xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/model/audio/fish_speech.py
@@ -0,0 +1,228 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import logging
+import os.path
+import queue
+import sys
+from io import BytesIO
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+
+from ...device_utils import get_available_device, is_device_available
+
+if TYPE_CHECKING:
+    from .core import AudioModelFamilyV1
+
+logger = logging.getLogger(__name__)
+
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    import wave
+
+    buffer = BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
+
+
+class FishSpeechModel:
+    def __init__(
+        self,
+        model_uid: str,
+        model_path: str,
+        model_spec: "AudioModelFamilyV1",
+        device: Optional[str] = None,
+        **kwargs,
+    ):
+        self._model_uid = model_uid
+        self._model_path = model_path
+        self._model_spec = model_spec
+        self._device = device
+        self._llama_queue = None
+        self._model = None
+        self._kwargs = kwargs
+
+    def load(self):
+        # There are too many imports from fish_speech.
+        sys.path.insert(
+            0, os.path.join(os.path.dirname(__file__), "../../thirdparty/fish_speech")
+        )
+
+        from tools.llama.generate import launch_thread_safe_queue
+        from tools.vqgan.inference import load_model as load_decoder_model
+
+        if self._device is None:
+            self._device = get_available_device()
+        else:
+            if not is_device_available(self._device):
+                raise ValueError(f"Device {self._device} is not available!")
+
+        logger.info("Loading Llama model...")
+        self._llama_queue = launch_thread_safe_queue(
+            checkpoint_path=self._model_path,
+            device=self._device,
+            precision=torch.bfloat16,
+            compile=False,
+        )
+        logger.info("Llama model loaded, loading VQ-GAN model...")
+
+        checkpoint_path = os.path.join(
+            self._model_path,
+            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+        )
+        self._model = load_decoder_model(
+            config_name="firefly_gan_vq",
+            checkpoint_path=checkpoint_path,
+            device=self._device,
+        )
+
+    @torch.inference_mode()
+    def _inference(
+        self,
+        text,
+        enable_reference_audio,
+        reference_audio,
+        reference_text,
+        max_new_tokens,
+        chunk_length,
+        top_p,
+        repetition_penalty,
+        temperature,
+        streaming=False,
+    ):
+        from fish_speech.utils import autocast_exclude_mps
+        from tools.api import decode_vq_tokens, encode_reference
+        from tools.llama.generate import (
+            GenerateRequest,
+            GenerateResponse,
+            WrappedGenerateResponse,
+        )
+
+        # Parse reference audio aka prompt
+        prompt_tokens = encode_reference(
+            decoder_model=self._model,
+            reference_audio=reference_audio,
+            enable_reference_audio=enable_reference_audio,
+        )
+
+        # LLAMA Inference
+        request = dict(
+            device=self._model.device,
+            max_new_tokens=max_new_tokens,
+            text=text,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            compile=False,
+            iterative_prompt=chunk_length > 0,
+            chunk_length=chunk_length,
+            max_length=2048,
+            prompt_tokens=prompt_tokens if enable_reference_audio else None,
+            prompt_text=reference_text if enable_reference_audio else None,
+        )
+
+        response_queue = queue.Queue()
+        self._llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        if streaming:
+            yield wav_chunk_header(), None, None
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                raise Exception(str(result.response))
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=self._model.device.type, dtype=torch.bfloat16
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=self._model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+            if streaming:
+                yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+
+        if len(segments) == 0:
+            raise Exception("No audio generated, please check the input text.")
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (self._model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def speech(
+        self,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+        stream: bool = False,
+        **kwargs,
+    ):
+        logger.warning("Fish speech does not support setting voice: %s.", voice)
+        if speed != 1.0:
+            logger.warning("Fish speech does not support setting speed: %s.", speed)
+        if stream is True:
+            logger.warning("stream mode is not implemented.")
+        import torchaudio
+
+        result = list(
+            self._inference(
+                text=input,
+                enable_reference_audio=False,
+                reference_audio=None,
+                reference_text="",
+                max_new_tokens=0,
+                chunk_length=100,
+                top_p=0.7,
+                repetition_penalty=1.2,
+                temperature=0.7,
+            )
+        )
+        sample_rate, audio = result[0][1]
+        audio = np.array([audio])
+
+        # Save the generated audio
+        with BytesIO() as out:
+            torchaudio.save(
+                out, torch.from_numpy(audio), sample_rate, format=response_format
+            )
+            return out.getvalue()
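
The new file above is the FishSpeech text-to-speech backend. As a rough, hedged sketch of the API it introduces (in practice xinference's audio model factory constructs and loads this class; the checkpoint path below is illustrative):

```python
# Minimal sketch of driving FishSpeechModel directly; the checkpoint directory is
# a placeholder, and model_spec is unused by the load()/speech() paths shown above.
from xinference.model.audio.fish_speech import FishSpeechModel

model = FishSpeechModel(
    model_uid="fishspeech-demo",
    model_path="/models/FishSpeech-1.2-SFT",  # hypothetical local checkpoint dir
    model_spec=None,                          # AudioModelFamilyV1 in real use
    device=None,                              # falls back to get_available_device()
)
model.load()  # loads the Llama queue, then the firefly VQ-GAN decoder

# speech() only warns about voice/speed/stream and returns encoded audio bytes.
audio_bytes = model.speech(input="Hello from FishSpeech!", voice="default")
with open("hello.mp3", "wb") as f:
    f.write(audio_bytes)
```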
xinference/model/audio/model_spec.json
@@ -146,5 +146,13 @@
     "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
     "ability": "text-to-audio",
     "multilingual": true
+  },
+  {
+    "model_name": "FishSpeech-1.2-SFT",
+    "model_family": "FishAudio",
+    "model_id": "fishaudio/fish-speech-1.2-sft",
+    "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
+    "ability": "text-to-audio",
+    "multilingual": true
   }
 ]
xinference/model/embedding/core.py
@@ -154,10 +154,32 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
+            import torch
+
+            torch_dtype_str = self._kwargs.get("torch_dtype")
+            if torch_dtype_str is not None:
+                try:
+                    torch_dtype = getattr(torch, torch_dtype_str)
+                    if torch_dtype not in [
+                        torch.float16,
+                        torch.float32,
+                        torch.bfloat16,
+                    ]:
+                        logger.warning(
+                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                        )
+                        torch_dtype = torch.float32
+                except AttributeError:
+                    logger.warning(
+                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            else:
+                torch_dtype = "auto"
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto"},
+                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
             )
         else:
            self._model = SentenceTransformer(self._model_path, device=self._device)
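
The hunk above lets a `torch_dtype` kwarg steer how GTE-Qwen2 embedding models are loaded: only float16, float32 and bfloat16 are accepted, anything else falls back to fp32, and omitting it keeps `"auto"`. A hedged sketch of passing it through the client; the model name is illustrative and not verified against the builtin registry:

```python
# Hedged sketch: supplying the new torch_dtype kwarg when launching a GTE-Qwen2
# embedding model. Extra kwargs reach EmbeddingModel as self._kwargs, where load()
# resolves them exactly as in the hunk above.
from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model_uid = client.launch_model(
    model_name="gte-Qwen2-1.5B-instruct",  # assumed builtin name
    model_type="embedding",
    torch_dtype="bfloat16",  # invalid values log a warning and fall back to fp32
)
```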
xinference/model/image/model_spec.json
@@ -24,7 +24,8 @@
     "model_revision": "ea42f8cef0f178587cf766dc8129abd379c90671",
     "model_ability": [
       "text2image",
-      "image2image"
+      "image2image",
+      "inpainting"
     ]
   },
   {
xinference/model/image/model_spec_modelscope.json
@@ -27,7 +27,8 @@
     "model_revision": "master",
     "model_ability": [
       "text2image",
-      "image2image"
+      "image2image",
+      "inpainting"
     ]
   },
   {
xinference/model/image/stable_diffusion/core.py
@@ -24,6 +24,9 @@ from functools import partial
 from io import BytesIO
 from typing import Dict, List, Optional, Union
 
+import PIL.Image
+from PIL import ImageOps
+
 from ....constants import XINFERENCE_IMAGE_DIR
 from ....device_utils import move_model_to_available_device
 from ....types import Image, ImageList, LoRA
@@ -46,8 +49,13 @@ class DiffusionModel:
         self._model_uid = model_uid
         self._model_path = model_path
         self._device = device
+        # when a model has text2image ability,
+        # it will be loaded as AutoPipelineForText2Image
+        # for image2image and inpainting,
+        # we convert to the corresponding model
         self._model = None
         self._i2i_model = None  # image to image model
+        self._inpainting_model = None  # inpainting model
         self._lora_model = lora_model
         self._lora_load_kwargs = lora_load_kwargs or {}
         self._lora_fuse_kwargs = lora_fuse_kwargs or {}
@@ -152,6 +160,10 @@
         model=None,
         **kwargs,
     ):
+        import gc
+
+        from ....device_utils import empty_cache
+
         logger.debug(
             "stable diffusion args: %s",
             kwargs,
@@ -159,6 +171,11 @@
         model = model if model is not None else self._model
         assert callable(model)
         images = model(**kwargs).images
+
+        # clean cache
+        gc.collect()
+        empty_cache()
+
         if response_format == "url":
             os.makedirs(XINFERENCE_IMAGE_DIR, exist_ok=True)
             image_list = []
@@ -209,9 +226,17 @@
             **kwargs,
         )
 
+    @staticmethod
+    def pad_to_multiple(image, multiple=8):
+        x, y = image.size
+        padding_x = (multiple - x % multiple) % multiple
+        padding_y = (multiple - y % multiple) % multiple
+        padding = (0, 0, padding_x, padding_y)
+        return ImageOps.expand(image, padding)
+
     def image_to_image(
         self,
-        image: bytes,
+        image: PIL.Image,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
@@ -236,6 +261,11 @@
             width, height = map(int, re.split(r"[^\d]+", size))
             kwargs["width"] = width
             kwargs["height"] = height
+        if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
+            # Model like SD3 image to image requires image's height and width is times of 16
+            # padding the image if specified
+            image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
+
         self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
@@ -258,6 +288,23 @@
         response_format: str = "url",
         **kwargs,
     ):
+        if "inpainting" not in self._abilities:
+            raise RuntimeError(f"{self._model_uid} does not support inpainting")
+
+        if (
+            "text2image" in self._abilities or "image2image" in self._abilities
+        ) and self._model is not None:
+            from diffusers import AutoPipelineForInpainting
+
+            if self._inpainting_model is not None:
+                model = self._inpainting_model
+            else:
+                model = self._inpainting_model = AutoPipelineForInpainting.from_pipe(
+                    self._model
+                )
+        else:
+            model = self._model
+
         width, height = map(int, re.split(r"[^\d]+", size))
         return self._call_model(
             image=image,
@@ -268,5 +315,6 @@
             width=width,
             num_images_per_prompt=n,
             response_format=response_format,
+            model=model,
             **kwargs,
         )
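
The `pad_to_multiple()` helper added above exists because SD3-style image-to-image pipelines expect dimensions divisible by 16; passing `padding_image_to_multiple=16` in the request kwargs triggers it. A quick standalone check of the arithmetic (the function body is copied from the hunk above):

```python
# Standalone check of the padding arithmetic: a 1000x600 image padded to a
# multiple of 16 grows to 1008x608 (padding is applied on the right and bottom).
from PIL import Image, ImageOps

def pad_to_multiple(image, multiple=8):
    x, y = image.size
    padding_x = (multiple - x % multiple) % multiple
    padding_y = (multiple - y % multiple) % multiple
    return ImageOps.expand(image, (0, 0, padding_x, padding_y))

img = Image.new("RGB", (1000, 600))
print(pad_to_multiple(img, multiple=16).size)  # -> (1008, 608)
```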
xinference/model/llm/__init__.py
@@ -34,13 +34,14 @@ from .llm_family import (
     BUILTIN_MODELSCOPE_LLM_FAMILIES,
     LLAMA_CLASSES,
     LLM_ENGINES,
+    LMDEPLOY_CLASSES,
     MLX_CLASSES,
     SGLANG_CLASSES,
     SUPPORTED_ENGINES,
     TRANSFORMERS_CLASSES,
     VLLM_CLASSES,
     CustomLLMFamilyV1,
-    GgmlLLMSpecV1,
+    LlamaCppLLMSpecV1,
     LLMFamilyV1,
     LLMSpecV1,
     MLXLLMSpecV1,
@@ -55,10 +56,10 @@ from .llm_family import (
 
 
 def check_format_with_engine(model_format, engine):
-    # only llama-cpp-python support and only support ggufv2 and ggmlv3
-    if model_format in ["ggufv2", "ggmlv3"] and engine != "llama.cpp":
+    # only llama-cpp-python support and only support ggufv2
+    if model_format in ["ggufv2"] and engine != "llama.cpp":
         return False
-    if model_format not in ["ggufv2", "ggmlv3"] and engine == "llama.cpp":
+    if model_format not in ["ggufv2"] and engine == "llama.cpp":
         return False
     return True
 
@@ -112,28 +113,27 @@ def generate_engine_config_by_model_family(model_family):
 
 
 def _install():
-    from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
+    from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel
+    from .lmdeploy.core import LMDeployChatModel, LMDeployModel
     from .mlx.core import MLXChatModel, MLXModel
-    from .pytorch.baichuan import BaichuanPytorchChatModel
-    from .pytorch.chatglm import ChatglmPytorchChatModel
-    from .pytorch.cogvlm2 import CogVLM2Model
-    from .pytorch.core import PytorchChatModel, PytorchModel
-    from .pytorch.deepseek_vl import DeepSeekVLChatModel
-    from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
-    from .pytorch.glm4v import Glm4VModel
-    from .pytorch.intern_vl import InternVLChatModel
-    from .pytorch.internlm2 import Internlm2PytorchChatModel
-    from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
-    from .pytorch.minicpmv25 import MiniCPMV25Model
-    from .pytorch.minicpmv26 import MiniCPMV26Model
-    from .pytorch.qwen_vl import QwenVLChatModel
-    from .pytorch.vicuna import VicunaPytorchChatModel
-    from .pytorch.yi_vl import YiVLChatModel
     from .sglang.core import SGLANGChatModel, SGLANGModel
-    from .vllm.core import VLLMChatModel, VLLMModel
+    from .transformers.chatglm import ChatglmPytorchChatModel
+    from .transformers.cogvlm2 import CogVLM2Model
+    from .transformers.cogvlm2_video import CogVLM2VideoModel
+    from .transformers.core import PytorchChatModel, PytorchModel
+    from .transformers.deepseek_vl import DeepSeekVLChatModel
+    from .transformers.glm4v import Glm4VModel
+    from .transformers.intern_vl import InternVLChatModel
+    from .transformers.internlm2 import Internlm2PytorchChatModel
+    from .transformers.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
+    from .transformers.minicpmv25 import MiniCPMV25Model
+    from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.qwen_vl import QwenVLChatModel
+    from .transformers.yi_vl import YiVLChatModel
+    from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
 
     try:
-        from .pytorch.omnilmm import OmniLMMModel
+        from .transformers.omnilmm import OmniLMMModel
     except ImportError as e:
         # For quite old transformers version,
         # import will generate error
@@ -148,18 +148,15 @@ def _install():
         ]
     )
     SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
-    VLLM_CLASSES.extend([VLLMModel, VLLMChatModel])
+    VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMVisionModel])
     MLX_CLASSES.extend([MLXModel, MLXChatModel])
+    LMDEPLOY_CLASSES.extend([LMDeployModel, LMDeployChatModel])
     TRANSFORMERS_CLASSES.extend(
         [
-            BaichuanPytorchChatModel,
-            VicunaPytorchChatModel,
-            FalconPytorchChatModel,
             ChatglmPytorchChatModel,
             LlamaPytorchModel,
             LlamaPytorchChatModel,
             PytorchChatModel,
-            FalconPytorchModel,
             Internlm2PytorchChatModel,
             QwenVLChatModel,
             YiVLChatModel,
@@ -167,6 +164,7 @@ def _install():
             InternVLChatModel,
             PytorchModel,
             CogVLM2Model,
+            CogVLM2VideoModel,
             MiniCPMV25Model,
             MiniCPMV26Model,
             Glm4VModel,
@@ -181,6 +179,7 @@ def _install():
     SUPPORTED_ENGINES["Transformers"] = TRANSFORMERS_CLASSES
     SUPPORTED_ENGINES["llama.cpp"] = LLAMA_CLASSES
     SUPPORTED_ENGINES["MLX"] = MLX_CLASSES
+    SUPPORTED_ENGINES["LMDEPLOY"] = LMDEPLOY_CLASSES
 
     json_path = os.path.join(
         os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
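
With ggmlv3 support dropped, `check_format_with_engine()` above now pairs gguf files only with the llama.cpp engine, and llama.cpp accepts nothing but ggufv2. A small illustration of the resulting truth table (the helper is defined at module level, so the import below should work as written):

```python
# Illustration of the tightened format/engine pairing after this change.
from xinference.model.llm import check_format_with_engine

assert check_format_with_engine("ggufv2", "llama.cpp") is True
assert check_format_with_engine("ggufv2", "vLLM") is False        # gguf needs llama.cpp
assert check_format_with_engine("pytorch", "llama.cpp") is False  # llama.cpp only takes gguf
assert check_format_with_engine("pytorch", "Transformers") is True
```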
xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py}
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
 import logging
 import os
 import time
@@ -104,35 +103,6 @@ class LlamaCppModel(LLM):
         generate_config.pop("lora_name", None)  # type: ignore
         return generate_config
 
-    def _convert_ggml_to_gguf(self, model_path: str) -> str:
-        from .tools import convert
-
-        root_dir = os.path.dirname(os.path.dirname(model_path))
-        gguf_dir = os.path.join(
-            root_dir,
-            "{}-ggufv2-{}b".format(
-                self.model_family.model_name, self.model_spec.model_size_in_billions
-            ),
-        )
-        os.makedirs(gguf_dir, exist_ok=True)
-        gguf_path = os.path.join(
-            gguf_dir,
-            "{}.{}.ggufv2".format(self.model_family.model_name, self.quantization),
-        )
-        # trick for validation, use a mark file to make sure the gguf file is converted
-        mark_file = os.path.join(gguf_dir, f"__valid_{self.quantization}")
-        if os.path.exists(mark_file):
-            return gguf_path
-        else:
-            logger.warning(
-                "You are using a model with ggmlv3, "
-                "and it will take some time to convert to ggufv2"
-            )
-            convert(model_path, gguf_path)
-            with open(mark_file, "w") as f:
-                f.write(str(datetime.datetime.now()))
-            return gguf_path
-
     def load(self):
         try:
             import llama_cpp
@@ -167,9 +137,6 @@ class LlamaCppModel(LLM):
         if os.path.exists(legacy_model_file_path):
             model_path = legacy_model_file_path
 
-        if self.model_spec.model_format == "ggmlv3":
-            model_path = self._convert_ggml_to_gguf(model_path)
-
         try:
             self._llm = Llama(
                 model_path=model_path,
@@ -183,7 +150,7 @@ class LlamaCppModel(LLM):
     def match(
         cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["ggmlv3", "ggufv2"]:
+        if llm_spec.model_format not in ["ggufv2"]:
             return False
         if "qwen" in llm_family.model_name:
             return False
@@ -285,7 +252,7 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
     def match(
         cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
     ) -> bool:
-        if llm_spec.model_format not in ["ggmlv3", "ggufv2"]:
+        if llm_spec.model_format not in ["ggufv2"]:
             return False
         if "chat" not in llm_family.model_ability:
             return False