xinference 0.11.3__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

@@ -0,0 +1,428 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import asyncio
+ import functools
+ import logging
+ from collections import deque
+ from enum import Enum
+ from typing import List, Optional, Set
+
+ import xoscar as xo
+
+ logger = logging.getLogger(__name__)
+
+ XINFERENCE_BATCHING_CLEAN_CACHE_INTERVAL = 5
+ XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
+ XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"
+ XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
+ XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
+
+
+ class AbortRequestMessage(Enum):
+     NOT_FOUND = 1
+     DONE = 2
+     NO_OP = 3
+
+
+ class InferenceRequest:
+     def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
+         # original prompt
+         self._prompt = prompt
+         # full prompt that contains the chat history and applies the chat template
+         self._full_prompt = None
+         # whether the current request is in the prefill phase
+         self._is_prefill = is_prefill
+         # full prompt tokens
+         self._prompt_tokens = None
+         # all new tokens generated during the decode phase
+         self._new_tokens = []
+         # kv_cache used in the decode phase
+         self._kv_cache = None
+         # positional args passed from the `chat` interface
+         self._inference_args = args
+         # kwargs passed from the `chat` interface, basically unused for now
+         self._inference_kwargs = kwargs
+         # whether this request should be stopped
+         self._stopped = False
+         # finish reason. If this is set, self._stopped is True.
+         self._finish_reason = None
+         # whether this request should be aborted;
+         # note that when this flag is True, self._stopped must also be True
+         self._aborted = False
+         # sanitized generate config
+         self._sanitized_generate_config = None
+         # used in stream mode
+         self.last_output_length = 0
+         # inference results;
+         # it is a list because when stream=True,
+         # self.completion contains all the results of a decode round.
+         self.completion = []
+         # how upstream retrieves the returned results:
+         # when stream=True, it is an asyncio.Queue,
+         # and when stream=False, it is an asyncio future.
+         self.future_or_queue = future_or_queue
+         # Records the error message when this request fails.
+         # stopped must be set to True whenever this field is set.
+         self.error_msg: Optional[str] = None
+
+         # check the integrity of the args passed from upstream
+         self._check_args()
+
+     def _check_args(self):
+         assert len(self._inference_args) == 3
+         # system prompt
+         assert self._inference_args[0] is None or isinstance(
+             self._inference_args[0], str
+         )
+         # chat history
+         assert self._inference_args[1] is None or isinstance(
+             self._inference_args[1], list
+         )
+         # generate config
+         assert self._inference_args[2] is None or isinstance(
+             self._inference_args[2], dict
+         )
+
+     @property
+     def prompt(self):
+         return self._prompt
+
+     @property
+     def system_prompt(self):
+         return self._inference_args[0]
+
+     @property
+     def chat_history(self):
+         return self._inference_args[1]
+
+     @property
+     def full_prompt(self):
+         return self._full_prompt
+
+     @full_prompt.setter
+     def full_prompt(self, value: str):
+         self._full_prompt = value
+
+     @property
+     def is_prefill(self):
+         return self._is_prefill
+
+     @is_prefill.setter
+     def is_prefill(self, value: bool):
+         self._is_prefill = value
+
+     @property
+     def prompt_tokens(self):
+         return self._prompt_tokens
+
+     @prompt_tokens.setter
+     def prompt_tokens(self, value: List[int]):
+         self._prompt_tokens = value
+
+     @property
+     def kv_cache(self):
+         return self._kv_cache
+
+     @kv_cache.setter
+     def kv_cache(self, value):
+         self._kv_cache = value
+
+     @property
+     def new_tokens(self):
+         return self._new_tokens
+
+     def append_new_token(self, token: int):
+         self._new_tokens.append(token)
+
+     @property
+     def generate_config(self):
+         return self._inference_args[2]
+
+     @property
+     def sanitized_generate_config(self):
+         return self._sanitized_generate_config
+
+     @sanitized_generate_config.setter
+     def sanitized_generate_config(self, value: dict):
+         self._sanitized_generate_config = value
+
+     @property
+     def stopped(self):
+         return self._stopped
+
+     @stopped.setter
+     def stopped(self, value: bool):
+         self._stopped = value
+
+     @property
+     def finish_reason(self):
+         return self._finish_reason
+
+     @finish_reason.setter
+     def finish_reason(self, value: Optional[str]):
+         self._finish_reason = value
+
+     @property
+     def stream(self) -> bool:
+         return (
+             False
+             if self.generate_config is None
+             else self.generate_config.get("stream", False)
+         )
+
+     @property
+     def stream_interval(self) -> int:
+         return self.sanitized_generate_config.get("stream_interval", 2)
+
+     @property
+     def include_usage(self) -> bool:
+         stream_options = self.sanitized_generate_config.get("stream_options", None)
+         include_usage = (
+             stream_options["include_usage"]
+             if isinstance(stream_options, dict)
+             else False
+         )
+         return include_usage
+
+     @property
+     def aborted(self) -> bool:
+         return self._aborted
+
+     @aborted.setter
+     def aborted(self, value: bool):
+         self._aborted = value
+
+     @property
+     def request_id(self) -> Optional[str]:
+         return (
+             None
+             if self.generate_config is None
+             else self.generate_config.get("request_id", None)
+         )
+
+     @functools.lru_cache
+     def get_generate_configs(self, eos_token_id: int):
+         from ..types import max_tokens_field
+
+         max_new_tokens = int(
+             self.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+         )
+         stream_interval = self.sanitized_generate_config.get("stream_interval", 2)
+         include_usage = self.include_usage
+         stop_str = self.sanitized_generate_config.get("stop", None)
+         stop_token_ids = (
+             self.sanitized_generate_config.get("stop_token_ids", None) or []
+         )
+         stop_token_ids = set(stop_token_ids)
+         stop_token_ids.add(eos_token_id)
+         temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
+         repetition_penalty = float(
+             self.sanitized_generate_config.get("repetition_penalty", 1.0)
+         )
+         top_p = float(self.sanitized_generate_config.get("top_p", 1.0))
+         top_k = int(self.sanitized_generate_config.get("top_k", -1))  # -1 means disabled
+         return (
+             max_new_tokens,
+             stream_interval,
+             include_usage,
+             stop_str,
+             stop_token_ids,
+             temperature,
+             repetition_penalty,
+             top_p,
+             top_k,
+         )
+
+
+ def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
+     from transformers.cache_utils import DynamicCache
+
+     cache = DynamicCache.from_legacy_cache(data)
+     batch_size = cache.key_cache[0].shape[0]
+     batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
+     for idx in range(len(cache)):
+         cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
+         cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
+     return cache.to_legacy_cache()
+
+
+ class SchedulerActor(xo.StatelessActor):
+     @classmethod
+     def gen_uid(cls, model_uid: str, replica_id: str):
+         return f"{model_uid}-{replica_id}-scheduler-actor"
+
+     def __init__(self):
+         super().__init__()
+         self._waiting_queue: deque[InferenceRequest] = deque()
+         self._running_queue: deque[InferenceRequest] = deque()
+         self._model = None
+         self._id_to_req = {}
+         self._abort_req_ids: Set[str] = set()
+         self._isolation = None
+
+     async def __post_create__(self):
+         from ..isolation import Isolation
+
+         self._isolation = Isolation(
+             asyncio.new_event_loop(), threaded=True, daemon=True
+         )
+         self._isolation.start()
+         asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+     async def __pre_destroy__(self):
+         try:
+             assert self._isolation is not None
+             self._isolation.stop()
+             del self._isolation
+         except Exception as e:
+             logger.debug(
+                 f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+             )
+
+     def set_model(self, model):
+         self._model = model
+
+     def get_max_num_seqs(self):
+         assert self._model is not None
+         return self._model.get_max_num_seqs()
+
+     def _check_request_aborted(self, req: InferenceRequest):
+         if req.request_id and req.request_id in self._abort_req_ids:
+             req.aborted = True
+             req.stopped = True
+
+     def _handle_request(self) -> Optional[List[InferenceRequest]]:
+         if self._model is None:
+             return None
+         max_num_seqs = self.get_max_num_seqs()
+         # currently an FCFS strategy
+         running_list: List[InferenceRequest] = []
+         while len(self._running_queue) > 0:
+             if len(running_list) == max_num_seqs:
+                 break
+             req = self._running_queue.popleft()
+             self._check_request_aborted(req)
+             running_list.append(req)
+
+         waiting_list: List[InferenceRequest] = []
+         if len(running_list) < max_num_seqs:
+             while len(self._waiting_queue) > 0:
+                 req = self._waiting_queue.popleft()
+                 self._check_request_aborted(req)
+                 waiting_list.append(req)
+                 if len(running_list) + len(waiting_list) == max_num_seqs:
+                     break
+         # the waiting list must come first
+         return waiting_list + running_list
+
+     @staticmethod
+     def _empty_cache():
+         from ..model.llm.pytorch.utils import empty_cache
+
+         empty_cache()
+
+     async def step(self):
+         req_list = self._handle_request()
+         if not req_list:
+             return
+         self._model.batch_inference(req_list)
+
+         stopped_batch_indexes = set()
+         for idx, r in enumerate(req_list):
+             if r.stream:
+                 for completion in r.completion:
+                     await r.future_or_queue.put(completion)
+                 r.completion = []
+
+             if not r.stopped:
+                 self._running_queue.append(r)
+             else:
+                 if r.new_tokens:
+                     stopped_batch_indexes.add(idx)
+                 # set kv_cache to None so it can be garbage collected
+                 r.kv_cache = None
+                 rid = r.request_id
+                 # clear bookkeeping data structures
+                 if rid is not None:
+                     self._id_to_req.pop(rid, None)
+                     self._abort_req_ids.discard(rid)
+
+                 if r.aborted:  # stopped due to abort
+                     # handle the abort result
+                     if r.stream:
+                         await r.future_or_queue.put(XINFERENCE_STREAMING_ABORT_FLAG)
+                     else:
+                         r.future_or_queue.set_result(
+                             XINFERENCE_NON_STREAMING_ABORT_FLAG
+                         )
+                 else:
+                     if r.error_msg is None:  # normal stop
+                         if not r.stream:
+                             r.future_or_queue.set_result(r.completion[0])
+                         else:
+                             await r.future_or_queue.put(XINFERENCE_STREAMING_DONE_FLAG)
+                     # Abnormal stop: currently this indicates that the parameter check
+                     # failed and the request did not participate in inference
+                     else:
+                         if not r.stream:
+                             r.future_or_queue.set_exception(ValueError(r.error_msg))
+                         else:
+                             await r.future_or_queue.put(
+                                 XINFERENCE_STREAMING_ERROR_FLAG + r.error_msg
+                             )
+
+         # Some requests have completed, so the batch size of the kv cache needs to shrink.
+         if stopped_batch_indexes and len(self._running_queue) > 0:
+             kv_cache = self._running_queue[0].kv_cache
+             reduced_kv_cache = _get_valid_batch_kv_cache(
+                 kv_cache, stopped_batch_indexes
+             )
+             for r in self._running_queue:
+                 r.kv_cache = reduced_kv_cache
+
+         self._empty_cache()
+
+     async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
+         req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
+         rid = req.request_id
+         if rid is not None:
+             if rid in self._id_to_req:
+                 raise KeyError(f"Request id: {rid} has already existed!")
+             self._id_to_req[rid] = req
+         self._waiting_queue.append(req)
+
+     async def abort_request(self, req_id: str) -> str:
+         """
+         Abort a submitted request.
+         If the request is finished or not found, this method is a no-op.
+         """
+         if req_id not in self._id_to_req:
+             logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+             return AbortRequestMessage.NOT_FOUND.name
+         else:
+             self._abort_req_ids.add(req_id)
+             logger.info(f"Request id: {req_id} found to be aborted.")
+             return AbortRequestMessage.DONE.name
+
+     async def run(self):
+         try:
+             while True:
+                 # wait 10ms
+                 await asyncio.sleep(0.01)
+                 await self.step()
+         except Exception as e:
+             logger.exception(
+                 f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+             )
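
The streaming flags defined at the top of this new scheduler module act as in-band sentinels: `step()` pushes `XINFERENCE_STREAMING_DONE_FLAG`, `XINFERENCE_STREAMING_ERROR_FLAG` (with the error message concatenated after it), or `XINFERENCE_STREAMING_ABORT_FLAG` onto the request's `asyncio.Queue` when a streamed request finishes, fails validation, or is aborted. The sketch below shows how an upstream consumer might drain that queue; the import path and the `consume_stream` helper name are assumptions for illustration, not part of this diff.

```python
import asyncio

# Assumed import path for the scheduler module added in this diff.
from xinference.core.scheduler import (
    XINFERENCE_STREAMING_ABORT_FLAG,
    XINFERENCE_STREAMING_DONE_FLAG,
    XINFERENCE_STREAMING_ERROR_FLAG,
)


async def consume_stream(queue: asyncio.Queue):
    """Hypothetical consumer: yield completion chunks until a sentinel arrives."""
    while True:
        item = await queue.get()
        if item == XINFERENCE_STREAMING_DONE_FLAG:
            return  # normal end of stream
        if item == XINFERENCE_STREAMING_ABORT_FLAG:
            raise RuntimeError("request was aborted")
        if isinstance(item, str) and item.startswith(XINFERENCE_STREAMING_ERROR_FLAG):
            # step() concatenates the error message right after the flag
            raise ValueError(item[len(XINFERENCE_STREAMING_ERROR_FLAG) :])
        yield item
```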
@@ -1010,6 +1010,32 @@ class SupervisorActor(xo.StatelessActor):
              cached_models.append(cache_entry)
          return cached_models
 
+     @log_async(logger=logger)
+     async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+         from .scheduler import AbortRequestMessage
+
+         res = {"msg": AbortRequestMessage.NO_OP.name}
+         replica_info = self._model_uid_to_replica_info.get(model_uid, None)
+         if not replica_info:
+             return res
+         replica_cnt = replica_info.replica
+
+         # query all replicas
+         for rep_mid in iter_replica_model_uid(model_uid, replica_cnt):
+             worker_ref = self._replica_model_uid_to_worker.get(rep_mid, None)
+             if worker_ref is None:
+                 continue
+             model_ref = await worker_ref.get_model(model_uid=rep_mid)
+             result_info = await model_ref.abort_request(request_id)
+             res["msg"] = result_info
+             if result_info == AbortRequestMessage.DONE.name:
+                 break
+             elif result_info == AbortRequestMessage.NOT_FOUND.name:
+                 logger.debug(f"Request id: {request_id} not found for model {rep_mid}")
+             else:
+                 logger.debug(f"No-op for model {rep_mid}")
+         return res
+
      @log_async(logger=logger)
      async def add_worker(self, worker_address: str):
          from .worker import WorkerActor
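
Aborting only works if the caller supplied a `request_id` inside `generate_config` when submitting the request, since that is the key `SchedulerActor.add_request` indexes by. A minimal sketch of the new supervisor-side flow, assuming you already hold an actor reference to the supervisor (how that reference is obtained is outside this diff):

```python
# Hypothetical helper; only abort_request(model_uid, request_id) and the
# returned message names ("DONE", "NOT_FOUND", "NO_OP") come from the code above.
async def abort_example(supervisor_ref, model_uid: str, request_id: str) -> bool:
    res = await supervisor_ref.abort_request(model_uid, request_id)
    # The supervisor walks every replica and stops at the first one reporting DONE;
    # NOT_FOUND / NO_OP mean the request is unknown or nothing was running.
    return res["msg"] == "DONE"
```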
xinference/isolation.py CHANGED
@@ -19,13 +19,19 @@ from typing import Any, Coroutine
 
  class Isolation:
      # TODO: better move isolation to xoscar.
-     def __init__(self, loop: asyncio.AbstractEventLoop, threaded: bool = True):
+     def __init__(
+         self,
+         loop: asyncio.AbstractEventLoop,
+         threaded: bool = True,
+         daemon: bool = True,
+     ):
          self._loop = loop
          self._threaded = threaded
 
          self._stopped = None
          self._thread = None
          self._thread_ident = None
+         self._daemon = daemon
 
      def _run(self):
          asyncio.set_event_loop(self._loop)
@@ -35,7 +41,8 @@ class Isolation:
      def start(self):
          if self._threaded:
              self._thread = thread = threading.Thread(target=self._run)
-             thread.daemon = True
+             if self._daemon:
+                 thread.daemon = True
              thread.start()
              self._thread_ident = thread.ident
 
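
The new `daemon` flag controls whether the background thread that runs the isolated event loop is marked as a daemon thread. `SchedulerActor.__post_create__` above passes `daemon=True` so the scheduling loop never keeps the worker process alive on shutdown. A small sketch of the same pattern, assuming `Isolation` is importable from `xinference.isolation` as the file header suggests:

```python
import asyncio

from xinference.isolation import Isolation


async def heartbeat():
    # placeholder coroutine standing in for SchedulerActor.run()
    while True:
        await asyncio.sleep(1.0)


# A dedicated event loop in a daemon thread: it will not block process exit.
isolation = Isolation(asyncio.new_event_loop(), threaded=True, daemon=True)
isolation.start()
asyncio.run_coroutine_threadsafe(heartbeat(), loop=isolation.loop)
```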
@@ -0,0 +1,84 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import logging
+ from io import BytesIO
+ from typing import TYPE_CHECKING, Optional
+
+ if TYPE_CHECKING:
+     from .core import AudioModelFamilyV1
+
+ logger = logging.getLogger(__name__)
+
+
+ class ChatTTSModel:
+     def __init__(
+         self,
+         model_uid: str,
+         model_path: str,
+         model_spec: "AudioModelFamilyV1",
+         device: Optional[str] = None,
+         **kwargs,
+     ):
+         self._model_uid = model_uid
+         self._model_path = model_path
+         self._model_spec = model_spec
+         self._device = device
+         self._model = None
+         self._kwargs = kwargs
+
+     def load(self):
+         import torch
+
+         from xinference.thirdparty import ChatTTS
+
+         torch._dynamo.config.cache_size_limit = 64
+         torch._dynamo.config.suppress_errors = True
+         torch.set_float32_matmul_precision("high")
+         self._model = ChatTTS.Chat()
+         self._model.load_models(
+             source="local", local_path=self._model_path, compile=True
+         )
+
+     def speech(
+         self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+     ):
+         import numpy as np
+         import torch
+         import torchaudio
+         import xxhash
+
+         seed = xxhash.xxh32_intdigest(voice)
+
+         torch.manual_seed(seed)
+         np.random.seed(seed)
+         torch.cuda.manual_seed(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+         assert self._model is not None
+         rnd_spk_emb = self._model.sample_random_speaker()
+
+         default = 5
+         infer_speed = int(default * speed)
+         params_infer_code = {"spk_emb": rnd_spk_emb, "prompt": f"[speed_{infer_speed}]"}
+
+         assert self._model is not None
+         wavs = self._model.infer([input], params_infer_code=params_infer_code)
+
+         # save the generated audio
+         with BytesIO() as out:
+             torchaudio.save(
+                 out, torch.from_numpy(wavs[0]), 24000, format=response_format
+             )
+             return out.getvalue()
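
A minimal, hypothetical usage sketch of the new `ChatTTSModel`; the model path and `model_spec` are placeholders, and in practice instances are built by `create_audio_model_instance()` (see the next hunk) rather than by hand:

```python
# Placeholders: a real AudioModelFamilyV1 spec and a downloaded ChatTTS checkpoint.
model = ChatTTSModel(
    model_uid="chattts-demo",
    model_path="/path/to/ChatTTS",  # placeholder
    model_spec=chattts_spec,        # placeholder AudioModelFamilyV1 instance
)
model.load()

# `voice` seeds the random speaker embedding, so the same string reproduces the
# same voice; `speed` is mapped onto ChatTTS's "[speed_N]" prompt.
audio_bytes = model.speech("Hello from xinference!", voice="seed-42", speed=1.0)
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)
```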
@@ -14,11 +14,12 @@
  import logging
  import os
  from collections import defaultdict
- from typing import Dict, List, Optional, Tuple
+ from typing import Dict, List, Optional, Tuple, Union
 
  from ...constants import XINFERENCE_CACHE_DIR
  from ..core import CacheableModelSpec, ModelDescription
  from ..utils import valid_model_revision
+ from .chattts import ChatTTSModel
  from .whisper import WhisperModel
 
  MAX_ATTEMPTS = 3
@@ -130,10 +131,16 @@ def get_cache_status(
 
  def create_audio_model_instance(
      subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
- ) -> Tuple[WhisperModel, AudioModelDescription]:
+ ) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
      model_spec = match_audio(model_name)
      model_path = cache(model_spec)
-     model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+     model: Union[WhisperModel, ChatTTSModel]
+     if model_spec.model_family == "whisper":
+         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+     elif model_spec.model_family == "ChatTTS":
+         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+     else:
+         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
      model_description = AudioModelDescription(
          subpool_addr, devices, model_spec, model_path=model_path
      )
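
With this change, audio model creation dispatches on `model_spec.model_family`. A hedged call-site sketch (the address, device list, and model name are placeholders; this factory is normally invoked by the worker rather than by user code):

```python
# Hypothetical call site for the updated factory function.
model, description = create_audio_model_instance(
    subpool_addr="127.0.0.1:37537",  # placeholder subpool address
    devices=["0"],                   # placeholder device list
    model_uid="ChatTTS-demo",
    model_name="ChatTTS",            # routes to ChatTTSModel via model_family
)
model.load()  # loading is left to the caller
```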