xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (65)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +29 -2
  4. xinference/client/restful/restful_client.py +10 -0
  5. xinference/constants.py +7 -3
  6. xinference/core/image_interface.py +76 -23
  7. xinference/core/model.py +158 -46
  8. xinference/core/progress_tracker.py +187 -0
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/supervisor.py +11 -0
  11. xinference/core/utils.py +9 -0
  12. xinference/core/worker.py +1 -0
  13. xinference/deploy/supervisor.py +4 -0
  14. xinference/model/__init__.py +4 -0
  15. xinference/model/audio/chattts.py +2 -1
  16. xinference/model/audio/core.py +0 -2
  17. xinference/model/audio/model_spec.json +8 -0
  18. xinference/model/audio/model_spec_modelscope.json +9 -0
  19. xinference/model/image/core.py +6 -7
  20. xinference/model/image/scheduler/__init__.py +13 -0
  21. xinference/model/image/scheduler/flux.py +533 -0
  22. xinference/model/image/sdapi.py +35 -4
  23. xinference/model/image/stable_diffusion/core.py +215 -110
  24. xinference/model/image/utils.py +39 -3
  25. xinference/model/llm/__init__.py +2 -0
  26. xinference/model/llm/llm_family.json +185 -17
  27. xinference/model/llm/llm_family_modelscope.json +124 -12
  28. xinference/model/llm/transformers/chatglm.py +104 -0
  29. xinference/model/llm/transformers/cogvlm2.py +2 -1
  30. xinference/model/llm/transformers/cogvlm2_video.py +2 -0
  31. xinference/model/llm/transformers/core.py +43 -113
  32. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  33. xinference/model/llm/transformers/deepseek_vl.py +2 -0
  34. xinference/model/llm/transformers/glm4v.py +2 -1
  35. xinference/model/llm/transformers/intern_vl.py +2 -0
  36. xinference/model/llm/transformers/internlm2.py +3 -95
  37. xinference/model/llm/transformers/minicpmv25.py +2 -0
  38. xinference/model/llm/transformers/minicpmv26.py +2 -0
  39. xinference/model/llm/transformers/omnilmm.py +2 -0
  40. xinference/model/llm/transformers/opt.py +68 -0
  41. xinference/model/llm/transformers/qwen2_audio.py +11 -4
  42. xinference/model/llm/transformers/qwen2_vl.py +2 -28
  43. xinference/model/llm/transformers/qwen_vl.py +2 -1
  44. xinference/model/llm/transformers/utils.py +36 -283
  45. xinference/model/llm/transformers/yi_vl.py +2 -0
  46. xinference/model/llm/utils.py +60 -16
  47. xinference/model/llm/vllm/core.py +68 -9
  48. xinference/model/llm/vllm/utils.py +0 -1
  49. xinference/model/utils.py +7 -4
  50. xinference/model/video/core.py +0 -2
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  55. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  57. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
  58. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
  59. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  61. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  62. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  63. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  64. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  65. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/model/image/scheduler/flux.py (new file)
@@ -0,0 +1,533 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #      http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import asyncio
+ import logging
+ import os
+ import re
+ import typing
+ from collections import deque
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import xoscar as xo
+
+ from ..utils import handle_image_result
+
+ if TYPE_CHECKING:
+     from ..stable_diffusion.core import DiffusionModel
+
+
+ logger = logging.getLogger(__name__)
+ DEFAULT_MAX_SEQUENCE_LENGTH = 512
+
+
+ class Text2ImageRequest:
+     def __init__(
+         self,
+         unique_id,
+         future,
+         prompt: str,
+         n: int,
+         size: str,
+         response_format: str,
+         *args,
+         **kwargs,
+     ):
+         self._unique_id = unique_id
+         self.future = future
+         self._prompt = prompt
+         self._n = n
+         self._size = size
+         self._response_format = response_format
+         self._args = args
+         self._kwargs = kwargs
+         self._width = -1
+         self._height = -1
+         self._generate_kwargs: Dict[str, Any] = {}
+         self._set_width_and_height()
+         self.is_encode = True
+         self.scheduler = None
+         self.done_steps = 0
+         self.total_steps = 0
+         self.static_tensors: Dict[str, torch.Tensor] = {}
+         self.timesteps = None
+         self.dtype = None
+         self.output = None
+         self.error_msg: Optional[str] = None
+         self.aborted = False
+
+     def _set_width_and_height(self):
+         self._width, self._height = map(int, re.split(r"[^\d]+", self._size))
+
+     def set_generate_kwargs(self, generate_kwargs: Dict):
+         self._generate_kwargs = {k: v for k, v in generate_kwargs.items()}
+
+     @property
+     def prompt(self):
+         return self._prompt
+
+     @property
+     def n(self):
+         return self._n
+
+     @property
+     def size(self):
+         return self._size
+
+     @property
+     def response_format(self):
+         return self._response_format
+
+     @property
+     def kwargs(self):
+         return self._kwargs
+
+     @property
+     def width(self):
+         return self._width
+
+     @property
+     def height(self):
+         return self._height
+
+     @property
+     def generate_kwargs(self):
+         return self._generate_kwargs
+
+     @property
+     def request_id(self):
+         return self._unique_id
+
+
+ class FluxBatchSchedulerActor(xo.StatelessActor):
+     @classmethod
+     def gen_uid(cls, model_uid: str):
+         return f"{model_uid}-scheduler-actor"
+
+     def __init__(self):
+         from ....device_utils import get_available_device
+
+         super().__init__()
+         self._waiting_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+         self._running_queue: deque[Text2ImageRequest] = deque()  # type: ignore
+         self._model = None
+         self._available_device = get_available_device()
+         self._id_to_req: Dict[str, Text2ImageRequest] = {}
+
+     def set_model(self, model):
+         """
+         Must use `set_model`. Otherwise, the model will be copied once.
+         """
+         self._model = model
+
+     async def __post_create__(self):
+         from ....isolation import Isolation
+
+         self._isolation = Isolation(
+             asyncio.new_event_loop(), threaded=True, daemon=True
+         )
+         self._isolation.start()
+         asyncio.run_coroutine_threadsafe(self.run(), loop=self._isolation.loop)
+
+     async def __pre_destroy__(self):
+         try:
+             assert self._isolation is not None
+             self._isolation.stop()
+             del self._isolation
+         except Exception as e:
+             logger.debug(
+                 f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
+             )
+
+     async def add_request(self, unique_id: str, future, *args, **kwargs):
+         req = Text2ImageRequest(unique_id, future, *args, **kwargs)
+         rid = req.request_id
+         if rid is not None:
+             if rid in self._id_to_req:
+                 raise KeyError(f"Request id: {rid} has already existed!")
+             self._id_to_req[rid] = req
+         self._waiting_queue.append(req)
+
+     async def abort_request(self, req_id: str) -> str:
+         """
+         Abort a request.
+         Abort a submitted request. If the request is finished or not found, this method will be a no-op.
+         """
+         from ....core.utils import AbortRequestMessage
+
+         if req_id not in self._id_to_req:
+             logger.info(f"Request id: {req_id} not found. No-op for xinference.")
+             return AbortRequestMessage.NOT_FOUND.name
+         else:
+             self._id_to_req[req_id].aborted = True
+             logger.info(f"Request id: {req_id} found to be aborted.")
+             return AbortRequestMessage.DONE.name
+
+     def _handle_request(
+         self,
+     ) -> Optional[Tuple[List[Text2ImageRequest], List[Text2ImageRequest]]]:
+         """
+         Every request may generate `n>=1` images.
+         Here we need to decide whether to wait or not based on the value of `n` of each request.
+         """
+         if self._model is None:
+             return None
+         max_num_images = self._model.get_max_num_images_for_batching()
+         cur_num_images = 0
+         abort_list: List[Text2ImageRequest] = []
+         # currently, FCFS strategy
+         running_list: List[Text2ImageRequest] = []
+         while len(self._running_queue) > 0:
+             req = self._running_queue.popleft()
+             if req.aborted:
+                 abort_list.append(req)
+             else:
+                 running_list.append(req)
+                 cur_num_images += req.n
+
+         # Remove all the aborted requests in the waiting queue
+         waiting_tmp_list: List[Text2ImageRequest] = []
+         while len(self._waiting_queue) > 0:
+             req = self._waiting_queue.popleft()
+             if req.aborted:
+                 abort_list.append(req)
+             else:
+                 waiting_tmp_list.append(req)
+         self._waiting_queue.extend(waiting_tmp_list)
+
+         waiting_list: List[Text2ImageRequest] = []
+         while len(self._waiting_queue) > 0:
+             req = self._waiting_queue[0]
+             if req.n + cur_num_images <= max_num_images:
+                 waiting_list.append(self._waiting_queue.popleft())
+                 cur_num_images += req.n
+             else:
+                 logger.warning(
+                     f"Current queue is full, with an upper limit of max_num_images: {max_num_images}. "
+                     f"Requests will continue to wait."
+                 )
+                 break
+
+         return waiting_list + running_list, abort_list
+
+     @staticmethod
+     def _empty_cache():
+         from ....device_utils import empty_cache
+
+         empty_cache()
+
+     async def step(self):
+         res = self._handle_request()
+         if res is None:
+             return
+         req_list, abort_list = res
+         # handle abort
+         if abort_list:
+             for r in abort_list:
+                 r.future.set_exception(
+                     RuntimeError(
+                         f"Request: {r.request_id} has been cancelled by another `abort_request` request."
+                     )
+                 )
+                 self._id_to_req.pop(r.request_id, None)
+         if not req_list:
+             return
+         _batch_text_to_image(self._model, req_list, self._available_device)
+         # handle results
+         for r in req_list:
+             if r.error_msg is not None:
+                 r.future.set_exception(ValueError(r.error_msg))
+                 self._id_to_req.pop(r.request_id, None)
+                 continue
+             if r.output is not None:
+                 r.future.set_result(
+                     handle_image_result(r.response_format, r.output.images)
+                 )
+                 self._id_to_req.pop(r.request_id, None)
+             else:
+                 self._running_queue.append(r)
+         self._empty_cache()
+
+     async def run(self):
+         try:
+             while True:
+                 # wait 10ms
+                 await asyncio.sleep(0.01)
+                 await self.step()
+         except Exception as e:
+             logger.exception(
+                 f"Scheduler actor uid: {self.uid}, address: {self.address} run with error: {e}"
+             )
+
+
+ def _cat_tensors(infos: List[Dict]) -> Dict:
+     keys = infos[0].keys()
+     res = {}
+     for k in keys:
+         tmp = [info[k] for info in infos]
+         res[k] = torch.cat(tmp)
+     return res
+
+
+ @typing.no_type_check
+ @torch.inference_mode()
+ def _batch_text_to_image_internal(
+     model_cls: "DiffusionModel",
+     req_list: List[Text2ImageRequest],
+     available_device: str,
+ ):
+     from diffusers.pipelines.flux.pipeline_flux import (
+         calculate_shift,
+         retrieve_timesteps,
+     )
+     from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+     from diffusers.schedulers.scheduling_flow_match_euler_discrete import (
+         FlowMatchEulerDiscreteScheduler,
+     )
+
+     device = model_cls._model._execution_device
+     height, width = req_list[0].height, req_list[0].width
+     cur_batch_max_sequence_length = [
+         r.generate_kwargs.get("max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH)
+         for r in req_list
+         if not r.is_encode
+     ]
+     for r in req_list:
+         if r.is_encode:
+             generate_kwargs = model_cls._model_spec.default_generate_config.copy()
+             generate_kwargs.update({k: v for k, v in r.kwargs.items() if v is not None})
+             model_cls._filter_kwargs(model_cls._model, generate_kwargs)
+             r.set_generate_kwargs(generate_kwargs)
+
+             # check max_sequence_length
+             max_sequence_length = r.generate_kwargs.get(
+                 "max_sequence_length", DEFAULT_MAX_SEQUENCE_LENGTH
+             )
+             if (
+                 cur_batch_max_sequence_length
+                 and max_sequence_length != cur_batch_max_sequence_length[0]
+             ):
+                 r.is_encode = False
+                 r.error_msg = (
+                     f"The max_sequence_length of the current request: {max_sequence_length} is "
+                     f"different from the setting in the running batch: {cur_batch_max_sequence_length[0]}, "
+                     f"please be consistent."
+                 )
+                 continue
+
+             num_images_per_prompt = r.n
+             callback_on_step_end_tensor_inputs = r.generate_kwargs.get(
+                 "callback_on_step_end_tensor_inputs", ["latents"]
+             )
+             num_inference_steps = r.generate_kwargs.get("num_inference_steps", 28)
+             guidance_scale = r.generate_kwargs.get("guidance_scale", 7.0)
+             generator = None
+             seed = r.generate_kwargs.get("seed", None)
+             if seed is not None:
+                 generator = torch.Generator(device=available_device)  # type: ignore
+                 if seed != -1:
+                     generator = generator.manual_seed(seed)
+             latents = None
+             timesteps = None
+
+             # Each request must build its own scheduler instance,
+             # otherwise the mixing of variables at `scheduler.STEP` will result in an error.
+             r.scheduler = FlowMatchEulerDiscreteScheduler(
+                 model_cls._model.scheduler.config.num_train_timesteps,
+                 model_cls._model.scheduler.config.shift,
+                 model_cls._model.scheduler.config.use_dynamic_shifting,
+                 model_cls._model.scheduler.config.base_shift,
+                 model_cls._model.scheduler.config.max_shift,
+                 model_cls._model.scheduler.config.base_image_seq_len,
+                 model_cls._model.scheduler.config.max_image_seq_len,
+             )
+
+             # check inputs
+             model_cls._model.check_inputs(
+                 r.prompt,
+                 None,
+                 height,
+                 width,
+                 prompt_embeds=None,
+                 pooled_prompt_embeds=None,
+                 callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+                 max_sequence_length=max_sequence_length,
+             )
+
+             # handle prompt
+             (
+                 prompt_embeds,
+                 pooled_prompt_embeds,
+                 text_ids,
+             ) = model_cls._model.encode_prompt(
+                 prompt=r.prompt,
+                 prompt_2=None,
+                 prompt_embeds=None,
+                 pooled_prompt_embeds=None,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 lora_scale=None,
+             )
+
+             # Prepare latent variables
+             num_channels_latents = model_cls._model.transformer.config.in_channels // 4
+             latents, latent_image_ids = model_cls._model.prepare_latents(
+                 num_images_per_prompt,
+                 num_channels_latents,
+                 height,
+                 width,
+                 prompt_embeds.dtype,
+                 device,
+                 generator,
+                 latents,
+             )
+
+             # Prepare timesteps
+             sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+             image_seq_len = latents.shape[1]
+
+             mu = calculate_shift(
+                 image_seq_len,
+                 r.scheduler.config["base_image_seq_len"],
+                 r.scheduler.config["max_image_seq_len"],
+                 r.scheduler.config["base_shift"],
+                 r.scheduler.config["max_shift"],
+             )
+             timesteps, num_inference_steps = retrieve_timesteps(
+                 r.scheduler,
+                 num_inference_steps,
+                 device,
+                 timesteps,
+                 sigmas,
+                 mu=mu,
+             )
+
+             # handle guidance
+             if model_cls._model.transformer.config.guidance_embeds:
+                 guidance = torch.full(
+                     [1], guidance_scale, device=device, dtype=torch.float32
+                 )
+                 guidance = guidance.expand(latents.shape[0])
+             else:
+                 guidance = None
+
+             r.static_tensors["latents"] = latents
+             r.static_tensors["guidance"] = guidance
+             r.static_tensors["pooled_prompt_embeds"] = pooled_prompt_embeds
+             r.static_tensors["prompt_embeds"] = prompt_embeds
+             r.static_tensors["text_ids"] = text_ids
+             r.static_tensors["latent_image_ids"] = latent_image_ids
+             r.timesteps = timesteps
+             r.dtype = latents.dtype
+             r.total_steps = len(timesteps)
+             r.is_encode = False
+
+     running_req_list = [r for r in req_list if r.error_msg is None]
+     static_tensors = _cat_tensors([r.static_tensors for r in running_req_list])
+
+     # Do a step
+     timestep_tmp = []
+     for r in running_req_list:
+         timestep_tmp.append(r.timesteps[r.done_steps].expand(r.n).to(r.dtype))
+         r.done_steps += 1
+     timestep = torch.cat(timestep_tmp)
+     noise_pred = model_cls._model.transformer(
+         hidden_states=static_tensors["latents"],
+         # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+         timestep=timestep / 1000,
+         guidance=static_tensors["guidance"],
+         pooled_projections=static_tensors["pooled_prompt_embeds"],
+         encoder_hidden_states=static_tensors["prompt_embeds"],
+         txt_ids=static_tensors["text_ids"],
+         img_ids=static_tensors["latent_image_ids"],
+         joint_attention_kwargs=None,
+         return_dict=False,
+     )[0]
+
+     # update latents
+     start_idx = 0
+     for r in running_req_list:
+         n = r.n
+         # handle diffusion scheduler step
+         _noise_pred = noise_pred[start_idx : start_idx + n, ::]
+         _timestep = timestep[start_idx]
+         latents_out = r.scheduler.step(
+             _noise_pred, _timestep, r.static_tensors["latents"], return_dict=False
+         )[0]
+         r.static_tensors["latents"] = latents_out
+         start_idx += n
+
+         logger.info(
+             f"Request {r.request_id} has done {r.done_steps} / {r.total_steps} steps."
+         )
+
+         # process result
+         if r.done_steps == r.total_steps:
+             output_type = r.generate_kwargs.get("output_type", "pil")
+             _latents = r.static_tensors["latents"]
+             if output_type == "latent":
+                 image = _latents
+             else:
+                 _latents = model_cls._model._unpack_latents(
+                     _latents, height, width, model_cls._model.vae_scale_factor
+                 )
+                 _latents = (
+                     _latents / model_cls._model.vae.config.scaling_factor
+                 ) + model_cls._model.vae.config.shift_factor
+                 image = model_cls._model.vae.decode(_latents, return_dict=False)[0]
+                 image = model_cls._model.image_processor.postprocess(
+                     image, output_type=output_type
+                 )
+
+             is_padded = r.generate_kwargs.get("is_padded", None)
+             origin_size = r.generate_kwargs.get("origin_size", None)
+
+             if is_padded and origin_size:
+                 new_images = []
+                 x, y = origin_size
+                 for img in image:
+                     new_images.append(img.crop((0, 0, x, y)))
+                 image = new_images
+
+             r.output = FluxPipelineOutput(images=image)
+             logger.info(
+                 f"Request {r.request_id} has completed total {r.total_steps} steps."
+             )
+
+
+ def _batch_text_to_image(
+     model_cls: "DiffusionModel",
+     req_list: List[Text2ImageRequest],
+     available_device: str,
+ ):
+     from ....core.model import OutOfMemoryError
+
+     try:
+         _batch_text_to_image_internal(model_cls, req_list, available_device)
+     except OutOfMemoryError:
+         logger.exception(
+             f"Batch text_to_image out of memory. "
+             f"Xinference will restart the model: {model_cls._model_uid}. "
+             f"Please be patient for a few moments."
+         )
+         # Just kill the process and let xinference auto-recover the model
+         os._exit(1)
+     except Exception as e:
+         logger.exception(f"Internal error for batch text_to_image: {e}.")
+         # If internal error happens, just skip all the requests in this batch.
+         # If not handle here, the client will hang.
+         for r in req_list:
+             r.error_msg = str(e)
xinference/model/image/sdapi.py
@@ -11,11 +11,12 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import base64
  import io
  import warnings

- from PIL import Image
+ from PIL import Image, ImageOps


  class SDAPIToDiffusersConverter:
@@ -26,11 +27,12 @@ class SDAPIToDiffusersConverter:
          "width",
          "height",
          "sampler_name",
+         "progressor",
      }
      txt2img_arg_mapping = {
          "steps": "num_inference_steps",
          "cfg_scale": "guidance_scale",
-         # "denoising_strength": "strength",
+         "denoising_strength": "strength",
      }
      img2img_identical_args = {
          "prompt",
@@ -39,12 +41,15 @@ class SDAPIToDiffusersConverter:
          "width",
          "height",
          "sampler_name",
+         "progressor",
      }
      img2img_arg_mapping = {
          "init_images": "image",
+         "mask": "mask_image",
          "steps": "num_inference_steps",
          "cfg_scale": "guidance_scale",
          "denoising_strength": "strength",
+         "inpaint_full_res_padding": "padding_mask_crop",
      }

      @staticmethod
@@ -121,12 +126,38 @@ class SDAPIDiffusionModelMixin:

      def img2img(self, **kwargs):
          init_images = kwargs.pop("init_images", [])
-         kwargs["init_images"] = [self._decode_b64_img(i) for i in init_images]
+         kwargs["init_images"] = init_images = [
+             self._decode_b64_img(i) for i in init_images
+         ]
+         if len(init_images) == 1:
+             kwargs["init_images"] = init_images[0]
+         mask_image = kwargs.pop("mask", None)
+         if mask_image:
+             if kwargs.pop("inpainting_mask_invert"):
+                 mask_image = ImageOps.invert(mask_image)
+
+             kwargs["mask"] = self._decode_b64_img(mask_image)
+
+         # process inpaint_full_res and inpaint_full_res_padding
+         if kwargs.pop("inpaint_full_res", None):
+             kwargs["inpaint_full_res_padding"] = kwargs.pop(
+                 "inpaint_full_res_padding", 0
+             )
+         else:
+             # inpaint_full_res_padding is turned `into padding_mask_crop`
+             # in diffusers, if padding_mask_crop is passed, it will do inpaint_full_res
+             # so if not inpaint_full_rs, we need to pop this option
+             kwargs.pop("inpaint_full_res_padding", None)
+
          clip_skip = kwargs.get("override_settings", {}).get("clip_skip")
          converted_kwargs = self._check_kwargs("img2img", kwargs)
          if clip_skip:
              converted_kwargs["clip_skip"] = clip_skip
-         result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+
+         if not converted_kwargs.get("mask_image"):
+             result = self.image_to_image(response_format="b64_json", **converted_kwargs)  # type: ignore
+         else:
+             result = self.inpainting(response_format="b64_json", **converted_kwargs)  # type: ignore

          # convert to SD API result
          return {