xinference 0.15.4__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/constants.py +4 -4
- xinference/core/model.py +89 -18
- xinference/core/scheduler.py +10 -7
- xinference/core/utils.py +9 -0
- xinference/deploy/supervisor.py +4 -0
- xinference/model/__init__.py +4 -0
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/stable_diffusion/core.py +6 -31
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +169 -1
- xinference/model/llm/llm_family_modelscope.json +108 -0
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/core.py +37 -111
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/utils.py +4 -284
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +16 -1
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
- xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/METADATA +36 -4
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/RECORD +36 -33
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/model/image/scheduler/flux.py
ADDED
@@ -0,0 +1,533 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import logging
+import os
+import re
+import typing
+from collections import deque
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import xoscar as xo
+
+from ..utils import handle_image_result
+
+if TYPE_CHECKING:
+    from ..stable_diffusion.core import DiffusionModel
+
+
+logger = logging.getLogger(__name__)
+DEFAULT_MAX_SEQUENCE_LENGTH = 512
+
+
+class Text2ImageRequest:
+    def __init__(
+        self,
+        unique_id,
+        future,
+        prompt: str,
+        n: int,
+        size: str,
+        response_format: str,
+        *args,
+        **kwargs,
+    ):
+        self._unique_id = unique_id
+        self.future = future
+        self._prompt = prompt
+        self._n = n
+        self._size = size
+        self._response_format = response_format
+        self._args = args
+        self._kwargs = kwargs
+        self._width = -1
+        self._height = -1
+        self._generate_kwargs: Dict[str, Any] = {}
+        self._set_width_and_height()
+        self.is_encode = True
+        self.scheduler = None
+        self.done_steps = 0
+        self.total_steps = 0
+        self.static_tensors: Dict[str, torch.Tensor] = {}
+        self.timesteps = None
+        self.dtype = None
+        self.output = None
+        self.error_msg: Optional[str] = None
+        self.aborted = False
+
+    def _set_width_and_height(self):
+        self._width, self._height = map(int, re.split(r"[^\d]+", self._size))
+
+    def set_generate_kwargs(self, generate_kwargs: Dict):
+        self._generate_kwargs = {k: v for k, v in generate_kwargs.items()}
+
+    @property
+    def prompt(self):
+        return self._prompt
+
+    @property
+    def n(self):
+        return self._n
+
+    @property
+    def size(self):
+        return self._size
+
+    @property
+    def response_format(self):
+        return self._response_format
+
+    @property
+    def kwargs(self):
+        return self._kwargs
+
+    @property
+    def width(self):
+        return self._width
+
+    @property
+    def height(self):
+        return self._height
+
+    @property
+    def generate_kwargs(self):
+        return self._generate_kwargs
+
+    @property
+    def request_id(self):
+        return self._unique_id
+
+
+class FluxBatchSchedulerActor(xo.StatelessActor):
+    @classmethod
+    def gen_uid(cls, model_uid: str):
+        return f"{model_uid}-scheduler-actor"
+
+    def __init__(self):
+        from ....device_utils import get_available_device
+
+        super().__init__()
+        self._waiting_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+        self._running_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+        self._model = None
+        self._available_device = get_available_device()
+        self._id_to_req: Dict[str, Text2ImageRequest] = {}
+
+    def set_model(self, model):
+        """
+        Must use `set_model`. Otherwise, the model will be copied once.
+        """
+        self._model = model
+
+    async def __post_create__(self):
+        from ....isolation import Isolation
+
+        self._isolation = Isolation(
+            asyncio.new_event_loop(), threaded=True, daemon=True
+        )
+        self._isolation.start()
+        asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+    async def __pre_destroy__(self):
+        try:
+            assert self._isolation is not None
+            self._isolation.stop()
+            del self._isolation
+        except Exception as e:
+            logger.debug(
+                f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+            )
+
+    async def add_request(self, unique_id: str, future, *args, **kwargs):
+        req = Text2ImageRequest(unique_id, future, *args, **kwargs)
+        rid = req.request_id
+        if rid is not None:
+            if rid in self._id_to_req:
+                raise KeyError(f"Request id: {rid} has already existed!")
+            self._id_to_req[rid] = req
+        self._waiting_queue.append(req)
+
+    async def abort_request(self, req_id: str) -> str:
+        """
+        Abort a request.
+        Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+        """
+        from ....core.utils import AbortRequestMessage
+
+        if req_id not in self._id_to_req:
+            logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+            return AbortRequestMessage.NOT_FOUND.name
+        else:
+            self._id_to_req[req_id].aborted = True
+            logger.info(f"Request id: {req_id} found to be aborted.")
+            return AbortRequestMessage.DONE.name
+
+    def _handle_request(
+        self,
+    ) -> Optional[Tuple[List[Text2ImageRequest], List[Text2ImageRequest]]]:
+        """
+        Every request may generate `n>=1` images.
+        Here we need to decide whether to wait or not based on the value of `n` of each request.
+        """
+        if self._model is None:
+            return None
+        max_num_images = self._model.get_max_num_images_for_batching()
+        cur_num_images = 0
+        abort_list: List[Text2ImageRequest] = []
+        # currently, FCFS strategy
+        running_list: List[Text2ImageRequest] = []
+        while len(self._running_queue) > 0:
+            req = self._running_queue.popleft()
+            if req.aborted:
+                abort_list.append(req)
+            else:
+                running_list.append(req)
+                cur_num_images += req.n
+
+        # Remove all the aborted requests in the waiting queue
+        waiting_tmp_list: List[Text2ImageRequest] = []
+        while len(self._waiting_queue) > 0:
+            req = self._waiting_queue.popleft()
+            if req.aborted:
+                abort_list.append(req)
+            else:
+                waiting_tmp_list.append(req)
+        self._waiting_queue.extend(waiting_tmp_list)
+
+        waiting_list: List[Text2ImageRequest] = []
+        while len(self._waiting_queue) > 0:
+            req = self._waiting_queue[0]
+            if req.n + cur_num_images <= max_num_images:
+                waiting_list.append(self._waiting_queue.popleft())
+                cur_num_images += req.n
+            else:
+                logger.warning(
+                    f"Current queue is full, with an upper limit of max_num_images: {max_num_images}. "
+                    f"Requests will continue to wait."
+                )
+                break
+
+        return waiting_list + running_list, abort_list
+
+    @staticmethod
+    def _empty_cache():
+        from ....device_utils import empty_cache
+
+        empty_cache()
+
+    async def step(self):
+        res = self._handle_request()
+        if res is None:
+            return
+        req_list, abort_list = res
+        # handle abort
+        if abort_list:
+            for r in abort_list:
+                r.future.set_exception(
+                    RuntimeError(
+                        f"Request: {r.request_id} has been cancelled by another `abort_request` request."
+                    )
+                )
+                self._id_to_req.pop(r.request_id, None)
+        if not req_list:
+            return
+        _batch_text_to_image(self._model, req_list, self._available_device)
+        # handle results
+        for r in req_list:
+            if r.error_msg is not None:
+                r.future.set_exception(ValueError(r.error_msg))
+                self._id_to_req.pop(r.request_id, None)
+                continue
+            if r.output is not None:
+                r.future.set_result(
+                    handle_image_result(r.response_format, r.output.images)
+                )
+                self._id_to_req.pop(r.request_id, None)
+            else:
+                self._running_queue.append(r)
+        self._empty_cache()
+
+    async def run(self):
+        try:
+            while True:
+                # wait 10ms
+                await asyncio.sleep(0.01)
+                await self.step()
+        except Exception as e:
+            logger.exception(
+                f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+            )
+
+
+def _cat_tensors(infos: List[Dict]) -> Dict:
+    keys = infos[0].keys()
+    res = {}
+    for k in keys:
+        tmp = [info[k] for info in infos]
+        res[k] = torch.cat(tmp)
+    return res
+
+
+@typing.no_type_check
+@torch.inference_mode()
+def _batch_text_to_image_internal(
+    model_cls: "DiffusionModel",
+    req_list: List[Text2ImageRequest],
+    available_device: str,
+):
+    from diffusers.pipelines.flux.pipeline_flux import (
+        calculate_shift,
+        retrieve_timesteps,
+    )
+    from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+    from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
+        FlowMatchEulerDiscreteScheduler,
+    )
+
+    device = model_cls._model._execution_device
+    height, width = req_list[0].height, req_list[0].width
+    cur_batch_max_sequence_length = [
+        r.generate_kwargs.get("max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH)
+        for r in req_list
+        if not r.is_encode
+    ]
+    for r in req_list:
+        if r.is_encode:
+            generate_kwargs = model_cls._model_spec.default_generate_config.copy()
+            generate_kwargs.update({k: v for k, v in r.kwargs.items() if v is not None})
+            model_cls._filter_kwargs(model_cls._model, generate_kwargs)
+            r.set_generate_kwargs(generate_kwargs)
+
+            # check max_sequence_length
+            max_sequence_length = r.generate_kwargs.get(
+                "max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH
+            )
+            if (
+                cur_batch_max_sequence_length
+                and max_sequence_length != cur_batch_max_sequence_length[0]
+            ):
+                r.is_encode = False
+                r.error_msg = (
+                    f"The max_sequence_length of the current request: {max_sequence_length} is "
+                    f"different from the setting in the running batch: {cur_batch_max_sequence_length[0]}, "
+                    f"please be consistent."
+                )
+                continue
+
+            num_images_per_prompt = r.n
+            callback_on_step_end_tensor_inputs = r.generate_kwargs.get(
+                "callback_on_step_end_tensor_inputs", ["latents"]
+            )
+            num_inference_steps = r.generate_kwargs.get("num_inference_steps", 28)
+            guidance_scale = r.generate_kwargs.get("guidance_scale", 7.0)
+            generator = None
+            seed = r.generate_kwargs.get("seed", None)
+            if seed is not None:
+                generator = torch.Generator(device=available_device)  # type: ignore
+                if seed != -1:
+                    generator = generator.manual_seed(seed)
+            latents = None
+            timesteps = None
+
+            # Each request must build its own scheduler instance,
+            # otherwise the mixing of variables at `scheduler.STEP` will result in an error.
+            r.scheduler = FlowMatchEulerDiscreteScheduler(
+                model_cls._model.scheduler.config.num_train_timesteps,
+                model_cls._model.scheduler.config.shift,
+                model_cls._model.scheduler.config.use_dynamic_shifting,
+                model_cls._model.scheduler.config.base_shift,
+                model_cls._model.scheduler.config.max_shift,
+                model_cls._model.scheduler.config.base_image_seq_len,
+                model_cls._model.scheduler.config.max_image_seq_len,
+            )
+
+            # check inputs
+            model_cls._model.check_inputs(
+                r.prompt,
+                None,
+                height,
+                width,
+                prompt_embeds=None,
+                pooled_prompt_embeds=None,
+                callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+                max_sequence_length=max_sequence_length,
+            )
+
+            # handle prompt
+            (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                text_ids,
+            ) = model_cls._model.encode_prompt(
+                prompt=r.prompt,
+                prompt_2=None,
+                prompt_embeds=None,
+                pooled_prompt_embeds=None,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=None,
+            )
+
+            # Prepare latent variables
+            num_channels_latents = model_cls._model.transformer.config.in_channels // 4
+            latents, latent_image_ids = model_cls._model.prepare_latents(
+                num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                latents,
+            )
+
+            # Prepare timesteps
+            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            image_seq_len = latents.shape[1]
+
+            mu = calculate_shift(
+                image_seq_len,
+                r.scheduler.config["base_image_seq_len"],
+                r.scheduler.config["max_image_seq_len"],
+                r.scheduler.config["base_shift"],
+                r.scheduler.config["max_shift"],
+            )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                r.scheduler,
+                num_inference_steps,
+                device,
+                timesteps,
+                sigmas,
+                mu=mu,
+            )
+
+            # handle guidance
+            if model_cls._model.transformer.config.guidance_embeds:
+                guidance = torch.full(
+                    [1], guidance_scale, device=device, dtype=torch.float32
+                )
+                guidance = guidance.expand(latents.shape[0])
+            else:
+                guidance = None
+
+            r.static_tensors["latents"] = latents
+            r.static_tensors["guidance"] = guidance
+            r.static_tensors["pooled_prompt_embeds"] = pooled_prompt_embeds
+            r.static_tensors["prompt_embeds"] = prompt_embeds
+            r.static_tensors["text_ids"] = text_ids
+            r.static_tensors["latent_image_ids"] = latent_image_ids
+            r.timesteps = timesteps
+            r.dtype = latents.dtype
+            r.total_steps = len(timesteps)
+            r.is_encode = False
+
+    running_req_list = [r for r in req_list if r.error_msg is None]
+    static_tensors = _cat_tensors([r.static_tensors for r in running_req_list])
+
+    # Do a step
+    timestep_tmp = []
+    for r in running_req_list:
+        timestep_tmp.append(r.timesteps[r.done_steps].expand(r.n).to(r.dtype))
+        r.done_steps += 1
+    timestep = torch.cat(timestep_tmp)
+    noise_pred = model_cls._model.transformer(
+        hidden_states=static_tensors["latents"],
+        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+        timestep=timestep / 1000,
+        guidance=static_tensors["guidance"],
+        pooled_projections=static_tensors["pooled_prompt_embeds"],
+        encoder_hidden_states=static_tensors["prompt_embeds"],
+        txt_ids=static_tensors["text_ids"],
+        img_ids=static_tensors["latent_image_ids"],
+        joint_attention_kwargs=None,
+        return_dict=False,
+    )[0]
+
+    # update latents
+    start_idx = 0
+    for r in running_req_list:
+        n = r.n
+        # handle diffusion scheduler step
+        _noise_pred = noise_pred[start_idx : start_idx + n, ::]
+        _timestep = timestep[start_idx]
+        latents_out = r.scheduler.step(
+            _noise_pred, _timestep, r.static_tensors["latents"], return_dict=False
+        )[0]
+        r.static_tensors["latents"] = latents_out
+        start_idx += n
+
+        logger.info(
+            f"Request {r.request_id} has done {r.done_steps} / {r.total_steps} steps."
+        )
+
+        # process result
+        if r.done_steps == r.total_steps:
+            output_type = r.generate_kwargs.get("output_type", "pil")
+            _latents = r.static_tensors["latents"]
+            if output_type == "latent":
+                image = _latents
+            else:
+                _latents = model_cls._model._unpack_latents(
+                    _latents, height, width, model_cls._model.vae_scale_factor
+                )
+                _latents = (
+                    _latents / model_cls._model.vae.config.scaling_factor
+                ) + model_cls._model.vae.config.shift_factor
+                image = model_cls._model.vae.decode(_latents, return_dict=False)[0]
+                image = model_cls._model.image_processor.postprocess(
+                    image, output_type=output_type
+                )
+
+            is_padded = r.generate_kwargs.get("is_padded", None)
+            origin_size = r.generate_kwargs.get("origin_size", None)
+
+            if is_padded and origin_size:
+                new_images = []
+                x, y = origin_size
+                for img in image:
+                    new_images.append(img.crop((0, 0, x, y)))
+                image = new_images
+
+            r.output = FluxPipelineOutput(images=image)
+            logger.info(
+                f"Request {r.request_id} has completed total {r.total_steps} steps."
+            )
+
+
+def _batch_text_to_image(
+    model_cls: "DiffusionModel",
+    req_list: List[Text2ImageRequest],
+    available_device: str,
+):
+    from ....core.model import OutOfMemoryError
+
+    try:
+        _batch_text_to_image_internal(model_cls, req_list, available_device)
+    except OutOfMemoryError:
+        logger.exception(
+            f"Batch text_to_image out of memory. "
+            f"Xinference will restart the model: {model_cls._model_uid}. "
+            f"Please be patient for a few moments."
+        )
+        # Just kill the process and let xinference auto-recover the model
+        os._exit(1)
+    except Exception as e:
+        logger.exception(f"Internal error for batch text_to_image: {e}.")
+        # If internal error happens, just skip all the requests in this batch.
+        # If not handle here, the client will hang.
+        for r in req_list:
+            r.error_msg = str(e)
xinference/model/image/stable_diffusion/core.py
CHANGED
@@ -12,31 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import base64
 import contextlib
 import gc
 import inspect
 import itertools
 import logging
-import os
 import re
 import sys
-import time
-import uuid
 import warnings
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from io import BytesIO
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import PIL.Image
 import torch
 from PIL import ImageOps

-from ....constants import XINFERENCE_IMAGE_DIR
 from ....device_utils import get_available_device, move_model_to_available_device
-from ....types import Image, ImageList, LoRA
+from ....types import LoRA
 from ..sdapi import SDAPIDiffusionModelMixin
+from ..utils import handle_image_result

 if TYPE_CHECKING:
     from ....core.progress_tracker import Progressor
@@ -297,6 +290,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if self._kwargs.get("vae_tiling", False):
             model.enable_vae_tiling()

+    def get_max_num_images_for_batching(self):
+        return self._kwargs.get("max_num_images", 16)
+
     @staticmethod
     def _get_scheduler(model: Any, sampler_name: str):
         if not sampler_name or sampler_name == "default":
@@ -476,28 +472,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
         if return_images:
             return images

-        if response_format == "url":
-            os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
-            image_list = []
-            with ThreadPoolExecutor() as executor:
-                for img in images:
-                    path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg")
-                    image_list.append(Image(url=path, b64_json=None))
-                    executor.submit(img.save, path, "jpeg")
-            return ImageList(created=int(time.time()), data=image_list)
-        elif response_format == "b64_json":
-
-            def _gen_base64_image(_img):
-                buffered = BytesIO()
-                _img.save(buffered, format="jpeg")
-                return base64.b64encode(buffered.getvalue()).decode()
-
-            with ThreadPoolExecutor() as executor:
-                results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
-                image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
-                return ImageList(created=int(time.time()), data=image_list)
-        else:
-            raise ValueError(f"Unsupported response format: {response_format}")
+        return handle_image_result(response_format, images)

     @classmethod
     def _filter_kwargs(cls, model, kwargs: dict):
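The new `get_max_num_images_for_batching` hook reads an optional `max_num_images` entry from the model kwargs (defaulting to 16), which the Flux scheduler uses as its per-batch image cap. A hedged sketch of passing that knob at launch time through the Python client follows; whether launch-time kwargs reach `DiffusionModel._kwargs` unchanged is an assumption here, and the model name is only an example.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
# Assumed plumbing: extra launch kwargs land in DiffusionModel._kwargs, so
# max_num_images caps how many images one batched transformer pass may produce.
model_uid = client.launch_model(
    model_name="FLUX.1-dev",
    model_type="image",
    max_num_images=8,
)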
xinference/model/image/utils.py
CHANGED
@@ -11,16 +11,52 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+import base64
+import os
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional

-from .core import ImageModelFamilyV1
+from ...constants import XINFERENCE_IMAGE_DIR
+from ...types import Image, ImageList
+
+if TYPE_CHECKING:
+    from .core import ImageModelFamilyV1


 def get_model_version(
-    image_model: ImageModelFamilyV1, controlnet: Optional[ImageModelFamilyV1]
+    image_model: "ImageModelFamilyV1", controlnet: Optional["ImageModelFamilyV1"]
 ) -> str:
     return (
         image_model.model_name
         if controlnet is None
         else f"{image_model.model_name}--{controlnet.model_name}"
     )
+
+
+def handle_image_result(response_format: str, images) -> ImageList:
+    if response_format == "url":
+        os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
+        image_list = []
+        with ThreadPoolExecutor() as executor:
+            for img in images:
+                path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg")
+                image_list.append(Image(url=path, b64_json=None))
+                executor.submit(img.save, path, "jpeg")
+        return ImageList(created=int(time.time()), data=image_list)
+    elif response_format == "b64_json":
+
+        def _gen_base64_image(_img):
+            buffered = BytesIO()
+            _img.save(buffered, format="jpeg")
+            return base64.b64encode(buffered.getvalue()).decode()
+
+        with ThreadPoolExecutor() as executor:
+            results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
+            image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
+            return ImageList(created=int(time.time()), data=image_list)
+    else:
+        raise ValueError(f"Unsupported response format: {response_format}")
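`handle_image_result` is now the single place where both the Flux batch scheduler and `DiffusionModel` turn PIL images into an API response. A small usage sketch (the toy image below is illustrative only):

from PIL import Image as PILImage

from xinference.model.image.utils import handle_image_result

imgs = [PILImage.new("RGB", (64, 64), "white")]
result = handle_image_result("b64_json", imgs)  # ImageList of base64-encoded JPEGs
urls = handle_image_result("url", imgs)  # ImageList of file paths under XINFERENCE_IMAGE_DIR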
xinference/model/llm/__init__.py
CHANGED
@@ -146,6 +146,7 @@ def _install():
     from .transformers.internlm2 import Internlm2PytorchChatModel
     from .transformers.minicpmv25 import MiniCPMV25Model
     from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.opt import OptPytorchModel
     from .transformers.qwen2_audio import Qwen2AudioChatModel
     from .transformers.qwen2_vl import Qwen2VLChatModel
     from .transformers.qwen_vl import QwenVLChatModel
@@ -190,6 +191,7 @@ def _install():
             Glm4VModel,
             DeepSeekV2PytorchModel,
             DeepSeekV2PytorchChatModel,
+            OptPytorchModel,
         ]
     )
     if OmniLMMModel:  # type: ignore