xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-16T18:10:38+0800",
+ "date": "2024-08-30T18:54:16+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "e4d225774dc7a9a9260396bf833e03a1df8e8a92",
- "version": "0.14.2"
+ "full-revisionid": "f3d510eceffbbbc41ce065919fd2c48411017573",
+ "version": "0.14.4"
 }
 ''' # END VERSION_JSON
 
xinference/core/chat_interface.py CHANGED
@@ -340,7 +340,7 @@ class GradioInterface:
             state = gr.State([])
             with gr.Row():
                 chatbot = gr.Chatbot(
-                    elem_id="chatbot", label=self.model_name, height=550, scale=7
+                    elem_id="chatbot", label=self.model_name, height=700, scale=7
                 )
                 with gr.Column(scale=3):
                     imagebox = gr.Image(type="filepath")
xinference/core/image_interface.py CHANGED
@@ -163,6 +163,7 @@ class ImageInterface:
         size_width: int,
         size_height: int,
         num_inference_steps: int,
+        padding_image_to_multiple: int,
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -178,6 +179,7 @@
         num_inference_steps = (
             None if num_inference_steps == -1 else num_inference_steps  # type: ignore
        )
+        padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
 
         bio = io.BytesIO()
         image.save(bio, format="png")
@@ -190,6 +192,7 @@
             size=size,
             response_format="b64_json",
             num_inference_steps=num_inference_steps,
+            padding_image_to_multiple=padding_image_to_multiple,
         )
 
         images = []
@@ -222,9 +225,14 @@
                 n = gr.Number(label="Number of image", value=1)
                 size_width = gr.Number(label="Width", value=-1)
                 size_height = gr.Number(label="Height", value=-1)
+
+            with gr.Row():
                 num_inference_steps = gr.Number(
                     label="Inference Step Number", value=-1
                 )
+                padding_image_to_multiple = gr.Number(
+                    label="Padding image to multiple", value=-1
+                )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -242,6 +250,7 @@
                 size_width,
                 size_height,
                 num_inference_steps,
+                padding_image_to_multiple,
             ],
             outputs=output_gallery,
         )
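For context, a minimal client-side sketch of how the new option could be passed (not taken from the diff; the endpoint, model uid, and file name are placeholders, and the exact image_to_image signature should be checked against the client docs):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # placeholder endpoint
model = client.get_model("my-sd-model")  # placeholder uid of a launched image model

with open("input.png", "rb") as f:
    result = model.image_to_image(
        image=f.read(),
        prompt="a watercolor landscape",
        response_format="b64_json",
        num_inference_steps=30,
        # New in this release; -1 in the Gradio form means "unset".
        # Pads the input so width/height become a multiple of the given value.
        padding_image_to_multiple=8,
    )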
xinference/core/model.py CHANGED
@@ -177,6 +177,7 @@ class ModelActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
     ):
         super().__init__()
+        from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
@@ -192,7 +193,9 @@ class ModelActor(xo.StatelessActor):
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel))
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
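The change above extends the set of engines that bypass the actor-level lock. A reduced sketch of the pattern, with invented names, assuming engines such as vLLM, SGLang, and LMDeploy schedule concurrent requests themselves:

import asyncio

class ModelWrapper:
    def __init__(self, model, handles_own_scheduling: bool):
        self._model = model
        # Engines with built-in request scheduling need no serializing lock.
        self._lock = None if handles_own_scheduling else asyncio.Lock()

    async def call(self, coro_fn, *args):
        if self._lock is None:
            return await coro_fn(*args)  # engine interleaves requests itself
        async with self._lock:  # otherwise: one request at a time
            return await coro_fn(*args)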
xinference/core/worker.py CHANGED
@@ -39,9 +39,11 @@ from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
 from ..types import PeftModelConfig
+from .cache_tracker import CacheTrackerActor
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
+from .status_guard import StatusGuardActor
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
 
 logger = getLogger(__name__)
@@ -71,6 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
+        self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
+            None
+        )
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
+            None
+        )
 
         # internal states.
         # temporary placeholder during model launch process:
@@ -135,7 +146,7 @@
             else:
                 recover_count = self._model_uid_to_recover_count.get(model_uid)
                 try:
-                    await self.terminate_model(model_uid)
+                    await self.terminate_model(model_uid, is_model_die=True)
                 except Exception:
                     pass
                 if recover_count is not None:
@@ -308,56 +319,50 @@
         Params:
             add_worker: By default will call supervisor.add_worker after first connect
         """
-        from .status_guard import StatusGuardActor
         from .supervisor import SupervisorActor
 
         if self._supervisor_ref is not None:
             return self._supervisor_ref
-        self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(  # type: ignore
+        supervisor_ref = await xo.actor_ref(  # type: ignore
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
+        # Prevent concurrent operations leads to double initialization, check again.
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref = supervisor_ref
         if add_worker and len(self._model_uid_to_model) == 0:
             # Newly started (or restarted), has no model, notify supervisor
             await self._supervisor_ref.add_worker(self.address)
             logger.info("Connected to supervisor as a fresh worker")
 
-        self._status_guard_ref: xo.ActorRefType[  # type: ignore
-            "StatusGuardActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=StatusGuardActor.uid()
-        )
-
-        self._event_collector_ref: xo.ActorRefType[  # type: ignore
-            EventCollectorActor
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=EventCollectorActor.uid()
-        )
-        from .cache_tracker import CacheTrackerActor
-
-        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
-            "CacheTrackerActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=CacheTrackerActor.uid()
-        )
-        # cache_tracker is on supervisor
-        from ..model.audio import get_audio_model_descriptions
-        from ..model.embedding import get_embedding_model_descriptions
-        from ..model.flexible import get_flexible_model_descriptions
-        from ..model.image import get_image_model_descriptions
-        from ..model.llm import get_llm_model_descriptions
-        from ..model.rerank import get_rerank_model_descriptions
-
-        # record model version
-        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        self._status_guard_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+        self._event_collector_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        self._cache_tracker_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
         return self._supervisor_ref
 
     @staticmethod
@@ -659,6 +664,8 @@
 
             ret.sort(key=sort_helper)
             return ret
+        elif model_type == "video":
+            return []
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
 
@@ -698,6 +705,8 @@
             for f in get_user_defined_audios():
                 if f.model_name == model_name:
                     return f
+        elif model_type == "video":
+            return None
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
 
@@ -734,7 +743,7 @@
         elif model_type == "image":
             return ["text_to_image"]
         elif model_type == "audio":
-            return ["audio_to_text"]
+            return [model._model_spec.ability]
         elif model_type == "video":
             return ["text_to_video"]
         elif model_type == "flexible":
@@ -793,6 +802,7 @@
             logger.exception(e)
             raise
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -908,12 +918,13 @@
         )
 
     @log_async(logger=logger)
-    async def terminate_model(self, model_uid: str):
+    async def terminate_model(self, model_uid: str, is_model_die=False):
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
         origin_uid, _, __ = parse_replica_model_uid(model_uid)
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -956,11 +967,16 @@
         self._model_uid_to_recover_count.pop(model_uid, None)
         self._model_uid_to_launch_args.pop(model_uid, None)
 
+        if is_model_die:
+            status = LaunchStatus.ERROR.name
+        else:
+            status = LaunchStatus.TERMINATED.name
+
         if self._status_guard_ref is None:
             _ = await self.get_supervisor_ref()
         assert self._status_guard_ref is not None
         await self._status_guard_ref.update_instance_info(
-            origin_uid, {"status": LaunchStatus.TERMINATED.name}
+            origin_uid, {"status": status}
         )
 
     # Provide an interface for future version of supervisor to call
@@ -1081,7 +1097,7 @@
         paths.update([os.path.realpath(path) for path in paths])
 
         # get tensorizer path
-        from ..model.llm.pytorch.tensorizer_utils import get_tensorizer_dir
+        from ..model.llm.transformers.tensorizer_utils import get_tensorizer_dir
 
         tensorizer_path = get_tensorizer_dir(path)
         if os.path.isdir(tensorizer_path):
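The rewritten get_supervisor_ref above is a double-checked initialization: the await on xo.actor_ref is a suspension point where another coroutine may finish the same initialization first, hence the second check. A stripped-down sketch (names simplified, resolver faked):

import asyncio

class Worker:
    def __init__(self):
        self._supervisor_ref = None

    async def _resolve(self):
        await asyncio.sleep(0.01)  # stands in for `await xo.actor_ref(...)`
        return object()

    async def get_supervisor_ref(self):
        if self._supervisor_ref is not None:
            return self._supervisor_ref
        ref = await self._resolve()  # other coroutines may run during this await
        if self._supervisor_ref is not None:
            # A concurrent caller completed initialization while we were suspended.
            return self._supervisor_ref
        self._supervisor_ref = ref
        return ref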
xinference/model/audio/chattts.py CHANGED
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
 
@@ -61,16 +65,29 @@ class ChatTTSModel:
         import torchaudio
         import xxhash
 
-        seed = xxhash.xxh32_intdigest(voice)
+        rnd_spk_emb = None
 
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-        torch.cuda.manual_seed(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+        if len(voice) > 400:
+            try:
+                assert self._model is not None
+                b = base64.b64decode(voice)
+                bio = BytesIO(b)
+                tensor = torch.load(bio, map_location="cpu")
+                rnd_spk_emb = self._model._encode_spk_emb(tensor)
+                logger.info("Speech by input speaker")
+            except Exception as e:
+                logger.info("Fallback to random speaker due to %s", e)
 
-        assert self._model is not None
-        rnd_spk_emb = self._model.sample_random_speaker()
+        if rnd_spk_emb is None:
+            seed = xxhash.xxh32_intdigest(voice)
+
+            set_all_random_seed(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+            assert self._model is not None
+            rnd_spk_emb = self._model.sample_random_speaker()
+            logger.info("Speech by voice %s", voice)
 
         default = 5
         infer_speed = int(default * speed)
@@ -100,7 +117,6 @@ class ChatTTSModel:
             if new_last_pos != last_pos:
                 out.seek(last_pos)
                 encoded_bytes = out.read()
-                print(len(encoded_bytes))
                 yield encoded_bytes
                 last_pos = new_last_pos
 
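With the branch added above, a voice string longer than 400 characters is treated as a base64-encoded, torch-saved speaker-embedding tensor; shorter strings still seed a deterministic random speaker. A hypothetical client-side sketch of producing such a string (the embedding shape is a stand-in, not taken from the diff):

import base64
from io import BytesIO

import torch

spk_emb = torch.randn(768)  # stand-in for a real ChatTTS speaker embedding
buf = BytesIO()
torch.save(spk_emb, buf)
voice = base64.b64encode(buf.getvalue()).decode()
# Pass this as the `voice` argument of speech(); decoding failures fall
# back to the random-speaker path, per the except branch above.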
xinference/model/audio/core.py CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
 
@@ -46,6 +47,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     multilingual: bool
+    ability: str
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
 
@@ -156,13 +158,15 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
-    model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
+    model: Union[
+        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+    ]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
@@ -171,6 +175,8 @@
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "CosyVoice":
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "FishAudio":
+        model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/cosyvoice.py CHANGED
@@ -16,6 +16,8 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
 
@@ -67,6 +69,7 @@ class CosyVoiceModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
         instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+        seed: Optional[int] = kwargs.pop("seed", 0)
 
         if "SFT" in self._model_spec.model_name:
             # inference_sft
@@ -87,9 +90,6 @@
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
-            assert (
-                instruct_text is not None
-            ), "CosyVoice Instruct model expect a instruct_text"
         else:
             # inference_zero_shot
             # inference_cross_lingual
@@ -99,6 +99,7 @@
             ), "CosyVoice model does not support instruct_text"
 
         assert self._model is not None
+        set_all_random_seed(seed)
         if prompt_speech:
             assert not voice, "voice can't be set with prompt speech."
             with io.BytesIO(prompt_speech) as prompt_speech_io:
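The new seed kwarg (default 0) is popped before inference and fed to set_all_random_seed, making CosyVoice output reproducible. A hedged usage sketch, assuming `model` is a loaded CosyVoiceModel (or a client handle that forwards extra kwargs to speech()):

audio_a = model.speech(input="Hello, world", voice="中文女", seed=42)
audio_b = model.speech(input="Hello, world", voice="中文女", seed=42)
# With identical inputs and seeds, the two waveforms should now match.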
xinference/model/audio/custom.py CHANGED
@@ -88,6 +88,10 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_AUDIO_LOCK:
         for model_name in (
             list(BUILTIN_AUDIO_MODELS.keys())
@@ -102,11 +106,6 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
         UD_AUDIOS.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
         )
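The URI check now runs for every registration, not only persisted ones. A sketch of the effect (field values are illustrative and may not match the exact CustomAudioModelFamilyV1 schema; consult the custom-model docs):

from xinference.model.audio.custom import (
    CustomAudioModelFamilyV1,
    register_audio,
)

spec = CustomAudioModelFamilyV1(
    model_name="my-custom-tts",
    model_family="ChatTTS",
    model_id="local/my-custom-tts",
    model_revision="",
    multilingual=False,
    ability="text_to_audio",  # newly required field, per core.py above
    model_uri="not a valid uri",
)
register_audio(spec, persist=False)  # now raises ValueError up front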
xinference/model/audio/fish_speech.py ADDED
@@ -0,0 +1,228 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import logging
+import os.path
+import queue
+import sys
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    import wave
+
+    buffer = BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+class FishSpeechModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._llama_queue = None
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        # There are too many imports from fish_speech.
+        sys.path.insert(
+            0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
+        )
+
+        from tools.llama.generate import launch_thread_safe_queue
+        from tools.vqgan.inference import load_model as load_decoder_model
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        logger.info("Loading Llama model...")
+        self._llama_queue = launch_thread_safe_queue(
+            checkpoint_path=self._model_path,
+            device=self._device,
+            precision=torch.bfloat16,
+            compile=False,
+        )
+        logger.info("Llama model loaded, loading VQ-GAN model...")
+
+        checkpoint_path = os.path.join(
+            self._model_path,
+            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+        )
+        self._model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path=checkpoint_path,
+            device=self._device,
+        )
+
+    @torch.inference_mode()
+    def _inference(
+        self,
+        text,
+        enable_reference_audio,
+        reference_audio,
+        reference_text,
+        max_new_tokens,
+        chunk_length,
+        top_p,
+        repetition_penalty,
+        temperature,
+        streaming=False,
+    ):
+        from fish_speech.utils import autocast_exclude_mps
+        from tools.api import decode_vq_tokens, encode_reference
+        from tools.llama.generate import (
+            GenerateRequest,
+            GenerateResponse,
+            WrappedGenerateResponse,
+        )
+
+        # Parse reference audio aka prompt
+        prompt_tokens = encode_reference(
+            decoder_model=self._model,
+            reference_audio=reference_audio,
+            enable_reference_audio=enable_reference_audio,
+        )
+
+        # LLAMA Inference
+        request = dict(
+            device=self._model.device,
+            max_new_tokens=max_new_tokens,
+            text=text,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            compile=False,
+            iterative_prompt=chunk_length > 0,
+            chunk_length=chunk_length,
+            max_length=2048,
+            prompt_tokens=prompt_tokens if enable_reference_audio else None,
+            prompt_text=reference_text if enable_reference_audio else None,
+        )
+
+        response_queue = queue.Queue()
+        self._llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        if streaming:
+            yield wav_chunk_header(), None, None
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                raise Exception(str(result.response))
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=self._model.device.type, dtype=torch.bfloat16
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=self._model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+            if streaming:
+                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+        if len(segments) == 0:
+            raise Exception("No audio generated, please check the input text.")
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (self._model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        logger.warning("Fish speech does not support setting voice: %s.", voice)
+        if speed != 1.0:
+            logger.warning("Fish speech does not support setting speed: %s.", speed)
+        if stream is True:
+            logger.warning("stream mode is not implemented.")
+        import torchaudio
+
+        result = list(
+            self._inference(
+                text=input,
+                enable_reference_audio=False,
+                reference_audio=None,
+                reference_text="",
+                max_new_tokens=0,
+                chunk_length=100,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+            )
+        )
+        sample_rate, audio = result[0][1]
+        audio = np.array([audio])
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(audio), sample_rate, format=response_format
+            )
+            return out.getvalue()
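End to end, the new family behaves like any other audio model. A hedged usage sketch (the model name is assumed from the new model_spec.json entry in this release; verify against the registered audio models):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
uid = client.launch_model(model_name="FishSpeech-1.2-SFT", model_type="audio")
model = client.get_model(uid)

# speech() above warns that `voice` and `speed` are unsupported and that
# streaming is not implemented, then returns encoded audio bytes.
wav_bytes = model.speech("Hello from fish speech!", response_format="wav")
with open("out.wav", "wb") as f:
    f.write(wav_bytes)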