xinference 0.15.3__py3-none-any.whl → 0.15.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +29 -2
- xinference/client/restful/restful_client.py +10 -0
- xinference/constants.py +4 -0
- xinference/core/image_interface.py +76 -23
- xinference/core/model.py +80 -39
- xinference/core/progress_tracker.py +187 -0
- xinference/core/supervisor.py +11 -0
- xinference/core/worker.py +1 -0
- xinference/model/audio/chattts.py +2 -1
- xinference/model/audio/core.py +0 -2
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/image/core.py +6 -7
- xinference/model/image/sdapi.py +35 -4
- xinference/model/image/stable_diffusion/core.py +208 -78
- xinference/model/llm/llm_family.json +16 -16
- xinference/model/llm/llm_family_modelscope.json +16 -12
- xinference/model/llm/transformers/cogvlm2.py +2 -1
- xinference/model/llm/transformers/cogvlm2_video.py +2 -0
- xinference/model/llm/transformers/core.py +6 -2
- xinference/model/llm/transformers/deepseek_vl.py +2 -0
- xinference/model/llm/transformers/glm4v.py +2 -1
- xinference/model/llm/transformers/intern_vl.py +2 -0
- xinference/model/llm/transformers/minicpmv25.py +2 -0
- xinference/model/llm/transformers/minicpmv26.py +2 -0
- xinference/model/llm/transformers/omnilmm.py +2 -0
- xinference/model/llm/transformers/qwen2_audio.py +11 -4
- xinference/model/llm/transformers/qwen2_vl.py +2 -28
- xinference/model/llm/transformers/qwen_vl.py +2 -1
- xinference/model/llm/transformers/utils.py +35 -2
- xinference/model/llm/transformers/yi_vl.py +2 -0
- xinference/model/llm/utils.py +58 -14
- xinference/model/llm/vllm/core.py +52 -8
- xinference/model/llm/vllm/utils.py +0 -1
- xinference/model/utils.py +7 -4
- xinference/model/video/core.py +0 -2
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/METADATA +3 -3
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/RECORD +43 -42
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/LICENSE +0 -0
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/WHEEL +0 -0
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/top_level.txt +0 -0
xinference/_version.py
CHANGED

@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-…
+ "date": "2024-10-12T18:28:41+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "…
- "version": "0.15.3"
+ "full-revisionid": "c0be11504c70f6c392cbdb67c86cf12153353f70",
+ "version": "0.15.4"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py
CHANGED

@@ -524,6 +524,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/requests/{request_id}/progress",
+            self.get_progress,
+            methods=["get"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,

@@ -1486,6 +1496,17 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def get_progress(self, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            result = {"progress": await supervisor_ref.get_progress(request_id)}
+            return JSONResponse(content=result)
+        except KeyError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model

@@ -1853,10 +1874,16 @@
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-        from ..model.llm.utils import …
+        from ..model.llm.utils import (
+            GLM4_TOOL_CALL_FAMILY,
+            LLAMA3_TOOL_CALL_FAMILY,
+            QWEN_TOOL_CALL_FAMILY,
+        )
 
         model_family = desc.get("model_family", "")
-        function_call_models = …
+        function_call_models = (
+            QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
+        )
 
         if model_family not in function_call_models:
            if body.tools:
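
Together these hunks add a read-only progress endpoint to the REST API. A hedged sketch of querying it directly; host, port, and the request id are placeholders, and authentication headers are omitted:

```python
# Sketch: query the new progress endpoint over raw HTTP.
# The id must match one passed to a generation call via `request_id`.
import requests

request_id = "my-request-id"  # placeholder
resp = requests.get(f"http://127.0.0.1:9997/v1/requests/{request_id}/progress")
if resp.status_code == 200:
    print(resp.json())  # e.g. {"progress": 0.42}
else:
    # per the hunk above: 400 for unknown request ids, 500 otherwise
    print(resp.status_code, resp.text)
```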
xinference/client/restful/restful_client.py
CHANGED

@@ -1385,6 +1385,16 @@ class Client:
         response_json = response.json()
         return response_json
 
+    def get_progress(self, request_id: str):
+        url = f"{self.base_url}/v1/requests/{request_id}/progress"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get progress, detail: {_get_error_string(response)}"
+            )
+        response_json = response.json()
+        return response_json
+
     def abort_cluster(self):
         url = f"{self.base_url}/v1/clusters"
         response = requests.delete(url, headers=self._headers)
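
A client-side usage sketch, mirroring what the Gradio image interface below does: launch a generation with an explicit `request_id` in one thread and poll `get_progress` from another. The server address and model UID are placeholders:

```python
# Hedged usage sketch; assumes a running server with an image model launched.
import threading
import time
import uuid

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("<image-model-uid>")  # placeholder UID

request_id = str(uuid.uuid4())

def generate():
    # Passing request_id registers this request with the progress tracker.
    model.text_to_image(prompt="an astronaut riding a horse", request_id=request_id)

t = threading.Thread(target=generate)
t.start()
while t.is_alive():
    try:
        print(client.get_progress(request_id))  # e.g. {'progress': 0.42}
    except RuntimeError:
        pass  # progress may not be registered on the server yet
    time.sleep(1)
t.join()
```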
xinference/constants.py
CHANGED

@@ -28,6 +28,7 @@ XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
 XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
+XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
 
 
 def get_xinference_home() -> str:

@@ -82,3 +83,6 @@ XINFERENCE_DISABLE_METRICS = bool(
 XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
     int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
 )
+XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
+    os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
+)
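
The new constant caps download retry attempts (default 3); the download helpers in xinference/model/utils.py, also touched in this release, presumably consume it. A minimal sketch of overriding it, noting that the value is read once at import time:

```python
# Sketch: override the download retry cap. Must happen before
# xinference.constants is imported, since the value is read at import time.
import os

os.environ["XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"] = "5"

from xinference.constants import XINFERENCE_DOWNLOAD_MAX_ATTEMPTS

print(XINFERENCE_DOWNLOAD_MAX_ATTEMPTS)  # 5
```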
xinference/core/image_interface.py
CHANGED

@@ -16,6 +16,9 @@ import base64
 import io
 import logging
 import os
+import threading
+import time
+import uuid
 from typing import Dict, List, Optional, Union
 
 import gradio as gr

@@ -84,6 +87,7 @@ class ImageInterface:
         num_inference_steps: int,
         negative_prompt: Optional[str] = None,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 

@@ -99,19 +103,43 @@ class ImageInterface:
         )
         sampler_name = None if sampler_name == "default" else sampler_name
 
-        response = model.text_to_image(
-            prompt=prompt,
-            n=n,
-            size=size,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            sampler_name=sampler_name,
-            response_format="b64_json",
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.text_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    n=n,
+                    size=size,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    negative_prompt=negative_prompt,
+                    sampler_name=sampler_name,
+                    response_format="b64_json",
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))

@@ -184,6 +212,7 @@ class ImageInterface:
         num_inference_steps: int,
         padding_image_to_multiple: int,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 

@@ -205,20 +234,44 @@ class ImageInterface:
         bio = io.BytesIO()
         image.save(bio, format="png")
 
-        response = model.image_to_image(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            n=n,
-            image=bio.getvalue(),
-            size=size,
-            response_format="b64_json",
-            num_inference_steps=num_inference_steps,
-            padding_image_to_multiple=padding_image_to_multiple,
-            sampler_name=sampler_name,
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.image_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    n=n,
+                    image=bio.getvalue(),
+                    size=size,
+                    response_format="b64_json",
+                    num_inference_steps=num_inference_steps,
+                    padding_image_to_multiple=padding_image_to_multiple,
+                    sampler_name=sampler_name,
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
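
Both handlers move the blocking generation call into a worker thread so the handler itself stays free to poll the server and push updates through Gradio's progress API: declaring a `gr.Progress()` parameter lets a handler report fractional progress to the UI while it runs. A standalone sketch of that mechanism, independent of Xinference:

```python
# Minimal Gradio sketch of the progress-reporting mechanism used above.
import time

import gradio as gr

def slow_task(text: str, progress=gr.Progress()):
    for i in range(10):
        progress(i / 10, desc="Working")  # push a fraction in [0, 1] to the UI
        time.sleep(0.1)
    return text.upper()

demo = gr.Interface(fn=slow_task, inputs="text", outputs="text")
# demo.launch()
```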
xinference/core/model.py
CHANGED

@@ -44,6 +44,7 @@ import xoscar as xo
 from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
 
 if TYPE_CHECKING:
+    from .progress_tracker import ProgressTrackerActor
     from .worker import WorkerActor
     from ..model.llm.core import LLM
     from ..model.core import ModelDescription

@@ -177,6 +178,7 @@ class ModelActor(xo.StatelessActor):
 
     def __init__(
         self,
+        supervisor_address: str,
         worker_address: str,
         model: "LLM",
         model_description: Optional["ModelDescription"] = None,

@@ -188,6 +190,7 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._supervisor_address = supervisor_address
         self._worker_address = worker_address
         self._model = model
         self._model_description = (

@@ -205,6 +208,7 @@ class ModelActor(xo.StatelessActor):
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
+        self._progress_tracker_ref = None
         self._serve_count = 0
         self._metrics_labels = {
             "type": self._model_description.get("model_type", "unknown"),

@@ -275,6 +279,28 @@ class ModelActor(xo.StatelessActor):
         )
         return self._worker_ref
 
+    async def _get_progress_tracker_ref(
+        self,
+    ) -> xo.ActorRefType["ProgressTrackerActor"]:
+        from .progress_tracker import ProgressTrackerActor
+
+        if self._progress_tracker_ref is None:
+            self._progress_tracker_ref = await xo.actor_ref(
+                address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
+            )
+        return self._progress_tracker_ref
+
+    async def _get_progressor(self, request_id: str):
+        from .progress_tracker import Progressor
+
+        progressor = Progressor(
+            request_id,
+            await self._get_progress_tracker_ref(),
+            asyncio.get_running_loop(),
+        )
+        await progressor.start()
+        return progressor
+
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
 

@@ -732,17 +758,20 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper_json(
-                self._model.text_to_image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.text_to_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )

@@ -753,12 +782,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "txt2img"):
-            return await self._call_wrapper_json(
-                self._model.txt2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.txt2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
 
     @log_async(

@@ -776,19 +808,22 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper_json(
-                self._model.image_to_image,
-                image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.image_to_image,
+                    image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )

@@ -799,12 +834,15 @@
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "img2img"):
-            return await self._call_wrapper_json(
-                self._model.img2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.img2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
 
     @log_async(

@@ -823,20 +861,23 @@
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
+        kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper_json(
-                self._model.inpainting,
-                image,
-                mask_image,
-                prompt,
-                negative_prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.inpainting,
+                    image,
+                    mask_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
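
Every image entry point now follows the same shape: resolve a Progressor for the request (the popped `request_id` may be None, in which case the Progressor's guards make it a no-op), inject it into `kwargs["progressor"]` for the model implementation to report into, and run the call inside `with progressor:` so the sub-range is closed out even on failure. A minimal, runnable sketch of that injection pattern with stand-in names, not the Xinference API:

```python
# Illustrative only: a dummy progress reporter injected via kwargs,
# mirroring the actor code above.
from typing import Any, Dict


class DummyProgressor:
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_progress(1.0)  # always finish the sub-range on exit
        return False

    def set_progress(self, progress: float) -> None:
        print(f"progress: {progress:.2f}")


def model_text_to_image(prompt: str, **kwargs: Any) -> Dict[str, Any]:
    progressor = kwargs.pop("progressor")
    for step in range(4):
        progressor.set_progress((step + 1) / 4)  # report per inference step
    return {"data": [prompt]}


kwargs: Dict[str, Any] = {}
progressor = kwargs["progressor"] = DummyProgressor()
with progressor:
    print(model_text_to_image("an astronaut", **kwargs))
```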
xinference/core/progress_tracker.py
ADDED

@@ -0,0 +1,187 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import dataclasses
+import logging
+import os
+import time
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import xoscar as xo
+
+TO_REMOVE_PROGRESS_INTERVAL = float(
+    os.getenv("XINFERENCE_REMOVE_PROGRESS_INTERVAL", 5 * 60)
+)  # 5min
+CHECK_PROGRESS_INTERVAL = float(
+    os.getenv("XINFERENCE_CHECK_PROGRESS_INTERVAL", 1 * 60)
+)  # 1min
+UPLOAD_PROGRESS_SPAN = float(
+    os.getenv("XINFERENCE_UPLOAD_PROGRESS_SPAN", 0.05)
+)  # not upload when change less than 0.1
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class _ProgressInfo:
+    progress: float
+    last_updated: float
+    info: Optional[str] = None
+
+
+class ProgressTrackerActor(xo.StatelessActor):
+    _request_id_to_progress: Dict[str, _ProgressInfo]
+
+    @classmethod
+    def default_uid(cls) -> str:
+        return "progress_tracker"
+
+    def __init__(
+        self,
+        to_remove_interval: float = TO_REMOVE_PROGRESS_INTERVAL,
+        check_interval: float = CHECK_PROGRESS_INTERVAL,
+    ):
+        super().__init__()
+
+        self._request_id_to_progress = {}
+        self._clear_finished_task = None
+        self._to_remove_interval = to_remove_interval
+        self._check_interval = check_interval
+
+    async def __post_create__(self):
+        self._clear_finished_task = asyncio.create_task(self._clear_finished())
+
+    async def __pre_destroy__(self):
+        if self._clear_finished_task:
+            self._clear_finished_task.cancel()
+
+    async def _clear_finished(self):
+        to_remove_request_ids = []
+        while True:
+            now = time.time()
+            for request_id, progress in self._request_id_to_progress.items():
+                if abs(progress.progress - 1.0) > 1e-5:
+                    continue
+
+                # finished
+                if now - progress.last_updated > self._to_remove_interval:
+                    to_remove_request_ids.append(request_id)
+
+            for rid in to_remove_request_ids:
+                del self._request_id_to_progress[rid]
+
+            if to_remove_request_ids:
+                logger.debug(
+                    "Remove requests %s due to it's finished for over %s seconds",
+                    to_remove_request_ids,
+                    self._to_remove_interval,
+                )
+
+            await asyncio.sleep(self._check_interval)
+
+    def start(self, request_id: str):
+        self._request_id_to_progress[request_id] = _ProgressInfo(
+            progress=0.0, last_updated=time.time()
+        )
+
+    def set_progress(self, request_id: str, progress: float):
+        assert progress <= 1.0
+        info = self._request_id_to_progress[request_id]
+        info.progress = progress
+        info.last_updated = time.time()
+        logger.debug(
+            "Setting progress, request id: %s, progress: %s", request_id, progress
+        )
+
+    def get_progress(self, request_id: str) -> float:
+        return self._request_id_to_progress[request_id].progress
+
+
+class Progressor:
+    _sub_progress_stack: List[Tuple[float, float]]
+
+    def __init__(
+        self,
+        request_id: str,
+        progress_tracker_ref: xo.ActorRefType["ProgressTrackerActor"],
+        loop: asyncio.AbstractEventLoop,
+        upload_span: float = UPLOAD_PROGRESS_SPAN,
+    ):
+        self.request_id = request_id
+        self.progress_tracker_ref = progress_tracker_ref
+        self.loop = loop
+        # uploading when progress changes over this span
+        # to prevent from frequently uploading
+        self._upload_span = upload_span
+
+        self._last_report_progress = 0.0
+        self._current_progress = 0.0
+        self._sub_progress_stack = [(0.0, 1.0)]
+        self._current_sub_progress_start = 0.0
+        self._current_sub_progress_end = 1.0
+
+    async def start(self):
+        if self.request_id:
+            await self.progress_tracker_ref.start(self.request_id)
+
+    def split_stages(self, n_stage: int, stage_weight: Optional[List[float]] = None):
+        if self.request_id:
+            if stage_weight is not None:
+                if len(stage_weight) != n_stage + 1:
+                    raise ValueError(
+                        f"stage_weight should have size {n_stage + 1}, got {len(stage_weight)}"
+                    )
+                progresses = stage_weight
+            else:
+                progresses = np.linspace(
+                    self._current_sub_progress_start,
+                    self._current_sub_progress_end,
+                    n_stage + 1,
+                )
+            spans = [(progresses[i], progresses[i + 1]) for i in range(n_stage)]
+            self._sub_progress_stack.extend(spans[::-1])
+
+    def __enter__(self):
+        if self.request_id:
+            (
+                self._current_sub_progress_start,
+                self._current_sub_progress_end,
+            ) = self._sub_progress_stack[-1]
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.request_id:
+            self._sub_progress_stack.pop()
+            # force to set progress to 1.0 for this sub progress
+            # nevertheless it is done or not
+            self.set_progress(1.0)
+        return False
+
+    def set_progress(self, progress: float):
+        if self.request_id:
+            self._current_progress = (
+                self._current_sub_progress_start
+                + (self._current_sub_progress_end - self._current_sub_progress_start)
+                * progress
+            )
+            if (
+                self._current_progress - self._last_report_progress >= self._upload_span
+                or 1.0 - progress < 1e-5
+            ):
+                set_progress = self.progress_tracker_ref.set_progress(
+                    self.request_id, self._current_progress
+                )
+                asyncio.run_coroutine_threadsafe(set_progress, self.loop)  # type: ignore
+                self._last_report_progress = self._current_progress
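
The Progressor maps nested sub-stages onto the global [0, 1] range: `split_stages` pushes equal (or weighted) spans onto a stack, each `with` block consumes the span on top, and `set_progress` rescales a stage-local fraction into that span, only uploading to the tracker when the change exceeds `UPLOAD_PROGRESS_SPAN` or the stage completes. A runnable sketch of just that arithmetic, with the actor calls stripped out:

```python
# Sketch of Progressor's sub-range bookkeeping (no actors involved).
import numpy as np

start, end = 0.0, 1.0
n_stage = 2

progresses = np.linspace(start, end, n_stage + 1)   # [0.0, 0.5, 1.0]
spans = [(progresses[i], progresses[i + 1]) for i in range(n_stage)]
stack = [(start, end)] + spans[::-1]                 # LIFO: stage 1 on top

sub_start, sub_end = stack[-1]                       # "enter" stage 1: (0.0, 0.5)
local_progress = 0.5                                 # halfway through stage 1
global_progress = sub_start + (sub_end - sub_start) * local_progress
print(global_progress)                               # 0.25
```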
xinference/core/supervisor.py
CHANGED

@@ -130,6 +130,7 @@ class SupervisorActor(xo.StatelessActor):
         )
         logger.info(f"Xinference supervisor {self.address} started")
         from .cache_tracker import CacheTrackerActor
+        from .progress_tracker import ProgressTrackerActor
         from .status_guard import StatusGuardActor
 
         self._status_guard_ref: xo.ActorRefType[  # type: ignore

@@ -142,6 +143,13 @@
         ] = await xo.create_actor(
             CacheTrackerActor, address=self.address, uid=CacheTrackerActor.default_uid()
         )
+        self._progress_tracker: xo.ActorRefType[  # type: ignore
+            "ProgressTrackerActor"
+        ] = await xo.create_actor(
+            ProgressTrackerActor,
+            address=self.address,
+            uid=ProgressTrackerActor.default_uid(),
+        )
 
         from .event import EventCollectorActor
 

@@ -1360,3 +1368,6 @@
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)
+
+    async def get_progress(self, request_id: str) -> float:
+        return await self._progress_tracker.get_progress(request_id)
xinference/model/audio/chattts.py
CHANGED

@@ -53,7 +53,8 @@ class ChatTTSModel:
         torch._dynamo.config.suppress_errors = True
         torch.set_float32_matmul_precision("high")
         self._model = ChatTTS.Chat()
-        …
+        logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
+        self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
 
     def speech(
         self,
xinference/model/audio/core.py
CHANGED