xinference 0.15.3__py3-none-any.whl → 0.15.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (43)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +29 -2
  3. xinference/client/restful/restful_client.py +10 -0
  4. xinference/constants.py +4 -0
  5. xinference/core/image_interface.py +76 -23
  6. xinference/core/model.py +80 -39
  7. xinference/core/progress_tracker.py +187 -0
  8. xinference/core/supervisor.py +11 -0
  9. xinference/core/worker.py +1 -0
  10. xinference/model/audio/chattts.py +2 -1
  11. xinference/model/audio/core.py +0 -2
  12. xinference/model/audio/model_spec.json +8 -0
  13. xinference/model/audio/model_spec_modelscope.json +9 -0
  14. xinference/model/image/core.py +6 -7
  15. xinference/model/image/sdapi.py +35 -4
  16. xinference/model/image/stable_diffusion/core.py +208 -78
  17. xinference/model/llm/llm_family.json +16 -16
  18. xinference/model/llm/llm_family_modelscope.json +16 -12
  19. xinference/model/llm/transformers/cogvlm2.py +2 -1
  20. xinference/model/llm/transformers/cogvlm2_video.py +2 -0
  21. xinference/model/llm/transformers/core.py +6 -2
  22. xinference/model/llm/transformers/deepseek_vl.py +2 -0
  23. xinference/model/llm/transformers/glm4v.py +2 -1
  24. xinference/model/llm/transformers/intern_vl.py +2 -0
  25. xinference/model/llm/transformers/minicpmv25.py +2 -0
  26. xinference/model/llm/transformers/minicpmv26.py +2 -0
  27. xinference/model/llm/transformers/omnilmm.py +2 -0
  28. xinference/model/llm/transformers/qwen2_audio.py +11 -4
  29. xinference/model/llm/transformers/qwen2_vl.py +2 -28
  30. xinference/model/llm/transformers/qwen_vl.py +2 -1
  31. xinference/model/llm/transformers/utils.py +35 -2
  32. xinference/model/llm/transformers/yi_vl.py +2 -0
  33. xinference/model/llm/utils.py +58 -14
  34. xinference/model/llm/vllm/core.py +52 -8
  35. xinference/model/llm/vllm/utils.py +0 -1
  36. xinference/model/utils.py +7 -4
  37. xinference/model/video/core.py +0 -2
  38. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/METADATA +3 -3
  39. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/RECORD +43 -42
  40. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/LICENSE +0 -0
  41. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/WHEEL +0 -0
  42. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/entry_points.txt +0 -0
  43. {xinference-0.15.3.dist-info → xinference-0.15.4.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-09-30T20:17:26+0800",
+ "date": "2024-10-12T18:28:41+0800",
 "dirty": false,
 "error": null,
- "full-revisionid": "00a9ee15279a60a6d75393c4720d8da5cbbf5796",
- "version": "0.15.3"
+ "full-revisionid": "c0be11504c70f6c392cbdb67c86cf12153353f70",
+ "version": "0.15.4"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -524,6 +524,16 @@ class RESTfulAPI:
                 else None
             ),
         )
+        self._router.add_api_route(
+            "/v1/requests/{request_id}/progress",
+            self.get_progress,
+            methods=["get"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/images/generations",
             self.create_images,
@@ -1486,6 +1496,17 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def get_progress(self, request_id: str) -> JSONResponse:
+        try:
+            supervisor_ref = await self._get_supervisor_ref()
+            result = {"progress": await supervisor_ref.get_progress(request_id)}
+            return JSONResponse(content=result)
+        except KeyError as e:
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_images(self, request: Request) -> Response:
         body = TextToImageRequest.parse_obj(await request.json())
         model_uid = body.model
@@ -1853,10 +1874,16 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
-        from ..model.llm.utils import GLM4_TOOL_CALL_FAMILY, QWEN_TOOL_CALL_FAMILY
+        from ..model.llm.utils import (
+            GLM4_TOOL_CALL_FAMILY,
+            LLAMA3_TOOL_CALL_FAMILY,
+            QWEN_TOOL_CALL_FAMILY,
+        )
 
         model_family = desc.get("model_family", "")
-        function_call_models = QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY
+        function_call_models = (
+            QWEN_TOOL_CALL_FAMILY + GLM4_TOOL_CALL_FAMILY + LLAMA3_TOOL_CALL_FAMILY
+        )
 
         if model_family not in function_call_models:
             if body.tools:
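With the route above in place, progress can be polled over plain HTTP while a generation request is running. A minimal sketch, assuming a server on the default address and a request id you supplied yourself (both are placeholders):

import requests

base_url = "http://127.0.0.1:9997"  # placeholder: your xinference endpoint
request_id = "my-request-id"        # placeholder: the id passed to the generation call

resp = requests.get(f"{base_url}/v1/requests/{request_id}/progress")
resp.raise_for_status()
print(resp.json())  # e.g. {"progress": 0.42}; a 400 means the request id is unknown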
xinference/client/restful/restful_client.py CHANGED
@@ -1385,6 +1385,16 @@ class Client:
         response_json = response.json()
         return response_json
 
+    def get_progress(self, request_id: str):
+        url = f"{self.base_url}/v1/requests/{request_id}/progress"
+        response = requests.get(url, headers=self._headers)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to get progress, detail: {_get_error_string(response)}"
+            )
+        response_json = response.json()
+        return response_json
+
     def abort_cluster(self):
         url = f"{self.base_url}/v1/clusters"
         response = requests.delete(url, headers=self._headers)
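Combined with a caller-chosen request_id, get_progress lets a second thread watch a long-running generation, which is the pattern the Gradio image interface below adopts. A sketch, assuming a running server and an already-launched image model (the model uid is a placeholder):

import threading
import time
import uuid

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-image-model")  # placeholder uid
request_id = str(uuid.uuid4())

# Run the blocking call in a worker thread so the main thread can poll.
t = threading.Thread(
    target=model.text_to_image,
    kwargs=dict(prompt="a red panda", request_id=request_id),
)
t.start()
while t.is_alive():
    print(client.get_progress(request_id))  # e.g. {"progress": 0.42}
    time.sleep(1)
t.join()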
xinference/constants.py CHANGED
@@ -28,6 +28,7 @@ XINFERENCE_ENV_HEALTH_CHECK_TIMEOUT = "XINFERENCE_HEALTH_CHECK_TIMEOUT"
 XINFERENCE_ENV_DISABLE_HEALTH_CHECK = "XINFERENCE_DISABLE_HEALTH_CHECK"
 XINFERENCE_ENV_DISABLE_METRICS = "XINFERENCE_DISABLE_METRICS"
 XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING = "XINFERENCE_TRANSFORMERS_ENABLE_BATCHING"
+XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS = "XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"
 
 
 def get_xinference_home() -> str:
@@ -82,3 +83,6 @@ XINFERENCE_DISABLE_METRICS = bool(
 XINFERENCE_TRANSFORMERS_ENABLE_BATCHING = bool(
     int(os.environ.get(XINFERENCE_ENV_TRANSFORMERS_ENABLE_BATCHING, 0))
 )
+XINFERENCE_DOWNLOAD_MAX_ATTEMPTS = int(
+    os.environ.get(XINFERENCE_ENV_DOWNLOAD_MAX_ATTEMPTS, 3)
+)
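Because the constant is evaluated at import time, the variable has to be set before xinference is imported (or before the server process starts). A sketch:

import os

# Hypothetical value: retry downloads up to 5 times instead of the default 3.
os.environ["XINFERENCE_DOWNLOAD_MAX_ATTEMPTS"] = "5"

from xinference.constants import XINFERENCE_DOWNLOAD_MAX_ATTEMPTS

print(XINFERENCE_DOWNLOAD_MAX_ATTEMPTS)  # 5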
xinference/core/image_interface.py CHANGED
@@ -16,6 +16,9 @@ import base64
 import io
 import logging
 import os
+import threading
+import time
+import uuid
 from typing import Dict, List, Optional, Union
 
 import gradio as gr
@@ -84,6 +87,7 @@ class ImageInterface:
         num_inference_steps: int,
         negative_prompt: Optional[str] = None,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -99,19 +103,43 @@
         )
         sampler_name = None if sampler_name == "default" else sampler_name
 
-        response = model.text_to_image(
-            prompt=prompt,
-            n=n,
-            size=size,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            sampler_name=sampler_name,
-            response_format="b64_json",
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.text_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    n=n,
+                    size=size,
+                    num_inference_steps=num_inference_steps,
+                    guidance_scale=guidance_scale,
+                    negative_prompt=negative_prompt,
+                    sampler_name=sampler_name,
+                    response_format="b64_json",
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
@@ -184,6 +212,7 @@
         num_inference_steps: int,
         padding_image_to_multiple: int,
         sampler_name: Optional[str] = None,
+        progress=gr.Progress(),
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -205,20 +234,44 @@
         bio = io.BytesIO()
         image.save(bio, format="png")
 
-        response = model.image_to_image(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            n=n,
-            image=bio.getvalue(),
-            size=size,
-            response_format="b64_json",
-            num_inference_steps=num_inference_steps,
-            padding_image_to_multiple=padding_image_to_multiple,
-            sampler_name=sampler_name,
-        )
+        response = None
+        exc = None
+        request_id = str(uuid.uuid4())
+
+        def run_in_thread():
+            nonlocal exc, response
+            try:
+                response = model.image_to_image(
+                    request_id=request_id,
+                    prompt=prompt,
+                    negative_prompt=negative_prompt,
+                    n=n,
+                    image=bio.getvalue(),
+                    size=size,
+                    response_format="b64_json",
+                    num_inference_steps=num_inference_steps,
+                    padding_image_to_multiple=padding_image_to_multiple,
+                    sampler_name=sampler_name,
+                )
+            except Exception as e:
+                exc = e
+
+        t = threading.Thread(target=run_in_thread)
+        t.start()
+        while t.is_alive():
+            try:
+                cur_progress = client.get_progress(request_id)["progress"]
+            except (KeyError, RuntimeError):
+                cur_progress = 0.0
+
+            progress(cur_progress, desc="Generating images")
+            time.sleep(1)
+
+        if exc:
+            raise exc
 
         images = []
-        for image_dict in response["data"]:
+        for image_dict in response["data"]:  # type: ignore
             assert image_dict["b64_json"] is not None
             image_data = base64.b64decode(image_dict["b64_json"])
             image = PIL.Image.open(io.BytesIO(image_data))
xinference/core/model.py CHANGED
@@ -44,6 +44,7 @@ import xoscar as xo
 from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
 
 if TYPE_CHECKING:
+    from .progress_tracker import ProgressTrackerActor
     from .worker import WorkerActor
     from ..model.llm.core import LLM
     from ..model.core import ModelDescription
@@ -177,6 +178,7 @@ class ModelActor(xo.StatelessActor):
 
     def __init__(
         self,
+        supervisor_address: str,
         worker_address: str,
         model: "LLM",
         model_description: Optional["ModelDescription"] = None,
@@ -188,6 +190,7 @@ class ModelActor(xo.StatelessActor):
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
 
+        self._supervisor_address = supervisor_address
         self._worker_address = worker_address
         self._model = model
         self._model_description = (
@@ -205,6 +208,7 @@ class ModelActor(xo.StatelessActor):
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
+        self._progress_tracker_ref = None
         self._serve_count = 0
         self._metrics_labels = {
             "type": self._model_description.get("model_type", "unknown"),
@@ -275,6 +279,28 @@ class ModelActor(xo.StatelessActor):
         )
         return self._worker_ref
 
+    async def _get_progress_tracker_ref(
+        self,
+    ) -> xo.ActorRefType["ProgressTrackerActor"]:
+        from .progress_tracker import ProgressTrackerActor
+
+        if self._progress_tracker_ref is None:
+            self._progress_tracker_ref = await xo.actor_ref(
+                address=self._supervisor_address, uid=ProgressTrackerActor.default_uid()
+            )
+        return self._progress_tracker_ref
+
+    async def _get_progressor(self, request_id: str):
+        from .progress_tracker import Progressor
+
+        progressor = Progressor(
+            request_id,
+            await self._get_progress_tracker_ref(),
+            asyncio.get_running_loop(),
+        )
+        await progressor.start()
+        return progressor
+
     def is_vllm_backend(self) -> bool:
         from ..model.llm.vllm.core import VLLMModel
 
@@ -732,17 +758,20 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
-            return await self._call_wrapper_json(
-                self._model.text_to_image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.text_to_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -753,12 +782,15 @@ class ModelActor(xo.StatelessActor):
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "txt2img"):
-            return await self._call_wrapper_json(
-                self._model.txt2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.txt2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
 
     @log_async(
@@ -776,19 +808,22 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "image_to_image"):
-            return await self._call_wrapper_json(
-                self._model.image_to_image,
-                image,
-                prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.image_to_image,
+                    image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
@@ -799,12 +834,15 @@ class ModelActor(xo.StatelessActor):
         self,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
         if hasattr(self._model, "img2img"):
-            return await self._call_wrapper_json(
-                self._model.img2img,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
            )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.img2img,
+                    **kwargs,
+                )
         raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
 
     @log_async(
@@ -823,20 +861,23 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
-        kwargs.pop("request_id", None)
+        kwargs["negative_prompt"] = negative_prompt
         if hasattr(self._model, "inpainting"):
-            return await self._call_wrapper_json(
-                self._model.inpainting,
-                image,
-                mask_image,
-                prompt,
-                negative_prompt,
-                n,
-                size,
-                response_format,
-                *args,
-                **kwargs,
+            progressor = kwargs["progressor"] = await self._get_progressor(
+                kwargs.pop("request_id", None)
             )
+            with progressor:
+                return await self._call_wrapper_json(
+                    self._model.inpainting,
+                    image,
+                    mask_image,
+                    prompt,
+                    n,
+                    size,
+                    response_format,
+                    *args,
+                    **kwargs,
+                )
         raise AttributeError(
             f"Model {self._model.model_spec} is not for creating image."
         )
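On the implementation side, each wrapped call now receives the Progressor as a progressor kwarg; the stable_diffusion/core.py changes (not shown in this excerpt) feed it from the per-step callback of the diffusers pipeline. A rough sketch of what a consuming model method could look like; the pipeline attribute and callback wiring here are assumptions, not the actual code:

def text_to_image(self, prompt, n, size, response_format, **kwargs):
    progressor = kwargs.pop("progressor", None)  # injected by ModelActor
    steps = kwargs.get("num_inference_steps") or 50

    def on_step_end(pipe, step, timestep, callback_kwargs):
        # Map the denoising step onto the current progress span.
        if progressor is not None:
            progressor.set_progress((step + 1) / steps)
        return callback_kwargs

    images = self._pipeline(  # hypothetical attribute holding a diffusers pipeline
        prompt,
        num_inference_steps=steps,
        callback_on_step_end=on_step_end,
    ).images
    # encode `images` according to response_format and return
    ...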
xinference/core/progress_tracker.py ADDED
@@ -0,0 +1,187 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import dataclasses
+import logging
+import os
+import time
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import xoscar as xo
+
+TO_REMOVE_PROGRESS_INTERVAL = float(
+    os.getenv("XINFERENCE_REMOVE_PROGRESS_INTERVAL", 5 * 60)
+)  # 5min
+CHECK_PROGRESS_INTERVAL = float(
+    os.getenv("XINFERENCE_CHECK_PROGRESS_INTERVAL", 1 * 60)
+)  # 1min
+UPLOAD_PROGRESS_SPAN = float(
+    os.getenv("XINFERENCE_UPLOAD_PROGRESS_SPAN", 0.05)
+)  # not upload when change less than 0.1
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class _ProgressInfo:
+    progress: float
+    last_updated: float
+    info: Optional[str] = None
+
+
+class ProgressTrackerActor(xo.StatelessActor):
+    _request_id_to_progress: Dict[str, _ProgressInfo]
+
+    @classmethod
+    def default_uid(cls) -> str:
+        return "progress_tracker"
+
+    def __init__(
+        self,
+        to_remove_interval: float = TO_REMOVE_PROGRESS_INTERVAL,
+        check_interval: float = CHECK_PROGRESS_INTERVAL,
+    ):
+        super().__init__()
+
+        self._request_id_to_progress = {}
+        self._clear_finished_task = None
+        self._to_remove_interval = to_remove_interval
+        self._check_interval = check_interval
+
+    async def __post_create__(self):
+        self._clear_finished_task = asyncio.create_task(self._clear_finished())
+
+    async def __pre_destroy__(self):
+        if self._clear_finished_task:
+            self._clear_finished_task.cancel()
+
+    async def _clear_finished(self):
+        to_remove_request_ids = []
+        while True:
+            now = time.time()
+            for request_id, progress in self._request_id_to_progress.items():
+                if abs(progress.progress - 1.0) > 1e-5:
+                    continue
+
+                # finished
+                if now - progress.last_updated > self._to_remove_interval:
+                    to_remove_request_ids.append(request_id)
+
+            for rid in to_remove_request_ids:
+                del self._request_id_to_progress[rid]
+
+            if to_remove_request_ids:
+                logger.debug(
+                    "Remove requests %s due to it's finished for over %s seconds",
+                    to_remove_request_ids,
+                    self._to_remove_interval,
+                )
+
+            await asyncio.sleep(self._check_interval)
+
+    def start(self, request_id: str):
+        self._request_id_to_progress[request_id] = _ProgressInfo(
+            progress=0.0, last_updated=time.time()
+        )
+
+    def set_progress(self, request_id: str, progress: float):
+        assert progress <= 1.0
+        info = self._request_id_to_progress[request_id]
+        info.progress = progress
+        info.last_updated = time.time()
+        logger.debug(
+            "Setting progress, request id: %s, progress: %s", request_id, progress
+        )
+
+    def get_progress(self, request_id: str) -> float:
+        return self._request_id_to_progress[request_id].progress
+
+
+class Progressor:
+    _sub_progress_stack: List[Tuple[float, float]]
+
+    def __init__(
+        self,
+        request_id: str,
+        progress_tracker_ref: xo.ActorRefType["ProgressTrackerActor"],
+        loop: asyncio.AbstractEventLoop,
+        upload_span: float = UPLOAD_PROGRESS_SPAN,
+    ):
+        self.request_id = request_id
+        self.progress_tracker_ref = progress_tracker_ref
+        self.loop = loop
+        # uploading when progress changes over this span
+        # to prevent from frequently uploading
+        self._upload_span = upload_span
+
+        self._last_report_progress = 0.0
+        self._current_progress = 0.0
+        self._sub_progress_stack = [(0.0, 1.0)]
+        self._current_sub_progress_start = 0.0
+        self._current_sub_progress_end = 1.0
+
+    async def start(self):
+        if self.request_id:
+            await self.progress_tracker_ref.start(self.request_id)
+
+    def split_stages(self, n_stage: int, stage_weight: Optional[List[float]] = None):
+        if self.request_id:
+            if stage_weight is not None:
+                if len(stage_weight) != n_stage + 1:
+                    raise ValueError(
+                        f"stage_weight should have size {n_stage + 1}, got {len(stage_weight)}"
+                    )
+                progresses = stage_weight
+            else:
+                progresses = np.linspace(
+                    self._current_sub_progress_start,
+                    self._current_sub_progress_end,
+                    n_stage + 1,
+                )
+            spans = [(progresses[i], progresses[i + 1]) for i in range(n_stage)]
+            self._sub_progress_stack.extend(spans[::-1])
+
+    def __enter__(self):
+        if self.request_id:
+            (
+                self._current_sub_progress_start,
+                self._current_sub_progress_end,
+            ) = self._sub_progress_stack[-1]
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.request_id:
+            self._sub_progress_stack.pop()
+            # force to set progress to 1.0 for this sub progress
+            # nevertheless it is done or not
+            self.set_progress(1.0)
+        return False
+
+    def set_progress(self, progress: float):
+        if self.request_id:
+            self._current_progress = (
+                self._current_sub_progress_start
+                + (self._current_sub_progress_end - self._current_sub_progress_start)
+                * progress
+            )
+            if (
+                self._current_progress - self._last_report_progress >= self._upload_span
+                or 1.0 - progress < 1e-5
+            ):
+                set_progress = self.progress_tracker_ref.set_progress(
+                    self.request_id, self._current_progress
+                )
+                asyncio.run_coroutine_threadsafe(set_progress, self.loop)  # type: ignore
+                self._last_report_progress = self._current_progress
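Progressor keeps a stack of (start, end) spans: split_stages(n) pushes n sub-spans of the current span, each with block activates and then pops the top one, and set_progress(x) maps x in [0, 1] onto the active span before reporting it to the tracker. The intended call pattern, sketched (in practice the instance is built by ModelActor._get_progressor):

progressor.split_stages(2)        # pushes spans (0.0, 0.5) and (0.5, 1.0)
with progressor:                  # activates (0.0, 0.5)
    progressor.set_progress(0.5)  # overall progress reported: 0.25
                                  # __exit__ forces the span to finish: 0.5
with progressor:                  # activates (0.5, 1.0)
    progressor.set_progress(0.5)  # overall progress reported: 0.75
                                  # __exit__ brings it to 1.0, marking the request done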
xinference/core/supervisor.py CHANGED
@@ -130,6 +130,7 @@ class SupervisorActor(xo.StatelessActor):
         )
         logger.info(f"Xinference supervisor {self.address} started")
         from .cache_tracker import CacheTrackerActor
+        from .progress_tracker import ProgressTrackerActor
         from .status_guard import StatusGuardActor
 
         self._status_guard_ref: xo.ActorRefType[  # type: ignore
@@ -142,6 +143,13 @@ class SupervisorActor(xo.StatelessActor):
         ] = await xo.create_actor(
             CacheTrackerActor, address=self.address, uid=CacheTrackerActor.default_uid()
         )
+        self._progress_tracker: xo.ActorRefType[  # type: ignore
+            "ProgressTrackerActor"
+        ] = await xo.create_actor(
+            ProgressTrackerActor,
+            address=self.address,
+            uid=ProgressTrackerActor.default_uid(),
+        )
 
         from .event import EventCollectorActor
 
@@ -1360,3 +1368,6 @@ class SupervisorActor(xo.StatelessActor):
     @staticmethod
     def record_metrics(name, op, kwargs):
         record_metrics(name, op, kwargs)
+
+    async def get_progress(self, request_id: str) -> float:
+        return await self._progress_tracker.get_progress(request_id)
xinference/core/worker.py CHANGED
@@ -885,6 +885,7 @@ class WorkerActor(xo.StatelessActor):
             ModelActor,
             address=subpool_address,
             uid=model_uid,
+            supervisor_address=self._supervisor_address,
             worker_address=self.address,
             model=model,
             model_description=model_description,
xinference/model/audio/chattts.py CHANGED
@@ -53,7 +53,8 @@ class ChatTTSModel:
         torch._dynamo.config.suppress_errors = True
         torch.set_float32_matmul_precision("high")
         self._model = ChatTTS.Chat()
-        self._model.load(source="custom", custom_path=self._model_path, compile=True)
+        logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
+        self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
 
     def speech(
         self,
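With load() now forwarding the launch-time kwargs, options such as compile become user-chosen rather than hard-coded to True. A sketch, assuming the kwargs given to launch_model reach the audio model's constructor unchanged:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder address
model_uid = client.launch_model(
    model_name="ChatTTS",
    model_type="audio",
    compile=False,  # forwarded into ChatTTS.Chat.load(); True enables torch.compile
)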
xinference/model/audio/core.py CHANGED
@@ -25,8 +25,6 @@ from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
 
-MAX_ATTEMPTS = 3
-
 logger = logging.getLogger(__name__)
 
 # Used for check whether the model is cached.