xinference 0.15.4__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference has been flagged as possibly problematic.

Files changed (38)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/constants.py +4 -4
  4. xinference/core/model.py +89 -18
  5. xinference/core/scheduler.py +10 -7
  6. xinference/core/utils.py +9 -0
  7. xinference/deploy/supervisor.py +4 -0
  8. xinference/model/__init__.py +4 -0
  9. xinference/model/image/scheduler/__init__.py +13 -0
  10. xinference/model/image/scheduler/flux.py +533 -0
  11. xinference/model/image/stable_diffusion/core.py +6 -31
  12. xinference/model/image/utils.py +39 -3
  13. xinference/model/llm/__init__.py +2 -0
  14. xinference/model/llm/llm_family.json +169 -1
  15. xinference/model/llm/llm_family_modelscope.json +108 -0
  16. xinference/model/llm/transformers/chatglm.py +104 -0
  17. xinference/model/llm/transformers/core.py +37 -111
  18. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  19. xinference/model/llm/transformers/internlm2.py +3 -95
  20. xinference/model/llm/transformers/opt.py +68 -0
  21. xinference/model/llm/transformers/utils.py +4 -284
  22. xinference/model/llm/utils.py +2 -2
  23. xinference/model/llm/vllm/core.py +16 -1
  24. xinference/utils.py +2 -3
  25. xinference/web/ui/build/asset-manifest.json +3 -3
  26. xinference/web/ui/build/index.html +1 -1
  27. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  28. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  29. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  30. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/METADATA +36 -4
  31. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/RECORD +36 -33
  32. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  33. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  34. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  35. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  36. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  37. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  38. {xinference-0.15.4.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/model/image/scheduler/flux.py (new file)
@@ -0,0 +1,533 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import asyncio
+ import logging
+ import os
+ import re
+ import typing
+ from collections import deque
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import xoscar as xo
+
+ from ..utils import handle_image_result
+
+ if TYPE_CHECKING:
+     from ..stable_diffusion.core import DiffusionModel
+
+
+ logger = logging.getLogger(__name__)
+ DEFAULT_MAX_SEQUENCE_LENGTH = 512
+
+
+ class Text2ImageRequest:
+     def __init__(
+         self,
+         unique_id,
+         future,
+         prompt: str,
+         n: int,
+         size: str,
+         response_format: str,
+         *args,
+         **kwargs,
+     ):
+         self._unique_id = unique_id
+         self.future = future
+         self._prompt = prompt
+         self._n = n
+         self._size = size
+         self._response_format = response_format
+         self._args = args
+         self._kwargs = kwargs
+         self._width = -1
+         self._height = -1
+         self._generate_kwargs: Dict[str, Any] = {}
+         self._set_width_and_height()
+         self.is_encode = True
+         self.scheduler = None
+         self.done_steps = 0
+         self.total_steps = 0
+         self.static_tensors: Dict[str, torch.Tensor] = {}
+         self.timesteps = None
+         self.dtype = None
+         self.output = None
+         self.error_msg: Optional[str] = None
+         self.aborted = False
+
+     def _set_width_and_height(self):
+         self._width, self._height = map(int, re.split(r"[^\d]+", self._size))
+
+     def set_generate_kwargs(self, generate_kwargs: Dict):
+         self._generate_kwargs = {k: v for k, v in generate_kwargs.items()}
+
+     @property
+     def prompt(self):
+         return self._prompt
+
+     @property
+     def n(self):
+         return self._n
+
+     @property
+     def size(self):
+         return self._size
+
+     @property
+     def response_format(self):
+         return self._response_format
+
+     @property
+     def kwargs(self):
+         return self._kwargs
+
+     @property
+     def width(self):
+         return self._width
+
+     @property
+     def height(self):
+         return self._height
+
+     @property
+     def generate_kwargs(self):
+         return self._generate_kwargs
+
+     @property
+     def request_id(self):
+         return self._unique_id
+
+
+ class FluxBatchSchedulerActor(xo.StatelessActor):
+     @classmethod
+     def gen_uid(cls, model_uid: str):
+         return f"{model_uid}-scheduler-actor"
+
+     def __init__(self):
+         from ....device_utils import get_available_device
+
+         super().__init__()
+         self._waiting_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+         self._running_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+         self._model = None
+         self._available_device = get_available_device()
+         self._id_to_req: Dict[str, Text2ImageRequest] = {}
+
+     def set_model(self, model):
+         """
+         Must use `set_model`. Otherwise, the model will be copied once.
+         """
+         self._model = model
+
+     async def __post_create__(self):
+         from ....isolation import Isolation
+
+         self._isolation = Isolation(
+             asyncio.new_event_loop(), threaded=True, daemon=True
+         )
+         self._isolation.start()
+         asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+     async def __pre_destroy__(self):
+         try:
+             assert self._isolation is not None
+             self._isolation.stop()
+             del self._isolation
+         except Exception as e:
+             logger.debug(
+                 f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+             )
+
+     async def add_request(self, unique_id: str, future, *args, **kwargs):
+         req = Text2ImageRequest(unique_id, future, *args, **kwargs)
+         rid = req.request_id
+         if rid is not None:
+             if rid in self._id_to_req:
+                 raise KeyError(f"Request id: {rid} has already existed!")
+             self._id_to_req[rid] = req
+         self._waiting_queue.append(req)
+
+     async def abort_request(self, req_id: str) -> str:
+         """
+         Abort a request.
+         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+         """
+         from ....core.utils import AbortRequestMessage
+
+         if req_id not in self._id_to_req:
+             logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+             return AbortRequestMessage.NOT_FOUND.name
+         else:
+             self._id_to_req[req_id].aborted = True
+             logger.info(f"Request id: {req_id} found to be aborted.")
+             return AbortRequestMessage.DONE.name
+
+     def _handle_request(
+         self,
+     ) -> Optional[Tuple[List[Text2ImageRequest], List[Text2ImageRequest]]]:
+         """
+         Every request may generate `n>=1` images.
+         Here we need to decide whether to wait or not based on the value of `n` of each request.
+         """
+         if self._model is None:
+             return None
+         max_num_images = self._model.get_max_num_images_for_batching()
+         cur_num_images = 0
+         abort_list: List[Text2ImageRequest] = []
+         # currently, FCFS strategy
+         running_list: List[Text2ImageRequest] = []
+         while len(self._running_queue) > 0:
+             req = self._running_queue.popleft()
+             if req.aborted:
+                 abort_list.append(req)
+             else:
+                 running_list.append(req)
+                 cur_num_images += req.n
+
+         # Remove all the aborted requests in the waiting queue
+         waiting_tmp_list: List[Text2ImageRequest] = []
+         while len(self._waiting_queue) > 0:
+             req = self._waiting_queue.popleft()
+             if req.aborted:
+                 abort_list.append(req)
+             else:
+                 waiting_tmp_list.append(req)
+         self._waiting_queue.extend(waiting_tmp_list)
+
+         waiting_list: List[Text2ImageRequest] = []
+         while len(self._waiting_queue) > 0:
+             req = self._waiting_queue[0]
+             if req.n + cur_num_images <= max_num_images:
+                 waiting_list.append(self._waiting_queue.popleft())
+                 cur_num_images += req.n
+             else:
+                 logger.warning(
+                     f"Current queue is full, with an upper limit of max_num_images: {max_num_images}. "
+                     f"Requests will continue to wait."
+                 )
+                 break
+
+         return waiting_list + running_list, abort_list
+
+     @staticmethod
+     def _empty_cache():
+         from ....device_utils import empty_cache
+
+         empty_cache()
+
+     async def step(self):
+         res = self._handle_request()
+         if res is None:
+             return
+         req_list, abort_list = res
+         # handle abort
+         if abort_list:
+             for r in abort_list:
+                 r.future.set_exception(
+                     RuntimeError(
+                         f"Request: {r.request_id} has been cancelled by another `abort_request` request."
+                     )
+                 )
+                 self._id_to_req.pop(r.request_id, None)
+         if not req_list:
+             return
+         _batch_text_to_image(self._model, req_list, self._available_device)
+         # handle results
+         for r in req_list:
+             if r.error_msg is not None:
+                 r.future.set_exception(ValueError(r.error_msg))
+                 self._id_to_req.pop(r.request_id, None)
+                 continue
+             if r.output is not None:
+                 r.future.set_result(
+                     handle_image_result(r.response_format, r.output.images)
+                 )
+                 self._id_to_req.pop(r.request_id, None)
+             else:
+                 self._running_queue.append(r)
+         self._empty_cache()
+
+     async def run(self):
+         try:
+             while True:
+                 # wait 10ms
+                 await asyncio.sleep(0.01)
+                 await self.step()
+         except Exception as e:
+             logger.exception(
+                 f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+             )
+
+
+ def _cat_tensors(infos: List[Dict]) -> Dict:
+     keys = infos[0].keys()
+     res = {}
+     for k in keys:
+         tmp = [info[k] for info in infos]
+         res[k] = torch.cat(tmp)
+     return res
+
+
+ @typing.no_type_check
+ @torch.inference_mode()
+ def _batch_text_to_image_internal(
+     model_cls: "DiffusionModel",
+     req_list: List[Text2ImageRequest],
+     available_device: str,
+ ):
+     from diffusers.pipelines.flux.pipeline_flux import (
+         calculate_shift,
+         retrieve_timesteps,
+     )
+     from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+     from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
+         FlowMatchEulerDiscreteScheduler,
+     )
+
+     device = model_cls._model._execution_device
+     height, width = req_list[0].height, req_list[0].width
+     cur_batch_max_sequence_length = [
+         r.generate_kwargs.get("max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH)
+         for r in req_list
+         if not r.is_encode
+     ]
+     for r in req_list:
+         if r.is_encode:
+             generate_kwargs = model_cls._model_spec.default_generate_config.copy()
+             generate_kwargs.update({k: v for k, v in r.kwargs.items() if v is not None})
+             model_cls._filter_kwargs(model_cls._model, generate_kwargs)
+             r.set_generate_kwargs(generate_kwargs)
+
+             # check max_sequence_length
+             max_sequence_length = r.generate_kwargs.get(
+                 "max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH
+             )
+             if (
+                 cur_batch_max_sequence_length
+                 and max_sequence_length != cur_batch_max_sequence_length[0]
+             ):
+                 r.is_encode = False
+                 r.error_msg = (
+                     f"The max_sequence_length of the current request: {max_sequence_length} is "
+                     f"different from the setting in the running batch: {cur_batch_max_sequence_length[0]}, "
+                     f"please be consistent."
+                 )
+                 continue
+
+             num_images_per_prompt = r.n
+             callback_on_step_end_tensor_inputs = r.generate_kwargs.get(
+                 "callback_on_step_end_tensor_inputs", ["latents"]
+             )
+             num_inference_steps = r.generate_kwargs.get("num_inference_steps", 28)
+             guidance_scale = r.generate_kwargs.get("guidance_scale", 7.0)
+             generator = None
+             seed = r.generate_kwargs.get("seed", None)
+             if seed is not None:
+                 generator = torch.Generator(device=available_device)  # type: ignore
+                 if seed != -1:
+                     generator = generator.manual_seed(seed)
+             latents = None
+             timesteps = None
+
+             # Each request must build its own scheduler instance,
+             # otherwise the mixing of variables at `scheduler.STEP` will result in an error.
+             r.scheduler = FlowMatchEulerDiscreteScheduler(
+                 model_cls._model.scheduler.config.num_train_timesteps,
+                 model_cls._model.scheduler.config.shift,
+                 model_cls._model.scheduler.config.use_dynamic_shifting,
+                 model_cls._model.scheduler.config.base_shift,
+                 model_cls._model.scheduler.config.max_shift,
+                 model_cls._model.scheduler.config.base_image_seq_len,
+                 model_cls._model.scheduler.config.max_image_seq_len,
+             )
+
+             # check inputs
+             model_cls._model.check_inputs(
+                 r.prompt,
+                 None,
+                 height,
+                 width,
+                 prompt_embeds=None,
+                 pooled_prompt_embeds=None,
+                 callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+                 max_sequence_length=max_sequence_length,
+             )
+
+             # handle prompt
+             (
+                 prompt_embeds,
+                 pooled_prompt_embeds,
+                 text_ids,
+             ) = model_cls._model.encode_prompt(
+                 prompt=r.prompt,
+                 prompt_2=None,
+                 prompt_embeds=None,
+                 pooled_prompt_embeds=None,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 lora_scale=None,
+             )
+
+             # Prepare latent variables
+             num_channels_latents = model_cls._model.transformer.config.in_channels // 4
+             latents, latent_image_ids = model_cls._model.prepare_latents(
+                 num_images_per_prompt,
+                 num_channels_latents,
+                 height,
+                 width,
+                 prompt_embeds.dtype,
+                 device,
+                 generator,
+                 latents,
+             )
+
+             # Prepare timesteps
+             sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+             image_seq_len = latents.shape[1]
+
+             mu = calculate_shift(
+                 image_seq_len,
+                 r.scheduler.config["base_image_seq_len"],
+                 r.scheduler.config["max_image_seq_len"],
+                 r.scheduler.config["base_shift"],
+                 r.scheduler.config["max_shift"],
+             )
+             timesteps, num_inference_steps = retrieve_timesteps(
+                 r.scheduler,
+                 num_inference_steps,
+                 device,
+                 timesteps,
+                 sigmas,
+                 mu=mu,
+             )
+
+             # handle guidance
+             if model_cls._model.transformer.config.guidance_embeds:
+                 guidance = torch.full(
+                     [1], guidance_scale, device=device, dtype=torch.float32
+                 )
+                 guidance = guidance.expand(latents.shape[0])
+             else:
+                 guidance = None
+
+             r.static_tensors["latents"] = latents
+             r.static_tensors["guidance"] = guidance
+             r.static_tensors["pooled_prompt_embeds"] = pooled_prompt_embeds
+             r.static_tensors["prompt_embeds"] = prompt_embeds
+             r.static_tensors["text_ids"] = text_ids
+             r.static_tensors["latent_image_ids"] = latent_image_ids
+             r.timesteps = timesteps
+             r.dtype = latents.dtype
+             r.total_steps = len(timesteps)
+             r.is_encode = False
+
+     running_req_list = [r for r in req_list if r.error_msg is None]
+     static_tensors = _cat_tensors([r.static_tensors for r in running_req_list])
+
+     # Do a step
+     timestep_tmp = []
+     for r in running_req_list:
+         timestep_tmp.append(r.timesteps[r.done_steps].expand(r.n).to(r.dtype))
+         r.done_steps += 1
+     timestep = torch.cat(timestep_tmp)
+     noise_pred = model_cls._model.transformer(
+         hidden_states=static_tensors["latents"],
+         # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+         timestep=timestep / 1000,
+         guidance=static_tensors["guidance"],
+         pooled_projections=static_tensors["pooled_prompt_embeds"],
+         encoder_hidden_states=static_tensors["prompt_embeds"],
+         txt_ids=static_tensors["text_ids"],
+         img_ids=static_tensors["latent_image_ids"],
+         joint_attention_kwargs=None,
+         return_dict=False,
+     )[0]
+
+     # update latents
+     start_idx = 0
+     for r in running_req_list:
+         n = r.n
+         # handle diffusion scheduler step
+         _noise_pred = noise_pred[start_idx : start_idx + n, ::]
+         _timestep = timestep[start_idx]
+         latents_out = r.scheduler.step(
+             _noise_pred, _timestep, r.static_tensors["latents"], return_dict=False
+         )[0]
+         r.static_tensors["latents"] = latents_out
+         start_idx += n
+
+         logger.info(
+             f"Request {r.request_id} has done {r.done_steps} / {r.total_steps} steps."
+         )
+
+         # process result
+         if r.done_steps == r.total_steps:
+             output_type = r.generate_kwargs.get("output_type", "pil")
+             _latents = r.static_tensors["latents"]
+             if output_type == "latent":
+                 image = _latents
+             else:
+                 _latents = model_cls._model._unpack_latents(
+                     _latents, height, width, model_cls._model.vae_scale_factor
+                 )
+                 _latents = (
+                     _latents / model_cls._model.vae.config.scaling_factor
+                 ) + model_cls._model.vae.config.shift_factor
+                 image = model_cls._model.vae.decode(_latents, return_dict=False)[0]
+                 image = model_cls._model.image_processor.postprocess(
+                     image, output_type=output_type
+                 )
+
+             is_padded = r.generate_kwargs.get("is_padded", None)
+             origin_size = r.generate_kwargs.get("origin_size", None)
+
+             if is_padded and origin_size:
+                 new_images = []
+                 x, y = origin_size
+                 for img in image:
+                     new_images.append(img.crop((0, 0, x, y)))
+                 image = new_images
+
+             r.output = FluxPipelineOutput(images=image)
+             logger.info(
+                 f"Request {r.request_id} has completed total {r.total_steps} steps."
+             )
+
+
+ def _batch_text_to_image(
+     model_cls: "DiffusionModel",
+     req_list: List[Text2ImageRequest],
+     available_device: str,
+ ):
+     from ....core.model import OutOfMemoryError
+
+     try:
+         _batch_text_to_image_internal(model_cls, req_list, available_device)
+     except OutOfMemoryError:
+         logger.exception(
+             f"Batch text_to_image out of memory. "
+             f"Xinference will restart the model: {model_cls._model_uid}. "
+             f"Please be patient for a few moments."
+         )
+         # Just kill the process and let xinference auto-recover the model
+         os._exit(1)
+     except Exception as e:
+         logger.exception(f"Internal error for batch text_to_image: {e}.")
+         # If internal error happens, just skip all the requests in this batch.
+         # If not handle here, the client will hang.
+         for r in req_list:
+             r.error_msg = str(e)
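
The new flux.py module adds continuous batching for FLUX text-to-image generation: FluxBatchSchedulerActor keeps waiting/running deques of Text2ImageRequest objects, packs requests into a batch of at most get_max_num_images_for_batching() images (FCFS), runs one denoising step across the whole batch roughly every 10 ms, and resolves each request's future via handle_image_result once all of its steps finish. The sketch below shows, under assumptions, how a caller could drive the actor directly; the real wiring lives in xinference/core/model.py (also changed in this release but not shown here), and the use of a concurrent.futures.Future wrapped for awaiting is an illustrative choice, not necessarily what xinference does internally.

    # Illustrative sketch only -- not the actual xinference call site.
    import asyncio
    from concurrent.futures import Future

    import xoscar as xo

    from xinference.model.image.scheduler.flux import FluxBatchSchedulerActor


    async def generate_one(address: str, model_uid: str, diffusion_model):
        # Create the per-model scheduler actor (normally done by the model actor).
        scheduler_ref = await xo.create_actor(
            FluxBatchSchedulerActor,
            address=address,
            uid=FluxBatchSchedulerActor.gen_uid(model_uid),
        )
        await scheduler_ref.set_model(diffusion_model)

        # A thread-safe future: the scheduler loop calls set_result()/set_exception()
        # on it when the request's diffusion steps complete or fail.
        fut: Future = Future()
        await scheduler_ref.add_request(
            "req-1",               # unique request id
            fut,                   # resolved with an ImageList
            "a cat on the moon",   # prompt
            1,                     # n: number of images
            "1024*1024",           # size, parsed into width and height
            "b64_json",            # response_format
        )
        return await asyncio.wrap_future(fut)
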
xinference/model/image/stable_diffusion/core.py
@@ -12,31 +12,24 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import base64
  import contextlib
  import gc
  import inspect
  import itertools
  import logging
- import os
  import re
  import sys
- import time
- import uuid
  import warnings
- from concurrent.futures import ThreadPoolExecutor
- from functools import partial
- from io import BytesIO
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
  import PIL.Image
  import torch
  from PIL import ImageOps
 
- from ....constants import XINFERENCE_IMAGE_DIR
  from ....device_utils import get_available_device, move_model_to_available_device
- from ....types import Image, ImageList, LoRA
+ from ....types import LoRA
  from ..sdapi import SDAPIDiffusionModelMixin
+ from ..utils import handle_image_result
 
  if TYPE_CHECKING:
      from ....core.progress_tracker import Progressor
@@ -297,6 +290,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
          if self._kwargs.get("vae_tiling", False):
              model.enable_vae_tiling()
 
+     def get_max_num_images_for_batching(self):
+         return self._kwargs.get("max_num_images", 16)
+
      @staticmethod
      def _get_scheduler(model: Any, sampler_name: str):
          if not sampler_name or sampler_name == "default":
@@ -476,28 +472,7 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
          if return_images:
              return images
 
-         if response_format == "url":
-             os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
-             image_list = []
-             with ThreadPoolExecutor() as executor:
-                 for img in images:
-                     path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg")
-                     image_list.append(Image(url=path, b64_json=None))
-                     executor.submit(img.save, path, "jpeg")
-             return ImageList(created=int(time.time()), data=image_list)
-         elif response_format == "b64_json":
-
-             def _gen_base64_image(_img):
-                 buffered = BytesIO()
-                 _img.save(buffered, format="jpeg")
-                 return base64.b64encode(buffered.getvalue()).decode()
-
-             with ThreadPoolExecutor() as executor:
-                 results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
-                 image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
-             return ImageList(created=int(time.time()), data=image_list)
-         else:
-             raise ValueError(f"Unsupported response format: {response_format}")
+         return handle_image_result(response_format, images)
 
      @classmethod
      def _filter_kwargs(cls, model, kwargs: dict):
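
The DiffusionModel changes above are small: image generation now delegates response packaging to the shared handle_image_result helper (defined in xinference/model/image/utils.py, next hunk), and get_max_num_images_for_batching exposes the per-batch image cap used by the new FLUX scheduler, read from the model's kwargs with a default of 16. A hedged sketch of how that cap would typically be supplied follows; the assumption that extra launch kwargs flow through to DiffusionModel's kwargs is mine, and the model name is only an example.

    # Illustrative only: tune the FLUX batching cap at launch time.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(
        model_name="FLUX.1-dev",  # example image model name
        model_type="image",
        max_num_images=8,  # assumed to reach the DiffusionModel kwargs as "max_num_images"
    )
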
xinference/model/image/utils.py
@@ -11,16 +11,52 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- from typing import Optional
+ import base64
+ import os
+ import time
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from functools import partial
+ from io import BytesIO
+ from typing import TYPE_CHECKING, Optional
 
- from .core import ImageModelFamilyV1
+ from ...constants import XINFERENCE_IMAGE_DIR
+ from ...types import Image, ImageList
+
+ if TYPE_CHECKING:
+     from .core import ImageModelFamilyV1
 
 
  def get_model_version(
-     image_model: ImageModelFamilyV1, controlnet: Optional[ImageModelFamilyV1]
+     image_model: "ImageModelFamilyV1", controlnet: Optional["ImageModelFamilyV1"]
  ) -> str:
      return (
          image_model.model_name
          if controlnet is None
          else f"{image_model.model_name}--{controlnet.model_name}"
      )
+
+
+ def handle_image_result(response_format: str, images) -> ImageList:
+     if response_format == "url":
+         os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
+         image_list = []
+         with ThreadPoolExecutor() as executor:
+             for img in images:
+                 path = os.path.join(XINFERENCE_IMAGE_DIR, uuid.uuid4().hex + ".jpg")
+                 image_list.append(Image(url=path, b64_json=None))
+                 executor.submit(img.save, path, "jpeg")
+         return ImageList(created=int(time.time()), data=image_list)
+     elif response_format == "b64_json":
+
+         def _gen_base64_image(_img):
+             buffered = BytesIO()
+             _img.save(buffered, format="jpeg")
+             return base64.b64encode(buffered.getvalue()).decode()
+
+         with ThreadPoolExecutor() as executor:
+             results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
+             image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
+         return ImageList(created=int(time.time()), data=image_list)
+     else:
+         raise ValueError(f"Unsupported response format: {response_format}")
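
handle_image_result centralizes the packaging that stable_diffusion/core.py previously did inline: for "url" it saves JPEGs under XINFERENCE_IMAGE_DIR and returns their file paths, for "b64_json" it returns base64-encoded JPEG bytes, and any other format raises ValueError. A minimal sketch of calling the helper directly, assuming only that Pillow is installed:

    # Minimal check of the shared helper with in-memory PIL images.
    from PIL import Image as PILImage

    from xinference.model.image.utils import handle_image_result

    images = [PILImage.new("RGB", (64, 64), color="white") for _ in range(2)]
    result = handle_image_result("b64_json", images)
    print(result)  # ImageList with two base64-encoded JPEG payloads
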
xinference/model/llm/__init__.py
@@ -146,6 +146,7 @@ def _install():
      from .transformers.internlm2 import Internlm2PytorchChatModel
      from .transformers.minicpmv25 import MiniCPMV25Model
      from .transformers.minicpmv26 import MiniCPMV26Model
+     from .transformers.opt import OptPytorchModel
      from .transformers.qwen2_audio import Qwen2AudioChatModel
      from .transformers.qwen2_vl import Qwen2VLChatModel
      from .transformers.qwen_vl import QwenVLChatModel
@@ -190,6 +191,7 @@ def _install():
              Glm4VModel,
              DeepSeekV2PytorchModel,
              DeepSeekV2PytorchChatModel,
+             OptPytorchModel,
          ]
      )
      if OmniLMMModel:  # type: ignore