xinference 0.16.3__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. (The original page linked to further details at this point.)

Files changed (54)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/constants.py +1 -0
  5. xinference/core/model.py +10 -3
  6. xinference/core/supervisor.py +8 -2
  7. xinference/core/utils.py +67 -2
  8. xinference/model/audio/model_spec.json +1 -1
  9. xinference/model/image/stable_diffusion/core.py +5 -2
  10. xinference/model/llm/llm_family.json +176 -4
  11. xinference/model/llm/llm_family_modelscope.json +211 -0
  12. xinference/model/llm/mlx/core.py +45 -2
  13. xinference/model/rerank/core.py +11 -4
  14. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  15. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  16. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  17. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  18. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  19. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  20. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  21. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  22. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  23. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  24. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  25. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  26. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  27. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  28. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  29. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  30. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  31. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  32. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  33. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  34. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  35. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  36. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  37. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  38. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
  39. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
  40. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  42. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  46. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  50. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  52. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  53. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  54. {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-11-07T16:55:36+0800",
11
+ "date": "2024-11-15T17:33:11+0800",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "85ab86bf1c0967e45fbec995534cd5a0c9a9c439",
15
- "version": "0.16.3"
14
+ "full-revisionid": "4c96475b8f90e354aa1b47856fda4db098b62b65",
15
+ "version": "1.0.0"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
xinference/api/restful_api.py CHANGED
@@ -52,10 +52,14 @@ from xoscar.utils import get_next_port
52
52
 
53
53
  from .._compat import BaseModel, Field
54
54
  from .._version import get_versions
55
- from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT, XINFERENCE_DISABLE_METRICS
55
+ from ..constants import (
56
+ XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
57
+ XINFERENCE_DEFAULT_ENDPOINT_PORT,
58
+ XINFERENCE_DISABLE_METRICS,
59
+ )
56
60
  from ..core.event import Event, EventCollectorActor, EventType
57
61
  from ..core.supervisor import SupervisorActor
58
- from ..core.utils import json_dumps
62
+ from ..core.utils import CancelMixin, json_dumps
59
63
  from ..types import (
60
64
  ChatCompletion,
61
65
  Completion,
@@ -111,6 +115,7 @@ class RerankRequest(BaseModel):
111
115
  return_documents: Optional[bool] = False
112
116
  return_len: Optional[bool] = False
113
117
  max_chunks_per_doc: Optional[int] = None
118
+ kwargs: Optional[str] = None
114
119
 
115
120
 
116
121
  class TextToImageRequest(BaseModel):
@@ -206,7 +211,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
206
211
  model_ability: List[str]
207
212
 
208
213
 
209
- class RESTfulAPI:
214
+ class RESTfulAPI(CancelMixin):
210
215
  def __init__(
211
216
  self,
212
217
  supervisor_address: str,
@@ -1311,11 +1316,6 @@ class RESTfulAPI:
1311
1316
  payload = await request.json()
1312
1317
  body = RerankRequest.parse_obj(payload)
1313
1318
  model_uid = body.model
1314
- kwargs = {
1315
- key: value
1316
- for key, value in payload.items()
1317
- if key not in RerankRequest.__annotations__.keys()
1318
- }
1319
1319
 
1320
1320
  try:
1321
1321
  model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1329,6 +1329,10 @@ class RESTfulAPI:
1329
1329
  raise HTTPException(status_code=500, detail=str(e))
1330
1330
 
1331
1331
  try:
1332
+ if body.kwargs is not None:
1333
+ parsed_kwargs = json.loads(body.kwargs)
1334
+ else:
1335
+ parsed_kwargs = {}
1332
1336
  scores = await model.rerank(
1333
1337
  body.documents,
1334
1338
  body.query,
@@ -1336,7 +1340,7 @@ class RESTfulAPI:
1336
1340
  max_chunks_per_doc=body.max_chunks_per_doc,
1337
1341
  return_documents=body.return_documents,
1338
1342
  return_len=body.return_len,
1339
- **kwargs,
1343
+ **parsed_kwargs,
1340
1344
  )
1341
1345
  return Response(scores, media_type="application/json")
1342
1346
  except RuntimeError as re:
@@ -1531,8 +1535,11 @@ class RESTfulAPI:
1531
1535
  await self._report_error_event(model_uid, str(e))
1532
1536
  raise HTTPException(status_code=500, detail=str(e))
1533
1537
 
1538
+ request_id = None
1534
1539
  try:
1535
1540
  kwargs = json.loads(body.kwargs) if body.kwargs else {}
1541
+ request_id = kwargs.get("request_id")
1542
+ self._add_running_task(request_id)
1536
1543
  image_list = await model.text_to_image(
1537
1544
  prompt=body.prompt,
1538
1545
  n=body.n,
@@ -1541,6 +1548,11 @@ class RESTfulAPI:
1541
1548
  **kwargs,
1542
1549
  )
1543
1550
  return Response(content=image_list, media_type="application/json")
1551
+ except asyncio.CancelledError:
1552
+ err_str = f"The request has been cancelled: {request_id}"
1553
+ logger.error(err_str)
1554
+ await self._report_error_event(model_uid, err_str)
1555
+ raise HTTPException(status_code=409, detail=err_str)
1544
1556
  except RuntimeError as re:
1545
1557
  logger.error(re, exc_info=True)
1546
1558
  await self._report_error_event(model_uid, str(re))
@@ -1686,11 +1698,14 @@ class RESTfulAPI:
1686
1698
  await self._report_error_event(model_uid, str(e))
1687
1699
  raise HTTPException(status_code=500, detail=str(e))
1688
1700
 
1701
+ request_id = None
1689
1702
  try:
1690
1703
  if kwargs is not None:
1691
1704
  parsed_kwargs = json.loads(kwargs)
1692
1705
  else:
1693
1706
  parsed_kwargs = {}
1707
+ request_id = parsed_kwargs.get("request_id")
1708
+ self._add_running_task(request_id)
1694
1709
  image_list = await model_ref.image_to_image(
1695
1710
  image=Image.open(image.file),
1696
1711
  prompt=prompt,
@@ -1701,6 +1716,11 @@ class RESTfulAPI:
1701
1716
  **parsed_kwargs,
1702
1717
  )
1703
1718
  return Response(content=image_list, media_type="application/json")
1719
+ except asyncio.CancelledError:
1720
+ err_str = f"The request has been cancelled: {request_id}"
1721
+ logger.error(err_str)
1722
+ await self._report_error_event(model_uid, err_str)
1723
+ raise HTTPException(status_code=409, detail=err_str)
1704
1724
  except RuntimeError as re:
1705
1725
  logger.error(re, exc_info=True)
1706
1726
  await self._report_error_event(model_uid, str(re))
@@ -1734,11 +1754,14 @@ class RESTfulAPI:
1734
1754
  await self._report_error_event(model_uid, str(e))
1735
1755
  raise HTTPException(status_code=500, detail=str(e))
1736
1756
 
1757
+ request_id = None
1737
1758
  try:
1738
1759
  if kwargs is not None:
1739
1760
  parsed_kwargs = json.loads(kwargs)
1740
1761
  else:
1741
1762
  parsed_kwargs = {}
1763
+ request_id = parsed_kwargs.get("request_id")
1764
+ self._add_running_task(request_id)
1742
1765
  im = Image.open(image.file)
1743
1766
  mask_im = Image.open(mask_image.file)
1744
1767
  if not size:
@@ -1755,6 +1778,11 @@ class RESTfulAPI:
1755
1778
  **parsed_kwargs,
1756
1779
  )
1757
1780
  return Response(content=image_list, media_type="application/json")
1781
+ except asyncio.CancelledError:
1782
+ err_str = f"The request has been cancelled: {request_id}"
1783
+ logger.error(err_str)
1784
+ await self._report_error_event(model_uid, err_str)
1785
+ raise HTTPException(status_code=409, detail=err_str)
1758
1786
  except RuntimeError as re:
1759
1787
  logger.error(re, exc_info=True)
1760
1788
  await self._report_error_event(model_uid, str(re))
@@ -1782,17 +1810,25 @@ class RESTfulAPI:
1782
1810
  await self._report_error_event(model_uid, str(e))
1783
1811
  raise HTTPException(status_code=500, detail=str(e))
1784
1812
 
1813
+ request_id = None
1785
1814
  try:
1786
1815
  if kwargs is not None:
1787
1816
  parsed_kwargs = json.loads(kwargs)
1788
1817
  else:
1789
1818
  parsed_kwargs = {}
1819
+ request_id = parsed_kwargs.get("request_id")
1820
+ self._add_running_task(request_id)
1790
1821
  im = Image.open(image.file)
1791
1822
  text = await model_ref.ocr(
1792
1823
  image=im,
1793
1824
  **parsed_kwargs,
1794
1825
  )
1795
1826
  return Response(content=text, media_type="text/plain")
1827
+ except asyncio.CancelledError:
1828
+ err_str = f"The request has been cancelled: {request_id}"
1829
+ logger.error(err_str)
1830
+ await self._report_error_event(model_uid, err_str)
1831
+ raise HTTPException(status_code=409, detail=err_str)
1796
1832
  except RuntimeError as re:
1797
1833
  logger.error(re, exc_info=True)
1798
1834
  await self._report_error_event(model_uid, str(re))
@@ -2111,10 +2147,25 @@ class RESTfulAPI:
2111
2147
  logger.error(e, exc_info=True)
2112
2148
  raise HTTPException(status_code=500, detail=str(e))
2113
2149
 
2114
- async def abort_request(self, model_uid: str, request_id: str) -> JSONResponse:
2150
+ async def abort_request(
2151
+ self, request: Request, model_uid: str, request_id: str
2152
+ ) -> JSONResponse:
2115
2153
  try:
2154
+ payload = await request.json()
2155
+ block_duration = payload.get(
2156
+ "block_duration", XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION
2157
+ )
2158
+ logger.info(
2159
+ "Abort request with model uid: %s, request id: %s, block duration: %s",
2160
+ model_uid,
2161
+ request_id,
2162
+ block_duration,
2163
+ )
2116
2164
  supervisor_ref = await self._get_supervisor_ref()
2117
- res = await supervisor_ref.abort_request(model_uid, request_id)
2165
+ res = await supervisor_ref.abort_request(
2166
+ model_uid, request_id, block_duration
2167
+ )
2168
+ self._cancel_running_task(request_id, block_duration)
2118
2169
  return JSONResponse(content=res)
2119
2170
  except Exception as e:
2120
2171
  logger.error(e, exc_info=True)
xinference/client/restful/restful_client.py CHANGED
@@ -174,6 +174,7 @@ class RESTfulRerankModelHandle(RESTfulModelHandle):
174
174
  "max_chunks_per_doc": max_chunks_per_doc,
175
175
  "return_documents": return_documents,
176
176
  "return_len": return_len,
177
+ "kwargs": json.dumps(kwargs),
177
178
  }
178
179
  request_body.update(kwargs)
179
180
  response = requests.post(url, json=request_body, headers=self.auth_headers)
@@ -1357,7 +1358,7 @@ class Client:
1357
1358
  response_data = response.json()
1358
1359
  return response_data
1359
1360
 
1360
- def abort_request(self, model_uid: str, request_id: str):
1361
+ def abort_request(self, model_uid: str, request_id: str, block_duration: int = 30):
1361
1362
  """
1362
1363
  Abort a request.
1363
1364
  Abort a submitted request. If the request is finished or not found, this method will be a no-op.
@@ -1369,13 +1370,18 @@ class Client:
1369
1370
  Model uid.
1370
1371
  request_id: str
1371
1372
  Request id.
1373
+ block_duration: int
1374
+ The duration to make the request id abort. If set to 0, the abort_request will be immediate, which may
1375
+ prevent it from taking effect if it arrives before the request operation.
1372
1376
  Returns
1373
1377
  -------
1374
1378
  Dict
1375
1379
  Return empty dict.
1376
1380
  """
1377
1381
  url = f"{self.base_url}/v1/models/{model_uid}/requests/{request_id}/abort"
1378
- response = requests.post(url, headers=self._headers)
1382
+ response = requests.post(
1383
+ url, headers=self._headers, json={"block_duration": block_duration}
1384
+ )
1379
1385
  if response.status_code != 200:
1380
1386
  raise RuntimeError(
1381
1387
  f"Failed to abort request, detail: {_get_error_string(response)}"
xinference/constants.py CHANGED
@@ -88,3 +88,4 @@ XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE = os.environ.get(
88
88
  XINFERENCE_ENV_TEXT_TO_IMAGE_BATCHING_SIZE, None
89
89
  )
90
90
  XINFERENCE_LAUNCH_MODEL_RETRY = 3
91
+ XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION = 30
xinference/core/model.py CHANGED
@@ -41,6 +41,7 @@ import sse_starlette.sse
41
41
  import xoscar as xo
42
42
 
43
43
  from ..constants import (
44
+ XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
44
45
  XINFERENCE_LAUNCH_MODEL_RETRY,
45
46
  XINFERENCE_TEXT_TO_IMAGE_BATCHING_SIZE,
46
47
  )
@@ -57,7 +58,7 @@ import logging
57
58
  logger = logging.getLogger(__name__)
58
59
 
59
60
  from ..device_utils import empty_cache
60
- from .utils import json_dumps, log_async
61
+ from .utils import CancelMixin, json_dumps, log_async
61
62
 
62
63
  try:
63
64
  from torch.cuda import OutOfMemoryError
@@ -136,7 +137,7 @@ def oom_check(fn):
136
137
  return _wrapper
137
138
 
138
139
 
139
- class ModelActor(xo.StatelessActor):
140
+ class ModelActor(xo.StatelessActor, CancelMixin):
140
141
  _replica_model_uid: Optional[str]
141
142
 
142
143
  @classmethod
@@ -553,6 +554,7 @@ class ModelActor(xo.StatelessActor):
553
554
 
554
555
  @oom_check
555
556
  async def _call_wrapper(self, output_type: str, fn: Callable, *args, **kwargs):
557
+ self._add_running_task(kwargs.get("request_id"))
556
558
  if self._lock is None:
557
559
  if inspect.iscoroutinefunction(fn):
558
560
  ret = await fn(*args, **kwargs)
@@ -761,9 +763,14 @@ class ModelActor(xo.StatelessActor):
761
763
  prompt_tokens,
762
764
  )
763
765
 
764
- async def abort_request(self, request_id: str) -> str:
766
+ async def abort_request(
767
+ self,
768
+ request_id: str,
769
+ block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
770
+ ) -> str:
765
771
  from .utils import AbortRequestMessage
766
772
 
773
+ self._cancel_running_task(request_id, block_duration)
767
774
  if self.allow_batching():
768
775
  if self._scheduler_ref is None:
769
776
  return AbortRequestMessage.NOT_FOUND.name
xinference/core/supervisor.py CHANGED
@@ -35,6 +35,7 @@ from typing import (
35
35
  import xoscar as xo
36
36
 
37
37
  from ..constants import (
38
+ XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
38
39
  XINFERENCE_DISABLE_HEALTH_CHECK,
39
40
  XINFERENCE_HEALTH_CHECK_FAILURE_THRESHOLD,
40
41
  XINFERENCE_HEALTH_CHECK_INTERVAL,
@@ -1213,7 +1214,12 @@ class SupervisorActor(xo.StatelessActor):
1213
1214
  return cached_models
1214
1215
 
1215
1216
  @log_async(logger=logger)
1216
- async def abort_request(self, model_uid: str, request_id: str) -> Dict:
1217
+ async def abort_request(
1218
+ self,
1219
+ model_uid: str,
1220
+ request_id: str,
1221
+ block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
1222
+ ) -> Dict:
1217
1223
  from .scheduler import AbortRequestMessage
1218
1224
 
1219
1225
  res = {"msg": AbortRequestMessage.NO_OP.name}
@@ -1228,7 +1234,7 @@ class SupervisorActor(xo.StatelessActor):
1228
1234
  if worker_ref is None:
1229
1235
  continue
1230
1236
  model_ref = await worker_ref.get_model(model_uid=rep_mid)
1231
- result_info = await model_ref.abort_request(request_id)
1237
+ result_info = await model_ref.abort_request(request_id, block_duration)
1232
1238
  res["msg"] = result_info
1233
1239
  if result_info == AbortRequestMessage.DONE.name:
1234
1240
  break
xinference/core/utils.py CHANGED
@@ -11,11 +11,13 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ import asyncio
14
15
  import logging
15
16
  import os
16
17
  import random
17
18
  import string
18
19
  import uuid
20
+ import weakref
19
21
  from enum import Enum
20
22
  from typing import Dict, Generator, List, Optional, Tuple, Union
21
23
 
@@ -23,7 +25,10 @@ import orjson
23
25
  from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown
24
26
 
25
27
  from .._compat import BaseModel
26
- from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
28
+ from ..constants import (
29
+ XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
30
+ XINFERENCE_LOG_ARG_MAX_LENGTH,
31
+ )
27
32
 
28
33
  logger = logging.getLogger(__name__)
29
34
 
@@ -49,13 +54,20 @@ def log_async(
49
54
  ):
50
55
  import time
51
56
  from functools import wraps
57
+ from inspect import signature
52
58
 
53
59
  def decorator(func):
54
60
  func_name = func.__name__
61
+ sig = signature(func)
55
62
 
56
63
  @wraps(func)
57
64
  async def wrapped(*args, **kwargs):
58
- request_id_str = kwargs.get("request_id", "")
65
+ try:
66
+ bound_args = sig.bind_partial(*args, **kwargs)
67
+ arguments = bound_args.arguments
68
+ except TypeError:
69
+ arguments = {}
70
+ request_id_str = arguments.get("request_id", "")
59
71
  if not request_id_str:
60
72
  request_id_str = uuid.uuid1()
61
73
  if func_name == "text_to_image":
@@ -269,3 +281,56 @@ def assign_replica_gpu(
269
281
  if isinstance(gpu_idx, list) and gpu_idx:
270
282
  return gpu_idx[rep_id::replica]
271
283
  return gpu_idx
284
+
285
+
286
+ class CancelMixin:
287
+ _CANCEL_TASK_NAME = "abort_block"
288
+
289
+ def __init__(self):
290
+ self._running_tasks: weakref.WeakValueDictionary[
291
+ str, asyncio.Task
292
+ ] = weakref.WeakValueDictionary()
293
+
294
+ def _add_running_task(self, request_id: Optional[str]):
295
+ """Add current asyncio task to the running task.
296
+ :param request_id: The corresponding request id.
297
+ """
298
+ if request_id is None:
299
+ return
300
+ running_task = self._running_tasks.get(request_id)
301
+ if running_task is not None:
302
+ if running_task.get_name() == self._CANCEL_TASK_NAME:
303
+ raise Exception(f"The request has been aborted: {request_id}")
304
+ raise Exception(f"Duplicate request id: {request_id}")
305
+ current_task = asyncio.current_task()
306
+ assert current_task is not None
307
+ self._running_tasks[request_id] = current_task
308
+
309
+ def _cancel_running_task(
310
+ self,
311
+ request_id: Optional[str],
312
+ block_duration: int = XINFERENCE_DEFAULT_CANCEL_BLOCK_DURATION,
313
+ ):
314
+ """Cancel the running asyncio task.
315
+ :param request_id: The request id to cancel.
316
+ :param block_duration: The duration seconds to ensure the request can't be executed.
317
+ """
318
+ if request_id is None:
319
+ return
320
+ running_task = self._running_tasks.pop(request_id, None)
321
+ if running_task is not None:
322
+ running_task.cancel()
323
+
324
+ async def block_task():
325
+ """This task is for blocking the request for a duration."""
326
+ try:
327
+ await asyncio.sleep(block_duration)
328
+ logger.info("Abort block end for request: %s", request_id)
329
+ except asyncio.CancelledError:
330
+ logger.info("Abort block is cancelled for request: %s", request_id)
331
+
332
+ if block_duration > 0:
333
+ logger.info("Abort block start for request: %s", request_id)
334
+ self._running_tasks[request_id] = asyncio.create_task(
335
+ block_task(), name=self._CANCEL_TASK_NAME
336
+ )
xinference/model/audio/model_spec.json CHANGED
@@ -159,7 +159,7 @@
159
159
  "model_name": "FishSpeech-1.4",
160
160
  "model_family": "FishAudio",
161
161
  "model_id": "fishaudio/fish-speech-1.4",
162
- "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
162
+ "model_revision": "069c573759936b35191d3380deb89183c0656f59",
163
163
  "model_ability": "text-to-audio",
164
164
  "multilingual": true
165
165
  }
xinference/model/image/stable_diffusion/core.py CHANGED
@@ -17,9 +17,11 @@ import gc
17
17
  import inspect
18
18
  import itertools
19
19
  import logging
20
+ import os
20
21
  import re
21
22
  import sys
22
23
  import warnings
24
+ from glob import glob
23
25
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
24
26
 
25
27
  import PIL.Image
@@ -194,8 +196,9 @@ class DiffusionModel(SDAPIDiffusionModelMixin):
194
196
  if sys.platform != "darwin" and torch_dtype is None:
195
197
  # The following params crashes on Mac M2
196
198
  self._torch_dtype = self._kwargs["torch_dtype"] = torch.float16
197
- self._kwargs["variant"] = "fp16"
198
- self._kwargs["use_safetensors"] = True
199
+ self._kwargs["use_safetensors"] = any(
200
+ glob(os.path.join(self._model_path, "*/*.safetensors"))
201
+ )
199
202
  if isinstance(torch_dtype, str):
200
203
  self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
201
204