xinference 0.11.2.post1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +83 -8
- xinference/client/restful/restful_client.py +70 -0
- xinference/constants.py +8 -0
- xinference/core/__init__.py +0 -2
- xinference/core/cache_tracker.py +22 -1
- xinference/core/chat_interface.py +71 -10
- xinference/core/model.py +141 -12
- xinference/core/scheduler.py +428 -0
- xinference/core/supervisor.py +31 -3
- xinference/core/worker.py +8 -3
- xinference/isolation.py +9 -2
- xinference/model/audio/chattts.py +84 -0
- xinference/model/audio/core.py +10 -3
- xinference/model/audio/model_spec.json +20 -0
- xinference/model/llm/__init__.py +6 -0
- xinference/model/llm/llm_family.json +1063 -260
- xinference/model/llm/llm_family_modelscope.json +686 -13
- xinference/model/llm/pytorch/baichuan.py +2 -1
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/cogvlm2.py +316 -0
- xinference/model/llm/pytorch/core.py +92 -6
- xinference/model/llm/pytorch/glm4v.py +258 -0
- xinference/model/llm/pytorch/intern_vl.py +5 -10
- xinference/model/llm/pytorch/minicpmv25.py +232 -0
- xinference/model/llm/pytorch/utils.py +386 -2
- xinference/model/llm/vllm/core.py +7 -1
- xinference/thirdparty/ChatTTS/__init__.py +1 -0
- xinference/thirdparty/ChatTTS/core.py +200 -0
- xinference/types.py +3 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/METADATA +28 -11
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/RECORD +36 -29
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/LICENSE +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/WHEEL +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.11.2.post1.dist-info → xinference-0.12.0.dist-info}/top_level.txt +0 -0
xinference/core/scheduler.py
ADDED

@@ -0,0 +1,428 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import functools
+import logging
+from collections import deque
+from enum import Enum
+from typing import List, Optional, Set
+
+import xoscar as xo
+
+logger = logging.getLogger(__name__)
+
+XINFERENCE_BATCHING_CLEAN_CACHE_INTERVAL = 5
+XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
+XINFERENCE_STREAMING_ERROR_FLAG = "<XINFERENCE_STREAMING_ERROR>"
+XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
+XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
+
+
+class AbortRequestMessage(Enum):
+    NOT_FOUND = 1
+    DONE = 2
+    NO_OP = 3
+
+
+class InferenceRequest:
+    def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
+        # original prompt
+        self._prompt = prompt
+        # full prompt that contains chat history and applies chat template
+        self._full_prompt = None
+        # whether the current request is in the prefill phase
+        self._is_prefill = is_prefill
+        # full prompt tokens
+        self._prompt_tokens = None
+        # all new generated tokens during decode phase
+        self._new_tokens = []
+        # kv_cache used in decode phase
+        self._kv_cache = None
+        # use passed args from `chat` interface
+        self._inference_args = args
+        # use passed kwargs from `chat` interface, basically not used for now
+        self._inference_kwargs = kwargs
+        # should this request be stopped
+        self._stopped = False
+        # finish reason. If this is set, self._stopped is True.
+        self._finish_reason = None
+        # should this request be aborted
+        # note that when this flag is True, assert self._stopped is True
+        self._aborted = False
+        # sanitized generate config
+        self._sanitized_generate_config = None
+        # Use in stream mode
+        self.last_output_length = 0
+        # inference results,
+        # it is a list type because when stream=True,
+        # self.completion contains all the results in a decode round.
+        self.completion = []
+        # The way upstream gets the returned results,
+        # when stream=True, it is an asyncio.Queue,
+        # and when stream=False, it is an asyncio future.
+        self.future_or_queue = future_or_queue
+        # Record error message when this request has error.
+        # Must set stopped=True when this field is set.
+        self.error_msg: Optional[str] = None
+
+        # check the integrity of args passed upstream
+        self._check_args()
+
+    def _check_args(self):
+        assert len(self._inference_args) == 3
+        # system prompt
+        assert self._inference_args[0] is None or isinstance(
+            self._inference_args[0], str
+        )
+        # chat history
+        assert self._inference_args[1] is None or isinstance(
+            self._inference_args[1], list
+        )
+        # generate config
+        assert self._inference_args[2] is None or isinstance(
+            self._inference_args[2], dict
+        )
+
+    @property
+    def prompt(self):
+        return self._prompt
+
+    @property
+    def system_prompt(self):
+        return self._inference_args[0]
+
+    @property
+    def chat_history(self):
+        return self._inference_args[1]
+
+    @property
+    def full_prompt(self):
+        return self._full_prompt
+
+    @full_prompt.setter
+    def full_prompt(self, value: str):
+        self._full_prompt = value
+
+    @property
+    def is_prefill(self):
+        return self._is_prefill
+
+    @is_prefill.setter
+    def is_prefill(self, value: bool):
+        self._is_prefill = value
+
+    @property
+    def prompt_tokens(self):
+        return self._prompt_tokens
+
+    @prompt_tokens.setter
+    def prompt_tokens(self, value: List[int]):
+        self._prompt_tokens = value
+
+    @property
+    def kv_cache(self):
+        return self._kv_cache
+
+    @kv_cache.setter
+    def kv_cache(self, value):
+        self._kv_cache = value
+
+    @property
+    def new_tokens(self):
+        return self._new_tokens
+
+    def append_new_token(self, token: int):
+        self._new_tokens.append(token)
+
+    @property
+    def generate_config(self):
+        return self._inference_args[2]
+
+    @property
+    def sanitized_generate_config(self):
+        return self._sanitized_generate_config
+
+    @sanitized_generate_config.setter
+    def sanitized_generate_config(self, value: dict):
+        self._sanitized_generate_config = value
+
+    @property
+    def stopped(self):
+        return self._stopped
+
+    @stopped.setter
+    def stopped(self, value: bool):
+        self._stopped = value
+
+    @property
+    def finish_reason(self):
+        return self._finish_reason
+
+    @finish_reason.setter
+    def finish_reason(self, value: Optional[str]):
+        self._finish_reason = value
+
+    @property
+    def stream(self) -> bool:
+        return (
+            False
+            if self.generate_config is None
+            else self.generate_config.get("stream", False)
+        )
+
+    @property
+    def stream_interval(self) -> int:
+        return self.sanitized_generate_config.get("stream_interval", 2)
+
+    @property
+    def include_usage(self) -> bool:
+        stream_options = self.sanitized_generate_config.get("stream_options", None)
+        include_usage = (
+            stream_options["include_usage"]
+            if isinstance(stream_options, dict)
+            else False
+        )
+        return include_usage
+
+    @property
+    def aborted(self) -> bool:
+        return self._aborted
+
+    @aborted.setter
+    def aborted(self, value: bool):
+        self._aborted = value
+
+    @property
+    def request_id(self) -> Optional[str]:
+        return (
+            None
+            if self.generate_config is None
+            else self.generate_config.get("request_id", None)
+        )
+
+    @functools.lru_cache
+    def get_generate_configs(self, eos_token_id: int):
+        from ..types import max_tokens_field
+
+        max_new_tokens = int(
+            self.sanitized_generate_config.get("max_tokens", max_tokens_field.default)
+        )
+        stream_interval = self.sanitized_generate_config.get("stream_interval", 2)
+        include_usage = self.include_usage
+        stop_str = self.sanitized_generate_config.get("stop", None)
+        stop_token_ids = (
+            self.sanitized_generate_config.get("stop_token_ids", None) or []
+        )
+        stop_token_ids = set(stop_token_ids)
+        stop_token_ids.add(eos_token_id)
+        temperature = float(self.sanitized_generate_config.get("temperature", 1.0))
+        repetition_penalty = float(
+            self.sanitized_generate_config.get("repetition_penalty", 1.0)
+        )
+        top_p = float(self.sanitized_generate_config.get("top_p", 1.0))
+        top_k = int(self.sanitized_generate_config.get("top_k", -1))  # -1 means disable
+        return (
+            max_new_tokens,
+            stream_interval,
+            include_usage,
+            stop_str,
+            stop_token_ids,
+            temperature,
+            repetition_penalty,
+            top_p,
+            top_k,
+        )
+
+
+def _get_valid_batch_kv_cache(data, skipped_indexes: Set[int]):
+    from transformers.cache_utils import DynamicCache
+
+    cache = DynamicCache.from_legacy_cache(data)
+    batch_size = cache.key_cache[0].shape[0]
+    batch_slices = [num for num in range(batch_size) if num not in skipped_indexes]
+    for idx in range(len(cache)):
+        cache.key_cache[idx] = cache.key_cache[idx][batch_slices, ::]
+        cache.value_cache[idx] = cache.value_cache[idx][batch_slices, ::]
+    return cache.to_legacy_cache()
+
+
+class SchedulerActor(xo.StatelessActor):
+    @classmethod
+    def gen_uid(cls, model_uid: str, replica_id: str):
+        return f"{model_uid}-{replica_id}-scheduler-actor"
+
+    def __init__(self):
+        super().__init__()
+        self._waiting_queue: deque[InferenceRequest] = deque()
+        self._running_queue: deque[InferenceRequest] = deque()
+        self._model = None
+        self._id_to_req = {}
+        self._abort_req_ids: Set[str] = set()
+        self._isolation = None
+
+    async def __post_create__(self):
+        from ..isolation import Isolation
+
+        self._isolation = Isolation(
+            asyncio.new_event_loop(), threaded=True, daemon=True
+        )
+        self._isolation.start()
+        asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+    async def __pre_destroy__(self):
+        try:
+            assert self._isolation is not None
+            self._isolation.stop()
+            del self._isolation
+        except Exception as e:
+            logger.debug(
+                f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+            )
+
+    def set_model(self, model):
+        self._model = model
+
+    def get_max_num_seqs(self):
+        assert self._model is not None
+        return self._model.get_max_num_seqs()
+
+    def _check_request_aborted(self, req: InferenceRequest):
+        if req.request_id and req.request_id in self._abort_req_ids:
+            req.aborted = True
+            req.stopped = True
+
+    def _handle_request(self) -> Optional[List[InferenceRequest]]:
+        if self._model is None:
+            return None
+        max_num_seqs = self.get_max_num_seqs()
+        # currently, FCFS strategy
+        running_list: List[InferenceRequest] = []
+        while len(self._running_queue) > 0:
+            if len(running_list) == max_num_seqs:
+                break
+            req = self._running_queue.popleft()
+            self._check_request_aborted(req)
+            running_list.append(req)
+
+        waiting_list: List[InferenceRequest] = []
+        if len(running_list) < max_num_seqs:
+            while len(self._waiting_queue) > 0:
+                req = self._waiting_queue.popleft()
+                self._check_request_aborted(req)
+                waiting_list.append(req)
+                if len(running_list) + len(waiting_list) == max_num_seqs:
+                    break
+        # must waiting_list in front
+        return waiting_list + running_list
+
+    @staticmethod
+    def _empty_cache():
+        from ..model.llm.pytorch.utils import empty_cache
+
+        empty_cache()
+
+    async def step(self):
+        req_list = self._handle_request()
+        if not req_list:
+            return
+        self._model.batch_inference(req_list)
+
+        stopped_batch_indexes = set()
+        for idx, r in enumerate(req_list):
+            if r.stream:
+                for completion in r.completion:
+                    await r.future_or_queue.put(completion)
+                r.completion = []
+
+            if not r.stopped:
+                self._running_queue.append(r)
+            else:
+                if r.new_tokens:
+                    stopped_batch_indexes.add(idx)
+                # set kv_cache to None for collection
+                r.kv_cache = None
+                rid = r.request_id
+                # clear data structure
+                if rid is not None:
+                    self._id_to_req.pop(rid, None)
+                    self._abort_req_ids.discard(rid)
+
+                if r.aborted:  # stop due to abort
+                    # handle abort result
+                    if r.stream:
+                        await r.future_or_queue.put(XINFERENCE_STREAMING_ABORT_FLAG)
+                    else:
+                        r.future_or_queue.set_result(
+                            XINFERENCE_NON_STREAMING_ABORT_FLAG
+                        )
+                else:
+                    if r.error_msg is None:  # normal stop
+                        if not r.stream:
+                            r.future_or_queue.set_result(r.completion[0])
+                        else:
+                            await r.future_or_queue.put(XINFERENCE_STREAMING_DONE_FLAG)
+                    # Abnormal stop, currently indicates that the parameter check does not pass,
+                    # and does not participate in the inference
+                    else:
+                        if not r.stream:
+                            r.future_or_queue.set_exception(ValueError(r.error_msg))
+                        else:
+                            await r.future_or_queue.put(
+                                XINFERENCE_STREAMING_ERROR_FLAG + r.error_msg
+                            )
+
+        # Some requests have been completed. Batch size needs to be reduced for kv cache.
+        if stopped_batch_indexes and len(self._running_queue) > 0:
+            kv_cache = self._running_queue[0].kv_cache
+            reduced_kv_cache = _get_valid_batch_kv_cache(
+                kv_cache, stopped_batch_indexes
+            )
+            for r in self._running_queue:
+                r.kv_cache = reduced_kv_cache
+
+        self._empty_cache()
+
+    async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
+        req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
+        rid = req.request_id
+        if rid is not None:
+            if rid in self._id_to_req:
+                raise KeyError(f"Request id: {rid} has already existed!")
+            self._id_to_req[rid] = req
+        self._waiting_queue.append(req)
+
+    async def abort_request(self, req_id: str) -> str:
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        """
+        if req_id not in self._id_to_req:
+            logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+            return AbortRequestMessage.NOT_FOUND.name
+        else:
+            self._abort_req_ids.add(req_id)
+            logger.info(f"Request id: {req_id} found to be aborted.")
+            return AbortRequestMessage.DONE.name
+
+    async def run(self):
+        try:
+            while True:
+                # wait 10ms
+                await asyncio.sleep(0.01)
+                await self.step()
+        except Exception as e:
+            logger.exception(
+                f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+            )
xinference/core/supervisor.py
CHANGED
@@ -28,7 +28,7 @@ from ..constants import (
     XINFERENCE_HEALTH_CHECK_INTERVAL,
     XINFERENCE_HEALTH_CHECK_TIMEOUT,
 )
-from ..core import ModelActor
+from ..core.model import ModelActor
 from ..core.status_guard import InstanceInfo, LaunchStatus
 from ..types import PeftModelConfig
 from .metrics import record_metrics
@@ -993,8 +993,9 @@ class SupervisorActor(xo.StatelessActor):
                 "model_size_in_billions", None
             )
             quantizations = model_version.get("quantization", None)
-
-
+            actor_ip_address = model_version.get("actor_ip_address", None)
+            path = model_version.get("path", None)
+            real_path = model_version.get("real_path", None)

             cache_entry = {
                 "model_name": model_name,
@@ -1003,11 +1004,38 @@ class SupervisorActor(xo.StatelessActor):
                 "quantizations": quantizations,
                 "path": path,
                 "Actor IP Address": actor_ip_address,
+                "real_path": real_path,
             }

             cached_models.append(cache_entry)
         return cached_models

+    @log_async(logger=logger)
+    async def abort_request(self, model_uid: str, request_id: str) -> Dict:
+        from .scheduler import AbortRequestMessage
+
+        res = {"msg": AbortRequestMessage.NO_OP.name}
+        replica_info = self._model_uid_to_replica_info.get(model_uid, None)
+        if not replica_info:
+            return res
+        replica_cnt = replica_info.replica
+
+        # Query all replicas
+        for rep_mid in iter_replica_model_uid(model_uid, replica_cnt):
+            worker_ref = self._replica_model_uid_to_worker.get(rep_mid, None)
+            if worker_ref is None:
+                continue
+            model_ref = await worker_ref.get_model(model_uid=rep_mid)
+            result_info = await model_ref.abort_request(request_id)
+            res["msg"] = result_info
+            if result_info == AbortRequestMessage.DONE.name:
+                break
+            elif result_info == AbortRequestMessage.NOT_FOUND.name:
+                logger.debug(f"Request id: {request_id} not found for model {rep_mid}")
+            else:
+                logger.debug(f"No-op for model {rep_mid}")
+        return res
+
     @log_async(logger=logger)
     async def add_worker(self, worker_address: str):
         from .worker import WorkerActor
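
A short sketch of how a caller holding a supervisor reference might drive the new `abort_request` API; `supervisor_ref` and `cancel_generation` are illustrative names rather than part of xinference's public client API, and `request_id` is whatever the caller put into `generate_config["request_id"]` when submitting the request:

```python
from xinference.core.scheduler import AbortRequestMessage


async def cancel_generation(supervisor_ref, model_uid: str, request_id: str) -> bool:
    # The supervisor fans the abort out across replicas and returns a dict like
    # {"msg": "DONE"} / {"msg": "NOT_FOUND"} / {"msg": "NO_OP"}.
    res = await supervisor_ref.abort_request(model_uid, request_id)
    return res.get("msg") == AbortRequestMessage.DONE.name
```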
xinference/core/worker.py
CHANGED
@@ -30,9 +30,10 @@ from xoscar import MainActorPoolType
 from ..constants import (
     XINFERENCE_CACHE_DIR,
     XINFERENCE_DISABLE_HEALTH_CHECK,
+    XINFERENCE_DISABLE_METRICS,
     XINFERENCE_HEALTH_CHECK_INTERVAL,
 )
-from ..core import ModelActor
+from ..core.model import ModelActor
 from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
@@ -83,8 +84,12 @@ class WorkerActor(xo.StatelessActor):
         self._model_uid_to_recover_count: Dict[str, Optional[int]] = {}
         self._model_uid_to_launch_args: Dict[str, Dict] = {}

-
-
+        if XINFERENCE_DISABLE_METRICS:
+            logger.info(
+                "Worker metrics is disabled due to the environment XINFERENCE_DISABLE_METRICS=1"
+            )
+        elif metrics_exporter_host is not None or metrics_exporter_port is not None:
+            # metrics export server.
             logger.info(
                 f"Starting metrics export server at {metrics_exporter_host}:{metrics_exporter_port}"
             )
xinference/isolation.py
CHANGED
@@ -19,13 +19,19 @@ from typing import Any, Coroutine

 class Isolation:
     # TODO: better move isolation to xoscar.
-    def __init__(
+    def __init__(
+        self,
+        loop: asyncio.AbstractEventLoop,
+        threaded: bool = True,
+        daemon: bool = True,
+    ):
         self._loop = loop
         self._threaded = threaded

         self._stopped = None
         self._thread = None
         self._thread_ident = None
+        self._daemon = daemon

     def _run(self):
         asyncio.set_event_loop(self._loop)
@@ -35,7 +41,8 @@ class Isolation:
     def start(self):
         if self._threaded:
             self._thread = thread = threading.Thread(target=self._run)
-
+            if self._daemon:
+                thread.daemon = True
             thread.start()
             self._thread_ident = thread.ident

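
The new `daemon` flag is what lets `SchedulerActor.__post_create__` (earlier in this diff) run its polling loop on a background thread that will not keep the process alive at shutdown. A minimal sketch of the same pattern outside the actor, with `heartbeat` as a stand-in coroutine:

```python
import asyncio

from xinference.isolation import Isolation


async def heartbeat():
    while True:
        await asyncio.sleep(1.0)


# Same construction as SchedulerActor.__post_create__: a dedicated event loop
# running on a daemon thread, so interpreter exit is not blocked by the loop.
isolation = Isolation(asyncio.new_event_loop(), threaded=True, daemon=True)
isolation.start()
asyncio.run_coroutine_threadsafe(heartbeat(), loop=isolation.loop)
# ... later, tear it down explicitly if needed:
# isolation.stop()
```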
xinference/model/audio/chattts.py
ADDED

@@ -0,0 +1,84 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+class ChatTTSModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        import torch
+
+        from xinference.thirdparty import ChatTTS
+
+        torch._dynamo.config.cache_size_limit = 64
+        torch._dynamo.config.suppress_errors = True
+        torch.set_float32_matmul_precision("high")
+        self._model = ChatTTS.Chat()
+        self._model.load_models(
+            source="local", local_path=self._model_path, compile=True
+        )
+
+    def speech(
+        self, input: str, voice: str, response_format: str = "mp3", speed: float = 1.0
+    ):
+        import numpy as np
+        import torch
+        import torchaudio
+        import xxhash
+
+        seed = xxhash.xxh32_intdigest(voice)
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        torch.cuda.manual_seed(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+        assert self._model is not None
+        rnd_spk_emb = self._model.sample_random_speaker()
+
+        default = 5
+        infer_speed = int(default * speed)
+        params_infer_code = {"spk_emb": rnd_spk_emb, "prompt": f"[speed_{infer_speed}]"}
+
+        assert self._model is not None
+        wavs = self._model.infer([input], params_infer_code=params_infer_code)
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(wavs[0]), 24000, format=response_format
+            )
+            return out.getvalue()
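
A hedged usage sketch of the class above, driving it directly instead of going through xinference's launch path; the model path is a placeholder, and `model_spec` is passed as `None` here only because the constructor merely stores it (in practice `create_audio_model_instance` in the next file supplies the real `AudioModelFamilyV1` spec and cached path):

```python
from xinference.model.audio.chattts import ChatTTSModel

# Placeholder path to a locally cached ChatTTS model directory.
model = ChatTTSModel(
    model_uid="chattts-demo",
    model_path="/path/to/cached/ChatTTS",
    model_spec=None,
    device=None,
)
model.load()
# `voice` seeds the random speaker embedding (same string -> same voice);
# `speed` is mapped onto ChatTTS's [speed_N] prompt token.
audio_bytes = model.speech(
    input="Hello from xinference 0.12.0!", voice="1234", response_format="mp3"
)
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)
```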
xinference/model/audio/core.py
CHANGED
@@ -14,11 +14,12 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
+from .chattts import ChatTTSModel
 from .whisper import WhisperModel

 MAX_ATTEMPTS = 3
@@ -130,10 +131,16 @@ def get_cache_status(

 def create_audio_model_instance(
     subpool_addr: str, devices: List[str], model_uid: str, model_name: str, **kwargs
-) -> Tuple[WhisperModel, AudioModelDescription]:
+) -> Tuple[Union[WhisperModel, ChatTTSModel], AudioModelDescription]:
     model_spec = match_audio(model_name)
     model_path = cache(model_spec)
-    model
+    model: Union[WhisperModel, ChatTTSModel]
+    if model_spec.model_family == "whisper":
+        model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "ChatTTS":
+        model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
+    else:
+        raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
         subpool_addr, devices, model_spec, model_path=model_path
     )