xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.



Files changed (75)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +143 -6
  3. xinference/client/restful/restful_client.py +144 -5
  4. xinference/constants.py +5 -0
  5. xinference/core/cache_tracker.py +48 -28
  6. xinference/core/model.py +160 -19
  7. xinference/core/scheduler.py +446 -0
  8. xinference/core/supervisor.py +99 -24
  9. xinference/core/worker.py +68 -2
  10. xinference/deploy/cmdline.py +86 -2
  11. xinference/deploy/test/test_cmdline.py +19 -10
  12. xinference/isolation.py +9 -2
  13. xinference/model/audio/__init__.py +14 -1
  14. xinference/model/audio/chattts.py +84 -0
  15. xinference/model/audio/core.py +22 -4
  16. xinference/model/audio/custom.py +6 -4
  17. xinference/model/audio/model_spec.json +20 -0
  18. xinference/model/audio/model_spec_modelscope.json +20 -0
  19. xinference/model/llm/__init__.py +38 -2
  20. xinference/model/llm/llm_family.json +509 -1
  21. xinference/model/llm/llm_family.py +86 -1
  22. xinference/model/llm/llm_family_csghub.json +66 -0
  23. xinference/model/llm/llm_family_modelscope.json +411 -2
  24. xinference/model/llm/pytorch/chatglm.py +20 -13
  25. xinference/model/llm/pytorch/cogvlm2.py +76 -17
  26. xinference/model/llm/pytorch/core.py +141 -6
  27. xinference/model/llm/pytorch/glm4v.py +268 -0
  28. xinference/model/llm/pytorch/minicpmv25.py +232 -0
  29. xinference/model/llm/pytorch/qwen_vl.py +1 -1
  30. xinference/model/llm/pytorch/utils.py +405 -8
  31. xinference/model/llm/utils.py +14 -13
  32. xinference/model/llm/vllm/core.py +16 -4
  33. xinference/model/utils.py +8 -2
  34. xinference/thirdparty/ChatTTS/__init__.py +1 -0
  35. xinference/thirdparty/ChatTTS/core.py +200 -0
  36. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  37. xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
  38. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  39. xinference/thirdparty/ChatTTS/infer/api.py +125 -0
  40. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  41. xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
  42. xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
  43. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  44. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
  45. xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
  46. xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
  47. xinference/types.py +3 -0
  48. xinference/web/ui/build/asset-manifest.json +6 -6
  49. xinference/web/ui/build/index.html +1 -1
  50. xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
  51. xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
  52. xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
  53. xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
  54. xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
  55. xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
  58. xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
  59. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
  60. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
  61. xinference/web/ui/build/static/css/main.54bca460.css +0 -2
  62. xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
  63. xinference/web/ui/build/static/js/main.551aa479.js +0 -3
  64. xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
  65. xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
  66. xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
  67. xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
  68. xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
  69. xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
  70. xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
  71. /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
  72. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
  73. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
  74. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
  75. {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/model.py CHANGED
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,
@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo

+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM
@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"
@@ -181,9 +198,20 @@ class ModelActor(xo.StatelessActor):
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None

+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()

+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -235,8 +263,23 @@

         return isinstance(self._model, VLLMModel)

-    def load(self):
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel, PytorchModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchModel)
+            and self._model.__class__.__name__
+            in (PytorchChatModel.__name__, PytorchModel.__name__)
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )

     def model_uid(self):
         return (
@@ -343,21 +386,88 @@
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)

     @log_async(logger=logger)
     @request_limit
     @xo.generator
     async def generate(self, prompt: str, *args, **kwargs):
-        if hasattr(self._model, "generate"):
-            return await self._call_wrapper(
-                self._model.generate, prompt, *args, **kwargs
+        if self.allow_batching():
+            return await self.handle_batching_request(
+                prompt, "generate", *args, **kwargs
             )
-        if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
-                self._model.async_generate, prompt, *args, **kwargs
-            )
-        raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+        else:
+            if hasattr(self._model, "generate"):
+                return await self._call_wrapper(
+                    self._model.generate, prompt, *args, **kwargs
+                )
+            if hasattr(self._model, "async_generate"):
+                return await self._call_wrapper(
+                    self._model.async_generate, prompt, *args, **kwargs
+                )
+            raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+
+    @staticmethod
+    async def _queue_consumer(
+        queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def _get_stream_from_args(ability: str, *args) -> bool:
+        if ability == "chat":
+            assert args[2] is None or isinstance(args[2], dict)
+            return False if args[2] is None else args[2].get("stream", False)
+        else:
+            assert args[0] is None or isinstance(args[0], dict)
+            return False if args[0] is None else args[0].get("stream", False)
+
+    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+        stream = self._get_stream_from_args(ability, *args)
+        assert self._scheduler_ref is not None
+        if stream:
+            assert self._scheduler_ref is not None
+            queue: Queue[Any] = Queue()
+            ret = self._queue_consumer(queue)
+            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            gen = self._to_json_async_gen(ret)
+            self._current_generator = weakref.ref(gen)
+            return gen
+        else:
+            from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+            assert self._loop is not None
+            future = ConcurrentFuture()
+            await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+            fut = asyncio.wrap_future(future, loop=self._loop)
+            result = await fut
+            if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            return await asyncio.to_thread(json_dumps, result)

     @log_async(logger=logger)
     @request_limit
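Editor's note on the streaming path in the hunk above: the scheduler pushes decoded chunks into an asyncio.Queue and terminates the stream with sentinel strings, which _queue_consumer turns back into an async iterator. The following standalone sketch (plain asyncio, no xoscar actors) mirrors that sentinel protocol; the flag values are the ones defined in the new scheduler.py, everything else is illustrative.

import asyncio
from typing import Any, AsyncIterator

# Sentinel values copied from xinference/core/scheduler.py; the rest of this
# sketch is illustrative and not part of xinference.
DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"


async def queue_consumer(queue: asyncio.Queue) -> AsyncIterator[Any]:
    # Mirror of ModelActor._queue_consumer: yield chunks until a DONE sentinel,
    # re-raise errors that the producer encoded as prefixed strings.
    while True:
        res = await queue.get()
        if res == DONE_FLAG:
            break
        if isinstance(res, str) and res.startswith(ERROR_FLAG):
            raise RuntimeError(res[len(ERROR_FLAG):])
        yield res


async def fake_scheduler(queue: asyncio.Queue) -> None:
    # Stand-in for SchedulerActor.step() publishing completion chunks.
    for chunk in ("Hello", ", ", "world"):
        await queue.put(chunk)
    await queue.put(DONE_FLAG)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(fake_scheduler(queue))
    async for chunk in queue_consumer(queue):
        print(chunk, end="")
    print()


asyncio.run(main())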
@@ -366,17 +476,22 @@
         start_time = time.time()
         response = None
         try:
-            if hasattr(self._model, "chat"):
-                response = await self._call_wrapper(
-                    self._model.chat, prompt, *args, **kwargs
+            if self.allow_batching():
+                return await self.handle_batching_request(
+                    prompt, "chat", *args, **kwargs
                 )
-                return response
-            if hasattr(self._model, "async_chat"):
-                response = await self._call_wrapper(
-                    self._model.async_chat, prompt, *args, **kwargs
-                )
-                return response
-            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None
@@ -395,6 +510,15 @@
                     prompt_tokens,
                 )

+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
@@ -482,6 +606,23 @@
             f"Model {self._model.model_spec} is not for creating translations."
         )

+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(
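Editor's note on the new speech handler above: it takes input, voice, response_format, and speed, matching OpenAI's text-to-speech parameters, and the restful_api.py changes in this release presumably expose it over HTTP. A rough client-side sketch, assuming an OpenAI-compatible /v1/audio/speech route on a locally running endpoint and an already-launched ChatTTS model registered as "ChatTTS"; the route, host/port, and model uid are assumptions, only the parameter names come from the diff.

# Hedged sketch: the route, port, and model uid below are assumptions; only
# the parameter names (input, voice, response_format, speed) appear in the diff.
import requests

resp = requests.post(
    "http://127.0.0.1:9997/v1/audio/speech",  # assumed OpenAI-compatible route
    json={
        "model": "ChatTTS",                   # assumed uid of a launched audio model
        "input": "Hello from xinference.",
        "voice": "default",                   # placeholder voice id
        "response_format": "mp3",
        "speed": 1.0,
    },
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # raw audio bytes in the requested format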
xinference/core/scheduler.py ADDED
@@ -0,0 +1,446 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import functools
+import logging
+import uuid
+from collections import deque
+from enum import Enum
+from typing import List, Optional, Set
+
+import xoscar as xo
+
+logger = logging.getLogger(__name__)
+
+XINFERENCE_BATCHING_CLEAN_CACHE_INTERVAL = 5
+XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
+XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"
+XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
+XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
+
+
+class AbortRequestMessage(Enum):
+    NOT_FOUND = 1
+    DONE = 2
+    NO_OP = 3
+
+
+class InferenceRequest:
+    def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
+        # original prompt
+        self._prompt = prompt
+        # full prompt that contains chat history and applies chat template
+        self._full_prompt = None
+        # whether the current request is in the prefill phase
+        self._is_prefill = is_prefill
+        # full prompt tokens
+        self._prompt_tokens = None
+        # all new generated tokens during decode phase
+        self._new_tokens = []
+        # kv_cache used in decode phase
+        self._kv_cache = None
+        # use passed args from upstream interface
+        self._inference_args = args
+        # use passed kwargs from upstream interface, basically not used for now
+        self._inference_kwargs = kwargs
+        # should this request be stopped
+        self._stopped = False
+        # finish reason. If this is set, self._stopped is True.
+        self._finish_reason = None
+        # should this request be aborted
+        # note that when this flag is True, assert self._stopped is True
+        self._aborted = False
+        # sanitized generate config
+        self._sanitized_generate_config = None
+        # Chunk id for results. In stream mode, all the chunk ids should be same.
+        self._stream_chunk_id = str(uuid.uuid4())
+        # Use in stream mode
+        self.last_output_length = 0
+        # inference results,
+        # it is a list type because when stream=True,
+        # self.completion contains all the results in a decode round.
+        self.completion = []
+        # The way upstream gets the returned results,
+        # when stream=True, it is an asyncio.Queue,
+        # and when stream=False, it is an asyncio future.
+        self.future_or_queue = future_or_queue
+        # Record error message when this request has error.
+        # Must set stopped=True when this field is set.
+        self.error_msg: Optional[str] = None
+
+        # check the integrity of args passed upstream
+        self._check_args()
+
+    def _check_args(self):
+        # chat
+        if len(self._inference_args) == 3:
+            # system prompt
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], str
+            )
+            # chat history
+            assert self._inference_args[1] is None or isinstance(
+                self._inference_args[1], list
+            )
+            # generate config
+            assert self._inference_args[2] is None or isinstance(
+                self._inference_args[2], dict
+            )
+        else:  # generate
+            assert len(self._inference_args) == 1
+            # generate config
+            assert self._inference_args[0] is None or isinstance(
+                self._inference_args[0], dict
+            )
+
+    @property
+    def prompt(self):
+        return self._prompt
+
+    @property
+    def system_prompt(self):
+        return self._inference_args[0]
+
+    @property
+    def chat_history(self):
+        return self._inference_args[1]
+
+    @property
+    def full_prompt(self):
+        return self._full_prompt
+
+    @full_prompt.setter
+    def full_prompt(self, value: str):
+        self._full_prompt = value
+
+    @property
+    def is_prefill(self):
+        return self._is_prefill
+
+    @is_prefill.setter
+    def is_prefill(self, value: bool):
+        self._is_prefill = value
+
+    @property
+    def prompt_tokens(self):
+        return self._prompt_tokens
+
+    @prompt_tokens.setter
+    def prompt_tokens(self, value: List[int]):
+        self._prompt_tokens = value
+
+    @property
+    def kv_cache(self):
+        return self._kv_cache
+
+    @kv_cache.setter
+    def kv_cache(self, value):
+        self._kv_cache = value
+
+    @property
+    def new_tokens(self):
+        return self._new_tokens
+
+    def append_new_token(self, token: int):
+        self._new_tokens.append(token)
+
+    @property
+    def generate_config(self):
+        return (
+            self._inference_args[2]
+            if len(self._inference_args) == 3
+            else self._inference_args[0]
+        )
+
+    @property
+    def sanitized_generate_config(self):
+        return self._sanitized_generate_config
+
+    @sanitized_generate_config.setter
+    def sanitized_generate_config(self, value: dict):
+        self._sanitized_generate_config = value
+
+    @property
+    def stopped(self):
+        return self._stopped
+
+    @stopped.setter
+    def stopped(self, value: bool):
+        self._stopped = value
+
+    @property
+    def finish_reason(self):
+        return self._finish_reason
+
+    @finish_reason.setter
+    def finish_reason(self, value: Optional[str]):
+        self._finish_reason = value
+
+    @property
+    def chunk_id(self):
+        return self._stream_chunk_id
+
+    @property
+    def stream(self) -> bool:
+        return (
+            False
+            if self.generate_config is None
+            else self.generate_config.get("stream", False)
+        )
+
+    @property
+    def stream_interval(self) -> int:
+        return self.sanitized_generate_config.get("stream_interval", 2)
+
+    @property
+    def include_usage(self) -> bool:
+        stream_options = self.sanitized_generate_config.get("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+        return include_usage
+
+    @property
+    def aborted(self) -> bool:
+        return self._aborted
+
+    @aborted.setter
+    def aborted(self, value: bool):
+        self._aborted = value
+
+    @property
+    def request_id(self) -> Optional[str]:
+        return (
+            None
+            if self.generate_config is None
+            else self.generate_config.get("request_id", None)
+        )
+
+    @functools.lru_cache
+    def get_generate_configs(self, eos_token_id: int):
+        from ..types import max_tokens_field
+
+        max_new_tokens = int(
+            self.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+        )
+        stream_interval = self.sanitized_generate_config.get("stream_interval", 2)
+        include_usage = self.include_usage
+        stop_str = self.sanitized_generate_config.get("stop", None)
+        stop_token_ids = (
+            self.sanitized_generate_config.get("stop_token_ids", None) or []
+        )
+        stop_token_ids = set(stop_token_ids)
+        stop_token_ids.add(eos_token_id)
+        temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
+        repetition_penalty = float(
+            self.sanitized_generate_config.get("repetition_penalty", 1.0)
+        )
+        top_p = float(self.sanitized_generate_config.get("top_p", 1.0))
+        top_k = int(self.sanitized_generate_config.get("top_k", -1))  # -1 means disable
+        return (
+            max_new_tokens,
+            stream_interval,
+            include_usage,
+            stop_str,
+            stop_token_ids,
+            temperature,
+            repetition_penalty,
+            top_p,
+            top_k,
+        )
+
+
+def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
+    from transformers.cache_utils import DynamicCache
+
+    cache = DynamicCache.from_legacy_cache(data)
+    batch_size = cache.key_cache[0].shape[0]
+    batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
+    for idx in range(len(cache)):
+        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
+        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
+    return cache.to_legacy_cache()
+
+
+class SchedulerActor(xo.StatelessActor):
+    @classmethod
+    def gen_uid(cls, model_uid: str, replica_id: str):
+        return f"{model_uid}-{replica_id}-scheduler-actor"
+
+    def __init__(self):
+        super().__init__()
+        self._waiting_queue: deque[InferenceRequest] = deque()
+        self._running_queue: deque[InferenceRequest] = deque()
+        self._model = None
+        self._id_to_req = {}
+        self._abort_req_ids: Set[str] = set()
+        self._isolation = None
+
+    async def __post_create__(self):
+        from ..isolation import Isolation
+
+        self._isolation = Isolation(
+            asyncio.new_event_loop(), threaded=True, daemon=True
+        )
+        self._isolation.start()
+        asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+    async def __pre_destroy__(self):
+        try:
+            assert self._isolation is not None
+            self._isolation.stop()
+            del self._isolation
+        except Exception as e:
+            logger.debug(
+                f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+            )
+
+    def set_model(self, model):
+        self._model = model
+
+    def get_max_num_seqs(self):
+        assert self._model is not None
+        return self._model.get_max_num_seqs()
+
+    def _check_request_aborted(self, req: InferenceRequest):
+        if req.request_id and req.request_id in self._abort_req_ids:
+            req.aborted = True
+            req.stopped = True
+
+    def _handle_request(self) -> Optional[List[InferenceRequest]]:
+        if self._model is None:
+            return None
+        max_num_seqs = self.get_max_num_seqs()
+        # currently, FCFS strategy
+        running_list: List[InferenceRequest] = []
+        while len(self._running_queue) > 0:
+            if len(running_list) == max_num_seqs:
+                break
+            req = self._running_queue.popleft()
+            self._check_request_aborted(req)
+            running_list.append(req)
+
+        waiting_list: List[InferenceRequest] = []
+        if len(running_list) < max_num_seqs:
+            while len(self._waiting_queue) > 0:
+                req = self._waiting_queue.popleft()
+                self._check_request_aborted(req)
+                waiting_list.append(req)
+                if len(running_list) + len(waiting_list) == max_num_seqs:
+                    break
+        # must waiting_list in front
+        return waiting_list + running_list
+
+    @staticmethod
+    def _empty_cache():
+        from ..model.llm.pytorch.utils import empty_cache
+
+        empty_cache()
+
+    async def step(self):
+        req_list = self._handle_request()
+        if not req_list:
+            return
+        self._model.batch_inference(req_list)
+
+        stopped_batch_indexes = set()
+        for idx, r in enumerate(req_list):
+            if r.stream:
+                for completion in r.completion:
+                    await r.future_or_queue.put(completion)
+                r.completion = []
+
+            if not r.stopped:
+                self._running_queue.append(r)
+            else:
+                if r.new_tokens:
+                    stopped_batch_indexes.add(idx)
+                # set kv_cache to None for collection
+                r.kv_cache = None
+                rid = r.request_id
+                # clear data structure
+                if rid is not None:
+                    self._id_to_req.pop(rid, None)
+                    self._abort_req_ids.discard(rid)
+
+                if r.aborted:  # stop due to abort
+                    # handle abort result
+                    if r.stream:
+                        await r.future_or_queue.put(XINFERENCE_STREAMING_ABORT_FLAG)
+                    else:
+                        r.future_or_queue.set_result(
+                            XINFERENCE_NON_STREAMING_ABORT_FLAG
+                        )
+                else:
+                    if r.error_msg is None:  # normal stop
+                        if not r.stream:
+                            r.future_or_queue.set_result(r.completion[0])
+                        else:
+                            await r.future_or_queue.put(XINFERENCE_STREAMING_DONE_FLAG)
+                    # Abnormal stop, currently indicates that the parameter check does not pass,
+                    # and does not participate in the inference
+                    else:
+                        if not r.stream:
+                            r.future_or_queue.set_exception(ValueError(r.error_msg))
+                        else:
+                            await r.future_or_queue.put(
+                                XINFERENCE_STREAMING_ERROR_FLAG + r.error_msg
+                            )
+
+        # Some requests have been completed. Batch size needs to be reduced for kv cache.
+        if stopped_batch_indexes and len(self._running_queue) > 0:
+            kv_cache = self._running_queue[0].kv_cache
+            reduced_kv_cache = _get_valid_batch_kv_cache(
+                kv_cache, stopped_batch_indexes
+            )
+            for r in self._running_queue:
+                r.kv_cache = reduced_kv_cache
+
+            self._empty_cache()
+
+    async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
+        req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
+        rid = req.request_id
+        if rid is not None:
+            if rid in self._id_to_req:
+                raise KeyError(f"Request id: {rid} has already existed!")
+            self._id_to_req[rid] = req
+        self._waiting_queue.append(req)
+
+    async def abort_request(self, req_id: str) -> str:
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        """
+        if req_id not in self._id_to_req:
+            logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+            return AbortRequestMessage.NOT_FOUND.name
+        else:
+            self._abort_req_ids.add(req_id)
+            logger.info(f"Request id: {req_id} found to be aborted.")
+            return AbortRequestMessage.DONE.name
+
+    async def run(self):
+        try:
+            while True:
+                # wait 10ms
+                await asyncio.sleep(0.01)
+                await self.step()
+        except Exception as e:
+            logger.exception(
+                f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+            )