xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/core/image_interface.py CHANGED

@@ -73,13 +73,17 @@ class ImageInterface:
         return interface
 
     def text2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def text_generate_image(
             prompt: str,
             n: int,
             size_width: int,
             size_height: int,
+            guidance_scale: int,
             num_inference_steps: int,
             negative_prompt: Optional[str] = None,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -89,16 +93,20 @@ class ImageInterface:
             assert isinstance(model, RESTfulImageModelHandle)
 
             size = f"{int(size_width)}*{int(size_height)}"
+            guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
             num_inference_steps = (
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             response = model.text_to_image(
                 prompt=prompt,
                 n=n,
                 size=size,
                 num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
                 negative_prompt=negative_prompt,
+                sampler_name=sampler_name,
                 response_format="b64_json",
             )
 
@@ -132,9 +140,16 @@ class ImageInterface:
                     n = gr.Number(label="Number of Images", value=1)
                     size_width = gr.Number(label="Width", value=1024)
                     size_height = gr.Number(label="Height", value=1024)
+                with gr.Row():
+                    guidance_scale = gr.Number(label="Guidance scale", value=-1)
                     num_inference_steps = gr.Number(
                         label="Inference Step Number", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
             with gr.Column():
                 image_output = gr.Gallery()
@@ -146,8 +161,10 @@ class ImageInterface:
                 n,
                 size_width,
                 size_height,
+                guidance_scale,
                 num_inference_steps,
                 negative_prompt,
+                sampler_name,
             ],
             outputs=image_output,
         )
@@ -155,6 +172,8 @@ class ImageInterface:
         return text2image_vl_interface
 
     def image2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def image_generate_image(
             prompt: str,
             negative_prompt: str,
@@ -164,6 +183,7 @@ class ImageInterface:
             size_height: int,
             num_inference_steps: int,
             padding_image_to_multiple: int,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -180,6 +200,7 @@ class ImageInterface:
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
             padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             bio = io.BytesIO()
             image.save(bio, format="png")
@@ -193,6 +214,7 @@ class ImageInterface:
                 response_format="b64_json",
                 num_inference_steps=num_inference_steps,
                 padding_image_to_multiple=padding_image_to_multiple,
+                sampler_name=sampler_name,
             )
 
             images = []
@@ -233,6 +255,11 @@ class ImageInterface:
                     padding_image_to_multiple = gr.Number(
                         label="Padding image to multiple", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
             with gr.Row():
                 with gr.Column(scale=1):
@@ -251,6 +278,7 @@ class ImageInterface:
                 size_height,
                 num_inference_steps,
                 padding_image_to_multiple,
+                sampler_name,
             ],
             outputs=output_gallery,
         )
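
Note: taken together, the image_interface.py hunks above thread two new knobs, guidance_scale and sampler_name, from the Gradio form through to the client call. Below is a minimal sketch of the resulting client-side call; the parameter names come straight from the diff, while the endpoint, model UID, and concrete values are hypothetical:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint
    model = client.get_model("sd-model")             # hypothetical image model UID

    response = model.text_to_image(
        prompt="an astronaut riding a horse",
        n=1,
        size="1024*1024",
        num_inference_steps=25,
        guidance_scale=7.5,             # new in 0.15.x; the UI maps -1 to None (use default)
        negative_prompt="low quality",
        sampler_name="Euler a",         # new in 0.15.x; assumed member of SAMPLING_METHODS
        response_format="b64_json",
    )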
xinference/core/model.py CHANGED
@@ -19,6 +19,7 @@ import json
 import os
 import time
 import types
+import uuid
 import weakref
 from asyncio.queues import Queue
 from asyncio.tasks import wait_for
@@ -65,7 +66,12 @@ except ImportError:
 OutOfMemoryError = _OutOfMemoryError
 
 
-XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = ["qwen-vl-chat", "cogvlm2", "glm-4v"]
+XINFERENCE_BATCHING_ALLOWED_VISION_MODELS = [
+    "qwen-vl-chat",
+    "cogvlm2",
+    "glm-4v",
+    "MiniCPM-V-2.6",
+]
 
 
 def request_limit(fn):
@@ -265,7 +271,7 @@ class ModelActor(xo.StatelessActor):
 
         if self._worker_ref is None:
             self._worker_ref = await xo.actor_ref(
-                address=self._worker_address, uid=WorkerActor.uid()
+                address=self._worker_address, uid=WorkerActor.default_uid()
             )
         return self._worker_ref
 
@@ -434,23 +440,35 @@ class ModelActor(xo.StatelessActor):
         assert output_type == "binary", f"Unknown output type '{output_type}'"
         return ret
 
-    @log_async(logger=logger)
     @request_limit
     @xo.generator
+    @log_async(logger=logger)
     async def generate(self, prompt: str, *args, **kwargs):
         if self.allow_batching():
+            # not support request_id
+            kwargs.pop("request_id", None)
             return await self.handle_batching_request(
                 prompt, "generate", *args, **kwargs
             )
         else:
             kwargs.pop("raw_params", None)
             if hasattr(self._model, "generate"):
+                # not support request_id
+                kwargs.pop("request_id", None)
                 return await self._call_wrapper_json(
                     self._model.generate, prompt, *args, **kwargs
                 )
             if hasattr(self._model, "async_generate"):
+                if "request_id" not in kwargs:
+                    kwargs["request_id"] = str(uuid.uuid1())
+                else:
+                    # model only accept string
+                    kwargs["request_id"] = str(kwargs["request_id"])
                 return await self._call_wrapper_json(
-                    self._model.async_generate, prompt, *args, **kwargs
+                    self._model.async_generate,
+                    prompt,
+                    *args,
+                    **kwargs,
                 )
         raise AttributeError(f"Model {self._model.model_spec} is not for generate.")
 
@@ -481,22 +499,27 @@ class ModelActor(xo.StatelessActor):
             yield res
 
     @staticmethod
-    def _get_stream_from_args(ability: str, *args) -> bool:
-        if ability == "chat":
-            assert args[2] is None or isinstance(args[2], dict)
-            return False if args[2] is None else args[2].get("stream", False)
-        else:
-            assert args[0] is None or isinstance(args[0], dict)
-            return False if args[0] is None else args[0].get("stream", False)
+    def _get_stream_from_args(*args) -> bool:
+        assert args[0] is None or isinstance(args[0], dict)
+        return False if args[0] is None else args[0].get("stream", False)
 
-    async def handle_batching_request(self, prompt: str, ability: str, *args, **kwargs):
-        stream = self._get_stream_from_args(ability, *args)
+    async def handle_batching_request(
+        self, prompt_or_messages: Union[str, List[Dict]], call_ability, *args, **kwargs
+    ):
+        """
+        The input parameter `prompt_or_messages`:
+        - when the model_ability is `generate`, it's `prompt`, which is str type.
+        - when the model_ability is `chat`, it's `messages`, which is List[Dict] type.
+        """
+        stream = self._get_stream_from_args(*args)
         assert self._scheduler_ref is not None
         if stream:
             assert self._scheduler_ref is not None
             queue: Queue[Any] = Queue()
             ret = self._queue_consumer(queue)
-            await self._scheduler_ref.add_request(prompt, queue, *args, **kwargs)
+            await self._scheduler_ref.add_request(
+                prompt_or_messages, queue, call_ability, *args, **kwargs
+            )
             gen = self._to_async_gen("json", ret)
             self._current_generator = weakref.ref(gen)
             return gen
@@ -505,7 +528,9 @@ class ModelActor(xo.StatelessActor):
 
             assert self._loop is not None
             future = ConcurrentFuture()
-            await self._scheduler_ref.add_request(prompt, future, *args, **kwargs)
+            await self._scheduler_ref.add_request(
+                prompt_or_messages, future, call_ability, *args, **kwargs
+            )
             fut = asyncio.wrap_future(future, loop=self._loop)
             result = await fut
             if result == XINFERENCE_NON_STREAMING_ABORT_FLAG:
@@ -514,27 +539,36 @@ class ModelActor(xo.StatelessActor):
                 )
             return await asyncio.to_thread(json_dumps, result)
 
-    @log_async(logger=logger)
     @request_limit
     @xo.generator
-    async def chat(self, prompt: str, *args, **kwargs):
+    @log_async(logger=logger)
+    async def chat(self, messages: List[Dict], *args, **kwargs):
         start_time = time.time()
         response = None
         try:
             if self.allow_batching():
+                # not support request_id
+                kwargs.pop("request_id", None)
                 return await self.handle_batching_request(
-                    prompt, "chat", *args, **kwargs
+                    messages, "chat", *args, **kwargs
                 )
             else:
                 kwargs.pop("raw_params", None)
                 if hasattr(self._model, "chat"):
+                    # not support request_id
+                    kwargs.pop("request_id", None)
                     response = await self._call_wrapper_json(
-                        self._model.chat, prompt, *args, **kwargs
+                        self._model.chat, messages, *args, **kwargs
                     )
                     return response
                 if hasattr(self._model, "async_chat"):
+                    if "request_id" not in kwargs:
+                        kwargs["request_id"] = str(uuid.uuid1())
+                    else:
+                        # model only accept string
+                        kwargs["request_id"] = str(kwargs["request_id"])
                     response = await self._call_wrapper_json(
-                        self._model.async_chat, prompt, *args, **kwargs
+                        self._model.async_chat, messages, *args, **kwargs
                     )
                     return response
             raise AttributeError(f"Model {self._model.model_spec} is not for chat.")
@@ -565,9 +599,10 @@ class ModelActor(xo.StatelessActor):
             return await self._scheduler_ref.abort_request(request_id)
         return AbortRequestMessage.NO_OP.name
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "create_embedding"):
             return await self._call_wrapper_json(
                 self._model.create_embedding, input, *args, **kwargs
@@ -577,8 +612,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating embedding."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def rerank(
         self,
         documents: List[str],
@@ -590,6 +625,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "rerank"):
             return await self._call_wrapper_json(
                 self._model.rerank,
@@ -604,8 +640,8 @@ class ModelActor(xo.StatelessActor):
             )
         raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")
 
-    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["audio"])
     async def transcriptions(
         self,
         audio: bytes,
@@ -614,7 +650,9 @@ class ModelActor(xo.StatelessActor):
         response_format: str = "json",
         temperature: float = 0,
         timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "transcriptions"):
             return await self._call_wrapper_json(
                 self._model.transcriptions,
@@ -629,8 +667,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating transcriptions."
         )
 
-    @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["audio"])
     async def translations(
         self,
         audio: bytes,
@@ -639,7 +677,9 @@ class ModelActor(xo.StatelessActor):
         response_format: str = "json",
         temperature: float = 0,
         timestamp_granularities: Optional[List[str]] = None,
+        **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "translations"):
             return await self._call_wrapper_json(
                 self._model.translations,
@@ -654,12 +694,9 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating translations."
         )
 
-    @log_async(
-        logger=logger,
-        args_formatter=lambda _, kwargs: kwargs.pop("prompt_speech", None),
-    )
     @request_limit
     @xo.generator
+    @log_async(logger=logger, ignore_kwargs=["prompt_speech"])
     async def speech(
         self,
         input: str,
@@ -669,6 +706,7 @@ class ModelActor(xo.StatelessActor):
         stream: bool = False,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "speech"):
             return await self._call_wrapper_binary(
                 self._model.speech,
@@ -683,8 +721,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating speech."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def text_to_image(
         self,
         prompt: str,
@@ -694,6 +732,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_image"):
             return await self._call_wrapper_json(
                 self._model.text_to_image,
@@ -708,6 +747,24 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def txt2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "txt2img"):
+            return await self._call_wrapper_json(
+                self._model.txt2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
+
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
     async def image_to_image(
         self,
         image: "PIL.Image",
@@ -719,6 +776,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "image_to_image"):
             return await self._call_wrapper_json(
                 self._model.image_to_image,
@@ -735,6 +793,24 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def img2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "img2img"):
+            return await self._call_wrapper_json(
+                self._model.img2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
+
+    @log_async(
+        logger=logger,
+        ignore_kwargs=["image"],
+    )
     async def inpainting(
         self,
         image: "PIL.Image",
@@ -747,6 +823,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "inpainting"):
             return await self._call_wrapper_json(
                 self._model.inpainting,
@@ -764,12 +841,13 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for creating image."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger, ignore_kwargs=["image"])
     async def infer(
         self,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "infer"):
             return await self._call_wrapper_json(
                 self._model.infer,
@@ -779,8 +857,8 @@ class ModelActor(xo.StatelessActor):
             f"Model {self._model.model_spec} is not for flexible infer."
         )
 
-    @log_async(logger=logger)
     @request_limit
+    @log_async(logger=logger)
     async def text_to_video(
         self,
         prompt: str,
@@ -788,6 +866,7 @@ class ModelActor(xo.StatelessActor):
         *args,
         **kwargs,
     ):
+        kwargs.pop("request_id", None)
         if hasattr(self._model, "text_to_video"):
             return await self._call_wrapper_json(
                 self._model.text_to_video,
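
Note: the chat entry point above now takes OpenAI-style messages (List[Dict]) instead of a prompt string plus separate system prompt and chat history. A minimal sketch of the new call shape from the RESTful client, assuming a hypothetical local endpoint and chat model UID; the generate_config contents are illustrative:

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")  # assumed local endpoint
    model = client.get_model("my-chat-model")        # hypothetical chat model UID

    completion = model.chat(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        generate_config={"max_tokens": 128, "stream": False},
    )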
xinference/core/scheduler.py CHANGED

@@ -18,7 +18,7 @@ import logging
 import uuid
 from collections import deque
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import xoscar as xo
 
@@ -37,13 +37,24 @@ class AbortRequestMessage(Enum):
 
 
 class InferenceRequest:
-    def __init__(self, prompt, future_or_queue, is_prefill, *args, **kwargs):
-        # original prompt
-        self._prompt = prompt
+    def __init__(
+        self,
+        prompt_or_messages,
+        future_or_queue,
+        is_prefill,
+        call_ability,
+        *args,
+        **kwargs,
+    ):
+        # original prompt, prompt(str) for generate model and messages(List[Dict]) for chat model
+        self._prompt = prompt_or_messages
         # full prompt that contains chat history and applies chat template
         self._full_prompt = None
         # whether the current request is in the prefill phase
         self._is_prefill = is_prefill
+        # the ability that the user calls this model for, that is `generate` / `chat` for now,
+        # which is for results formatting
+        self._call_ability = call_ability
         # full prompt tokens
         self._prompt_tokens = None
         # all new generated tokens during decode phase
@@ -88,38 +99,22 @@ class InferenceRequest:
         self._check_args()
 
     def _check_args(self):
-        # chat
-        if len(self._inference_args) == 3:
-            # system prompt
-            assert self._inference_args[0] is None or isinstance(
-                self._inference_args[0], str
-            )
-            # chat history
-            assert self._inference_args[1] is None or isinstance(
-                self._inference_args[1], list
-            )
-            # generate config
-            assert self._inference_args[2] is None or isinstance(
-                self._inference_args[2], dict
-            )
-        else:  # generate
-            assert len(self._inference_args) == 1
-            # generate config
-            assert self._inference_args[0] is None or isinstance(
-                self._inference_args[0], dict
-            )
+        assert len(self._inference_args) == 1
+        # generate config
+        assert self._inference_args[0] is None or isinstance(
+            self._inference_args[0], dict
+        )
 
     @property
     def prompt(self):
+        """
+        prompt for generate model and messages for chat model
+        """
         return self._prompt
 
     @property
-    def system_prompt(self):
-        return self._inference_args[0]
-
-    @property
-    def chat_history(self):
-        return self._inference_args[1]
+    def call_ability(self):
+        return self._call_ability
 
     @property
     def full_prompt(self):
@@ -162,11 +157,7 @@ class InferenceRequest:
 
     @property
     def generate_config(self):
-        return (
-            self._inference_args[2]
-            if len(self._inference_args) == 3
-            else self._inference_args[0]
-        )
+        return self._inference_args[0]
 
     @property
     def sanitized_generate_config(self):
@@ -423,8 +414,17 @@ class SchedulerActor(xo.StatelessActor):
 
         self._empty_cache()
 
-    async def add_request(self, prompt: str, future_or_queue, *args, **kwargs):
-        req = InferenceRequest(prompt, future_or_queue, True, *args, **kwargs)
+    async def add_request(
+        self,
+        prompt_or_messages: Union[str, List[Dict]],
+        future_or_queue,
+        call_ability,
+        *args,
+        **kwargs,
+    ):
+        req = InferenceRequest(
+            prompt_or_messages, future_or_queue, True, call_ability, *args, **kwargs
+        )
        rid = req.request_id
        if rid is not None:
            if rid in self._id_to_req:
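
Note: the scheduler hunks above replace the old positional (system_prompt, chat_history, generate_config) triple with a single generate-config argument plus an explicit call_ability. A sketch of how a streaming chat request would now be queued, assuming scheduler_ref is a live SchedulerActor reference; names beyond those in the diff are illustrative:

    from asyncio.queues import Queue

    async def submit_chat(scheduler_ref, messages):
        # prompt_or_messages, sink, call_ability, then the lone remaining
        # inference arg: the generate config (inference_args[0]).
        queue: Queue = Queue()
        await scheduler_ref.add_request(
            messages,          # List[Dict] for chat, str for generate
            queue,             # a Queue when streaming, a Future otherwise
            "chat",            # call_ability, used for result formatting
            {"stream": True},  # generate config; validated by _check_args()
        )
        return queue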
xinference/core/status_guard.py CHANGED

@@ -51,7 +51,7 @@ class StatusGuardActor(xo.StatelessActor):
         self._model_uid_to_info: Dict[str, InstanceInfo] = {}  # type: ignore
 
     @classmethod
-    def uid(cls) -> str:
+    def default_uid(cls) -> str:
         return "status_guard"
 
     @staticmethod
xinference/core/supervisor.py CHANGED

@@ -105,7 +105,7 @@ class SupervisorActor(xo.StatelessActor):
         self._lock = asyncio.Lock()
 
     @classmethod
-    def uid(cls) -> str:
+    def default_uid(cls) -> str:
         return "supervisor"
 
     def _get_worker_ref_by_ip(
@@ -135,12 +135,12 @@ class SupervisorActor(xo.StatelessActor):
         self._status_guard_ref: xo.ActorRefType[  # type: ignore
             "StatusGuardActor"
         ] = await xo.create_actor(
-            StatusGuardActor, address=self.address, uid=StatusGuardActor.uid()
+            StatusGuardActor, address=self.address, uid=StatusGuardActor.default_uid()
         )
         self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
             "CacheTrackerActor"
         ] = await xo.create_actor(
-            CacheTrackerActor, address=self.address, uid=CacheTrackerActor.uid()
+            CacheTrackerActor, address=self.address, uid=CacheTrackerActor.default_uid()
         )
 
         from .event import EventCollectorActor
@@ -148,7 +148,9 @@ class SupervisorActor(xo.StatelessActor):
         self._event_collector_ref: xo.ActorRefType[  # type: ignore
             EventCollectorActor
         ] = await xo.create_actor(
-            EventCollectorActor, address=self.address, uid=EventCollectorActor.uid()
+            EventCollectorActor,
+            address=self.address,
+            uid=EventCollectorActor.default_uid(),
         )
 
         from ..model.audio import (
@@ -308,14 +310,12 @@ class SupervisorActor(xo.StatelessActor):
     async def get_builtin_prompts() -> Dict[str, Any]:
         from ..model.llm.llm_family import BUILTIN_LLM_PROMPT_STYLE
 
-        data = {}
-        for k, v in BUILTIN_LLM_PROMPT_STYLE.items():
-            data[k] = v.dict()
-        return data
+        return {k: v for k, v in BUILTIN_LLM_PROMPT_STYLE.items()}
 
     @staticmethod
     async def get_builtin_families() -> Dict[str, List[str]]:
         from ..model.llm.llm_family import (
+            BUILTIN_LLM_FAMILIES,
             BUILTIN_LLM_MODEL_CHAT_FAMILIES,
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
             BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
@@ -325,6 +325,11 @@ class SupervisorActor(xo.StatelessActor):
             "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
             "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
             "tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
+            "vision": [
+                family.model_name
+                for family in BUILTIN_LLM_FAMILIES
+                if "vision" in family.model_ability
+            ],
         }
 
     async def get_devices_count(self) -> int:
@@ -1028,7 +1033,7 @@ class SupervisorActor(xo.StatelessActor):
         else:
             task = asyncio.create_task(_launch_model())
             ASYNC_LAUNCH_TASKS[model_uid] = task
-            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))
+            task.add_done_callback(lambda _: callback_for_async_launch(model_uid))  # type: ignore
         return model_uid
 
     async def get_instance_info(
@@ -1233,7 +1238,9 @@ class SupervisorActor(xo.StatelessActor):
             worker_address not in self._worker_address_to_worker
         ), f"Worker {worker_address} exists"
 
-        worker_ref = await xo.actor_ref(address=worker_address, uid=WorkerActor.uid())
+        worker_ref = await xo.actor_ref(
+            address=worker_address, uid=WorkerActor.default_uid()
+        )
         self._worker_address_to_worker[worker_address] = worker_ref
         logger.debug("Worker %s has been added successfully", worker_address)
1239
1246