xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.


Files changed (65)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +29 -2
  4. xinference/client/restful/restful_client.py +10 -0
  5. xinference/constants.py +7 -3
  6. xinference/core/image_interface.py +76 -23
  7. xinference/core/model.py +158 -46
  8. xinference/core/progress_tracker.py +187 -0
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/supervisor.py +11 -0
  11. xinference/core/utils.py +9 -0
  12. xinference/core/worker.py +1 -0
  13. xinference/deploy/supervisor.py +4 -0
  14. xinference/model/__init__.py +4 -0
  15. xinference/model/audio/chattts.py +2 -1
  16. xinference/model/audio/core.py +0 -2
  17. xinference/model/audio/model_spec.json +8 -0
  18. xinference/model/audio/model_spec_modelscope.json +9 -0
  19. xinference/model/image/core.py +6 -7
  20. xinference/model/image/scheduler/__init__.py +13 -0
  21. xinference/model/image/scheduler/flux.py +533 -0
  22. xinference/model/image/sdapi.py +35 -4
  23. xinference/model/image/stable_diffusion/core.py +215 -110
  24. xinference/model/image/utils.py +39 -3
  25. xinference/model/llm/__init__.py +2 -0
  26. xinference/model/llm/llm_family.json +185 -17
  27. xinference/model/llm/llm_family_modelscope.json +124 -12
  28. xinference/model/llm/transformers/chatglm.py +104 -0
  29. xinference/model/llm/transformers/cogvlm2.py +2 -1
  30. xinference/model/llm/transformers/cogvlm2_video.py +2 -0
  31. xinference/model/llm/transformers/core.py +43 -113
  32. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  33. xinference/model/llm/transformers/deepseek_vl.py +2 -0
  34. xinference/model/llm/transformers/glm4v.py +2 -1
  35. xinference/model/llm/transformers/intern_vl.py +2 -0
  36. xinference/model/llm/transformers/internlm2.py +3 -95
  37. xinference/model/llm/transformers/minicpmv25.py +2 -0
  38. xinference/model/llm/transformers/minicpmv26.py +2 -0
  39. xinference/model/llm/transformers/omnilmm.py +2 -0
  40. xinference/model/llm/transformers/opt.py +68 -0
  41. xinference/model/llm/transformers/qwen2_audio.py +11 -4
  42. xinference/model/llm/transformers/qwen2_vl.py +2 -28
  43. xinference/model/llm/transformers/qwen_vl.py +2 -1
  44. xinference/model/llm/transformers/utils.py +36 -283
  45. xinference/model/llm/transformers/yi_vl.py +2 -0
  46. xinference/model/llm/utils.py +60 -16
  47. xinference/model/llm/vllm/core.py +68 -9
  48. xinference/model/llm/vllm/utils.py +0 -1
  49. xinference/model/utils.py +7 -4
  50. xinference/model/video/core.py +0 -2
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
  55. xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  57. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
  58. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
  59. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  61. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
  62. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
  63. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
  64. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
  65. {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED
@@ -26,13 +26,9 @@ except:
 def _install():
     from xoscar.backends.router import Router
 
-    from .model import _install as install_model
-
     default_router = Router.get_instance_or_empty()
     Router.set_instance(default_router)
 
-    install_model()
-
 
 _install()
 del _install
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-09-30T20:17:26+0800",
+ "date": "2024-10-18T12:49:02+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "00a9ee15279a60a6d75393c4720d8da5cbbf5796",
- "version": "0.15.3"
+ "full-revisionid": "5f7dea44832a1c41f887b9a01377191894550057",
+ "version": "0.16.0"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -524,6 +524,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/requests/{request_id}/progress",
+            self.get_progress,
+            methods=["get"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -1486,6 +1496,17 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def get_progress(self, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            result = {"progress": await supervisor_ref.get_progress(request_id)}
+            return JSONResponse(content=result)
+        except KeyError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
@@ -1853,10 +1874,16 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-        from ..model.llm.utils import GLM4_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY
+        from ..model.llm.utils import (
+            GLM4_TOOL_CALL_FAMILY,
+            LLAMA3_TOOL_CALL_FAMILY,
+            QWEN_TOOL_CALL_FAMILY,
+        )
 
         model_family = desc.get("model_family", "")
-        function_call_models = QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY
+        function_call_models = (
+            QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
+        )
 
         if model_family not in function_call_models:
             if body.tools:
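
The new route can be exercised directly over HTTP while a generation request is in flight. A minimal polling sketch (the host, port, and request id are assumptions for illustration; the handler returns a JSON body with a single `progress` float):

import requests

# Hedged sketch: poll the new progress route for a request started with an
# explicit request_id. Assumes a local xinference server on its default port.
resp = requests.get("http://127.0.0.1:9997/v1/requests/my-request-id/progress")
resp.raise_for_status()  # 400 if the request id is unknown (the KeyError path)
print(resp.json())       # e.g. {"progress": 0.42}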
xinference/client/restful/restful_client.py CHANGED
@@ -1385,6 +1385,16 @@ class Client:
         response_json = response.json()
         return response_json
 
+    def get_progress(self, request_id: str):
+        url = f"{self.base_url}/v1/requests/{request_id}/progress"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get progress, detail: {_get_error_string(response)}"
+            )
+        response_json = response.json()
+        return response_json
+
     def abort_cluster(self):
         url = f"{self.base_url}/v1/clusters"
         response = requests.delete(url, headers=self._headers)
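
Paired with an explicit `request_id` on the generation call, `get_progress` supports the polling pattern the Gradio image interface adopts below. A client-side sketch (the endpoint and model uid are placeholder assumptions):

import threading
import time
import uuid

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint
model = client.get_model("my-image-model")       # hypothetical model uid

# Run the blocking call in a thread and poll progress from the main thread.
request_id = str(uuid.uuid4())
t = threading.Thread(
    target=model.text_to_image,
    kwargs=dict(prompt="an astronaut riding a horse", request_id=request_id),
)
t.start()
while t.is_alive():
    try:
        print(client.get_progress(request_id))  # e.g. {"progress": 0.42}
    except RuntimeError:
        pass  # the request may not be registered with the tracker yet
    time.sleep(1)
t.join()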
xinference/constants.py CHANGED
@@ -27,7 +27,8 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
-XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
+XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
+XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE = "XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"
 
 
 def get_xinference_home() -> str:
@@ -79,6 +80,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
-XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
-    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
+XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
+    os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
+)
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
+    XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
 )
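
XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE replaces the removed transformers batching switch; as the model.py changes below enforce, its value must equal the `size` of every batched text_to_image request. A configuration sketch (the `1024*1024` value is an assumption for illustration):

import os

# Hedged sketch: opt in to text-to-image batching before starting the server.
# Later text_to_image calls must use exactly this size string, otherwise
# ModelActor.handle_image_batching_request raises a RuntimeError.
os.environ["XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"] = "1024*1024"

# Also new in this release: configurable download retries (defaults to 3).
os.environ["XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"] = "5"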
xinference/core/image_interface.py CHANGED
@@ -16,6 +16,9 @@ import base64
 import io
 import logging
 import os
+import threading
+import time
+import uuid
 from typing import Dict, List, Optional, Union
 
 import gradio as gr
@@ -84,6 +87,7 @@ class ImageInterface:
         num_inference_steps: int,
         negative_prompt: Optional[str] = None,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -99,19 +103,43 @@
         )
         sampler_name = None if sampler_name == "default" else sampler_name
 
-        response = model.text_to_image(
-            prompt=prompt,
-            n=n,
-            size=size,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            sampler_name=sampler_name,
-            response_format="b64_json",
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.text_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    n=n,
+                    size=size,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    negative_prompt=negative_prompt,
+                    sampler_name=sampler_name,
+                    response_format="b64_json",
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
@@ -184,6 +212,7 @@ class ImageInterface:
         num_inference_steps: int,
         padding_image_to_multiple: int,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -205,20 +234,44 @@
         bio = io.BytesIO()
         image.save(bio, format="png")
 
-        response = model.image_to_image(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            n=n,
-            image=bio.getvalue(),
-            size=size,
-            response_format="b64_json",
-            num_inference_steps=num_inference_steps,
-            padding_image_to_multiple=padding_image_to_multiple,
-            sampler_name=sampler_name,
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.image_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    n=n,
+                    image=bio.getvalue(),
+                    size=size,
+                    response_format="b64_json",
+                    num_inference_steps=num_inference_steps,
+                    padding_image_to_multiple=padding_image_to_multiple,
+                    sampler_name=sampler_name,
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
xinference/core/model.py CHANGED
@@ -41,9 +41,10 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
-from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE
 
 if TYPE_CHECKING:
+    from .progress_tracker import ProgressTrackerActor
     from .worker import WorkerActor
     from ..model.llm.core import LLM
     from ..model.core import ModelDescription
@@ -73,6 +74,8 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
     "MiniCPM-V-2.6",
 ]
 
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
+
 
 def request_limit(fn):
     """
@@ -152,6 +155,16 @@ class ModelActor(xo.StatelessActor):
                     f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
                 )
 
+        if self.allow_batching_for_text_to_image():
+            try:
+                assert self._text_to_image_scheduler_ref is not None
+                await xo.destroy_actor(self._text_to_image_scheduler_ref)
+                del self._text_to_image_scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy text_to_image scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if hasattr(self._model, "stop") and callable(self._model.stop):
             self._model.stop()
 
@@ -177,6 +190,7 @@
 
     def __init__(
         self,
+        supervisor_address: str,
         worker_address: str,
         model: "LLM",
         model_description: Optional["ModelDescription"] = None,
@@ -188,6 +202,7 @@
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._supervisor_address = supervisor_address
         self._worker_address = worker_address
         self._model = model
         self._model_description = (
@@ -205,6 +220,7 @@
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
+        self._progress_tracker_ref = None
         self._serve_count = 0
         self._metrics_labels = {
             "type": self._model_description.get("model_type", "unknown"),
@@ -216,6 +232,7 @@
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
         self._scheduler_ref = None
+        self._text_to_image_scheduler_ref = None
 
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
@@ -229,6 +246,15 @@
                 uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
             )
 
+        if self.allow_batching_for_text_to_image():
+            from ..model.image.scheduler.flux import FluxBatchSchedulerActor
+
+            self._text_to_image_scheduler_ref = await xo.create_actor(
+                FluxBatchSchedulerActor,
+                address=self.address,
+                uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -275,6 +301,28 @@
             )
         return self._worker_ref
 
+    async def _get_progress_tracker_ref(
+        self,
+    ) -> xo.ActorRefType["ProgressTrackerActor"]:
+        from .progress_tracker import ProgressTrackerActor
+
+        if self._progress_tracker_ref is None:
+            self._progress_tracker_ref = await xo.actor_ref(
+                address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
+            )
+        return self._progress_tracker_ref
+
+    async def _get_progressor(self, request_id: str):
+        from .progress_tracker import Progressor
+
+        progressor = Progressor(
+            request_id,
+            await self._get_progress_tracker_ref(),
+            asyncio.get_running_loop(),
+        )
+        await progressor.start()
+        return progressor
+
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
 
@@ -285,10 +333,8 @@
 
         model_ability = self._model_description.get("model_ability", [])
 
-        condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
-            self._model, PytorchModel
-        )
-        if condition and "vision" in model_ability:
+        condition = isinstance(self._model, PytorchModel)
+        if condition and ("vision" in model_ability or "audio" in model_ability):
            if (
                self._model.model_family.model_name
                in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
@@ -305,6 +351,26 @@
             return False
         return condition
 
+    def allow_batching_for_text_to_image(self) -> bool:
+        from ..model.image.stable_diffusion.core import DiffusionModel
+
+        condition = XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE is not None and isinstance(
+            self._model, DiffusionModel
+        )
+
+        if condition:
+            model_name = self._model._model_spec.model_name  # type: ignore
+            if model_name in XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS:
+                return True
+            else:
+                logger.warning(
+                    f"Currently for image models with text_to_image ability, "
+                    f"xinference only supports {', '.join(XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS)} for batching. "
+                    f"Your model {model_name} is disqualified."
+                )
+                return False
+        return condition
+
     async def load(self):
         self._model.load()
         if self.allow_batching():
@@ -312,6 +378,11 @@
             logger.debug(
                 f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
             )
+        if self.allow_batching_for_text_to_image():
+            await self._text_to_image_scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
+            )
 
     def model_uid(self):
         return (
@@ -591,12 +662,16 @@
         )
 
     async def abort_request(self, request_id: str) -> str:
-        from .scheduler import AbortRequestMessage
+        from .utils import AbortRequestMessage
 
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
             return await self._scheduler_ref.abort_request(request_id)
+        elif self.allow_batching_for_text_to_image():
+            if self._text_to_image_scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._text_to_image_scheduler_ref.abort_request(request_id)
         return AbortRequestMessage.NO_OP.name
 
     @request_limit
@@ -721,6 +796,22 @@
                 f"Model {self._model.model_spec} is not for creating speech."
             )
 
+    async def handle_image_batching_request(self, unique_id, *args, **kwargs):
+        size = args[2]
+        if XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE != size:
+            raise RuntimeError(
+                f"The image size: {size} of text_to_image for batching "
+                f"must be the same as the environment variable: {XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE} you set."
+            )
+        assert self._loop is not None
+        future = ConcurrentFuture()
+        await self._text_to_image_scheduler_ref.add_request(
+            unique_id, future, *args, **kwargs
+        )
+        fut = asyncio.wrap_future(future, loop=self._loop)
+        result = await fut
+        return await asyncio.to_thread(json_dumps, result)
+
     @request_limit
     @log_async(logger=logger)
     async def text_to_image(
@@ -732,17 +823,26 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper_json(
-                self._model.text_to_image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
-            )
+            if self.allow_batching_for_text_to_image():
+                unique_id = kwargs.pop("request_id", None)
+                return await self.handle_image_batching_request(
+                    unique_id, prompt, n, size, response_format, *args, **kwargs
+                )
+            else:
+                progressor = kwargs["progressor"] = await self._get_progressor(
+                    kwargs.pop("request_id", None)
+                )
+                with progressor:
+                    return await self._call_wrapper_json(
+                        self._model.text_to_image,
+                        prompt,
+                        n,
+                        size,
+                        response_format,
+                        *args,
+                        **kwargs,
+                    )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -753,12 +853,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "txt2img"):
-            return await self._call_wrapper_json(
-                self._model.txt2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.txt2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
 
     @log_async(
@@ -776,19 +879,22 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper_json(
-                self._model.image_to_image,
-                image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.image_to_image,
+                    image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -799,12 +905,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "img2img"):
-            return await self._call_wrapper_json(
-                self._model.img2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
            )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.img2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
 
     @log_async(
@@ -823,20 +932,23 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
+        kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper_json(
-                self._model.inpainting,
-                image,
-                mask_image,
-                prompt,
-                negative_prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.inpainting,
+                    image,
+                    mask_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
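
Taken together, the gating above reduces to a small predicate. An illustrative restatement (not xinference code; it omits the DiffusionModel isinstance check) of when a text_to_image call is routed to the new FluxBatchSchedulerActor instead of the per-request Progressor path:

from typing import Optional

def routed_to_batch_scheduler(model_name: str, batching_size: Optional[str]) -> bool:
    # Mirrors allow_batching_for_text_to_image(): the environment variable must
    # be set and the model must be one of the whitelisted FLUX variants.
    allowed = ("FLUX.1-dev", "FLUX.1-schnell")
    return batching_size is not None and model_name in allowed

assert routed_to_batch_scheduler("FLUX.1-dev", "1024*1024")
assert not routed_to_batch_scheduler("FLUX.1-dev", None)         # env var unset
assert not routed_to_batch_scheduler("sd3-medium", "1024*1024")  # hypothetical non-FLUX model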