xinference 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (137)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +48 -41
  6. xinference/model/audio/chattts.py +24 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/fish_speech.py +228 -0
  9. xinference/model/audio/model_spec.json +8 -0
  10. xinference/model/embedding/core.py +23 -1
  11. xinference/model/image/model_spec.json +2 -1
  12. xinference/model/image/model_spec_modelscope.json +2 -1
  13. xinference/model/image/stable_diffusion/core.py +49 -1
  14. xinference/model/llm/__init__.py +6 -0
  15. xinference/model/llm/llm_family.json +54 -9
  16. xinference/model/llm/llm_family.py +2 -0
  17. xinference/model/llm/llm_family_modelscope.json +56 -10
  18. xinference/model/llm/lmdeploy/__init__.py +0 -0
  19. xinference/model/llm/lmdeploy/core.py +557 -0
  20. xinference/model/llm/transformers/cogvlm2.py +4 -45
  21. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  22. xinference/model/llm/transformers/core.py +1 -0
  23. xinference/model/llm/transformers/glm4v.py +2 -23
  24. xinference/model/llm/transformers/intern_vl.py +94 -11
  25. xinference/model/llm/transformers/minicpmv25.py +2 -23
  26. xinference/model/llm/transformers/minicpmv26.py +2 -22
  27. xinference/model/llm/transformers/yi_vl.py +2 -24
  28. xinference/model/llm/utils.py +10 -1
  29. xinference/model/llm/vllm/core.py +1 -1
  30. xinference/thirdparty/fish_speech/__init__.py +0 -0
  31. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  32. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  33. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  34. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  35. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  36. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  37. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  38. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  39. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  40. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  41. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  42. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  43. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  44. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  46. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  48. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  49. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  50. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  52. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  56. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  57. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  58. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  59. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  60. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  63. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  64. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  67. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  68. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  69. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  70. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  71. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  72. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  73. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  74. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  75. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  76. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  77. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  78. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  79. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  83. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  84. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  85. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  86. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  87. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  88. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  89. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  90. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  91. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  92. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  93. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  94. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  95. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  96. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  97. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  98. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  99. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  100. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  101. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  102. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  103. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  104. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  105. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  106. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  107. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  108. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  109. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  110. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  111. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  112. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  113. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  114. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  115. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  116. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  117. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  118. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  119. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  120. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  121. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  122. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  123. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  124. xinference/web/ui/build/asset-manifest.json +3 -3
  125. xinference/web/ui/build/index.html +1 -1
  126. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  127. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  129. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/METADATA +18 -6
  130. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/RECORD +135 -37
  131. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  132. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  133. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  134. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  135. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  136. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  137. {xinference-0.14.2.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-16T18:10:38+0800",
+ "date": "2024-08-23T18:14:53+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "e4d225774dc7a9a9260396bf833e03a1df8e8a92",
- "version": "0.14.2"
+ "full-revisionid": "b5002242e04634bca7e75cac9df0cdc6c0bf407a",
+ "version": "0.14.3"
 }
 ''' # END VERSION_JSON
 
xinference/core/chat_interface.py CHANGED
@@ -340,7 +340,7 @@ class GradioInterface:
             state = gr.State([])
             with gr.Row():
                 chatbot = gr.Chatbot(
-                    elem_id="chatbot", label=self.model_name, height=550, scale=7
+                    elem_id="chatbot", label=self.model_name, height=700, scale=7
                 )
                 with gr.Column(scale=3):
                     imagebox = gr.Image(type="filepath")
xinference/core/image_interface.py CHANGED
@@ -163,6 +163,7 @@ class ImageInterface:
         size_width: int,
         size_height: int,
         num_inference_steps: int,
+        padding_image_to_multiple: int,
     ) -> PIL.Image.Image:
         from ..client import RESTfulClient
 
@@ -178,6 +179,7 @@
         num_inference_steps = (
             None if num_inference_steps == -1 else num_inference_steps  # type: ignore
         )
+        padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
 
         bio = io.BytesIO()
         image.save(bio, format="png")
@@ -190,6 +192,7 @@
             size=size,
             response_format="b64_json",
             num_inference_steps=num_inference_steps,
+            padding_image_to_multiple=padding_image_to_multiple,
         )
 
         images = []
@@ -222,9 +225,14 @@
             n = gr.Number(label="Number of image", value=1)
             size_width = gr.Number(label="Width", value=-1)
             size_height = gr.Number(label="Height", value=-1)
+
+        with gr.Row():
             num_inference_steps = gr.Number(
                 label="Inference Step Number", value=-1
             )
+            padding_image_to_multiple = gr.Number(
+                label="Padding image to multiple", value=-1
+            )
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -242,6 +250,7 @@
                 size_width,
                 size_height,
                 num_inference_steps,
+                padding_image_to_multiple,
             ],
             outputs=output_gallery,
         )
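
The new padding_image_to_multiple value is forwarded unchanged from the Gradio form to the image-to-image request, so it can also be supplied directly through the Python client. A minimal sketch, assuming a running endpoint, an already-launched Stable Diffusion model uid, and an image_to_image handle method (none of which are shown in this diff):

    # Hedged sketch: pass padding_image_to_multiple through the REST client.
    # Endpoint, model uid and the image_to_image method name are assumptions.
    import io

    from PIL import Image
    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint
    model = client.get_model("my-sd-model")  # hypothetical model uid

    bio = io.BytesIO()
    Image.open("input.png").save(bio, format="png")

    result = model.image_to_image(  # assumed handle method
        image=bio.getvalue(),
        prompt="a watercolor landscape",
        response_format="b64_json",
        padding_image_to_multiple=8,  # pad width/height up to a multiple of 8
    )

Padding to a multiple such as 8 keeps the input dimensions compatible with the VAE downsampling factor of Stable Diffusion style models, which is presumably why the option was surfaced in the interface.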
xinference/core/model.py CHANGED
@@ -177,6 +177,7 @@ class ModelActor(xo.StatelessActor):
         request_limits: Optional[int] = None,
     ):
         super().__init__()
+        from ..model.llm.lmdeploy.core import LMDeployModel
         from ..model.llm.sglang.core import SGLANGModel
         from ..model.llm.transformers.core import PytorchModel
         from ..model.llm.vllm.core import VLLMModel
@@ -192,7 +193,9 @@
         self._current_generator = lambda: None
         self._lock = (
             None
-            if isinstance(self._model, (PytorchModel, VLLMModel, SGLANGModel))
+            if isinstance(
+                self._model, (PytorchModel, VLLMModel, SGLANGModel, LMDeployModel)
+            )
             else asyncio.locks.Lock()
         )
         self._worker_ref = None
xinference/core/worker.py CHANGED
@@ -39,9 +39,11 @@ from ..core.status_guard import LaunchStatus
 from ..device_utils import get_available_device_env_name, gpu_count
 from ..model.core import ModelDescription, create_model_instance
 from ..types import PeftModelConfig
+from .cache_tracker import CacheTrackerActor
 from .event import Event, EventCollectorActor, EventType
 from .metrics import launch_metrics_export_server, record_metrics
 from .resource import gather_node_info
+from .status_guard import StatusGuardActor
 from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir
 
 logger = getLogger(__name__)
@@ -71,6 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
+        self._status_guard_ref: xo.ActorRefType[  # type: ignore
+            "StatusGuardActor"
+        ] = None
+        self._event_collector_ref: xo.ActorRefType[  # type: ignore
+            EventCollectorActor
+        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
+            CacheTrackerActor
+        ] = None
 
         # internal states.
         # temporary placeholder during model launch process:
@@ -308,56 +319,50 @@ class WorkerActor(xo.StatelessActor):
         Params:
             add_worker: By default will call supervisor.add_worker after first connect
         """
-        from .status_guard import StatusGuardActor
         from .supervisor import SupervisorActor
 
         if self._supervisor_ref is not None:
             return self._supervisor_ref
-        self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref(  # type: ignore
+        supervisor_ref = await xo.actor_ref(  # type: ignore
             address=self._supervisor_address, uid=SupervisorActor.uid()
         )
+        # Prevent concurrent operations leads to double initialization, check again.
+        if self._supervisor_ref is not None:
+            return self._supervisor_ref
+        self._supervisor_ref = supervisor_ref
         if add_worker and len(self._model_uid_to_model) == 0:
             # Newly started (or restarted), has no model, notify supervisor
             await self._supervisor_ref.add_worker(self.address)
             logger.info("Connected to supervisor as a fresh worker")
 
-        self._status_guard_ref: xo.ActorRefType[  # type: ignore
-            "StatusGuardActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=StatusGuardActor.uid()
-        )
-
-        self._event_collector_ref: xo.ActorRefType[  # type: ignore
-            EventCollectorActor
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=EventCollectorActor.uid()
-        )
-        from .cache_tracker import CacheTrackerActor
-
-        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
-            "CacheTrackerActor"
-        ] = await xo.actor_ref(
-            address=self._supervisor_address, uid=CacheTrackerActor.uid()
-        )
-        # cache_tracker is on supervisor
-        from ..model.audio import get_audio_model_descriptions
-        from ..model.embedding import get_embedding_model_descriptions
-        from ..model.flexible import get_flexible_model_descriptions
-        from ..model.image import get_image_model_descriptions
-        from ..model.llm import get_llm_model_descriptions
-        from ..model.rerank import get_rerank_model_descriptions
-
-        # record model version
-        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
-        model_version_infos.update(get_llm_model_descriptions())
-        model_version_infos.update(get_embedding_model_descriptions())
-        model_version_infos.update(get_rerank_model_descriptions())
-        model_version_infos.update(get_image_model_descriptions())
-        model_version_infos.update(get_audio_model_descriptions())
-        model_version_infos.update(get_flexible_model_descriptions())
-        await self._cache_tracker_ref.record_model_version(
-            model_version_infos, self.address
-        )
+        self._status_guard_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=StatusGuardActor.uid()
+        )
+        self._event_collector_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=EventCollectorActor.uid()
+        )
+        self._cache_tracker_ref = await xo.actor_ref(
+            address=self._supervisor_address, uid=CacheTrackerActor.uid()
+        )
+        # cache_tracker is on supervisor
+        from ..model.audio import get_audio_model_descriptions
+        from ..model.embedding import get_embedding_model_descriptions
+        from ..model.flexible import get_flexible_model_descriptions
+        from ..model.image import get_image_model_descriptions
+        from ..model.llm import get_llm_model_descriptions
+        from ..model.rerank import get_rerank_model_descriptions
+
+        # record model version
+        model_version_infos: Dict[str, List[Dict]] = {}  # type: ignore
+        model_version_infos.update(get_llm_model_descriptions())
+        model_version_infos.update(get_embedding_model_descriptions())
+        model_version_infos.update(get_rerank_model_descriptions())
+        model_version_infos.update(get_image_model_descriptions())
+        model_version_infos.update(get_audio_model_descriptions())
+        model_version_infos.update(get_flexible_model_descriptions())
+        await self._cache_tracker_ref.record_model_version(
+            model_version_infos, self.address
+        )
         return self._supervisor_ref
 
     @staticmethod
@@ -734,7 +739,7 @@
         elif model_type == "image":
             return ["text_to_image"]
         elif model_type == "audio":
-            return ["audio_to_text"]
+            return [model._model_spec.ability]
         elif model_type == "video":
             return ["text_to_video"]
         elif model_type == "flexible":
@@ -793,6 +798,7 @@
             logger.exception(e)
             raise
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -914,6 +920,7 @@
             raise ValueError(f"{model_uid} is launching")
         origin_uid, _, __ = parse_replica_model_uid(model_uid)
         try:
+            _ = await self.get_supervisor_ref()
             if self._event_collector_ref is not None:
                 await self._event_collector_ref.report_event(
                     origin_uid,
@@ -1081,7 +1088,7 @@
         paths.update([os.path.realpath(path) for path in paths])
 
         # get tensorizer path
-        from ..model.llm.pytorch.tensorizer_utils import get_tensorizer_dir
+        from ..model.llm.transformers.tensorizer_utils import get_tensorizer_dir
 
         tensorizer_path = get_tensorizer_dir(path)
         if os.path.isdir(tensorizer_path):
xinference/model/audio/chattts.py CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
@@ -61,16 +62,31 @@ class ChatTTSModel:
         import torchaudio
         import xxhash
 
-        seed = xxhash.xxh32_intdigest(voice)
+        rnd_spk_emb = None
 
-        torch.manual_seed(seed)
-        np.random.seed(seed)
-        torch.cuda.manual_seed(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+        if len(voice) > 400:
+            try:
+                assert self._model is not None
+                b = base64.b64decode(voice)
+                bio = BytesIO(b)
+                tensor = torch.load(bio, map_location="cpu")
+                rnd_spk_emb = self._model._encode_spk_emb(tensor)
+                logger.info("Speech by input speaker")
+            except Exception as e:
+                logger.info("Fallback to random speaker due to %s", e)
 
-        assert self._model is not None
-        rnd_spk_emb = self._model.sample_random_speaker()
+        if rnd_spk_emb is None:
+            seed = xxhash.xxh32_intdigest(voice)
+
+            torch.manual_seed(seed)
+            np.random.seed(seed)
+            torch.cuda.manual_seed(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+            assert self._model is not None
+            rnd_spk_emb = self._model.sample_random_speaker()
+            logger.info("Speech by voice %s", voice)
 
         default = 5
         infer_speed = int(default * speed)
@@ -100,7 +116,6 @@ class ChatTTSModel:
             if new_last_pos != last_pos:
                 out.seek(last_pos)
                 encoded_bytes = out.read()
-                print(len(encoded_bytes))
                 yield encoded_bytes
                 last_pos = new_last_pos
 
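With this change the voice argument does double duty: a value longer than 400 characters is treated as a base64-encoded, torch-saved speaker-embedding tensor and passed through _encode_spk_emb, while anything shorter still just seeds a random speaker as before. A minimal client-side sketch of producing such a value, assuming spk_emb_tensor is a speaker-embedding torch.Tensor you already have (obtaining it is outside this diff):

    # Hedged sketch: serialize a ChatTTS speaker-embedding tensor for the voice field.
    # spk_emb_tensor is an assumed, pre-existing torch.Tensor speaker embedding.
    import base64
    from io import BytesIO

    import torch

    buf = BytesIO()
    torch.save(spk_emb_tensor, buf)  # matches the torch.load(...) call on the server side
    voice = base64.b64encode(buf.getvalue()).decode()  # well over 400 chars, so decoded server-side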
xinference/model/audio/core.py CHANGED
@@ -21,6 +21,7 @@ from ..core import CacheableModelSpec, ModelDescription
 from ..utils import valid_model_revision
 from .chattts import ChatTTSModel
 from .cosyvoice import CosyVoiceModel
+from .fish_speech import FishSpeechModel
 from .funasr import FunASRModel
 from .whisper import WhisperModel
 
@@ -46,6 +47,7 @@ class AudioModelFamilyV1(CacheableModelSpec):
     model_id: str
     model_revision: str
     multilingual: bool
+    ability: str
     default_model_config: Optional[Dict[str, Any]]
     default_transcription_config: Optional[Dict[str, Any]]
 
@@ -156,13 +158,15 @@ def create_audio_model_instance(
     model_path: Optional[str] = None,
     **kwargs,
 ) -> Tuple[
-    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel],
+    Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel],
     AudioModelDescription,
 ]:
     model_spec = match_audio(model_name, download_hub)
     if model_path is None:
         model_path = cache(model_spec)
-    model: Union[WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel]
+    model: Union[
+        WhisperModel, FunASRModel, ChatTTSModel, CosyVoiceModel, FishSpeechModel
+    ]
     if model_spec.model_family == "whisper":
         model = WhisperModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "funasr":
@@ -171,6 +175,8 @@ def create_audio_model_instance(
         model = ChatTTSModel(model_uid, model_path, model_spec, **kwargs)
     elif model_spec.model_family == "CosyVoice":
         model = CosyVoiceModel(model_uid, model_path, model_spec, **kwargs)
+    elif model_spec.model_family == "FishAudio":
+        model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
     else:
         raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
     model_description = AudioModelDescription(
xinference/model/audio/fish_speech.py ADDED
@@ -0,0 +1,228 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import logging
+import os.path
+import queue
+import sys
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    import wave
+
+    buffer = BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+class FishSpeechModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._llama_queue = None
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        # There are too many imports from fish_speech.
+        sys.path.insert(
+            0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
+        )
+
+        from tools.llama.generate import launch_thread_safe_queue
+        from tools.vqgan.inference import load_model as load_decoder_model
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        logger.info("Loading Llama model...")
+        self._llama_queue = launch_thread_safe_queue(
+            checkpoint_path=self._model_path,
+            device=self._device,
+            precision=torch.bfloat16,
+            compile=False,
+        )
+        logger.info("Llama model loaded, loading VQ-GAN model...")
+
+        checkpoint_path = os.path.join(
+            self._model_path,
+            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+        )
+        self._model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path=checkpoint_path,
+            device=self._device,
+        )
+
+    @torch.inference_mode()
+    def _inference(
+        self,
+        text,
+        enable_reference_audio,
+        reference_audio,
+        reference_text,
+        max_new_tokens,
+        chunk_length,
+        top_p,
+        repetition_penalty,
+        temperature,
+        streaming=False,
+    ):
+        from fish_speech.utils import autocast_exclude_mps
+        from tools.api import decode_vq_tokens, encode_reference
+        from tools.llama.generate import (
+            GenerateRequest,
+            GenerateResponse,
+            WrappedGenerateResponse,
+        )
+
+        # Parse reference audio aka prompt
+        prompt_tokens = encode_reference(
+            decoder_model=self._model,
+            reference_audio=reference_audio,
+            enable_reference_audio=enable_reference_audio,
+        )
+
+        # LLAMA Inference
+        request = dict(
+            device=self._model.device,
+            max_new_tokens=max_new_tokens,
+            text=text,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            compile=False,
+            iterative_prompt=chunk_length > 0,
+            chunk_length=chunk_length,
+            max_length=2048,
+            prompt_tokens=prompt_tokens if enable_reference_audio else None,
+            prompt_text=reference_text if enable_reference_audio else None,
+        )
+
+        response_queue = queue.Queue()
+        self._llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        if streaming:
+            yield wav_chunk_header(), None, None
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                raise Exception(str(result.response))
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=self._model.device.type, dtype=torch.bfloat16
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=self._model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+            if streaming:
+                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+        if len(segments) == 0:
+            raise Exception("No audio generated, please check the input text.")
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (self._model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        logger.warning("Fish speech does not support setting voice: %s.", voice)
+        if speed != 1.0:
+            logger.warning("Fish speech does not support setting speed: %s.", speed)
+        if stream is True:
+            logger.warning("stream mode is not implemented.")
+        import torchaudio
+
+        result = list(
+            self._inference(
+                text=input,
+                enable_reference_audio=False,
+                reference_audio=None,
+                reference_text="",
+                max_new_tokens=0,
+                chunk_length=100,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+            )
+        )
+        sample_rate, audio = result[0][1]
+        audio = np.array([audio])
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(audio), sample_rate, format=response_format
+            )
+            return out.getvalue()
xinference/model/audio/model_spec.json CHANGED
@@ -146,5 +146,13 @@
     "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "FishSpeech-1.2-SFT",
+    "model_family": "FishAudio",
+    "model_id": "fishaudio/fish-speech-1.2-sft",
+    "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
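
Together with the FishSpeechModel added above, this spec entry makes FishSpeech-1.2-SFT launchable as an audio model. A hedged sketch of driving it from the Python client; the endpoint address is an assumption, the model name comes from the spec entry, and launch_model/get_model/speech follow the existing audio-model client surface:

    # Hedged sketch: launch the new FishSpeech audio model and synthesize speech.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed endpoint
    uid = client.launch_model(model_name="FishSpeech-1.2-SFT", model_type="audio")
    model = client.get_model(uid)

    audio_bytes = model.speech("Hello from Fish Speech.")  # encoded audio, mp3 by default
    with open("out.mp3", "wb") as f:
        f.write(audio_bytes)

Note that FishSpeechModel.speech currently only logs warnings for voice, speed and stream, so in this release only the input text and response_format have any effect.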
xinference/model/embedding/core.py CHANGED
@@ -154,10 +154,32 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
+            import torch
+
+            torch_dtype_str = self._kwargs.get("torch_dtype")
+            if torch_dtype_str is not None:
+                try:
+                    torch_dtype = getattr(torch, torch_dtype_str)
+                    if torch_dtype not in [
+                        torch.float16,
+                        torch.float32,
+                        torch.bfloat16,
+                    ]:
+                        logger.warning(
+                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                        )
+                        torch_dtype = torch.float32
+                except AttributeError:
+                    logger.warning(
+                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            else:
+                torch_dtype = "auto"
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto"},
+                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
             )
         else:
             self._model = SentenceTransformer(self._model_path, device=self._device)
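
Since the dtype string is read from the model kwargs, it can be supplied when the embedding model is launched. A minimal sketch, assuming extra launch kwargs are forwarded into the model's kwargs and using an assumed endpoint and model name; torch_dtype must name one of float16, float32 or bfloat16, otherwise the code above falls back to fp32:

    # Hedged sketch: pass torch_dtype when launching a gte/qwen2 embedding model.
    # Endpoint and model name are assumptions; the kwarg name comes from the diff above.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")  # assumed endpoint
    uid = client.launch_model(
        model_name="gte-Qwen2-7B-instruct",  # assumed spec name containing "gte" and "qwen2"
        model_type="embedding",
        torch_dtype="bfloat16",  # read via self._kwargs.get("torch_dtype")
    )
    embeddings = client.get_model(uid).create_embedding(["hello world"])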
xinference/model/image/model_spec.json CHANGED
@@ -24,7 +24,8 @@
     "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671",
     "model_ability": [
       "text2image",
-      "image2image"
+      "image2image",
+      "inpainting"
     ]
   },
   {
xinference/model/image/model_spec_modelscope.json CHANGED
@@ -27,7 +27,8 @@
     "model_revision": "master",
     "model_ability": [
       "text2image",
-      "image2image"
+      "image2image",
+      "inpainting"
     ]
   },
   {