xinference 0.14.0.post1__py3-none-any.whl → 0.14.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (50)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +54 -0
  3. xinference/client/handlers.py +0 -3
  4. xinference/client/restful/restful_client.py +51 -134
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +1 -4
  7. xinference/core/image_interface.py +33 -5
  8. xinference/core/model.py +28 -2
  9. xinference/core/supervisor.py +37 -0
  10. xinference/core/worker.py +128 -84
  11. xinference/deploy/cmdline.py +1 -4
  12. xinference/model/audio/core.py +11 -3
  13. xinference/model/audio/funasr.py +114 -0
  14. xinference/model/audio/model_spec.json +20 -0
  15. xinference/model/audio/model_spec_modelscope.json +21 -0
  16. xinference/model/audio/whisper.py +1 -1
  17. xinference/model/core.py +12 -0
  18. xinference/model/image/core.py +3 -4
  19. xinference/model/image/model_spec.json +41 -13
  20. xinference/model/image/model_spec_modelscope.json +30 -10
  21. xinference/model/image/stable_diffusion/core.py +53 -2
  22. xinference/model/llm/__init__.py +2 -0
  23. xinference/model/llm/llm_family.json +83 -1
  24. xinference/model/llm/llm_family_modelscope.json +85 -1
  25. xinference/model/llm/pytorch/core.py +1 -0
  26. xinference/model/llm/pytorch/minicpmv26.py +247 -0
  27. xinference/model/llm/sglang/core.py +72 -34
  28. xinference/model/llm/vllm/core.py +38 -0
  29. xinference/model/video/__init__.py +62 -0
  30. xinference/model/video/core.py +178 -0
  31. xinference/model/video/diffusers.py +180 -0
  32. xinference/model/video/model_spec.json +11 -0
  33. xinference/model/video/model_spec_modelscope.json +12 -0
  34. xinference/types.py +10 -24
  35. xinference/web/ui/build/asset-manifest.json +3 -3
  36. xinference/web/ui/build/index.html +1 -1
  37. xinference/web/ui/build/static/js/{main.ef2a203a.js → main.17ca0398.js} +3 -3
  38. xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
  41. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/METADATA +21 -15
  42. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/RECORD +47 -40
  43. xinference/web/ui/build/static/js/main.ef2a203a.js.map +0 -1
  44. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
  46. /xinference/web/ui/build/static/js/{main.ef2a203a.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
  47. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/LICENSE +0 -0
  48. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/WHEEL +0 -0
  49. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/entry_points.txt +0 -0
  50. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.post1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-05T11:58:50+0800",
+ "date": "2024-08-12T12:36:32+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "111299317120411f407b015b2b7dbf8402aa35c8",
- "version": "0.14.0.post1"
+ "full-revisionid": "9afee766a3c5cc53e6035490400a4291b78e72ff",
+ "version": "0.14.1.post1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -65,6 +65,7 @@ from ..types import (
     CreateCompletion,
     ImageList,
     PeftModelConfig,
+    VideoList,
     max_tokens_field,
 )
 from .oauth2.auth_service import AuthService
@@ -123,6 +124,14 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class TextToVideoRequest(BaseModel):
+    model: str
+    prompt: Union[str, List[str]] = Field(description="The input to embed.")
+    n: Optional[int] = 1
+    kwargs: Optional[str] = None
+    user: Optional[str] = None
+
+
 class SpeechRequest(BaseModel):
     model: str
     input: str
@@ -158,6 +167,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_id: str
     controlnet: Union[None, List[Dict[str, Union[str, None]]]]
     model_revision: str
+    model_ability: List[str]
 
 
 class RESTfulAPI:
@@ -511,6 +521,17 @@
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/video/generations",
+            self.create_videos,
+            methods=["POST"],
+            response_model=VideoList,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/chat/completions",
             self.create_chat_completion,
@@ -1031,6 +1052,7 @@
             model_revision=body.model_revision,
             controlnet=body.controlnet,
             access_token=access_token,
+            model_ability=body.model_ability,
         ).build()
 
         gr.mount_gradio_app(self._app, interface, f"/{model_uid}")
@@ -1544,6 +1566,38 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def create_videos(self, request: Request) -> Response:
+        body = TextToVideoRequest.parse_obj(await request.json())
+        model_uid = body.model
+        try:
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            kwargs = json.loads(body.kwargs) if body.kwargs else {}
+            video_list = await model.text_to_video(
+                prompt=body.prompt,
+                n=body.n,
+                **kwargs,
+            )
+            return Response(content=video_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_chat_completion(self, request: Request) -> Response:
         raw_body = await request.json()
         body = CreateChatCompletion.parse_obj(raw_body)
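For reference, the new route can be exercised directly over HTTP once a video-capable model is running. Below is a minimal sketch using the requests library; the endpoint URL, model UID, and the num_inference_steps option are illustrative assumptions, not values taken from this diff.

import json

import requests

# Hypothetical deployment endpoint and model UID -- adjust to your setup.
BASE_URL = "http://127.0.0.1:9997"
payload = {
    "model": "my-video-model",                     # UID of an already-launched video model
    "prompt": "a corgi surfing a small wave",
    "n": 1,
    # Extra generation options travel as a JSON-encoded string, mirroring TextToVideoRequest.kwargs.
    "kwargs": json.dumps({"num_inference_steps": 25}),
}
resp = requests.post(f"{BASE_URL}/v1/video/generations", json=payload)
resp.raise_for_status()
print(resp.json())  # a VideoList-shaped JSON document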
xinference/client/handlers.py CHANGED
@@ -1,9 +1,6 @@
 from .restful.restful_client import (  # noqa: F401
     RESTfulAudioModelHandle as AudioModelHandle,
 )
-from .restful.restful_client import (  # noqa: F401
-    RESTfulChatglmCppChatModelHandle as ChatglmCppChatModelHandle,
-)
 from .restful.restful_client import (  # noqa: F401
     RESTfulChatModelHandle as ChatModelHandle,
 )
xinference/client/restful/restful_client.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 import json
 import typing
+import warnings
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
 
 import requests
@@ -24,13 +25,13 @@ if TYPE_CHECKING:
     ChatCompletion,
     ChatCompletionChunk,
     ChatCompletionMessage,
-    ChatglmCppGenerateConfig,
     Completion,
     CompletionChunk,
     Embedding,
     ImageList,
     LlamaCppGenerateConfig,
     PytorchGenerateConfig,
+    VideoList,
 )
 
 
@@ -370,6 +371,44 @@ class RESTfulImageModelHandle(RESTfulModelHandle):
         return response_data
 
 
+class RESTfulVideoModelHandle(RESTfulModelHandle):
+    def text_to_video(
+        self,
+        prompt: str,
+        n: int = 1,
+        **kwargs,
+    ) -> "VideoList":
+        """
+        Creates a video by the input text.
+
+        Parameters
+        ----------
+        prompt: `str` or `List[str]`
+            The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.
+        n: `int`, defaults to 1
+            The number of videos to generate per prompt. Must be between 1 and 10.
+        Returns
+        -------
+        VideoList
+            A list of video objects.
+        """
+        url = f"{self._base_url}/v1/video/generations"
+        request_body = {
+            "model": self._model_uid,
+            "prompt": prompt,
+            "n": n,
+            "kwargs": json.dumps(kwargs),
+        }
+        response = requests.post(url, json=request_body, headers=self.auth_headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to create the video, detail: {_get_error_string(response)}"
+            )
+
+        response_data = response.json()
+        return response_data
+
+
 class RESTfulGenerateModelHandle(RESTfulModelHandle):
     def generate(
         self,
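With this handle in place, video models can be driven from the Python client in the same way as image or audio models. A minimal sketch follows; the endpoint and model UID are illustrative, and the model must already be launched as a video model.

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")     # hypothetical endpoint
video_model = client.get_model("my-video-model")    # resolves to a RESTfulVideoModelHandle for video models
result = video_model.text_to_video(
    prompt="a timelapse of clouds rolling over mountains",
    n=1,
)
print(result)  # VideoList-shaped dict returned by /v1/video/generations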
@@ -470,81 +509,14 @@ class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
             Report the failure to generate the chat from the server. Detailed information provided in error message.
 
         """
-
-        url = f"{self._base_url}/v1/chat/completions"
-
-        if chat_history is None:
-            chat_history = []
-
-        chat_history = handle_system_prompts(chat_history, system_prompt)
-        chat_history.append({"role": "user", "content": prompt})  # type: ignore
-
-        request_body: Dict[str, Any] = {
-            "model": self._model_uid,
-            "messages": chat_history,
-        }
-        if tools is not None:
-            request_body["tools"] = tools
-        if generate_config is not None:
-            for key, value in generate_config.items():
-                request_body[key] = value
-
-        stream = bool(generate_config and generate_config.get("stream"))
-        response = requests.post(
-            url, json=request_body, stream=stream, headers=self.auth_headers
+        warnings.warn(
+            "The parameters `prompt`, `system_prompt` and `chat_history` will be deprecated in version v0.15.0, "
+            "and will be replaced by the parameter `messages`, "
+            "similar to the OpenAI API: https://platform.openai.com/docs/guides/chat-completions/getting-started",
+            category=DeprecationWarning,
+            stacklevel=2,
         )
 
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Failed to generate chat completion, detail: {_get_error_string(response)}"
-            )
-
-        if stream:
-            return streaming_response_iterator(response.iter_lines())
-
-        response_data = response.json()
-        return response_data
-
-
-class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
-    def chat(
-        self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List["ChatCompletionMessage"]] = None,
-        tools: Optional[List[Dict]] = None,
-        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
-    ) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
-        """
-        Given a list of messages comprising a conversation, the ChatGLM model will return a response via RESTful APIs.
-
-        Parameters
-        ----------
-        prompt: str
-            The user's input.
-        system_prompt: Optional[str]
-            The system context provide to Model prior to any chats.
-        chat_history: Optional[List["ChatCompletionMessage"]]
-            A list of messages comprising the conversation so far.
-        tools: Optional[List[Dict]]
-            A tool list.
-        generate_config: Optional["ChatglmCppGenerateConfig"]
-            Additional configuration for ChatGLM chat generation.
-
-        Returns
-        -------
-        Union["ChatCompletion", Iterator["ChatCompletionChunk"]]
-            Stream is a parameter in generate_config.
-            When stream is set to True, the function will return Iterator["ChatCompletionChunk"].
-            When stream is set to False, the function will return "ChatCompletion".
-
-        Raises
-        ------
-        RuntimeError
-            Report the failure to generate the chat from the server. Detailed information provided in error message.
-
-        """
-
         url = f"{self._base_url}/v1/chat/completions"
 
         if chat_history is None:
@@ -580,60 +552,6 @@ class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle):
         return response_data
 
 
-class RESTfulChatglmCppGenerateModelHandle(RESTfulChatglmCppChatModelHandle):
-    def generate(
-        self,
-        prompt: str,
-        generate_config: Optional["ChatglmCppGenerateConfig"] = None,
-    ) -> Union["Completion", Iterator["CompletionChunk"]]:
-        """
-        Given a prompt, the ChatGLM model will generate a response via RESTful APIs.
-
-        Parameters
-        ----------
-        prompt: str
-            The user's input.
-        generate_config: Optional["ChatglmCppGenerateConfig"]
-            Additional configuration for ChatGLM chat generation.
-
-        Returns
-        -------
-        Union["Completion", Iterator["CompletionChunk"]]
-            Stream is a parameter in generate_config.
-            When stream is set to True, the function will return Iterator["CompletionChunk"].
-            When stream is set to False, the function will return "Completion".
-
-        Raises
-        ------
-        RuntimeError
-            Report the failure to generate the content from the server. Detailed information provided in error message.
-
-        """
-
-        url = f"{self._base_url}/v1/completions"
-
-        request_body: Dict[str, Any] = {"model": self._model_uid, "prompt": prompt}
-        if generate_config is not None:
-            for key, value in generate_config.items():
-                request_body[key] = value
-
-        stream = bool(generate_config and generate_config.get("stream"))
-
-        response = requests.post(
-            url, json=request_body, stream=stream, headers=self.auth_headers
-        )
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Failed to generate completion, detail: {response.json()['detail']}"
-            )
-
-        if stream:
-            return streaming_response_iterator(response.iter_lines())
-
-        response_data = response.json()
-        return response_data
-
-
 class RESTfulAudioModelHandle(RESTfulModelHandle):
     def transcriptions(
         self,
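The chat() signature itself is unchanged in this release; the method now only emits a DeprecationWarning before issuing the same /v1/chat/completions request. The sketch below shows what callers see, plus the OpenAI-style messages shape the warning points to for v0.15.0 (that future form is an assumption drawn from the warning text, not an API present in this version); the endpoint and model UID are illustrative.

import warnings

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")    # hypothetical endpoint
chat_model = client.get_model("my-chat-model")     # a launched chat-capable LLM (hypothetical UID)

# Current, still-supported call: now triggers a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    completion = chat_model.chat(
        prompt="What is the capital of France?",
        system_prompt="You are a concise assistant.",
        chat_history=[],
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# The OpenAI-style shape the warning refers to for v0.15.0 (not accepted by this client version):
future_messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]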
@@ -1090,7 +1008,6 @@ class Client:
         -------
         ModelHandle
             The corresponding Model Handler based on the Model specified in the uid:
-            - :obj:`xinference.client.handlers.ChatglmCppChatModelHandle` -> provide handle to ChatGLM Model
             - :obj:`xinference.client.handlers.GenerateModelHandle` -> provide handle to basic generate Model. e.g. Baichuan.
             - :obj:`xinference.client.handlers.ChatModelHandle` -> provide handle to chat Model. e.g. Baichuan-chat.
 
@@ -1111,11 +1028,7 @@ class Client:
         desc = response.json()
 
         if desc["model_type"] == "LLM":
-            if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]:
-                return RESTfulChatglmCppGenerateModelHandle(
-                    model_uid, self.base_url, auth_headers=self._headers
-                )
-            elif "chat" in desc["model_ability"]:
+            if "chat" in desc["model_ability"]:
                 return RESTfulChatModelHandle(
                     model_uid, self.base_url, auth_headers=self._headers
                 )
@@ -1141,6 +1054,10 @@ class Client:
             return RESTfulAudioModelHandle(
                 model_uid, self.base_url, auth_headers=self._headers
            )
+        elif desc["model_type"] == "video":
+            return RESTfulVideoModelHandle(
+                model_uid, self.base_url, auth_headers=self._headers
+            )
         elif desc["model_type"] == "flexible":
             return RESTfulFlexibleModelHandle(
                 model_uid, self.base_url, auth_headers=self._headers
xinference/constants.py CHANGED
@@ -47,6 +47,7 @@ XINFERENCE_TENSORIZER_DIR = os.path.join(XINFERENCE_HOME, "tensorizer")
 XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
 XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
 XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")
+XINFERENCE_VIDEO_DIR = os.path.join(XINFERENCE_HOME, "video")
 XINFERENCE_AUTH_DIR = os.path.join(XINFERENCE_HOME, "auth")
 XINFERENCE_CSG_ENDPOINT = str(
     os.environ.get(XINFERENCE_ENV_CSG_ENDPOINT, "https://hub-stg.opencsg.com/")
xinference/core/chat_interface.py CHANGED
@@ -24,7 +24,6 @@ from gradio.components import Markdown, Textbox
 from gradio.layouts import Accordion, Column, Row
 
 from ..client.restful.restful_client import (
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,
 )
@@ -116,9 +115,7 @@ class GradioInterface:
         client = RESTfulClient(self.endpoint)
         client._set_token(self._access_token)
         model = client.get_model(self.model_uid)
-        assert isinstance(
-            model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)
-        )
+        assert isinstance(model, RESTfulChatModelHandle)
 
         response_content = ""
         for chunk in model.chat(
xinference/core/image_interface.py CHANGED
@@ -36,6 +36,7 @@ class ImageInterface:
         model_name: str,
         model_id: str,
         model_revision: str,
+        model_ability: List[str],
         controlnet: Union[None, List[Dict[str, Union[str, None]]]],
         access_token: Optional[str],
     ):
@@ -45,6 +46,7 @@ class ImageInterface:
         self.model_name = model_name
         self.model_id = model_id
         self.model_revision = model_revision
+        self.model_ability = model_ability
         self.controlnet = controlnet
         self.access_token = (
             access_token.replace("Bearer ", "") if access_token is not None else None
@@ -76,6 +78,7 @@ class ImageInterface:
             n: int,
             size_width: int,
             size_height: int,
+            num_inference_steps: int,
             negative_prompt: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
@@ -86,11 +89,15 @@ class ImageInterface:
             assert isinstance(model, RESTfulImageModelHandle)
 
             size = f"{int(size_width)}*{int(size_height)}"
+            num_inference_steps = (
+                None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+            )
 
             response = model.text_to_image(
                 prompt=prompt,
                 n=n,
                 size=size,
+                num_inference_steps=num_inference_steps,
                 negative_prompt=negative_prompt,
                 response_format="b64_json",
             )
@@ -125,13 +132,23 @@ class ImageInterface:
                     n = gr.Number(label="Number of Images", value=1)
                     size_width = gr.Number(label="Width", value=1024)
                     size_height = gr.Number(label="Height", value=1024)
+                    num_inference_steps = gr.Number(
+                        label="Inference Step Number", value=-1
+                    )
 
                 with gr.Column():
                     image_output = gr.Gallery()
 
             generate_button.click(
                 text_generate_image,
-                inputs=[prompt, n, size_width, size_height, negative_prompt],
+                inputs=[
+                    prompt,
+                    n,
+                    size_width,
+                    size_height,
+                    num_inference_steps,
+                    negative_prompt,
+                ],
                 outputs=image_output,
             )
 
@@ -145,6 +162,7 @@ class ImageInterface:
             n: int,
             size_width: int,
             size_height: int,
+            num_inference_steps: int,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -157,6 +175,9 @@ class ImageInterface:
                 size = f"{int(size_width)}*{int(size_height)}"
             else:
                 size = None
+            num_inference_steps = (
+                None if num_inference_steps == -1 else num_inference_steps  # type: ignore
+            )
 
             bio = io.BytesIO()
             image.save(bio, format="png")
@@ -168,6 +189,7 @@ class ImageInterface:
                 image=bio.getvalue(),
                 size=size,
                 response_format="b64_json",
+                num_inference_steps=num_inference_steps,
             )
 
             images = []
@@ -200,6 +222,9 @@ class ImageInterface:
                     n = gr.Number(label="Number of image", value=1)
                     size_width = gr.Number(label="Width", value=-1)
                     size_height = gr.Number(label="Height", value=-1)
+                    num_inference_steps = gr.Number(
+                        label="Inference Step Number", value=-1
+                    )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -216,6 +241,7 @@ class ImageInterface:
                     n,
                     size_width,
                     size_height,
+                    num_inference_steps,
                 ],
                 outputs=output_gallery,
             )
@@ -247,9 +273,11 @@ class ImageInterface:
             </div>
             """
         )
-        with gr.Tab("Text to Image"):
-            self.text2image_interface()
-        with gr.Tab("Image to Image"):
-            self.image2image_interface()
+        if "text2image" in self.model_ability:
+            with gr.Tab("Text to Image"):
+                self.text2image_interface()
+        if "image2image" in self.model_ability:
+            with gr.Tab("Image to Image"):
+                self.image2image_interface()
 
         return app
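The Gradio image UI now forwards an optional inference-step count (with -1 meaning "use the model default") to the underlying model calls. The same option can be passed through the Python client, since text_to_image forwards extra keyword arguments; a minimal sketch with an illustrative endpoint and model UID:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")      # hypothetical endpoint
image_model = client.get_model("my-sd-model")        # hypothetical image model UID
image_list = image_model.text_to_image(
    prompt="a watercolor painting of a lighthouse at dusk",
    n=1,
    size="1024*1024",
    response_format="b64_json",
    num_inference_steps=25,   # forwarded to the diffusion pipeline, as the UI now does
)
print(len(image_list["data"]))  # number of generated images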
xinference/core/model.py CHANGED
@@ -133,6 +133,7 @@ class ModelActor(xo.StatelessActor):
     async def __pre_destroy__(self):
         from ..model.embedding.core import EmbeddingModel
         from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
+        from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel
 
         if self.allow_batching():
@@ -145,8 +146,11 @@ class ModelActor(xo.StatelessActor):
                     f"Destroy scheduler actor failed, address: {self.address}, error: {e}"
                 )
 
+        if hasattr(self._model, "stop") and callable(self._model.stop):
+            self._model.stop()
+
         if (
-            isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
+            isinstance(self._model, (LLMPytorchModel, LLMVLLMModel, SGLANGModel))
             and self._model.model_spec.model_format == "pytorch"
         ) or isinstance(self._model, EmbeddingModel):
             try:
@@ -174,6 +178,7 @@ class ModelActor(xo.StatelessActor):
     ):
         super().__init__()
         from ..model.llm.pytorch.core import PytorchModel
+        from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.vllm.core import VLLMModel
 
         self._worker_address = worker_address
@@ -187,7 +192,7 @@ class ModelActor(xo.StatelessActor):
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(self._model, (PytorchModel, VLLMModel))
+            if isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel))
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
@@ -771,6 +776,27 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for flexible infer."
             )
 
+    @log_async(logger=logger)
+    @request_limit
+    async def text_to_video(
+        self,
+        prompt: str,
+        n: int = 1,
+        *args,
+        **kwargs,
+    ):
+        if hasattr(self._model, "text_to_video"):
+            return await self._call_wrapper_json(
+                self._model.text_to_video,
+                prompt,
+                n,
+                *args,
+                **kwargs,
+            )
+        raise AttributeError(
+            f"Model {self._model.model_spec} is not for creating video."
+        )
+
     async def record_metrics(self, name, op, kwargs):
         worker_ref = await self._get_worker_ref()
         await worker_ref.record_metrics(name, op, kwargs)
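ModelActor dispatches purely by duck typing here: any loaded model object that exposes a text_to_video method becomes reachable through the new endpoint, and an optional stop() method (checked in __pre_destroy__ above) is called on teardown. A sketch of that implied contract; the class below is illustrative, not code from this release:

class MyVideoModel:
    """Hypothetical model implementation; only the method names below matter to ModelActor."""

    def text_to_video(self, prompt: str, n: int = 1, **kwargs):
        # Must return a JSON-serializable payload (see xinference.types.VideoList).
        ...

    def stop(self):
        # Optional: release resources; invoked from ModelActor.__pre_destroy__ if present.
        ...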
xinference/core/supervisor.py CHANGED
@@ -64,6 +64,7 @@ if TYPE_CHECKING:
     from ..model.image import ImageModelFamilyV1
     from ..model.llm import LLMFamilyV1
     from ..model.rerank import RerankModelSpec
+    from ..model.video import VideoModelFamilyV1
     from .worker import WorkerActor
 
 
@@ -484,6 +485,31 @@ class SupervisorActor(xo.StatelessActor):
         res["model_instance_count"] = instance_cnt
         return res
 
+    async def _to_video_model_reg(
+        self, model_family: "VideoModelFamilyV1", is_builtin: bool
+    ) -> Dict[str, Any]:
+        from ..model.video import get_cache_status
+
+        instance_cnt = await self.get_instance_count(model_family.model_name)
+        version_cnt = await self.get_model_version_count(model_family.model_name)
+
+        if self.is_local_deployment():
+            # TODO: does not work when the supervisor and worker are running on separate nodes.
+            cache_status = get_cache_status(model_family)
+            res = {
+                **model_family.dict(),
+                "cache_status": cache_status,
+                "is_builtin": is_builtin,
+            }
+        else:
+            res = {
+                **model_family.dict(),
+                "is_builtin": is_builtin,
+            }
+        res["model_version_count"] = version_cnt
+        res["model_instance_count"] = instance_cnt
+        return res
+
     async def _to_flexible_model_reg(
         self, model_spec: "FlexibleModelSpec", is_builtin: bool
     ) -> Dict[str, Any]:
@@ -602,6 +628,17 @@ class SupervisorActor(xo.StatelessActor):
                     {"model_name": model_spec.model_name, "is_builtin": False}
                 )
 
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "video":
+            from ..model.video import BUILTIN_VIDEO_MODELS
+
+            for model_name, family in BUILTIN_VIDEO_MODELS.items():
+                if detailed:
+                    ret.append(await self._to_video_model_reg(family, is_builtin=True))
+                else:
+                    ret.append({"model_name": model_name, "is_builtin": True})
+
             ret.sort(key=sort_helper)
             return ret
         elif model_type == "rerank":
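On the client side, this supervisor branch makes "video" a first-class model_type for listing and launching. A minimal sketch follows; the endpoint and model name are illustrative, and list_model_registrations / launch_model are the pre-existing client calls rather than additions from this diff:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")          # hypothetical endpoint

# Built-in video registrations now flow through the new "video" branch above.
for reg in client.list_model_registrations(model_type="video"):
    print(reg["model_name"], reg["is_builtin"])

# Launch one and generate a clip (model name is a placeholder).
model_uid = client.launch_model(model_name="some-video-model", model_type="video")
video_model = client.get_model(model_uid)
print(video_model.text_to_video(prompt="a drone shot over a coastline"))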