xinference 0.9.2__py3-none-any.whl → 0.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +111 -13
- xinference/client/restful/restful_client.py +2 -1
- xinference/conftest.py +18 -15
- xinference/constants.py +2 -0
- xinference/core/image_interface.py +252 -0
- xinference/core/supervisor.py +3 -10
- xinference/deploy/cmdline.py +69 -4
- xinference/deploy/local.py +1 -1
- xinference/deploy/supervisor.py +1 -1
- xinference/model/image/__init__.py +13 -7
- xinference/model/image/core.py +17 -1
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/ggml/llamacpp.py +1 -5
- xinference/model/llm/llm_family.json +98 -13
- xinference/model/llm/llm_family_modelscope.json +98 -7
- xinference/model/llm/pytorch/chatglm.py +2 -1
- xinference/model/llm/pytorch/internlm2.py +2 -1
- xinference/model/llm/sglang/__init__.py +13 -0
- xinference/model/llm/sglang/core.py +365 -0
- xinference/model/llm/utils.py +35 -12
- xinference/model/llm/vllm/core.py +17 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.78829790.js → main.66b1c4fb.js} +3 -3
- xinference/web/ui/build/static/js/main.66b1c4fb.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/0bd70b1ecf307e2681318e864f4692305b6350c8683863007f4caf2f9ac33b6e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/3e055de705e397e1d413d7f429589b1a98dd78ef378b97f0cdb462c5f2487d5e.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/60c4b98d8ea7479fb0c94cfd19c8128f17bd7e27a1e73e6dd9adf6e9d88d18eb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/7e094845f611802b024b57439cbf911038169d06cdf6c34a72a7277f35aa71a4.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b400cfc9db57fa6c70cd2bad055b73c5079fde0ed37974009d898083f6af8cd8.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/e1d9b2ae4e1248658704bc6bfc5d6160dcd1a9e771ea4ae8c1fed0aaddeedd29.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/fd4a8ae5d192331af1bedd1d2d70efcc569708ee6cc4cb479b225d059482aa81.json +1 -0
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/METADATA +8 -5
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/RECORD +40 -37
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/WHEEL +1 -1
- xinference/web/ui/build/static/js/main.78829790.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/22858de5265f2d279fca9f2f54dfb147e4b2704200dfb5d2ad3ec9769417328f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/30670751f55508ef3b861e13dd71b9e5a10d2561373357a12fc3831a0b77fd93.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/396f7ce6ae6900bfdb00e369ade8a05045dc1df025610057ff7436d9e58af81c.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/5282ee05e064b3a80bc991e9003ddef6a4958471d8f4fc65589dc64553365cdd.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/83beb31daa7169fb0057453d4f86411f1effd3e3f7af97472cbd22accbfc65bb.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/ddf597663270471b31251b2abb36e3fa093efe20489387d996f993d2c61be112.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/e8687f75d2adacd34852b71c41ca17203d6fb4c8999ea55325bb2939f9d9ea90.json +0 -1
- xinference/web/ui/build/static/js/{main.78829790.js.LICENSE.txt → main.66b1c4fb.js.LICENSE.txt} +0 -0
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/LICENSE +0 -0
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.9.2.dist-info → xinference-0.9.4.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-03-
|
|
11
|
+
"date": "2024-03-21T14:58:01+0800",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.9.
|
|
14
|
+
"full-revisionid": "2c9465ade7f358d57d4bc087277882d896a8de15",
|
|
15
|
+
"version": "0.9.4"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
xinference/api/restful_api.py
CHANGED
|
@@ -22,7 +22,7 @@ import pprint
|
|
|
22
22
|
import sys
|
|
23
23
|
import time
|
|
24
24
|
import warnings
|
|
25
|
-
from typing import Any, List, Optional, Union
|
|
25
|
+
from typing import Any, Dict, List, Optional, Union
|
|
26
26
|
|
|
27
27
|
import gradio as gr
|
|
28
28
|
import xoscar as xo
|
|
@@ -59,6 +59,7 @@ from ..core.utils import json_dumps
|
|
|
59
59
|
from ..types import (
|
|
60
60
|
SPECIAL_TOOL_PROMPT,
|
|
61
61
|
ChatCompletion,
|
|
62
|
+
ChatCompletionMessage,
|
|
62
63
|
Completion,
|
|
63
64
|
CreateChatCompletion,
|
|
64
65
|
CreateCompletion,
|
|
@@ -135,6 +136,15 @@ class BuildGradioInterfaceRequest(BaseModel):
|
|
|
135
136
|
model_lang: List[str]
|
|
136
137
|
|
|
137
138
|
|
|
139
|
+
class BuildGradioImageInterfaceRequest(BaseModel):
    """JSON body accepted by ``POST /v1/ui/images/{model_uid}``.

    Describes an already-launched image model for which a Gradio UI is built.
    """

    model_type: str
    model_name: str
    model_family: str
    model_id: str
    # Optional list of ControlNet descriptors; each maps string keys to
    # string-or-None values.
    controlnet: Optional[List[Dict[str, Optional[str]]]]
    model_revision: str
|
|
146
|
+
|
|
147
|
+
|
|
138
148
|
class RESTfulAPI:
|
|
139
149
|
def __init__(
|
|
140
150
|
self,
|
|
@@ -246,6 +256,16 @@ class RESTfulAPI:
|
|
|
246
256
|
else None
|
|
247
257
|
),
|
|
248
258
|
)
|
|
259
|
+
self._router.add_api_route(
|
|
260
|
+
"/v1/ui/images/{model_uid}",
|
|
261
|
+
self.build_gradio_images_interface,
|
|
262
|
+
methods=["POST"],
|
|
263
|
+
dependencies=(
|
|
264
|
+
[Security(self._auth_service, scopes=["models:read"])]
|
|
265
|
+
if self.is_authenticated()
|
|
266
|
+
else None
|
|
267
|
+
),
|
|
268
|
+
)
|
|
249
269
|
self._router.add_api_route(
|
|
250
270
|
"/token", self.login_for_access_token, methods=["POST"]
|
|
251
271
|
)
|
|
@@ -584,8 +604,22 @@ class RESTfulAPI:
|
|
|
584
604
|
|
|
585
605
|
async def list_models(self) -> JSONResponse:
    """Return every running model wrapped in an OpenAI-style ``list`` envelope.

    Each entry starts with the OpenAI ``model`` object fields (``id``, ``object``,
    ``created``, ``owned_by``). The supervisor-reported info is merged after them,
    so keys it provides override those defaults.

    Raises:
        HTTPException: 500 if querying the supervisor fails.
    """
    try:
        supervisor_ref = await self._get_supervisor_ref()
        running = await supervisor_ref.list_models()

        data = [
            {
                "id": uid,
                "object": "model",
                "created": 0,
                "owned_by": "xinference",
                **info,
            }
            for uid, info in running.items()
        ]
        return JSONResponse(content={"object": "list", "data": data})
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
@@ -822,6 +856,56 @@ class RESTfulAPI:
|
|
|
822
856
|
|
|
823
857
|
return JSONResponse(content={"model_uid": model_uid})
|
|
824
858
|
|
|
859
|
+
async def build_gradio_images_interface(
    self, model_uid: str, request: Request
) -> JSONResponse:
    """
    Build a Gradio interface for image processing models.

    The JSON body is parsed as ``BuildGradioImageInterfaceRequest``. An
    ``ImageInterface`` is built that calls back into this server over HTTP, and it
    is mounted on the FastAPI app at ``/{model_uid}``.

    Args:
        model_uid: UID of the launched image model; also used as the mount path.
        request: Incoming request. Its JSON body and ``Authorization`` header
            are read.

    Returns:
        ``JSONResponse`` with ``{"model_uid": model_uid}``.

    Raises:
        HTTPException: 400 if ``model_type`` is not ``"image"`` or building the
            interface raises ``ValueError``; 500 for any other failure while
            building or mounting the interface.
    """
    payload = await request.json()
    body = BuildGradioImageInterfaceRequest.parse_obj(payload)
    assert self._app is not None
    # A wrong model type is a client error. Report it as 400 instead of using an
    # assert, which is stripped under ``python -O`` and otherwise surfaces as 500.
    if body.model_type != "image":
        raise HTTPException(
            status_code=400,
            detail=f"Expected model_type 'image', got {body.model_type!r}",
        )

    # asyncio.Lock() behaves differently in 3.9 than 3.10+
    # An event loop is required in 3.9 but not 3.10+
    if sys.version_info < (3, 10):
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            warnings.warn(
                "asyncio.Lock() requires an event loop in Python 3.9; "
                "a placeholder event loop has been created"
            )
            asyncio.set_event_loop(asyncio.new_event_loop())

    from ..core.image_interface import ImageInterface

    try:
        # Forward the caller's credentials so the UI's REST client can authenticate.
        access_token = request.headers.get("Authorization")
        # 0.0.0.0 is a bind address, not a connectable one; the UI talks to this
        # same server, so use localhost in that case.
        internal_host = "localhost" if self._host == "0.0.0.0" else self._host
        interface = ImageInterface(
            endpoint=f"http://{internal_host}:{self._port}",
            model_uid=model_uid,
            model_family=body.model_family,
            model_name=body.model_name,
            model_id=body.model_id,
            model_revision=body.model_revision,
            controlnet=body.controlnet,
            access_token=access_token,
        ).build()

        gr.mount_gradio_app(self._app, interface, f"/{model_uid}")
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        raise HTTPException(status_code=400, detail=str(ve))

    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

    return JSONResponse(content={"model_uid": model_uid})
|
|
908
|
+
|
|
825
909
|
async def terminate_model(self, model_uid: str) -> JSONResponse:
|
|
826
910
|
try:
|
|
827
911
|
assert self._app is not None
|
|
@@ -891,11 +975,17 @@ class RESTfulAPI:
|
|
|
891
975
|
self.handle_request_limit_error(re)
|
|
892
976
|
async for item in iterator:
|
|
893
977
|
yield item
|
|
978
|
+
except asyncio.CancelledError:
|
|
979
|
+
logger.info(
|
|
980
|
+
f"Disconnected from client (via refresh/close) {request.client} during generate."
|
|
981
|
+
)
|
|
982
|
+
return
|
|
894
983
|
except Exception as ex:
|
|
895
984
|
logger.exception("Completion stream got an error: %s", ex)
|
|
896
985
|
await self._report_error_event(model_uid, str(ex))
|
|
897
986
|
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
|
|
898
987
|
yield dict(data=json.dumps({"error": str(ex)}))
|
|
988
|
+
return
|
|
899
989
|
|
|
900
990
|
return EventSourceResponse(stream_results())
|
|
901
991
|
else:
|
|
@@ -1169,25 +1259,21 @@ class RESTfulAPI:
|
|
|
1169
1259
|
status_code=400, detail="Invalid input. Please specify the prompt."
|
|
1170
1260
|
)
|
|
1171
1261
|
|
|
1172
|
-
system_messages = []
|
|
1262
|
+
system_messages: List["ChatCompletionMessage"] = []
|
|
1263
|
+
system_messages_contents = []
|
|
1173
1264
|
non_system_messages = []
|
|
1174
1265
|
for msg in messages:
|
|
1175
1266
|
assert (
|
|
1176
1267
|
msg.get("content") != SPECIAL_TOOL_PROMPT
|
|
1177
1268
|
), f"Invalid message content {SPECIAL_TOOL_PROMPT}"
|
|
1178
1269
|
if msg["role"] == "system":
|
|
1179
|
-
|
|
1270
|
+
system_messages_contents.append(msg["content"])
|
|
1180
1271
|
else:
|
|
1181
1272
|
non_system_messages.append(msg)
|
|
1273
|
+
system_messages.append(
|
|
1274
|
+
{"role": "system", "content": ". ".join(system_messages_contents)}
|
|
1275
|
+
)
|
|
1182
1276
|
|
|
1183
|
-
if len(system_messages) > 1:
|
|
1184
|
-
raise HTTPException(
|
|
1185
|
-
status_code=400, detail="Multiple system messages are not supported."
|
|
1186
|
-
)
|
|
1187
|
-
if len(system_messages) == 1 and messages[0]["role"] != "system":
|
|
1188
|
-
raise HTTPException(
|
|
1189
|
-
status_code=400, detail="System message should be the first one."
|
|
1190
|
-
)
|
|
1191
1277
|
assert non_system_messages
|
|
1192
1278
|
|
|
1193
1279
|
has_tool_message = messages[-1].get("role") == "tool"
|
|
@@ -1273,11 +1359,23 @@ class RESTfulAPI:
|
|
|
1273
1359
|
async for item in iterator:
|
|
1274
1360
|
yield item
|
|
1275
1361
|
yield "[DONE]"
|
|
1362
|
+
# Note that asyncio.CancelledError does not inherit from Exception.
|
|
1363
|
+
# When the user uses ctrl+c to cancel the streaming chat, asyncio.CancelledError would be triggered.
|
|
1364
|
+
# See https://github.com/sysid/sse-starlette/blob/main/examples/example.py#L48
|
|
1365
|
+
except asyncio.CancelledError:
|
|
1366
|
+
logger.info(
|
|
1367
|
+
f"Disconnected from client (via refresh/close) {request.client} during chat."
|
|
1368
|
+
)
|
|
1369
|
+
# See https://github.com/sysid/sse-starlette/blob/main/examples/error_handling.py#L13
|
|
1370
|
+
# Use return to stop the generator from continuing.
|
|
1371
|
+
# TODO: Cannot yield here. Yield here would leads to error for the next streaming request.
|
|
1372
|
+
return
|
|
1276
1373
|
except Exception as ex:
|
|
1277
1374
|
logger.exception("Chat completion stream got an error: %s", ex)
|
|
1278
1375
|
await self._report_error_event(model_uid, str(ex))
|
|
1279
1376
|
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
|
|
1280
1377
|
yield dict(data=json.dumps({"error": str(ex)}))
|
|
1378
|
+
return
|
|
1281
1379
|
|
|
1282
1380
|
return EventSourceResponse(stream_results())
|
|
1283
1381
|
else:
|
xinference/conftest.py
CHANGED
|
@@ -208,10 +208,11 @@ def setup():
|
|
|
208
208
|
if not api_health_check(endpoint, max_attempts=10, sleep_interval=5):
|
|
209
209
|
raise RuntimeError("Endpoint is not available after multiple attempts")
|
|
210
210
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
211
|
+
try:
|
|
212
|
+
yield f"http://localhost:{port}", supervisor_addr
|
|
213
|
+
finally:
|
|
214
|
+
local_cluster_proc.kill()
|
|
215
|
+
restful_api_proc.kill()
|
|
215
216
|
|
|
216
217
|
|
|
217
218
|
@pytest.fixture
|
|
@@ -239,10 +240,11 @@ def setup_with_file_logging():
|
|
|
239
240
|
if not api_health_check(endpoint, max_attempts=3, sleep_interval=5):
|
|
240
241
|
raise RuntimeError("Endpoint is not available after multiple attempts")
|
|
241
242
|
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
243
|
+
try:
|
|
244
|
+
yield f"http://localhost:{port}", supervisor_addr, TEST_LOG_FILE_PATH
|
|
245
|
+
finally:
|
|
246
|
+
local_cluster_proc.kill()
|
|
247
|
+
restful_api_proc.kill()
|
|
246
248
|
|
|
247
249
|
|
|
248
250
|
@pytest.fixture
|
|
@@ -290,11 +292,12 @@ def setup_with_auth():
|
|
|
290
292
|
if not api_health_check(endpoint, max_attempts=10, sleep_interval=5):
|
|
291
293
|
raise RuntimeError("Endpoint is not available after multiple attempts")
|
|
292
294
|
|
|
293
|
-
yield f"http://localhost:{port}", supervisor_addr
|
|
294
|
-
|
|
295
|
-
local_cluster_proc.terminate()
|
|
296
|
-
restful_api_proc.terminate()
|
|
297
295
|
try:
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
296
|
+
yield f"http://localhost:{port}", supervisor_addr
|
|
297
|
+
finally:
|
|
298
|
+
local_cluster_proc.kill()
|
|
299
|
+
restful_api_proc.kill()
|
|
300
|
+
try:
|
|
301
|
+
os.remove(auth_file)
|
|
302
|
+
except:
|
|
303
|
+
pass
|
xinference/constants.py
CHANGED
|
@@ -25,6 +25,7 @@ XINFERENCE_ENV_HEALTH_CHECK_INTERVAL = "XINFERENCE_HEALTH_CHECK_INTERVAL"
|
|
|
25
25
|
XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
|
|
26
26
|
XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
|
|
27
27
|
XINFERENCE_ENV_DISABLE_VLLM = "XINFERENCE_DISABLE_VLLM"
|
|
28
|
+
# Name of the environment variable read to set the XINFERENCE_ENABLE_SGLANG flag.
XINFERENCE_ENV_ENABLE_SGLANG = "XINFERENCE_ENABLE_SGLANG"
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def get_xinference_home() -> str:
|
|
@@ -64,3 +65,4 @@ XINFERENCE_DISABLE_HEALTH_CHECK = bool(
|
|
|
64
65
|
int(os.environ.get(XINFERENCE_ENV_DISABLE_HEALTH_CHECK, 0))
|
|
65
66
|
)
|
|
66
67
|
XINFERENCE_DISABLE_VLLM = bool(int(os.environ.get(XINFERENCE_ENV_DISABLE_VLLM, 0)))
|
|
68
|
+
# Integer-valued flag: unset or "0" gives False, any other integer string gives True.
# A non-integer value (e.g. "true") makes int() raise ValueError at import time.
XINFERENCE_ENABLE_SGLANG = bool(int(os.environ.get(XINFERENCE_ENV_ENABLE_SGLANG, 0)))
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import base64
|
|
16
|
+
import io
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from typing import Dict, List, Optional, Union
|
|
20
|
+
|
|
21
|
+
import gradio as gr
|
|
22
|
+
import PIL.Image
|
|
23
|
+
from gradio import Markdown
|
|
24
|
+
|
|
25
|
+
from ..client.restful.restful_client import RESTfulImageModelHandle
|
|
26
|
+
|
|
27
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ImageInterface:
    """Gradio web UI for a Stable Diffusion model launched in Xinference.

    The UI does not run the model in-process. Each "Generate" click opens a
    ``RESTfulClient`` to ``endpoint`` and calls the model ``model_uid`` over
    HTTP. It requests base64-encoded results and decodes them into PIL images.
    """

    def __init__(
        self,
        endpoint: str,
        model_uid: str,
        model_family: str,
        model_name: str,
        model_id: str,
        model_revision: str,
        controlnet: Union[None, List[Dict[str, Union[str, None]]]],
        access_token: Optional[str],
    ):
        """
        Args:
            endpoint: Base URL of the Xinference RESTful API
                (e.g. ``http://localhost:9997``).
            model_uid: UID of the launched image model to call.
            model_family: Model family. ``build`` only accepts Stable Diffusion
                families.
            model_name: Display name used in the page title and heading.
            model_id: Stored on the instance; not used by the UI itself.
            model_revision: Stored on the instance; not used by the UI itself.
            controlnet: Stored on the instance; not used by the UI itself.
            access_token: Raw ``Authorization`` header value
                (``"Bearer <token>"``), or ``None`` if the header was absent.
        """
        self.endpoint = endpoint
        self.model_uid = model_uid
        self.model_family = model_family
        self.model_name = model_name
        self.model_id = model_id
        self.model_revision = model_revision
        self.controlnet = controlnet
        # Keep only the token itself; the REST client adds its own scheme.
        self.access_token = (
            access_token.replace("Bearer ", "") if access_token is not None else None
        )

    def build(self) -> gr.Blocks:
        """Create the Blocks app, enable its queue and start it for mounting.

        Returns:
            The ready-to-mount ``gr.Blocks`` app.

        Raises:
            ValueError: If ``model_family`` is not a Stable Diffusion family.
        """
        # An explicit exception (unlike ``assert``) survives ``python -O``, and
        # the REST handler maps ValueError to an HTTP 400.
        if "stable_diffusion" not in self.model_family:
            raise ValueError(
                "The image UI only supports stable_diffusion models, "
                f"got model family: {self.model_family}"
            )

        interface = self.build_main_interface()
        interface.queue()
        # Gradio initiates the queue during a startup event, but since the app has already been
        # started, that event will not run, so manually invoke the startup events.
        # See: https://github.com/gradio-app/gradio/issues/5228
        interface.startup_events()
        favicon_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.path.pardir,
            "web",
            "ui",
            "public",
            "favicon.svg",
        )
        interface.favicon_path = favicon_path
        return interface

    def _get_model_handle(self) -> RESTfulImageModelHandle:
        """Connect to ``endpoint`` and return the image model handle for ``model_uid``."""
        from ..client import RESTfulClient

        client = RESTfulClient(self.endpoint)
        client._set_token(self.access_token)
        model = client.get_model(self.model_uid)
        assert isinstance(model, RESTfulImageModelHandle)
        return model

    @staticmethod
    def _decode_images(response: Dict) -> List[PIL.Image.Image]:
        """Decode every ``b64_json`` entry of ``response["data"]`` into a PIL image."""
        images = []
        for image_dict in response["data"]:
            assert image_dict["b64_json"] is not None
            image_data = base64.b64decode(image_dict["b64_json"])
            images.append(PIL.Image.open(io.BytesIO(image_data)))
        return images

    def text2image_interface(self) -> "gr.Blocks":
        """Build the "Text to Image" tab: prompt fields, count/size inputs, gallery."""

        def text_generate_image(
            prompt: str,
            n: float,
            size_width: float,
            size_height: float,
            negative_prompt: Optional[str] = None,
        ) -> List[PIL.Image.Image]:
            model = self._get_model_handle()
            # gr.Number values may arrive as floats; cast every numeric field
            # (not just the size) to int before sending them to the API.
            response = model.text_to_image(
                prompt=prompt,
                n=int(n),
                size=f"{int(size_width)}*{int(size_height)}",
                negative_prompt=negative_prompt,
                response_format="b64_json",
            )
            return self._decode_images(response)

        with gr.Blocks() as text2image_blocks:
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=10):
                        prompt = gr.Textbox(
                            label="Prompt",
                            show_label=True,
                            placeholder="Enter prompt here...",
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative prompt",
                            show_label=True,
                            placeholder="Enter negative prompt here...",
                        )
                    with gr.Column(scale=1):
                        generate_button = gr.Button("Generate")

                with gr.Row():
                    n = gr.Number(label="Number of Images", value=1)
                    size_width = gr.Number(label="Width", value=1024)
                    size_height = gr.Number(label="Height", value=1024)

                with gr.Column():
                    image_output = gr.Gallery()

            generate_button.click(
                text_generate_image,
                inputs=[prompt, n, size_width, size_height, negative_prompt],
                outputs=image_output,
            )

        return text2image_blocks

    def image2image_interface(self) -> "gr.Blocks":
        """Build the "Image to Image" tab: prompts, upload slot, count/size inputs, gallery."""

        def image_generate_image(
            prompt: str,
            negative_prompt: str,
            image: Optional[PIL.Image.Image],
            n: float,
            size_width: float,
            size_height: float,
        ) -> List[PIL.Image.Image]:
            # Without an upload, gr.Image passes None; show a readable UI error
            # instead of failing with AttributeError on ``image.save``.
            if image is None:
                raise gr.Error("Please upload an image first.")

            model = self._get_model_handle()

            bio = io.BytesIO()
            image.save(bio, format="png")

            response = model.image_to_image(
                prompt=prompt,
                negative_prompt=negative_prompt,
                n=int(n),
                image=bio.getvalue(),
                size=f"{int(size_width)}*{int(size_height)}",
                response_format="b64_json",
            )
            return self._decode_images(response)

        with gr.Blocks() as image2image_blocks:
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=10):
                        prompt = gr.Textbox(
                            label="Prompt",
                            show_label=True,
                            placeholder="Enter prompt here...",
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            show_label=True,
                            placeholder="Enter negative prompt here...",
                        )
                    with gr.Column(scale=1):
                        generate_button = gr.Button("Generate")

                with gr.Row():
                    n = gr.Number(label="Number of image", value=1)
                    size_width = gr.Number(label="Width", value=512)
                    size_height = gr.Number(label="Height", value=512)

                with gr.Row():
                    with gr.Column(scale=1):
                        uploaded_image = gr.Image(type="pil", label="Upload Image")
                    with gr.Column(scale=1):
                        output_gallery = gr.Gallery()

            generate_button.click(
                image_generate_image,
                inputs=[
                    prompt,
                    negative_prompt,
                    uploaded_image,
                    n,
                    size_width,
                    size_height,
                ],
                outputs=output_gallery,
            )
        return image2image_blocks

    def build_main_interface(self) -> "gr.Blocks":
        """Assemble the page: title, model UID banner and the two generation tabs."""
        with gr.Blocks(
            title=f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨",
            css="""
                    .center{
                        display: flex;
                        justify-content: center;
                        align-items: center;
                        padding: 0px;
                        color: #9ea4b0 !important;
                    }
                    """,
            analytics_enabled=False,
        ) as app:
            Markdown(
                f"""
                <h1 class="center" style='text-align: center; margin-bottom: 1rem'>🎨 Xinference Stable Diffusion: {self.model_name} 🎨</h1>
                """
            )
            Markdown(
                f"""
                <div class="center">
                Model ID: {self.model_uid}
                </div>
                """
            )
            with gr.Tab("Text to Image"):
                self.text2image_interface()
            with gr.Tab("Image to Image"):
                self.image2image_interface()

        return app
|
xinference/core/supervisor.py
CHANGED
|
@@ -722,17 +722,10 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
722
722
|
if model_uid is None:
|
|
723
723
|
model_uid = self._gen_model_uid(model_name)
|
|
724
724
|
|
|
725
|
+
model_size = str(model_size_in_billions) if model_size_in_billions else ""
|
|
725
726
|
logger.debug(
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
f"model_format: %s, quantization: %s, replica: %s"
|
|
729
|
-
),
|
|
730
|
-
model_uid,
|
|
731
|
-
model_name,
|
|
732
|
-
str(model_size_in_billions) if model_size_in_billions else "",
|
|
733
|
-
model_format,
|
|
734
|
-
quantization,
|
|
735
|
-
replica,
|
|
727
|
+
f"Enter launch_builtin_model, model_uid: {model_uid}, model_name: {model_name}, model_size: {model_size}, "
|
|
728
|
+
f"model_format: {model_format}, quantization: {quantization}, replica: {replica}"
|
|
736
729
|
)
|
|
737
730
|
|
|
738
731
|
async def _launch_one_model(_replica_model_uid):
|