xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of xinference has been flagged as potentially problematic.

Files changed (67)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +48 -0
  4. xinference/client/restful/restful_client.py +19 -0
  5. xinference/constants.py +4 -4
  6. xinference/core/chat_interface.py +5 -1
  7. xinference/core/image_interface.py +5 -1
  8. xinference/core/model.py +195 -34
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/utils.py +9 -0
  11. xinference/model/__init__.py +4 -0
  12. xinference/model/audio/chattts.py +25 -14
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/image/core.py +59 -4
  17. xinference/model/image/model_spec.json +24 -3
  18. xinference/model/image/model_spec_modelscope.json +25 -3
  19. xinference/model/image/ocr/__init__.py +13 -0
  20. xinference/model/image/ocr/got_ocr2.py +76 -0
  21. xinference/model/image/scheduler/__init__.py +13 -0
  22. xinference/model/image/scheduler/flux.py +533 -0
  23. xinference/model/image/stable_diffusion/core.py +8 -34
  24. xinference/model/image/stable_diffusion/mlx.py +221 -0
  25. xinference/model/image/utils.py +39 -3
  26. xinference/model/llm/__init__.py +2 -0
  27. xinference/model/llm/llm_family.json +178 -1
  28. xinference/model/llm/llm_family_modelscope.json +119 -0
  29. xinference/model/llm/transformers/chatglm.py +104 -0
  30. xinference/model/llm/transformers/core.py +37 -111
  31. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  32. xinference/model/llm/transformers/internlm2.py +3 -95
  33. xinference/model/llm/transformers/opt.py +68 -0
  34. xinference/model/llm/transformers/utils.py +4 -284
  35. xinference/model/llm/utils.py +2 -2
  36. xinference/model/llm/vllm/core.py +16 -1
  37. xinference/thirdparty/mlx/__init__.py +13 -0
  38. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  39. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  40. xinference/thirdparty/mlx/flux/clip.py +154 -0
  41. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  42. xinference/thirdparty/mlx/flux/flux.py +247 -0
  43. xinference/thirdparty/mlx/flux/layers.py +302 -0
  44. xinference/thirdparty/mlx/flux/lora.py +76 -0
  45. xinference/thirdparty/mlx/flux/model.py +134 -0
  46. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  47. xinference/thirdparty/mlx/flux/t5.py +244 -0
  48. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  49. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  50. xinference/thirdparty/mlx/flux/utils.py +179 -0
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
  55. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  58. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
  59. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
  60. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  64. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/__init__.py CHANGED
@@ -26,13 +26,9 @@ except:
 def _install():
     from xoscar.backends.router import Router
 
-    from .model import _install as install_model
-
     default_router = Router.get_instance_or_empty()
     Router.set_instance(default_router)
 
-    install_model()
-
 
 _install()
 del _install
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-10-12T18:28:41+0800",
+ "date": "2024-10-25T12:51:06+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "c0be11504c70f6c392cbdb67c86cf12153353f70",
- "version": "0.15.4"
+ "full-revisionid": "d4cd7b15104c16838e3c562cf2d33337e3d38897",
+ "version": "0.16.1"
 }
 ''' # END VERSION_JSON
xinference/api/restful_api.py CHANGED
@@ -567,6 +567,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/images/ocr",
+            self.create_ocr,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         # SD WebUI API
         self._router.add_api_route(
             "/sdapi/v1/options",
@@ -1754,6 +1764,44 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_ocr(
+        self,
+        model: str = Form(...),
+        image: UploadFile = File(media_type="application/octet-stream"),
+        kwargs: Optional[str] = Form(None),
+    ) -> Response:
+        model_uid = model
+        try:
+            model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            if kwargs is not None:
+                parsed_kwargs = json.loads(kwargs)
+            else:
+                parsed_kwargs = {}
+            im = Image.open(image.file)
+            text = await model_ref.ocr(
+                image=im,
+                **parsed_kwargs,
+            )
+            return Response(content=text, media_type="text/plain")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_flexible_infer(self, request: Request) -> Response:
         payload = await request.json()

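The new route accepts a multipart form: a `model` uid field, an `image` file, and an optional `kwargs` field holding a JSON-encoded dict; it returns the recognized text as `text/plain`. A minimal sketch of calling it directly with `requests`, assuming a local server at 127.0.0.1:9997 and a launched OCR model with uid "got-ocr2" (both values are illustrative, not from this diff):

    import json
    import requests

    # Server address and model uid below are assumptions.
    with open("receipt.png", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:9997/v1/images/ocr",
            files={
                "model": (None, "got-ocr2"),       # plain form field
                "kwargs": (None, json.dumps({})),  # optional JSON-encoded kwargs
                "image": ("image", f, "application/octet-stream"),
            },
        )
    resp.raise_for_status()
    print(resp.text)  # recognized text, returned as text/plain
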
xinference/client/restful/restful_client.py CHANGED
@@ -369,6 +369,25 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         response_data = response.json()
         return response_data
 
+    def ocr(self, image: Union[str, bytes], **kwargs):
+        url = f"{self._base_url}/v1/images/ocr"
+        params = {
+            "model": self._model_uid,
+            "kwargs": json.dumps(kwargs),
+        }
+        files: List[Any] = []
+        for key, value in params.items():
+            files.append((key, (None, value)))
+        files.append(("image", ("image", image, "application/octet-stream")))
+        response = requests.post(url, files=files, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to ocr the images, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
 
 class RESTfulVideoModelHandle(RESTfulModelHandle):
     def text_to_video(
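
The same endpoint through the new client-side method; a minimal sketch assuming a running server and an already-launched OCR-capable model (the uid is hypothetical):

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model = client.get_model("my-ocr-model")  # hypothetical model uid
    with open("receipt.png", "rb") as f:
        text = model.ocr(f.read())  # raw bytes, per the Union[str, bytes] signature
    print(text)
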
xinference/constants.py CHANGED
@@ -27,8 +27,8 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
 XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
-XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
 XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
+XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE = "XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"
 
 
 def get_xinference_home() -> str:
@@ -80,9 +80,9 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
 XINFERENCE_DISABLE_METRICS = bool(
     int(os.environ.get(XINFERENCE_ENV_DISABLE_METRICS, 0))
 )
-XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
-    int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
-)
 XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
     os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
 )
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
+    XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
+)
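
This release replaces the transformers batching switch with a text-to-image batching size. As the core/model.py changes below show, the value is compared verbatim against the `size` argument of each `text_to_image` call, so it uses the same size-string format as the image API. A minimal sketch of enabling it, with an illustrative value:

    import os

    # Illustrative value, not from this diff: it must equal the exact `size`
    # string clients pass to text_to_image, or batched requests are rejected.
    os.environ["XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE"] = "1024*1024"

Set it before the xinference process starts, since constants.py reads it at import time.
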
xinference/core/chat_interface.py CHANGED
@@ -74,7 +74,11 @@ class GradioInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/image_interface.py CHANGED
@@ -63,7 +63,11 @@ class ImageInterface:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
xinference/core/model.py CHANGED
@@ -17,10 +17,10 @@ import functools
 import inspect
 import json
 import os
+import queue
 import time
 import types
 import uuid
-import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
 from concurrent.futures import Future as ConcurrentFuture
@@ -32,7 +32,6 @@ from typing import (
     Callable,
     Dict,
     Generator,
-    Iterator,
     List,
     Optional,
     Union,
@@ -41,7 +40,7 @@ from typing import (
 import sse_starlette.sse
 import xoscar as xo
 
-from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+from ..constants import XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE
 
 if TYPE_CHECKING:
     from .progress_tracker import ProgressTrackerActor
@@ -74,6 +73,8 @@ XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
     "MiniCPM-V-2.6",
 ]
 
+XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS = ["FLUX.1-dev", "FLUX.1-schnell"]
+
 
 def request_limit(fn):
     """
@@ -153,6 +154,16 @@ class ModelActor(xo.StatelessActor):
                 f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
             )
 
+        if self.allow_batching_for_text_to_image():
+            try:
+                assert self._text_to_image_scheduler_ref is not None
+                await xo.destroy_actor(self._text_to_image_scheduler_ref)
+                del self._text_to_image_scheduler_ref
+            except Exception as e:
+                logger.debug(
+                    f"Destroy text_to_image scheduler actor failed, address: {self.address}, error: {e}"
+                )
+
         if hasattr(self._model, "stop") and callable(self._model.stop):
             self._model.stop()
 
@@ -197,9 +208,8 @@ class ModelActor(xo.StatelessActor):
             model_description.to_dict() if model_description else {}
         )
         self._request_limits = request_limits
-
-        self._generators: Dict[str, Union[Iterator, AsyncGenerator]] = {}
-        self._current_generator = lambda: None
+        self._pending_requests: asyncio.Queue = asyncio.Queue()
+        self._handle_pending_requests_task = None
         self._lock = (
             None
             if isinstance(
@@ -220,10 +230,15 @@ class ModelActor(xo.StatelessActor):
         self._loop: Optional[asyncio.AbstractEventLoop] = None
 
         self._scheduler_ref = None
+        self._text_to_image_scheduler_ref = None
 
     async def __post_create__(self):
         self._loop = asyncio.get_running_loop()
 
+        self._handle_pending_requests_task = asyncio.create_task(
+            self._handle_pending_requests()
+        )
+
         if self.allow_batching():
             from .scheduler import SchedulerActor
 
@@ -233,6 +248,15 @@ class ModelActor(xo.StatelessActor):
                 uid=SchedulerActor.gen_uid(self.model_uid(), self._model.rep_id),
             )
 
+        if self.allow_batching_for_text_to_image():
+            from ..model.image.scheduler.flux import FluxBatchSchedulerActor
+
+            self._text_to_image_scheduler_ref = await xo.create_actor(
+                FluxBatchSchedulerActor,
+                address=self.address,
+                uid=FluxBatchSchedulerActor.gen_uid(self.model_uid()),
+            )
+
     async def _record_completion_metrics(
         self, duration, completion_tokens, prompt_tokens
     ):
@@ -311,10 +335,8 @@ class ModelActor(xo.StatelessActor):
 
         model_ability = self._model_description.get("model_ability", [])
 
-        condition = XINFERENCE_TRANSFORMERS_ENABLE_BATCHING and isinstance(
-            self._model, PytorchModel
-        )
-        if condition and "vision" in model_ability:
+        condition = isinstance(self._model, PytorchModel)
+        if condition and ("vision" in model_ability or "audio" in model_ability):
             if (
                 self._model.model_family.model_name
                 in XINFERENCE_BATCHING_ALLOWED_VISION_MODELS
@@ -331,6 +353,26 @@ class ModelActor(xo.StatelessActor):
                 return False
         return condition
 
+    def allow_batching_for_text_to_image(self) -> bool:
+        from ..model.image.stable_diffusion.core import DiffusionModel
+
+        condition = XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE is not None and isinstance(
+            self._model, DiffusionModel
+        )
+
+        if condition:
+            model_name = self._model._model_spec.model_name  # type: ignore
+            if model_name in XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS:
+                return True
+            else:
+                logger.warning(
+                    f"Currently for image models with text_to_image ability, "
+                    f"xinference only supports {', '.join(XINFERENCE_TEXT_TO_IMAGE_BATCHING_ALLOWED_MODELS)} for batching. "
+                    f"Your model {model_name} is disqualified."
+                )
+                return False
+        return condition
+
     async def load(self):
         self._model.load()
         if self.allow_batching():
@@ -338,6 +380,11 @@ class ModelActor(xo.StatelessActor):
             logger.debug(
                 f"Batching enabled for model: {self.model_uid()}, max_num_seqs: {self._model.get_max_num_seqs()}"
             )
+        if self.allow_batching_for_text_to_image():
+            await self._text_to_image_scheduler_ref.set_model(self._model)
+            logger.debug(
+                f"Batching enabled for model: {self.model_uid()}, max_num_images: {self._model.get_max_num_images_for_batching()}"
+            )
 
     def model_uid(self):
         return (
@@ -429,6 +476,43 @@ class ModelActor(xo.StatelessActor):
             )
         await asyncio.gather(*coros)
 
+    async def _handle_pending_requests(self):
+        logger.info("Start requests handler.")
+        while True:
+            gen, stream_out, stop = await self._pending_requests.get()
+
+            async def _async_wrapper(_gen):
+                try:
+                    # anext is only available for Python >= 3.10
+                    return await _gen.__anext__()  # noqa: F821
+                except StopAsyncIteration:
+                    return stop
+
+            def _wrapper(_gen):
+                # Avoid issue: https://github.com/python/cpython/issues/112182
+                try:
+                    return next(_gen)
+                except StopIteration:
+                    return stop
+
+            while True:
+                try:
+                    if inspect.isgenerator(gen):
+                        r = await asyncio.to_thread(_wrapper, gen)
+                    elif inspect.isasyncgen(gen):
+                        r = await _async_wrapper(gen)
+                    else:
+                        raise Exception(
+                            f"The generator {gen} should be a generator or an async generator, "
+                            f"but a {type(gen)} is got."
+                        )
+                    stream_out.put_nowait(r)
+                    if r is not stop:
+                        continue
+                except Exception:
+                    logger.exception("stream encountered an error.")
+                break
+
     async def _call_wrapper_json(self, fn: Callable, *args, **kwargs):
         return await self._call_wrapper("json", fn, *args, **kwargs)
 
@@ -442,6 +526,13 @@ class ModelActor(xo.StatelessActor):
                 ret = await fn(*args, **kwargs)
             else:
                 ret = await asyncio.to_thread(fn, *args, **kwargs)
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                return gen
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                return gen
         else:
             async with self._lock:
                 if inspect.iscoroutinefunction(fn):
@@ -449,17 +540,40 @@ class ModelActor(xo.StatelessActor):
                 else:
                     ret = await asyncio.to_thread(fn, *args, **kwargs)
 
-        if self._lock is not None and self._current_generator():
-            raise Exception("Parallel generation is not supported by llama-cpp-python.")
+            stream_out: Union[queue.Queue, asyncio.Queue]
+
+            if inspect.isgenerator(ret):
+                gen = self._to_generator(output_type, ret)
+                stream_out = queue.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                def _stream_out_generator():
+                    while True:
+                        o = stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_generator()
+
+            if inspect.isasyncgen(ret):
+                gen = self._to_async_gen(output_type, ret)
+                stream_out = asyncio.Queue()
+                stop = object()
+                self._pending_requests.put_nowait((gen, stream_out, stop))
+
+                async def _stream_out_async_gen():
+                    while True:
+                        o = await stream_out.get()
+                        if o is stop:
+                            break
+                        else:
+                            yield o
+
+                return _stream_out_async_gen()
 
-        if inspect.isgenerator(ret):
-            gen = self._to_generator(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
-        if inspect.isasyncgen(ret):
-            gen = self._to_async_gen(output_type, ret)
-            self._current_generator = weakref.ref(gen)
-            return gen
         if output_type == "json":
             return await asyncio.to_thread(json_dumps, ret)
         else:
@@ -547,7 +661,6 @@ class ModelActor(xo.StatelessActor):
                 prompt_or_messages, queue, call_ability, *args, **kwargs
             )
             gen = self._to_async_gen("json", ret)
-            self._current_generator = weakref.ref(gen)
             return gen
         else:
             from .scheduler import XINFERENCE_NON_STREAMING_ABORT_FLAG
@@ -617,12 +730,16 @@ class ModelActor(xo.StatelessActor):
         )
 
     async def abort_request(self, request_id: str) -> str:
-        from .scheduler import AbortRequestMessage
+        from .utils import AbortRequestMessage
 
         if self.allow_batching():
             if self._scheduler_ref is None:
                 return AbortRequestMessage.NOT_FOUND.name
             return await self._scheduler_ref.abort_request(request_id)
+        elif self.allow_batching_for_text_to_image():
+            if self._text_to_image_scheduler_ref is None:
+                return AbortRequestMessage.NOT_FOUND.name
+            return await self._text_to_image_scheduler_ref.abort_request(request_id)
         return AbortRequestMessage.NO_OP.name
 
     @request_limit
@@ -747,6 +864,22 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating speech."
             )
 
+    async def handle_image_batching_request(self, unique_id, *args, **kwargs):
+        size = args[2]
+        if XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE != size:
+            raise RuntimeError(
+                f"The image size: {size} of text_to_image for batching "
+                f"must be the same as the environment variable: {XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE} you set."
+            )
+        assert self._loop is not None
+        future = ConcurrentFuture()
+        await self._text_to_image_scheduler_ref.add_request(
+            unique_id, future, *args, **kwargs
+        )
+        fut = asyncio.wrap_future(future, loop=self._loop)
+        result = await fut
+        return await asyncio.to_thread(json_dumps, result)
+
     @request_limit
     @log_async(logger=logger)
     async def text_to_image(
@@ -759,19 +892,25 @@ class ModelActor(xo.StatelessActor):
         **kwargs,
     ):
         if hasattr(self._model, "text_to_image"):
-            progressor = kwargs["progressor"] = await self._get_progressor(
-                kwargs.pop("request_id", None)
-            )
-            with progressor:
-                return await self._call_wrapper_json(
-                    self._model.text_to_image,
-                    prompt,
-                    n,
-                    size,
-                    response_format,
-                    *args,
-                    **kwargs,
-                )
+            if self.allow_batching_for_text_to_image():
+                unique_id = kwargs.pop("request_id", None)
+                return await self.handle_image_batching_request(
+                    unique_id, prompt, n, size, response_format, *args, **kwargs
+                )
+            else:
+                progressor = kwargs["progressor"] = await self._get_progressor(
+                    kwargs.pop("request_id", None)
+                )
+                with progressor:
+                    return await self._call_wrapper_json(
+                        self._model.text_to_image,
+                        prompt,
+                        n,
+                        size,
+                        response_format,
+                        *args,
+                        **kwargs,
+                    )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -882,6 +1021,25 @@
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
+    async def ocr(
+        self,
+        image: "PIL.Image",
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "ocr"):
+            return await self._call_wrapper_json(
+                self._model.ocr,
+                image,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for ocr.")
+
     @request_limit
     @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
@@ -923,3 +1081,6 @@
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
+
+    async def get_pending_requests_count(self):
+        return self._pending_requests.qsize()
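
The weakref-based `_current_generator` bookkeeping is replaced by a single `_handle_pending_requests` task that drains each streaming generator into a per-request queue, ending every stream with a unique sentinel object. A minimal standalone sketch of that producer/consumer pattern, simplified to async generators only:

    import asyncio

    async def pump(pending: asyncio.Queue) -> None:
        # One background task serves all requests: drain each submitted
        # generator into its output queue, then emit that request's sentinel.
        while True:
            gen, out, stop = await pending.get()
            try:
                async for item in gen:
                    out.put_nowait(item)
            finally:
                out.put_nowait(stop)

    async def main() -> None:
        pending: asyncio.Queue = asyncio.Queue()
        pump_task = asyncio.create_task(pump(pending))

        async def numbers():
            for i in range(3):
                yield i

        out: asyncio.Queue = asyncio.Queue()
        stop = object()  # unique per request, like `stop = object()` above
        pending.put_nowait((numbers(), out, stop))

        while (item := await out.get()) is not stop:
            print(item)  # 0, 1, 2

        pump_task.cancel()

    asyncio.run(main())

Because a lone task consumes `self._pending_requests`, streams from models that cannot generate in parallel are serialized, which removes the need for the old "Parallel generation is not supported by llama-cpp-python" error path.
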
xinference/core/scheduler.py CHANGED
@@ -17,11 +17,12 @@ import functools
 import logging
 import uuid
 from collections import deque
-from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import xoscar as xo
 
+from .utils import AbortRequestMessage
+
 logger = logging.getLogger(__name__)
 
 XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
@@ -30,12 +31,6 @@ XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
 XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
 
 
-class AbortRequestMessage(Enum):
-    NOT_FOUND = 1
-    DONE = 2
-    NO_OP = 3
-
-
 class InferenceRequest:
     def __init__(
         self,
@@ -81,6 +76,10 @@ class InferenceRequest:
         self.padding_len = 0
         # Use in stream mode
         self.last_output_length = 0
+        # For tool call
+        self.tools = None
+        # Currently, for storing tool call streaming results.
+        self.outputs: List[str] = []  # type: ignore
         # inference results,
         # it is a list type because when stream=True,
         # self.completion contains all the results in a decode round.
@@ -112,6 +111,10 @@ class InferenceRequest:
         """
         return self._prompt
 
+    @prompt.setter
+    def prompt(self, value: str):
+        self._prompt = value
+
     @property
     def call_ability(self):
         return self._call_ability
xinference/core/utils.py CHANGED
@@ -16,6 +16,7 @@ import os
 import random
 import string
 import uuid
+from enum import Enum
 from typing import Dict, Generator, List, Optional, Tuple, Union
 
 import orjson
@@ -27,6 +28,12 @@ from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
 logger = logging.getLogger(__name__)
 
 
+class AbortRequestMessage(Enum):
+    NOT_FOUND = 1
+    DONE = 2
+    NO_OP = 3
+
+
 def truncate_log_arg(arg) -> str:
     s = str(arg)
     if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH:
@@ -51,6 +58,8 @@ def log_async(
             request_id_str = kwargs.get("request_id", "")
             if not request_id_str:
                 request_id_str = uuid.uuid1()
+            if func_name == "text_to_image":
+                kwargs["request_id"] = request_id_str
             request_id_str = f"[request {request_id_str}]"
             formatted_args = ",".join(map(truncate_log_arg, args))
             formatted_kwargs = ",".join(
xinference/model/__init__.py CHANGED
@@ -29,3 +29,7 @@ def _install():
     image_install()
     rerank_install()
     video_install()
+
+
+_install()
+del _install