xinference 0.11.3__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +143 -6
- xinference/client/restful/restful_client.py +144 -5
- xinference/constants.py +5 -0
- xinference/core/cache_tracker.py +48 -28
- xinference/core/model.py +160 -19
- xinference/core/scheduler.py +446 -0
- xinference/core/supervisor.py +99 -24
- xinference/core/worker.py +68 -2
- xinference/deploy/cmdline.py +86 -2
- xinference/deploy/test/test_cmdline.py +19 -10
- xinference/isolation.py +9 -2
- xinference/model/audio/__init__.py +14 -1
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +22 -4
- xinference/model/audio/custom.py +6 -4
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/audio/model_spec_modelscope.json +20 -0
- xinference/model/llm/__init__.py +38 -2
- xinference/model/llm/llm_family.json +509 -1
- xinference/model/llm/llm_family.py +86 -1
- xinference/model/llm/llm_family_csghub.json +66 -0
- xinference/model/llm/llm_family_modelscope.json +411 -2
- xinference/model/llm/pytorch/chatglm.py +20 -13
- xinference/model/llm/pytorch/cogvlm2.py +76 -17
- xinference/model/llm/pytorch/core.py +141 -6
- xinference/model/llm/pytorch/glm4v.py +268 -0
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/qwen_vl.py +1 -1
- xinference/model/llm/pytorch/utils.py +405 -8
- xinference/model/llm/utils.py +14 -13
- xinference/model/llm/vllm/core.py +16 -4
- xinference/model/utils.py +8 -2
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +40 -0
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +125 -0
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +155 -0
- xinference/thirdparty/ChatTTS/model/gpt.py +265 -0
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +23 -0
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +141 -0
- xinference/thirdparty/ChatTTS/utils/io_utils.py +14 -0
- xinference/types.py +3 -0
- xinference/web/ui/build/asset-manifest.json +6 -6
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/css/main.074e2b31.css +2 -0
- xinference/web/ui/build/static/css/main.074e2b31.css.map +1 -0
- xinference/web/ui/build/static/js/main.a58ff436.js +3 -0
- xinference/web/ui/build/static/js/main.a58ff436.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/10262a281dec3bc2b185f4385ceb6846626f52d41cb4d46c7c649e719f979d4d.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/762a75a62daf3bec2cfc97ec8612798493fb34ef87087dcad6aad64ab7f14345.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7f3bdb3a48fa00c046c8b185acd4da6f2e2940a20dbd77f9373d60de3fd6633e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f2f73bfdc13b12b02c8cbc4769b0b8e6367e9b6d8331c322d94318491a0b3653.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f51bf63ddaa7afd125ef2254a105789333eecc1c94fdf5157a9b88ef7ad0a5bd.json +1 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/METADATA +26 -9
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/RECORD +65 -47
- xinference/web/ui/build/static/css/main.54bca460.css +0 -2
- xinference/web/ui/build/static/css/main.54bca460.css.map +0 -1
- xinference/web/ui/build/static/js/main.551aa479.js +0 -3
- xinference/web/ui/build/static/js/main.551aa479.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1e86938a0cdf706d21e99b21f5d868fa247c0c88b26807047e26dcdc4d9a9db3.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/1fa824d82b2af519de7700c594e50bde4bbca60d13bd3fabff576802e4070304.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/3e737bcdbcbc407ccd65b90e199ef0c3214b261e8e41dbf14d921384a717d9ee.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/46b6dd1f6d1109cd0e2455a0ea0be3e9bda1097cd4ebec9c4040070372671cfc.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/59ce49eae0f486af4c5034d4d2f9ca77c3ec3a32ecc560085caf5ef482b5f4c9.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/a6da6bc3d0d2191adebee87fb58ecebe82d071087bd2f7f3a9c7fdd2ada130f2.json +0 -1
- /xinference/web/ui/build/static/js/{main.551aa479.js.LICENSE.txt → main.a58ff436.js.LICENSE.txt} +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/LICENSE +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/WHEEL +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.3.dist-info → xinference-0.12.1.dist-info}/top_level.txt +0 -0
xinference/core/model.py
CHANGED
```diff
@@ -20,9 +20,14 @@ import os
 import time
 import types
 import weakref
+from asyncio.queues import Queue
+from asyncio.tasks import wait_for
+from concurrent.futures import Future as ConcurrentFuture
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncGenerator,
+    AsyncIterator,
     Callable,
     Dict,
     Generator,
@@ -35,6 +40,8 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo

+from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+
 if TYPE_CHECKING:
     from .worker import WorkerActor
     from ..model.llm.core import LLM
```
```diff
@@ -125,6 +132,16 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

+        if self.allow_batching():
+            try:
+                assert self._scheduler_ref is not None
+                await xo.destroy_actor(self._scheduler_ref)
+                del self._scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if (
             isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
             and self._model.model_spec.model_format == "pytorch"
```
```diff
@@ -181,9 +198,20 @@ class ModelActor(xo.StatelessActor):
         }
         self._loop: Optional[asyncio.AbstractEventLoop] = None

+        self._scheduler_ref = None
+
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()

+        if self.allow_batching():
+            from .scheduler import SchedulerActor
+
+            self._scheduler_ref = await xo.create_actor(
+                SchedulerActor,
+                address=self.address,
+                uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
```
```diff
@@ -235,8 +263,23 @@ class ModelActor(xo.StatelessActor):

         return isinstance(self._model, VLLMModel)

-    def load(self):
+    def allow_batching(self) -> bool:
+        from ..model.llm.pytorch.core import PytorchChatModel, PytorchModel
+
+        return (
+            XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+            and isinstance(self._model, PytorchModel)
+            and self._model.__class__.__name__
+            in (PytorchChatModel.__name__, PytorchModel.__name__)
+        )
+
+    async def load(self):
         self._model.load()
+        if self.allow_batching():
+            await self._scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
+            )

     def model_uid(self):
         return (
```
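The batching path is opt-in via the new `XINFERENCE_TRANSFORMERS_ENABLE_BATCHING` constant (see `xinference/constants.py` in the file list above). A minimal sketch of the gate, assuming the constant is driven by an environment variable of the same name; the variable name and truthy parsing here are assumptions, only the gating logic mirrors `allow_batching()` above:

```python
# Illustrative sketch only. Assumes the constant is read from an environment
# variable named XINFERENCE_TRANSFORMERS_ENABLE_BATCHING (hypothetical parsing).
import os

XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = os.getenv(
    "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING", "0"
) in ("1", "true", "True")


def allow_batching(model) -> bool:
    # Batching only targets the plain transformers (PyTorch) backends; subclasses
    # with custom chat/generate implementations keep the original code path.
    return XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and type(model).__name__ in (
        "PytorchModel",
        "PytorchChatModel",
    )
```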
```diff
@@ -343,21 +386,88 @@ class ModelActor(xo.StatelessActor):
             gen = self._to_json_async_gen(ret)
             self._current_generator = weakref.ref(gen)
             return gen
+        if isinstance(ret, bytes):
+            return ret
         return await asyncio.to_thread(json_dumps, ret)

     @log_async(logger=logger)
     @request_limit
     @xo.generator
     async def generate(self, prompt: str, *args, **kwargs):
-        if hasattr(self._model, "generate"):
-            return await self._call_wrapper(
-                self._model.generate, prompt, *args, **kwargs
+        if self.allow_batching():
+            return await self.handle_batching_request(
+                prompt, "generate", *args, **kwargs
             )
-        if hasattr(self._model, "async_generate"):
-            return await self._call_wrapper(
-                self._model.async_generate, prompt, *args, **kwargs
-            )
-        raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+        else:
+            if hasattr(self._model, "generate"):
+                return await self._call_wrapper(
+                    self._model.generate, prompt, *args, **kwargs
+                )
+            if hasattr(self._model, "async_generate"):
+                return await self._call_wrapper(
+                    self._model.async_generate, prompt, *args, **kwargs
+                )
+            raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
+
+    @staticmethod
+    async def _queue_consumer(
+        queue: Queue, timeout: Optional[float] = None
+    ) -> AsyncIterator[Any]:
+        from .scheduler import (
+            XINFERENCE_STREAMING_ABORT_FLAG,
+            XINFERENCE_STREAMING_DONE_FLAG,
+            XINFERENCE_STREAMING_ERROR_FLAG,
+        )
+
+        while True:
+            # TODO: timeout setting
+            res = await wait_for(queue.get(), timeout)
+            if res == XINFERENCE_STREAMING_DONE_FLAG:
+                break
+            elif res == XINFERENCE_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            elif isinstance(res, str) and res.startswith(
+                XINFERENCE_STREAMING_ERROR_FLAG
+            ):
+                raise RuntimeError(res[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
+            else:
+                yield res
+
+    @staticmethod
+    def _get_stream_from_args(ability: str, *args) -> bool:
+        if ability == "chat":
+            assert args[2] is None or isinstance(args[2], dict)
+            return False if args[2] is None else args[2].get("stream", False)
+        else:
+            assert args[0] is None or isinstance(args[0], dict)
+            return False if args[0] is None else args[0].get("stream", False)
+
+    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
+        stream = self._get_stream_from_args(ability, *args)
+        assert self._scheduler_ref is not None
+        if stream:
+            assert self._scheduler_ref is not None
+            queue: Queue[Any] = Queue()
+            ret = self._queue_consumer(queue)
+            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            gen = self._to_json_async_gen(ret)
+            self._current_generator = weakref.ref(gen)
+            return gen
+        else:
+            from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
+
+            assert self._loop is not None
+            future = ConcurrentFuture()
+            await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+            fut = asyncio.wrap_future(future, loop=self._loop)
+            result = await fut
+            if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
+                raise RuntimeError(
+                    f"This request has been cancelled by another `abort_request` request."
+                )
+            return await asyncio.to_thread(json_dumps, result)

     @log_async(logger=logger)
     @request_limit
```
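`handle_batching_request` bridges the scheduler and the caller in two ways: streaming output is pushed into an `asyncio` queue and terminated by sentinel flags, while non-streaming output resolves a `concurrent.futures.Future` that is awaited via `asyncio.wrap_future`. A self-contained sketch of that pattern (the sentinel values are borrowed from `scheduler.py` below; the producer here is a stand-in for the real scheduler):

```python
import asyncio
from asyncio.queues import Queue
from concurrent.futures import Future as ConcurrentFuture

DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"    # sentinel: stream finished
ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"  # prefix: error message follows


async def queue_consumer(queue: Queue):
    # Mirrors ModelActor._queue_consumer: yield chunks until a sentinel arrives.
    while True:
        res = await queue.get()
        if res == DONE_FLAG:
            break
        if isinstance(res, str) and res.startswith(ERROR_FLAG):
            raise RuntimeError(res[len(ERROR_FLAG):])
        yield res


async def main():
    # Streaming path: a producer (the scheduler, in xinference) puts chunks
    # followed by the DONE sentinel; the consumer turns them into a stream.
    q: Queue = Queue()
    for chunk in ({"text": "Hel"}, {"text": "lo"}, DONE_FLAG):
        await q.put(chunk)
    print([c async for c in queue_consumer(q)])

    # Non-streaming path: a thread-safe Future is resolved elsewhere and
    # bridged into the running event loop with wrap_future.
    fut: ConcurrentFuture = ConcurrentFuture()
    loop = asyncio.get_running_loop()
    loop.call_later(0.01, fut.set_result, {"text": "hello"})
    print(await asyncio.wrap_future(fut, loop=loop))


asyncio.run(main())
```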
```diff
@@ -366,17 +476,22 @@ class ModelActor(xo.StatelessActor):
         start_time = time.time()
         response = None
         try:
-            if hasattr(self._model, "chat"):
-                response = await self._call_wrapper(
-                    self._model.chat, prompt, *args, **kwargs
+            if self.allow_batching():
+                return await self.handle_batching_request(
+                    prompt, "chat", *args, **kwargs
                 )
-                return response
-            if hasattr(self._model, "async_chat"):
-                response = await self._call_wrapper(
-                    self._model.async_chat, prompt, *args, **kwargs
-                )
-                return response
-            raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
+            else:
+                if hasattr(self._model, "chat"):
+                    response = await self._call_wrapper(
+                        self._model.chat, prompt, *args, **kwargs
+                    )
+                    return response
+                if hasattr(self._model, "async_chat"):
+                    response = await self._call_wrapper(
+                        self._model.async_chat, prompt, *args, **kwargs
+                    )
+                    return response
+                raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
         finally:
             # For the non stream result.
             record = None
```
```diff
@@ -395,6 +510,15 @@ class ModelActor(xo.StatelessActor):
                 prompt_tokens,
             )

+    async def abort_request(self, request_id: str) -> str:
+        from .scheduler import AbortRequestMessage
+
+        if self.allow_batching():
+            if self._scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._scheduler_ref.abort_request(request_id)
+        return AbortRequestMessage.NO_OP.name
+
     @log_async(logger=logger)
     @request_limit
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
```
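Aborting only takes effect on the batching path, and it keys off a `request_id` that the caller places in the generate config (see `InferenceRequest.request_id` in `scheduler.py` below). A hedged sketch of the flow against the actor interface above; `model_ref` and the positional layout of the chat arguments are assumptions, not a documented client API:

```python
# Sketch only: `model_ref` is assumed to be an xoscar reference to the
# ModelActor above, with transformers batching enabled.
import uuid

request_id = str(uuid.uuid4())


async def cancellable_chat(model_ref, prompt: str):
    # For the "chat" ability the extra args are (system_prompt, chat_history,
    # generate_config); request_id rides along inside generate_config.
    gen = await model_ref.chat(
        prompt, None, None, {"stream": True, "request_id": request_id}
    )
    async for chunk in gen:
        ...  # consume streamed chunks


# Later, from another task or connection:
#   status = await model_ref.abort_request(request_id)
# status is one of the AbortRequestMessage names: "DONE", "NOT_FOUND" or "NO_OP".
```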
```diff
@@ -482,6 +606,23 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating translations."
             )

+    @log_async(logger=logger)
+    @request_limit
+    async def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        if hasattr(self._model, "speech"):
+            return await self._call_wrapper(
+                self._model.speech,
+                input,
+                voice,
+                response_format,
+                speed,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating speech."
+        )
+
     @log_async(logger=logger)
     @request_limit
     async def text_to_image(
```
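The new `speech` endpoint pairs with the bundled ChatTTS support in `xinference/model/audio/` and `xinference/thirdparty/ChatTTS/`. A hedged usage sketch via the Python client; the client-side method names and defaults are assumptions inferred from the `restful_client.py` changes listed above, not verified signatures:

```python
# Hypothetical client usage; adjust names/arguments to the actual 0.12.x API.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(model_name="ChatTTS", model_type="audio")
model = client.get_model(model_uid)

# ModelActor.speech(input, voice, response_format="mp3", speed=1.0) above
# suggests a matching client call returning encoded audio bytes.
audio_bytes = model.speech("Hello, Xinference!", voice="", response_format="mp3")
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)
```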
xinference/core/scheduler.py
ADDED
@@ -0,0 +1,446 @@
```python
# Copyright 2022-2024 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import functools
import logging
import uuid
from collections import deque
from enum import Enum
from typing import List, Optional, Set

import xoscar as xo

logger = logging.getLogger(__name__)

XINFERENCE_BATCHING_CLEAN_CACHE_INTERVAL = 5
XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"
XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"


class AbortRequestMessage(Enum):
    NOT_FOUND = 1
    DONE = 2
    NO_OP = 3


class InferenceRequest:
    def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
        # original prompt
        self._prompt = prompt
        # full prompt that contains chat history and applies chat template
        self._full_prompt = None
        # whether the current request is in the prefill phase
        self._is_prefill = is_prefill
        # full prompt tokens
        self._prompt_tokens = None
        # all new generated tokens during decode phase
        self._new_tokens = []
        # kv_cache used in decode phase
        self._kv_cache = None
        # use passed args from upstream interface
        self._inference_args = args
        # use passed kwargs from upstream interface, basically not used for now
        self._inference_kwargs = kwargs
        # should this request be stopped
        self._stopped = False
        # finish reason. If this is set, self._stopped is True.
        self._finish_reason = None
        # should this request be aborted
        # note that when this flag is True, assert self._stopped is True
        self._aborted = False
        # sanitized generate config
        self._sanitized_generate_config = None
        # Chunk id for results. In stream mode, all the chunk ids should be same.
        self._stream_chunk_id = str(uuid.uuid4())
        # Use in stream mode
        self.last_output_length = 0
        # inference results,
        # it is a list type because when stream=True,
        # self.completion contains all the results in a decode round.
        self.completion = []
        # The way upstream gets the returned results,
        # when stream=True, it is an asyncio.Queue,
        # and when stream=False, it is an asyncio future.
        self.future_or_queue = future_or_queue
        # Record error message when this request has error.
        # Must set stopped=True when this field is set.
        self.error_msg: Optional[str] = None

        # check the integrity of args passed upstream
        self._check_args()

    def _check_args(self):
        # chat
        if len(self._inference_args) == 3:
            # system prompt
            assert self._inference_args[0] is None or isinstance(
                self._inference_args[0], str
            )
            # chat history
            assert self._inference_args[1] is None or isinstance(
                self._inference_args[1], list
            )
            # generate config
            assert self._inference_args[2] is None or isinstance(
                self._inference_args[2], dict
            )
        else:  # generate
            assert len(self._inference_args) == 1
            # generate config
            assert self._inference_args[0] is None or isinstance(
                self._inference_args[0], dict
            )

    @property
    def prompt(self):
        return self._prompt

    @property
    def system_prompt(self):
        return self._inference_args[0]

    @property
    def chat_history(self):
        return self._inference_args[1]

    @property
    def full_prompt(self):
        return self._full_prompt

    @full_prompt.setter
    def full_prompt(self, value: str):
        self._full_prompt = value

    @property
    def is_prefill(self):
        return self._is_prefill

    @is_prefill.setter
    def is_prefill(self, value: bool):
        self._is_prefill = value

    @property
    def prompt_tokens(self):
        return self._prompt_tokens

    @prompt_tokens.setter
    def prompt_tokens(self, value: List[int]):
        self._prompt_tokens = value

    @property
    def kv_cache(self):
        return self._kv_cache

    @kv_cache.setter
    def kv_cache(self, value):
        self._kv_cache = value

    @property
    def new_tokens(self):
        return self._new_tokens

    def append_new_token(self, token: int):
        self._new_tokens.append(token)

    @property
    def generate_config(self):
        return (
            self._inference_args[2]
            if len(self._inference_args) == 3
            else self._inference_args[0]
        )

    @property
    def sanitized_generate_config(self):
        return self._sanitized_generate_config

    @sanitized_generate_config.setter
    def sanitized_generate_config(self, value: dict):
        self._sanitized_generate_config = value

    @property
    def stopped(self):
        return self._stopped

    @stopped.setter
    def stopped(self, value: bool):
        self._stopped = value

    @property
    def finish_reason(self):
        return self._finish_reason

    @finish_reason.setter
    def finish_reason(self, value: Optional[str]):
        self._finish_reason = value

    @property
    def chunk_id(self):
        return self._stream_chunk_id

    @property
    def stream(self) -> bool:
        return (
            False
            if self.generate_config is None
            else self.generate_config.get("stream", False)
        )

    @property
    def stream_interval(self) -> int:
        return self.sanitized_generate_config.get("stream_interval", 2)

    @property
    def include_usage(self) -> bool:
        stream_options = self.sanitized_generate_config.get("stream_options", None)
        include_usage = (
            stream_options["include_usage"]
            if isinstance(stream_options, dict)
            else False
        )
        return include_usage

    @property
    def aborted(self) -> bool:
        return self._aborted

    @aborted.setter
    def aborted(self, value: bool):
        self._aborted = value

    @property
    def request_id(self) -> Optional[str]:
        return (
            None
            if self.generate_config is None
            else self.generate_config.get("request_id", None)
        )

    @functools.lru_cache
    def get_generate_configs(self, eos_token_id: int):
        from ..types import max_tokens_field

        max_new_tokens = int(
            self.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
        )
        stream_interval = self.sanitized_generate_config.get("stream_interval", 2)
        include_usage = self.include_usage
        stop_str = self.sanitized_generate_config.get("stop", None)
        stop_token_ids = (
            self.sanitized_generate_config.get("stop_token_ids", None) or []
        )
        stop_token_ids = set(stop_token_ids)
        stop_token_ids.add(eos_token_id)
        temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
        repetition_penalty = float(
            self.sanitized_generate_config.get("repetition_penalty", 1.0)
        )
        top_p = float(self.sanitized_generate_config.get("top_p", 1.0))
        top_k = int(self.sanitized_generate_config.get("top_k", -1))  # -1 means disable
        return (
            max_new_tokens,
            stream_interval,
            include_usage,
            stop_str,
            stop_token_ids,
            temperature,
            repetition_penalty,
            top_p,
            top_k,
        )


def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache.from_legacy_cache(data)
    batch_size = cache.key_cache[0].shape[0]
    batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
    for idx in range(len(cache)):
        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
    return cache.to_legacy_cache()


class SchedulerActor(xo.StatelessActor):
    @classmethod
    def gen_uid(cls, model_uid: str, replica_id: str):
        return f"{model_uid}-{replica_id}-scheduler-actor"

    def __init__(self):
        super().__init__()
        self._waiting_queue: deque[InferenceRequest] = deque()
        self._running_queue: deque[InferenceRequest] = deque()
        self._model = None
        self._id_to_req = {}
        self._abort_req_ids: Set[str] = set()
        self._isolation = None

    async def __post_create__(self):
        from ..isolation import Isolation

        self._isolation = Isolation(
            asyncio.new_event_loop(), threaded=True, daemon=True
        )
        self._isolation.start()
        asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)

    async def __pre_destroy__(self):
        try:
            assert self._isolation is not None
            self._isolation.stop()
            del self._isolation
        except Exception as e:
            logger.debug(
                f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
            )

    def set_model(self, model):
        self._model = model

    def get_max_num_seqs(self):
        assert self._model is not None
        return self._model.get_max_num_seqs()

    def _check_request_aborted(self, req: InferenceRequest):
        if req.request_id and req.request_id in self._abort_req_ids:
            req.aborted = True
            req.stopped = True

    def _handle_request(self) -> Optional[List[InferenceRequest]]:
        if self._model is None:
            return None
        max_num_seqs = self.get_max_num_seqs()
        # currently, FCFS strategy
        running_list: List[InferenceRequest] = []
        while len(self._running_queue) > 0:
            if len(running_list) == max_num_seqs:
                break
            req = self._running_queue.popleft()
            self._check_request_aborted(req)
            running_list.append(req)

        waiting_list: List[InferenceRequest] = []
        if len(running_list) < max_num_seqs:
            while len(self._waiting_queue) > 0:
                req = self._waiting_queue.popleft()
                self._check_request_aborted(req)
                waiting_list.append(req)
                if len(running_list) + len(waiting_list) == max_num_seqs:
                    break
        # must waiting_list in front
        return waiting_list + running_list

    @staticmethod
    def _empty_cache():
        from ..model.llm.pytorch.utils import empty_cache

        empty_cache()

    async def step(self):
        req_list = self._handle_request()
        if not req_list:
            return
        self._model.batch_inference(req_list)

        stopped_batch_indexes = set()
        for idx, r in enumerate(req_list):
            if r.stream:
                for completion in r.completion:
                    await r.future_or_queue.put(completion)
                r.completion = []

            if not r.stopped:
                self._running_queue.append(r)
            else:
                if r.new_tokens:
                    stopped_batch_indexes.add(idx)
                # set kv_cache to None for collection
                r.kv_cache = None
                rid = r.request_id
                # clear data structure
                if rid is not None:
                    self._id_to_req.pop(rid, None)
                    self._abort_req_ids.discard(rid)

                if r.aborted:  # stop due to abort
                    # handle abort result
                    if r.stream:
                        await r.future_or_queue.put(XINFERENCE_STREAMING_ABORT_FLAG)
                    else:
                        r.future_or_queue.set_result(
                            XINFERENCE_NON_STREAMING_ABORT_FLAG
                        )
                else:
                    if r.error_msg is None:  # normal stop
                        if not r.stream:
                            r.future_or_queue.set_result(r.completion[0])
                        else:
                            await r.future_or_queue.put(XINFERENCE_STREAMING_DONE_FLAG)
                    # Abnormal stop, currently indicates that the parameter check does not pass,
                    # and does not participate in the inference
                    else:
                        if not r.stream:
                            r.future_or_queue.set_exception(ValueError(r.error_msg))
                        else:
                            await r.future_or_queue.put(
                                XINFERENCE_STREAMING_ERROR_FLAG + r.error_msg
                            )

        # Some requests have been completed. Batch size needs to be reduced for kv cache.
        if stopped_batch_indexes and len(self._running_queue) > 0:
            kv_cache = self._running_queue[0].kv_cache
            reduced_kv_cache = _get_valid_batch_kv_cache(
                kv_cache, stopped_batch_indexes
            )
            for r in self._running_queue:
                r.kv_cache = reduced_kv_cache

            self._empty_cache()

    async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
        req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
        rid = req.request_id
        if rid is not None:
            if rid in self._id_to_req:
                raise KeyError(f"Request id: {rid} has already existed!")
            self._id_to_req[rid] = req
        self._waiting_queue.append(req)

    async def abort_request(self, req_id: str) -> str:
        """
        Abort a request.
        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
        """
        if req_id not in self._id_to_req:
            logger.info(f"Request id: {req_id} not found. No-op for xinference.")
            return AbortRequestMessage.NOT_FOUND.name
        else:
            self._abort_req_ids.add(req_id)
            logger.info(f"Request id: {req_id} found to be aborted.")
            return AbortRequestMessage.DONE.name

    async def run(self):
        try:
            while True:
                # wait 10ms
                await asyncio.sleep(0.01)
                await self.step()
        except Exception as e:
            logger.exception(
                f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
            )
```