xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/core/utils.py CHANGED
@@ -11,62 +11,120 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import copy
  import logging
  import os
  import random
  import string
- from typing import Dict, Generator, List, Tuple, Union
+ import uuid
+ from typing import Dict, Generator, List, Optional, Tuple, Union

  import orjson
  from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown

  from .._compat import BaseModel
+ from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH

  logger = logging.getLogger(__name__)


- def log_async(logger, args_formatter=None):
+ def truncate_log_arg(arg) -> str:
+     s = str(arg)
+     if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH:
+         s = s[0:XINFERENCE_LOG_ARG_MAX_LENGTH] + "..."
+     return s
+
+
+ def log_async(
+     logger,
+     level=logging.DEBUG,
+     ignore_kwargs: Optional[List[str]] = None,
+     log_exception=True,
+ ):
      import time
      from functools import wraps

      def decorator(func):
+         func_name = func.__name__
+
          @wraps(func)
          async def wrapped(*args, **kwargs):
-             if args_formatter is not None:
-                 formatted_args, formatted_kwargs = copy.copy(args), copy.copy(kwargs)
-                 args_formatter(formatted_args, formatted_kwargs)
-             else:
-                 formatted_args, formatted_kwargs = args, kwargs
-             logger.debug(
-                 f"Enter {func.__name__}, args: {formatted_args}, kwargs: {formatted_kwargs}"
+             request_id_str = kwargs.get("request_id", "")
+             if not request_id_str:
+                 request_id_str = uuid.uuid1()
+             request_id_str = f"[request {request_id_str}]"
+             formatted_args = ",".join(map(truncate_log_arg, args))
+             formatted_kwargs = ",".join(
+                 [
+                     "%s=%s" % (k, truncate_log_arg(v))
+                     for k, v in kwargs.items()
+                     if ignore_kwargs is None or k not in ignore_kwargs
+                 ]
              )
-             start = time.time()
-             ret = await func(*args, **kwargs)
-             logger.debug(
-                 f"Leave {func.__name__}, elapsed time: {int(time.time() - start)} s"
+             logger.log(
+                 level,
+                 f"{request_id_str} Enter {func_name}, args: {formatted_args}, kwargs: {formatted_kwargs}",
              )
-             return ret
+             start = time.time()
+             try:
+                 ret = await func(*args, **kwargs)
+                 logger.log(
+                     level,
+                     f"{request_id_str} Leave {func_name}, elapsed time: {int(time.time() - start)} s",
+                 )
+                 return ret
+             except Exception as e:
+                 if log_exception:
+                     logger.error(
+                         f"{request_id_str} Leave {func_name}, error: {e}, elapsed time: {int(time.time() - start)} s",
+                         exc_info=True,
+                     )
+                 else:
+                     logger.log(
+                         level,
+                         f"{request_id_str} Leave {func_name}, error: {e}, elapsed time: {int(time.time() - start)} s",
+                     )
+                 raise

          return wrapped

      return decorator


- def log_sync(logger):
+ def log_sync(logger, level=logging.DEBUG, log_exception=True):
      import time
      from functools import wraps

      def decorator(func):
          @wraps(func)
          def wrapped(*args, **kwargs):
-             logger.debug(f"Enter {func.__name__}, args: {args}, kwargs: {kwargs}")
-             start = time.time()
-             ret = func(*args, **kwargs)
-             logger.debug(
-                 f"Leave {func.__name__}, elapsed time: {int(time.time() - start)} s"
+             formatted_args = ",".join(map(truncate_log_arg, args))
+             formatted_kwargs = ",".join(
+                 map(lambda x: "%s=%s" % (x[0], truncate_log_arg(x[1])), kwargs.items())
              )
-             return ret
+             logger.log(
+                 level,
+                 f"Enter {func.__name__}, args: {formatted_args}, kwargs: {formatted_kwargs}",
+             )
+             start = time.time()
+             try:
+                 ret = func(*args, **kwargs)
+                 logger.log(
+                     level,
+                     f"Leave {func.__name__}, elapsed time: {int(time.time() - start)} s",
+                 )
+                 return ret
+             except Exception as e:
+                 if log_exception:
+                     logger.error(
+                         f"Leave {func.__name__}, error: {e}, elapsed time: {int(time.time() - start)} s",
+                         exc_info=True,
+                     )
+                 else:
+                     logger.log(
+                         level,
+                         f"Leave {func.__name__}, error: {e}, elapsed time: {int(time.time() - start)} s",
+                     )
+                 raise

          return wrapped
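For orientation, here is a minimal standalone sketch of how the reworked decorator behaves: each call is tagged with a request id, arguments are truncated before logging, and the log level is configurable. The truncation limit is inlined (the real code imports XINFERENCE_LOG_ARG_MAX_LENGTH from xinference.constants), and fake_launch is a hypothetical coroutine used only for demonstration.

```python
# Standalone sketch of the new logging behavior: request-id tagging plus
# argument truncation. MAX_LEN stands in for XINFERENCE_LOG_ARG_MAX_LENGTH.
import asyncio
import logging
import time
import uuid
from functools import wraps

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("sketch")

MAX_LEN = 100


def truncate_log_arg(arg) -> str:
    s = str(arg)
    return s if len(s) <= MAX_LEN else s[:MAX_LEN] + "..."


def log_async(logger, level=logging.DEBUG):
    def decorator(func):
        @wraps(func)
        async def wrapped(*args, **kwargs):
            # Tag every call so the Enter/Leave lines can be correlated.
            rid = kwargs.get("request_id") or uuid.uuid1()
            formatted_args = ",".join(map(truncate_log_arg, args))
            logger.log(level, f"[request {rid}] Enter {func.__name__}, args: {formatted_args}")
            start = time.time()
            ret = await func(*args, **kwargs)
            logger.log(
                level,
                f"[request {rid}] Leave {func.__name__}, elapsed time: {int(time.time() - start)} s",
            )
            return ret

        return wrapped

    return decorator


@log_async(logger, level=logging.INFO)
async def fake_launch(model_uid: str, payload: str) -> None:
    await asyncio.sleep(0.1)


asyncio.run(fake_launch("qwen2-instruct", "x" * 500))  # long payload is truncated in the log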
xinference/core/worker.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.

  import asyncio
+ import logging
  import os
  import platform
  import queue
@@ -73,15 +74,15 @@ class WorkerActor(xo.StatelessActor):
          self._supervisor_ref: Optional[xo.ActorRefType] = None
          self._main_pool = main_pool
          self._main_pool.recover_sub_pool = self.recover_sub_pool
-         self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
-             None
-         )
+         self._status_guard_ref: xo.ActorRefType[
+             "StatusGuardActor"
+         ] = None  # type: ignore
          self._event_collector_ref: xo.ActorRefType[  # type: ignore
              EventCollectorActor
          ] = None
-         self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
-             None
-         )
+         self._cache_tracker_ref: xo.ActorRefType[
+             CacheTrackerActor
+         ] = None  # type: ignore

          # internal states.
          # temporary placeholder during model launch process:
@@ -185,7 +186,7 @@
              break

      @classmethod
-     def uid(cls) -> str:
+     def default_uid(cls) -> str:
          return "worker"

      async def __post_create__(self):
@@ -270,9 +271,9 @@

          try:
              await self.get_supervisor_ref(add_worker=True)
-         except Exception as e:
+         except Exception:
              # Do not crash the worker if supervisor is down, auto re-connect later
-             logger.error(f"cannot connect to supervisor {e}")
+             logger.error(f"cannot connect to supervisor", exc_info=True)

          if not XINFERENCE_DISABLE_HEALTH_CHECK:
              from ..isolation import Isolation
@@ -324,7 +325,7 @@
          if self._supervisor_ref is not None:
              return self._supervisor_ref
          supervisor_ref = await xo.actor_ref(  # type: ignore
-             address=self._supervisor_address, uid=SupervisorActor.uid()
+             address=self._supervisor_address, uid=SupervisorActor.default_uid()
          )
          # Prevent concurrent operations leads to double initialization, check again.
          if self._supervisor_ref is not None:
@@ -336,13 +337,13 @@
              logger.info("Connected to supervisor as a fresh worker")

          self._status_guard_ref = await xo.actor_ref(
-             address=self._supervisor_address, uid=StatusGuardActor.uid()
+             address=self._supervisor_address, uid=StatusGuardActor.default_uid()
          )
          self._event_collector_ref = await xo.actor_ref(
-             address=self._supervisor_address, uid=EventCollectorActor.uid()
+             address=self._supervisor_address, uid=EventCollectorActor.default_uid()
          )
          self._cache_tracker_ref = await xo.actor_ref(
-             address=self._supervisor_address, uid=CacheTrackerActor.uid()
+             address=self._supervisor_address, uid=CacheTrackerActor.default_uid()
          )
          # cache_tracker is on supervisor
          from ..model.audio import get_audio_model_descriptions
@@ -770,7 +771,7 @@
              version_info["model_file_location"],
          )

-     @log_async(logger=logger)
+     @log_async(logger=logger, level=logging.INFO)
      async def launch_builtin_model(
          self,
          model_uid: str,
@@ -814,7 +815,7 @@
              )
          except Exception as e:
              # Report callback error can be log and ignore, should not interrupt the Process
-             logger.error("report_event error: %s" % (e))
+             logger.error("report_event error: %s" % (e), exc_info=True)

          if gpu_idx is not None:
              logger.info(
@@ -917,7 +918,7 @@
              {"model_ability": abilities, "status": LaunchStatus.READY.name},
          )

-     @log_async(logger=logger)
+     @log_async(logger=logger, level=logging.INFO)
      async def terminate_model(self, model_uid: str, is_model_die=False):
          # Terminate model while its launching is not allow
          if model_uid in self._model_uid_launching_guard:
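The recurring change across worker.py and the deploy modules below is the rename of the uid() classmethod to default_uid(), applied to every actor class and call site, presumably to keep the name uid free for the per-instance actor identity. A stub sketch of the pattern follows; the xoscar machinery is replaced by plain classes, and the supervisor's return value is an assumption (only the worker's "worker" is visible in the diff).

```python
# Stub illustration of the uid() -> default_uid() rename; only the naming
# convention is shown, not xoscar's actual actor API.
class SupervisorActor:
    @classmethod
    def default_uid(cls) -> str:
        # Assumed value; well-known name used to look up the singleton supervisor.
        return "supervisor"


class WorkerActor:
    @classmethod
    def default_uid(cls) -> str:
        # Matches the diff above: the worker's well-known uid.
        return "worker"


# Call sites now read ClassName.default_uid() instead of ClassName.uid():
registry = {
    SupervisorActor.default_uid(): SupervisorActor,
    WorkerActor.default_uid(): WorkerActor,
}
print(sorted(registry))  # ['supervisor', 'worker']
```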
xinference/deploy/cmdline.py CHANGED
@@ -17,7 +17,7 @@ import logging
  import os
  import sys
  import warnings
- from typing import List, Optional, Sequence, Tuple, Union
+ from typing import Dict, List, Optional, Sequence, Tuple, Union

  import click
  from xoscar.utils import get_next_port
@@ -38,7 +38,6 @@ from ..constants import (
      XINFERENCE_LOG_MAX_BYTES,
  )
  from ..isolation import Isolation
- from ..types import ChatCompletionMessage
  from .utils import (
      get_config_dict,
      get_log_file,
@@ -1210,13 +1209,12 @@ def model_chat(
      stream: bool,
      api_key: Optional[str],
  ):
-     # TODO: chat model roles may not be user and assistant.
      endpoint = get_endpoint(endpoint)
      client = RESTfulClient(base_url=endpoint, api_key=api_key)
      if api_key is None:
          client._set_token(get_stored_token(endpoint, client))

-     chat_history: "List[ChatCompletionMessage]" = []
+     messages: List[Dict] = []
      if stream:
          # TODO: when stream=True, RestfulClient cannot generate words one by one.
          # So use Client in temporary. The implementation needs to be changed to
@@ -1229,10 +1227,10 @@
              if prompt == "":
                  break
              print("Assistant: ", end="", file=sys.stdout)
+             messages.append(dict(role="user", content=prompt))
              response_content = ""
              for chunk in model.chat(
-                 prompt=prompt,
-                 chat_history=chat_history,
+                 messages,
                  generate_config={"stream": stream, "max_tokens": max_tokens},
              ):
                  delta = chunk["choices"][0]["delta"]
@@ -1242,10 +1240,7 @@
                  response_content += delta["content"]
                  print(delta["content"], end="", flush=True, file=sys.stdout)
              print("", file=sys.stdout)
-             chat_history.append(ChatCompletionMessage(role="user", content=prompt))
-             chat_history.append(
-                 ChatCompletionMessage(role="assistant", content=response_content)
-             )
+             messages.append(dict(role="assistant", content=response_content))

          model = client.get_model(model_uid=model_uid)

@@ -1274,20 +1269,17 @@
              prompt = input("User: ")
              if prompt == "":
                  break
-             chat_history.append(ChatCompletionMessage(role="user", content=prompt))
+             messages.append({"role": "user", "content": prompt})
              print("Assistant: ", end="", file=sys.stdout)
              response = restful_model.chat(
-                 prompt=prompt,
-                 chat_history=chat_history,
+                 messages,
                  generate_config={"stream": stream, "max_tokens": max_tokens},
              )
              if not isinstance(response, dict):
                  raise ValueError("chat result is not valid")
              response_content = response["choices"][0]["message"]["content"]
              print(f"{response_content}\n", file=sys.stdout)
-             chat_history.append(
-                 ChatCompletionMessage(role="assistant", content=response_content)
-             )
+             messages.append(dict(role="assistant", content=response_content))


  @cli.command("vllm-models", help="Query and display models compatible with vLLM.")
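The CLI now tracks the conversation as a plain list of OpenAI-style message dicts instead of ChatCompletionMessage objects, and passes the list as the first positional argument to chat(). A minimal sketch of the same flow against a running server; the endpoint URL and model uid are placeholders.

```python
# Sketch of the new message-list chat flow; assumes a running Xinference
# server at the given endpoint with a chat model already launched.
from xinference.client import RESTfulClient

client = RESTfulClient(base_url="http://127.0.0.1:9997")
model = client.get_model(model_uid="my-chat-model")  # placeholder uid

messages = [{"role": "user", "content": "Say hello in one sentence."}]
response = model.chat(
    messages,
    generate_config={"stream": False, "max_tokens": 128},
)
content = response["choices"][0]["message"]["content"]
# Keep the history growing the same way the CLI now does:
messages.append({"role": "assistant", "content": content})
print(content)
```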
xinference/deploy/local.py CHANGED
@@ -49,7 +49,7 @@ async def _start_local_cluster(
          address=address, logging_conf=logging_conf
      )
      await xo.create_actor(
-         SupervisorActor, address=address, uid=SupervisorActor.uid()
+         SupervisorActor, address=address, uid=SupervisorActor.default_uid()
      )
      await start_worker_components(
          address=address,
xinference/deploy/supervisor.py CHANGED
@@ -41,7 +41,7 @@ async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
          address=address, n_process=0, logging_conf={"dict": logging_conf}
      )
      await xo.create_actor(
-         SupervisorActor, address=address, uid=SupervisorActor.uid()
+         SupervisorActor, address=address, uid=SupervisorActor.default_uid()
      )
      await pool.join()
  except asyncio.exceptions.CancelledError:
xinference/deploy/utils.py CHANGED
@@ -167,7 +167,7 @@ def health_check(address: str, max_attempts: int, sleep_interval: int = 3) -> bool:
      from ..core.supervisor import SupervisorActor

      supervisor_ref: xo.ActorRefType[SupervisorActor] = await xo.actor_ref(  # type: ignore
-         address=address, uid=SupervisorActor.uid()
+         address=address, uid=SupervisorActor.default_uid()
      )

      await supervisor_ref.get_status()
xinference/deploy/worker.py CHANGED
@@ -43,7 +43,7 @@ async def start_worker_components(
      await xo.create_actor(
          WorkerActor,
          address=address,
-         uid=WorkerActor.uid(),
+         uid=WorkerActor.default_uid(),
          supervisor_address=supervisor_address,
          main_pool=main_pool,
          gpu_devices=gpu_device_indices,
xinference/model/audio/cosyvoice.py CHANGED
@@ -53,7 +53,82 @@ class CosyVoiceModel:

          from cosyvoice.cli.cosyvoice import CosyVoice

-         self._model = CosyVoice(self._model_path)
+         self._model = CosyVoice(
+             self._model_path, load_jit=self._kwargs.get("load_jit", False)
+         )
+
+     def _speech_handle(
+         self,
+         stream,
+         input,
+         instruct_text,
+         prompt_speech,
+         prompt_text,
+         voice,
+         response_format,
+     ):
+         if prompt_speech:
+             from cosyvoice.utils.file_utils import load_wav
+
+             with io.BytesIO(prompt_speech) as prompt_speech_io:
+                 prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+
+                 if prompt_text:
+                     logger.info("CosyVoice inference_zero_shot")
+                     output = self._model.inference_zero_shot(
+                         input, prompt_text, prompt_speech_16k, stream=stream
+                     )
+                 else:
+                     logger.info("CosyVoice inference_cross_lingual")
+                     output = self._model.inference_cross_lingual(
+                         input, prompt_speech_16k, stream=stream
+                     )
+         else:
+             available_speakers = self._model.list_avaliable_spks()
+             if not voice:
+                 voice = available_speakers[0]
+             else:
+                 assert (
+                     voice in available_speakers
+                 ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+             if instruct_text:
+                 logger.info("CosyVoice inference_instruct")
+                 output = self._model.inference_instruct(
+                     input, voice, instruct_text=instruct_text, stream=stream
+                 )
+             else:
+                 logger.info("CosyVoice inference_sft")
+                 output = self._model.inference_sft(input, voice, stream=stream)
+
+         import torch
+         import torchaudio
+
+         def _generator_stream():
+             with BytesIO() as out:
+                 writer = torchaudio.io.StreamWriter(out, format=response_format)
+                 writer.add_audio_stream(sample_rate=22050, num_channels=1)
+                 i = 0
+                 last_pos = 0
+                 with writer.open():
+                     for chunk in output:
+                         chunk = chunk["tts_speech"]
+                         trans_chunk = torch.transpose(chunk, 0, 1)
+                         writer.write_audio_chunk(i, trans_chunk)
+                         new_last_pos = out.tell()
+                         if new_last_pos != last_pos:
+                             out.seek(last_pos)
+                             encoded_bytes = out.read()
+                             yield encoded_bytes
+                             last_pos = new_last_pos
+
+         def _generator_block():
+             chunk = next(output)
+             assert isinstance(chunk, dict), "Expected data to be of type dict"
+             with BytesIO() as out:
+                 torchaudio.save(out, chunk["tts_speech"], 22050, format=response_format)
+                 return out.getvalue()
+
+         return _generator_stream() if stream else _generator_block()

      def speech(
          self,
@@ -64,12 +139,6 @@ class CosyVoiceModel:
          stream: bool = False,
          **kwargs,
      ):
-         if stream:
-             raise Exception("CosyVoiceModel does not support stream.")
-
-         import torchaudio
-         from cosyvoice.utils.file_utils import load_wav
-
          prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
          prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
          instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
@@ -103,39 +172,15 @@
          ), "CosyVoice model does not support instruct_text"

          assert self._model is not None
+
          set_all_random_seed(seed)
-         if prompt_speech:
-             assert not voice, "voice can't be set with prompt speech."
-             with io.BytesIO(prompt_speech) as prompt_speech_io:
-                 prompt_speech_16k = load_wav(prompt_speech_io, 16000)
-                 if prompt_text:
-                     logger.info("CosyVoice inference_zero_shot")
-                     output = self._model.inference_zero_shot(
-                         input, prompt_text, prompt_speech_16k
-                     )
-                 else:
-                     logger.info("CosyVoice inference_cross_lingual")
-                     output = self._model.inference_cross_lingual(
-                         input, prompt_speech_16k
-                     )
-         else:
-             available_speakers = self._model.list_avaliable_spks()
-             if not voice:
-                 voice = available_speakers[0]
-             else:
-                 assert (
-                     voice in available_speakers
-                 ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
-             if instruct_text:
-                 logger.info("CosyVoice inference_instruct")
-                 output = self._model.inference_instruct(
-                     input, voice, instruct_text=instruct_text
-                 )
-             else:
-                 logger.info("CosyVoice inference_sft")
-                 output = self._model.inference_sft(input, voice)

-         # Save the generated audio
-         with BytesIO() as out:
-             torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
-             return out.getvalue()
+         return self._speech_handle(
+             stream,
+             input,
+             instruct_text,
+             prompt_speech,
+             prompt_text,
+             voice,
+             response_format,
+         )
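Streaming synthesis is now supported end to end: with stream=True, speech() returns a generator of already-encoded audio chunks produced incrementally by torchaudio's StreamWriter. A hedged usage sketch, where model stands in for a loaded CosyVoiceModel instance:

```python
# Hypothetical consumer of the new streaming path; `model` stands in for a
# loaded CosyVoiceModel. Each yielded item is a bytes chunk of the encoded
# container, so it can be forwarded or appended to a file as-is.
def save_streamed_speech(model, text: str, path: str = "out.mp3") -> None:
    chunks = model.speech(
        text,
        voice="",              # empty selects the first available speaker
        response_format="mp3",
        stream=True,
    )
    with open(path, "wb") as f:
        for encoded_bytes in chunks:
            f.write(encoded_bytes)
```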
xinference/model/audio/fish_speech.py CHANGED
@@ -92,7 +92,7 @@ class FishSpeechModel:

          checkpoint_path = os.path.join(
              self._model_path,
-             "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+             "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
          )
          self._model = load_decoder_model(
              config_name="firefly_gan_vq",
@@ -159,11 +159,11 @@
          segments = []

          while True:
-             result: WrappedGenerateResponse = response_queue.get()
+             result: WrappedGenerateResponse = response_queue.get()  # type: ignore
              if result.status == "error":
                  raise Exception(str(result.response))

-             result: GenerateResponse = result.response
+             result: GenerateResponse = result.response  # type: ignore
              if result.action == "next":
                  break
@@ -213,12 +213,12 @@
                  text=input,
                  enable_reference_audio=False,
                  reference_audio=None,
-                 reference_text="",
-                 max_new_tokens=0,
-                 chunk_length=100,
-                 top_p=0.7,
-                 repetition_penalty=1.2,
-                 temperature=0.7,
+                 reference_text=kwargs.get("reference_text", ""),
+                 max_new_tokens=kwargs.get("max_new_tokens", 1024),
+                 chunk_length=kwargs.get("chunk_length", 200),
+                 top_p=kwargs.get("top_p", 0.7),
+                 repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+                 temperature=kwargs.get("temperature", 0.7),
              )
          )
          sample_rate, audio = result[0][1]
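Generation parameters that were previously hard-coded are now read from **kwargs, with new defaults (max_new_tokens 0 → 1024, chunk_length 100 → 200). A hedged sketch of a call overriding them; model stands in for a loaded FishSpeechModel, and the keyword names match the keys the new code reads via kwargs.get(...).

```python
# Hypothetical call against a loaded FishSpeechModel; returns encoded audio.
def synthesize(model) -> bytes:
    return model.speech(
        "Hello from FishSpeech 1.4",
        max_new_tokens=512,      # default is now 1024 (was hard-coded to 0)
        chunk_length=200,        # default is now 200 (was hard-coded to 100)
        top_p=0.8,
        repetition_penalty=1.2,
        temperature=0.6,
    )
```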
xinference/model/audio/model_spec.json CHANGED
@@ -126,32 +126,32 @@
      {
          "model_name": "CosyVoice-300M",
          "model_family": "CosyVoice",
-         "model_id": "model-scope/CosyVoice-300M",
-         "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+         "model_id": "FunAudioLLM/CosyVoice-300M",
+         "model_revision": "39c4e13d46bd4dfb840d214547623e5fcd2428e2",
          "model_ability": "audio-to-audio",
          "multilingual": true
      },
      {
          "model_name": "CosyVoice-300M-SFT",
          "model_family": "CosyVoice",
-         "model_id": "model-scope/CosyVoice-300M-SFT",
-         "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+         "model_id": "FunAudioLLM/CosyVoice-300M-SFT",
+         "model_revision": "096a5cff8d497fabb3dec2756a200f3688457a1b",
          "model_ability": "text-to-audio",
          "multilingual": true
      },
      {
          "model_name": "CosyVoice-300M-Instruct",
          "model_family": "CosyVoice",
-         "model_id": "model-scope/CosyVoice-300M-Instruct",
-         "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+         "model_id": "FunAudioLLM/CosyVoice-300M-Instruct",
+         "model_revision": "ba5265d9a3169c1fedce145122c9dd4bc24e062c",
          "model_ability": "text-to-audio",
          "multilingual": true
      },
      {
-         "model_name": "FishSpeech-1.2-SFT",
+         "model_name": "FishSpeech-1.4",
          "model_family": "FishAudio",
-         "model_id": "fishaudio/fish-speech-1.2-sft",
-         "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
+         "model_id": "fishaudio/fish-speech-1.4",
+         "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
          "model_ability": "text-to-audio",
          "multilingual": true
      }
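The spec entries above are plain JSON records, so downstream code can filter them by field. A small sketch, assuming a local copy of model_spec.json laid out as a top-level array like the excerpt shows:

```python
# Reads the audio model spec and lists text-to-audio entries; field names
# are taken from the records above, the file path is an assumption.
import json

with open("model_spec.json") as f:
    specs = json.load(f)

tts_models = [s["model_name"] for s in specs if s["model_ability"] == "text-to-audio"]
print(tts_models)  # includes CosyVoice-300M-SFT, CosyVoice-300M-Instruct, FishSpeech-1.4
```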
xinference/model/audio/whisper.py CHANGED
@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import logging
+ import os
+ from glob import glob
  from typing import TYPE_CHECKING, Dict, List, Optional, Union

  from ...device_utils import (
@@ -56,12 +58,13 @@ class WhisperModel:
              raise ValueError(f"Device {self._device} is not available!")

          torch_dtype = get_device_preferred_dtype(self._device)
+         use_safetensors = any(glob(os.path.join(self._model_path, "*.safetensors")))

          model = AutoModelForSpeechSeq2Seq.from_pretrained(
              self._model_path,
              torch_dtype=torch_dtype,
              low_cpu_mem_usage=True,
-             use_safetensors=True,
+             use_safetensors=use_safetensors,
          )
          model.to(self._device)
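The whisper fix makes use_safetensors conditional on the checkpoint actually containing *.safetensors files, so .bin-only checkpoints load again. The detection itself is a one-liner you can test on any directory:

```python
# The detection pattern from the diff, runnable standalone.
import os
from glob import glob


def has_safetensors(model_path: str) -> bool:
    # True iff the directory contains at least one *.safetensors file.
    return any(glob(os.path.join(model_path, "*.safetensors")))


# Passed to AutoModelForSpeechSeq2Seq.from_pretrained(use_safetensors=...),
# this falls back to the .bin weights when no safetensors shards exist.
print(has_safetensors("."))
```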