xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (60)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/conftest.py +0 -8
  5. xinference/constants.py +2 -0
  6. xinference/core/model.py +44 -5
  7. xinference/core/supervisor.py +13 -7
  8. xinference/core/utils.py +76 -12
  9. xinference/core/worker.py +5 -4
  10. xinference/deploy/cmdline.py +5 -0
  11. xinference/deploy/utils.py +7 -4
  12. xinference/model/audio/model_spec.json +2 -2
  13. xinference/model/image/stable_diffusion/core.py +5 -2
  14. xinference/model/llm/core.py +1 -3
  15. xinference/model/llm/llm_family.json +263 -4
  16. xinference/model/llm/llm_family_modelscope.json +302 -0
  17. xinference/model/llm/mlx/core.py +45 -2
  18. xinference/model/llm/vllm/core.py +2 -1
  19. xinference/model/rerank/core.py +11 -4
  20. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  21. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  22. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  23. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  24. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  25. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  26. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  27. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  28. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  30. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  32. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  34. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  35. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  36. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  37. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  38. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  39. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  40. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  41. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  42. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  43. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  44. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
  45. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
  46. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  47. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  56. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  58. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  59. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  60. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-11-01T17:56:47+0800",
+ "date": "2024-11-15T17:33:11+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "67e97ab485b539dc7a208825bee0504acc37044e",
- "version": "0.16.2"
+ "full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
+ "version": "1.0.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
 
 from .._compat import BaseModel, Field
 from .._version import get_versions
-from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_DEFAULT_ENDPOINT_PORT,
+    XINFERENCE_DISABLE_METRICS,
+)
 from ..core.event import Event, EventCollectorActor, EventType
 from ..core.supervisor import SupervisorActor
-from ..core.utils import json_dumps
+from ..core.utils import CancelMixin, json_dumps
 from ..types import (
     ChatCompletion,
     Completion,
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None
 
 
 class TextToImageRequest(BaseModel):
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_ability: List[str]
 
 
-class RESTfulAPI:
+class RESTfulAPI(CancelMixin):
     def __init__(
         self,
         supervisor_address: str,
@@ -1311,11 +1316,6 @@ class RESTfulAPI:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }
 
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1329,6 +1329,10 @@ class RESTfulAPI:
             raise HTTPException(status_code=500, detail=str(e))
 
         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
@@ -1336,7 +1340,7 @@ class RESTfulAPI:
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
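
With this change the rerank endpoint no longer scrapes extra options out of unrecognized top-level payload keys; they are expected in a `kwargs` field holding a JSON-encoded string (the client-side counterpart appears in restful_client.py below). A minimal sketch of a raw request under the new contract; the host, port, model uid, and the `top_n` option are illustrative assumptions, not values taken from this diff:

import json

import requests

# Assumed default endpoint and a placeholder model uid.
url = "http://localhost:9997/v1/rerank"
payload = {
    "model": "my-rerank-model",
    "query": "What is deep learning?",
    "documents": ["Deep learning is a branch of ML.", "How to cook pasta."],
    "return_documents": True,
    # Extra options now travel as one JSON-encoded string field:
    "kwargs": json.dumps({"top_n": 1}),
}
print(requests.post(url, json=payload).json())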
@@ -1531,8 +1535,11 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            request_id = kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model.text_to_image(
                 prompt=body.prompt,
                 n=body.n,
@@ -1541,6 +1548,11 @@ class RESTfulAPI:
                 **kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1698,14 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             image_list = await model_ref.image_to_image(
                 image=Image.open(image.file),
                 prompt=prompt,
@@ -1701,6 +1716,11 @@ class RESTfulAPI:
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1754,14 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             mask_im = Image.open(mask_image.file)
             if not size:
@@ -1755,6 +1778,11 @@ class RESTfulAPI:
                 **parsed_kwargs,
             )
             return Response(content=image_list, media_type="application/json")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1810,25 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+        request_id = None
         try:
             if kwargs is not None:
                 parsed_kwargs = json.loads(kwargs)
             else:
                 parsed_kwargs = {}
+            request_id = parsed_kwargs.get("request_id")
+            self._add_running_task(request_id)
             im = Image.open(image.file)
             text = await model_ref.ocr(
                 image=im,
                 **parsed_kwargs,
             )
             return Response(content=text, media_type="text/plain")
+        except asyncio.CancelledError:
+            err_str = f"The request has been cancelled: {request_id}"
+            logger.error(err_str)
+            await self._report_error_event(model_uid, err_str)
+            raise HTTPException(status_code=409, detail=err_str)
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             await self._report_error_event(model_uid, str(re))
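
The text-to-image, image-to-image, inpainting, and OCR handlers above now share one cancellation pattern: each registers the current asyncio task under the caller-supplied `request_id` before invoking the model, and a cancelled task surfaces as HTTP 409. A client opts in by placing a `request_id` inside the JSON-encoded `kwargs`; a hedged sketch (endpoint path, port, and model uid are assumptions):

import json
import uuid

import requests

request_id = str(uuid.uuid4())
# Registering the id lets a later abort_request cancel this call.
resp = requests.post(
    "http://localhost:9997/v1/images/generations",  # assumed default route
    json={
        "model": "my-image-model",  # placeholder uid
        "prompt": "an astronaut riding a horse",
        "kwargs": json.dumps({"request_id": request_id}),
    },
)
# If the id is aborted mid-flight, the server answers 409 with
# detail "The request has been cancelled: <request_id>".
print(resp.status_code)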
@@ -2111,10 +2147,25 @@ class RESTfulAPI:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
 
-    async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
+    async def abort_request(
+        self, request: Request, model_uid: str, request_id: str
+    ) -> JSONResponse:
         try:
+            payload = await request.json()
+            block_duration = payload.get(
+                "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
+            )
+            logger.info(
+                "Abort request with model uid: %s, request id: %s, block duration: %s",
+                model_uid,
+                request_id,
+                block_duration,
+            )
             supervisor_ref = await self._get_supervisor_ref()
-            res = await supervisor_ref.abort_request(model_uid, request_id)
+            res = await supervisor_ref.abort_request(
+                model_uid, request_id, block_duration
+            )
+            self._cancel_running_task(request_id, block_duration)
             return JSONResponse(content=res)
         except Exception as e:
             logger.error(e, exc_info=True)
xinference/client/restful/restful_client.py CHANGED
@@ -174,6 +174,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
@@ -1357,7 +1358,7 @@ class Client:
         response_data = response.json()
         return response_data
 
-    def abort_request(self, model_uid: str, request_id: str):
+    def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
         """
         Abort a request.
         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1370,18 @@ class Client:
             Model uid.
         request_id: str
             Request id.
+        block_duration: int
+            The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
+            prevent it from taking effect if it arrives before the request operation.
         Returns
         -------
         Dict
             Return empty dict.
         """
         url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
-        response = requests.post(url, headers=self._headers)
+        response = requests.post(
+            url, headers=self._headers, json={"block_duration": block_duration}
+        )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to abort request, detail: {_get_error_string(response)}"
xinference/conftest.py CHANGED
@@ -58,10 +58,6 @@ TEST_LOGGING_CONF = {
             "propagate": False,
         }
     },
-    "root": {
-        "level": "WARN",
-        "handlers": ["stream_handler"],
-    },
 }
 
 TEST_LOG_FILE_PATH = get_log_file(f"test_{get_timestamp_ms()}")
@@ -102,10 +98,6 @@ TEST_FILE_LOGGING_CONF = {
             "propagate": False,
         }
     },
-    "root": {
-        "level": "WARN",
-        "handlers": ["stream_handler", "file_handler"],
-    },
 }
 
xinference/constants.py CHANGED
@@ -87,3 +87,5 @@ XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
 XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
     XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
 )
+XINFERENCE_LAUNCH_MODEL_RETRY = 3
+XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30
xinference/core/model.py CHANGED
@@ -40,7 +40,11 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
-from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LAUNCH_MODEL_RETRY,
+    XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
+)
 
 if TYPE_CHECKING:
     from .progress_tracker import ProgressTrackerActor
@@ -54,7 +58,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 from ..device_utils import empty_cache
-from .utils import json_dumps, log_async
+from .utils import CancelMixin, json_dumps, log_async
 
 try:
     from torch.cuda import OutOfMemoryError
@@ -133,7 +137,9 @@ def oom_check(fn):
     return _wrapper
 
 
-class ModelActor(xo.StatelessActor):
+class ModelActor(xo.StatelessActor, CancelMixin):
+    _replica_model_uid: Optional[str]
+
     @classmethod
     def gen_uid(cls, model: "LLM"):
         return f"{model.__class__}-model-actor"
@@ -192,6 +198,7 @@ class ModelActor(xo.StatelessActor):
         supervisor_address: str,
         worker_address: str,
         model: "LLM",
+        replica_model_uid: str,
         model_description: Optional["ModelDescription"] = None,
         request_limits: Optional[int] = None,
     ):
@@ -203,6 +210,7 @@ class ModelActor(xo.StatelessActor):
 
         self._supervisor_address = supervisor_address
         self._worker_address = worker_address
+        self._replica_model_uid = replica_model_uid
         self._model = model
         self._model_description = (
             model_description.to_dict() if model_description else {}
@@ -257,6 +265,9 @@ class ModelActor(xo.StatelessActor):
             uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
         )
 
+    def __repr__(self) -> str:
+        return f"ModelActor({self._replica_model_uid})"
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -374,7 +385,28 @@ class ModelActor(xo.StatelessActor):
         return condition
 
     async def load(self):
-        self._model.load()
+        try:
+            # Change process title for model
+            import setproctitle
+
+            setproctitle.setproctitle(f"Model: {self._replica_model_uid}")
+        except ImportError:
+            pass
+        i = 0
+        while True:
+            i += 1
+            try:
+                self._model.load()
+                break
+            except Exception as e:
+                if (
+                    i < XINFERENCE_LAUNCH_MODEL_RETRY
+                    and str(e).find("busy or unavailable") >= 0
+                ):
+                    await asyncio.sleep(5)
+                    logger.warning("Retry to load model {model_uid}: %d times", i)
+                    continue
+                raise
         if self.allow_batching():
             await self._scheduler_ref.set_model(self._model)
             logger.debug(
@@ -385,6 +417,7 @@ class ModelActor(xo.StatelessActor):
         logger.debug(
             f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
         )
+        logger.info(f"{self} loaded")
 
     def model_uid(self):
         return (
@@ -521,6 +554,7 @@ class ModelActor(xo.StatelessActor):
 
     @oom_check
     async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
+        self._add_running_task(kwargs.get("request_id"))
         if self._lock is None:
             if inspect.iscoroutinefunction(fn):
                 ret = await fn(*args, **kwargs)
@@ -729,9 +763,14 @@ class ModelActor(xo.StatelessActor):
             prompt_tokens,
         )
 
-    async def abort_request(self, request_id: str) -> str:
+    async def abort_request(
+        self,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> str:
         from .utils import AbortRequestMessage
 
+        self._cancel_running_task(request_id, block_duration)
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
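
The supervisor, utils, and worker hunks that follow all track one schema change: replica model uids drop the embedded replica count, going from `{model_uid}-{replica}-{rep_id}` to `{model_uid}-{rep_id}`. Consequently `build_replica_model_uid` and `parse_replica_model_uid` lose a parameter and a return element, and `assign_replica_gpu` now takes `replica` explicitly. A minimal sketch of the new round trip, with illustrative values:

from xinference.core.utils import (
    assign_replica_gpu,
    build_replica_model_uid,
    parse_replica_model_uid,
)

rep_uid = build_replica_model_uid("qwen2-instruct", 1)  # "qwen2-instruct-1"
model_uid, rep_id = parse_replica_model_uid(rep_uid)    # ("qwen2-instruct", 1)

# GPU indices are striped across replicas: replica 1 of 2 takes
# every second index starting at 1.
print(assign_replica_gpu(rep_uid, 2, [0, 1, 2, 3]))     # [1, 3]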
xinference/core/supervisor.py CHANGED
@@ -35,6 +35,7 @@ from typing import (
 import xoscar as xo
 
 from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
     XINFERENCE_DISABLE_HEALTH_CHECK,
     XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
@@ -970,7 +971,7 @@ class SupervisorActor(xo.StatelessActor):
             raise ValueError(
                 f"Model is already in the model list, uid: {_replica_model_uid}"
             )
-        replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
+        replica_gpu_idx = assign_replica_gpu(_replica_model_uid, replica, gpu_idx)
         nonlocal model_type
 
         worker_ref = (
@@ -1084,7 +1085,7 @@ class SupervisorActor(xo.StatelessActor):
                 dead_models,
             )
             for replica_model_uid in dead_models:
-                model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
+                model_uid, _ = parse_replica_model_uid(replica_model_uid)
                 self._model_uid_to_replica_info.pop(model_uid, None)
                 self._replica_model_uid_to_worker.pop(
                     replica_model_uid, None
@@ -1137,7 +1138,7 @@ class SupervisorActor(xo.StatelessActor):
             raise ValueError(f"Model not found in the model list, uid: {model_uid}")
 
         replica_model_uid = build_replica_model_uid(
-            model_uid, replica_info.replica, next(replica_info.scheduler)
+            model_uid, next(replica_info.scheduler)
         )
 
         worker_ref = self._replica_model_uid_to_worker.get(replica_model_uid, None)
@@ -1154,7 +1155,7 @@ class SupervisorActor(xo.StatelessActor):
             raise ValueError(f"Model not found in the model list, uid: {model_uid}")
         # Use rep id 0 to instead of next(replica_info.scheduler) to avoid
         # consuming the generator.
-        replica_model_uid = build_replica_model_uid(model_uid, replica_info.replica, 0)
+        replica_model_uid = build_replica_model_uid(model_uid, 0)
         worker_ref = self._replica_model_uid_to_worker.get(replica_model_uid, None)
         if worker_ref is None:
             raise ValueError(
@@ -1213,7 +1214,12 @@ class SupervisorActor(xo.StatelessActor):
         return cached_models
 
     @log_async(logger=logger)
-    async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+    async def abort_request(
+        self,
+        model_uid: str,
+        request_id: str,
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ) -> Dict:
         from .scheduler import AbortRequestMessage
 
         res = {"msg": AbortRequestMessage.NO_OP.name}
@@ -1228,7 +1234,7 @@ class SupervisorActor(xo.StatelessActor):
             if worker_ref is None:
                 continue
             model_ref = await worker_ref.get_model(model_uid=rep_mid)
-            result_info = await model_ref.abort_request(request_id)
+            result_info = await model_ref.abort_request(request_id, block_duration)
             res["msg"] = result_info
             if result_info == AbortRequestMessage.DONE.name:
                 break
@@ -1260,7 +1266,7 @@ class SupervisorActor(xo.StatelessActor):
             uids_to_remove.append(model_uid)
 
         for replica_model_uid in uids_to_remove:
-            model_uid, _, _ = parse_replica_model_uid(replica_model_uid)
+            model_uid, _ = parse_replica_model_uid(replica_model_uid)
             self._model_uid_to_replica_info.pop(model_uid, None)
             self._replica_model_uid_to_worker.pop(replica_model_uid, None)
 
xinference/core/utils.py CHANGED
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
 import random
 import string
 import uuid
+import weakref
 from enum import Enum
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
@@ -23,7 +25,10 @@ import orjson
 from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
 
 from .._compat import BaseModel
-from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
+from ..constants import (
+    XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    XINFERENCE_LOG_ARG_MAX_LENGTH,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -49,13 +54,20 @@ def log_async(
 ):
     import time
     from functools import wraps
+    from inspect import signature
 
     def decorator(func):
         func_name = func.__name__
+        sig = signature(func)
 
         @wraps(func)
         async def wrapped(*args, **kwargs):
-            request_id_str = kwargs.get("request_id", "")
+            try:
+                bound_args = sig.bind_partial(*args, **kwargs)
+                arguments = bound_args.arguments
+            except TypeError:
+                arguments = {}
+            request_id_str = arguments.get("request_id", "")
             if not request_id_str:
                 request_id_str = uuid.uuid1()
             if func_name == "text_to_image":
@@ -146,27 +158,26 @@ def iter_replica_model_uid(model_uid: str, replica: int) -> Generator[str, None,
     """
     replica = int(replica)
     for rep_id in range(replica):
-        yield f"{model_uid}-{replica}-{rep_id}"
+        yield f"{model_uid}-{rep_id}"
 
 
-def build_replica_model_uid(model_uid: str, replica: int, rep_id: int) -> str:
+def build_replica_model_uid(model_uid: str, rep_id: int) -> str:
     """
     Build a replica model uid.
     """
-    return f"{model_uid}-{replica}-{rep_id}"
+    return f"{model_uid}-{rep_id}"
 
 
-def parse_replica_model_uid(replica_model_uid: str) -> Tuple[str, int, int]:
+def parse_replica_model_uid(replica_model_uid: str) -> Tuple[str, int]:
     """
-    Parse replica model uid to model uid, replica and rep id.
+    Parse replica model uid to model uid and rep id.
     """
     parts = replica_model_uid.split("-")
     if len(parts) == 1:
-        return replica_model_uid, -1, -1
+        return replica_model_uid, -1
     rep_id = int(parts.pop())
-    replica = int(parts.pop())
     model_uid = "-".join(parts)
-    return model_uid, replica, rep_id
+    return model_uid, rep_id
 
 
 def is_valid_model_uid(model_uid: str) -> bool:
@@ -261,12 +272,65 @@ def get_nvidia_gpu_info() -> Dict:
 
 
 def assign_replica_gpu(
-    _replica_model_uid: str, gpu_idx: Union[int, List[int]]
+    _replica_model_uid: str, replica: int, gpu_idx: Union[int, List[int]]
 ) -> List[int]:
-    model_uid, replica, rep_id = parse_replica_model_uid(_replica_model_uid)
+    model_uid, rep_id = parse_replica_model_uid(_replica_model_uid)
     rep_id, replica = int(rep_id), int(replica)
     if isinstance(gpu_idx, int):
         gpu_idx = [gpu_idx]
     if isinstance(gpu_idx, list) and gpu_idx:
         return gpu_idx[rep_id::replica]
     return gpu_idx
+
+
+class CancelMixin:
+    _CANCEL_TASK_NAME = "abort_block"
+
+    def __init__(self):
+        self._running_tasks: weakref.WeakValueDictionary[
+            str, asyncio.Task
+        ] = weakref.WeakValueDictionary()
+
+    def _add_running_task(self, request_id: Optional[str]):
+        """Add current asyncio task to the running task.
+        :param request_id: The corresponding request id.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.get(request_id)
+        if running_task is not None:
+            if running_task.get_name() == self._CANCEL_TASK_NAME:
+                raise Exception(f"The request has been aborted: {request_id}")
+            raise Exception(f"Duplicate request id: {request_id}")
+        current_task = asyncio.current_task()
+        assert current_task is not None
+        self._running_tasks[request_id] = current_task
+
+    def _cancel_running_task(
+        self,
+        request_id: Optional[str],
+        block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
+    ):
+        """Cancel the running asyncio task.
+        :param request_id: The request id to cancel.
+        :param block_duration: The duration seconds to ensure the request can't be executed.
+        """
+        if request_id is None:
+            return
+        running_task = self._running_tasks.pop(request_id, None)
+        if running_task is not None:
+            running_task.cancel()
+
+        async def block_task():
+            """This task is for blocking the request for a duration."""
+            try:
+                await asyncio.sleep(block_duration)
+                logger.info("Abort block end for request: %s", request_id)
+            except asyncio.CancelledError:
+                logger.info("Abort block is cancelled for request: %s", request_id)
+
+        if block_duration > 0:
+            logger.info("Abort block start for request: %s", request_id)
+            self._running_tasks[request_id] = asyncio.create_task(
+                block_task(), name=self._CANCEL_TASK_NAME
+            )
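
`CancelMixin` keeps a `weakref.WeakValueDictionary` mapping request ids to the asyncio tasks serving them. `_cancel_running_task` cancels the live task and, while `block_duration > 0`, parks a sentinel task named `abort_block` under the same id, so re-submitting or late-arriving work with that id raises instead of running. An illustrative, self-contained sketch of that behavior; the `Worker` class and the timings are made up for demonstration:

import asyncio

from xinference.core.utils import CancelMixin

class Worker(CancelMixin):
    async def handle(self, request_id: str) -> None:
        self._add_running_task(request_id)  # register the current task
        await asyncio.sleep(10)             # stand-in for model work

async def main() -> None:
    w = Worker()
    task = asyncio.create_task(w.handle("req-1"))
    await asyncio.sleep(0.1)                # let the handler register itself
    w._cancel_running_task("req-1", block_duration=5)
    try:
        await task
    except asyncio.CancelledError:
        print("req-1 cancelled")
    try:
        # The sentinel "abort_block" task is still alive, so reusing
        # the id within the 5-second window is rejected.
        await w.handle("req-1")
    except Exception as exc:
        print(exc)  # The request has been aborted: req-1

asyncio.run(main())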
xinference/core/worker.py CHANGED
@@ -157,7 +157,7 @@ class WorkerActor(xo.StatelessActor):
                 model_uid,
                 recover_count - 1,
             )
-            event_model_uid, _, __ = parse_replica_model_uid(model_uid)
+            event_model_uid, _ = parse_replica_model_uid(model_uid)
             try:
                 if self._event_collector_ref is not None:
                     await self._event_collector_ref.report_event(
@@ -377,7 +377,7 @@ class WorkerActor(xo.StatelessActor):
         return len(self._model_uid_to_model)
 
     async def is_model_vllm_backend(self, model_uid: str) -> bool:
-        _model_uid, _, _ = parse_replica_model_uid(model_uid)
+        _model_uid, _ = parse_replica_model_uid(model_uid)
         supervisor_ref = await self.get_supervisor_ref()
         model_ref = await supervisor_ref.get_model(_model_uid)
         return await model_ref.is_vllm_backend()
@@ -800,7 +800,7 @@ class WorkerActor(xo.StatelessActor):
         launch_args.update(kwargs)
 
         try:
-            origin_uid, _, _ = parse_replica_model_uid(model_uid)
+            origin_uid, _ = parse_replica_model_uid(model_uid)
         except Exception as e:
             logger.exception(e)
             raise
@@ -889,6 +889,7 @@ class WorkerActor(xo.StatelessActor):
             uid=model_uid,
             supervisor_address=self._supervisor_address,
             worker_address=self.address,
+            replica_model_uid=model_uid,
             model=model,
             model_description=model_description,
             request_limits=request_limits,
@@ -926,7 +927,7 @@ class WorkerActor(xo.StatelessActor):
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
-        origin_uid, _, __ = parse_replica_model_uid(model_uid)
+        origin_uid, _ = parse_replica_model_uid(model_uid)
         try:
             _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None: