xinference 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (83)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +204 -1
  3. xinference/client/restful/restful_client.py +4 -2
  4. xinference/core/image_interface.py +28 -0
  5. xinference/core/model.py +28 -0
  6. xinference/core/supervisor.py +6 -0
  7. xinference/model/audio/fish_speech.py +9 -9
  8. xinference/model/audio/model_spec.json +9 -9
  9. xinference/model/audio/whisper.py +4 -1
  10. xinference/model/image/core.py +2 -1
  11. xinference/model/image/model_spec.json +16 -4
  12. xinference/model/image/model_spec_modelscope.json +16 -4
  13. xinference/model/image/sdapi.py +136 -0
  14. xinference/model/image/stable_diffusion/core.py +148 -20
  15. xinference/model/llm/__init__.py +8 -0
  16. xinference/model/llm/llm_family.json +393 -0
  17. xinference/model/llm/llm_family.py +3 -1
  18. xinference/model/llm/llm_family_modelscope.json +408 -3
  19. xinference/model/llm/sglang/core.py +3 -0
  20. xinference/model/llm/transformers/chatglm.py +1 -1
  21. xinference/model/llm/transformers/core.py +6 -0
  22. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  23. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  24. xinference/model/llm/transformers/qwen2_vl.py +31 -5
  25. xinference/model/llm/utils.py +104 -84
  26. xinference/model/llm/vllm/core.py +8 -0
  27. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +2 -3
  28. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  37. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  38. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  39. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  40. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  42. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  43. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  44. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  45. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  46. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  47. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  48. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  49. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  50. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  51. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  52. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  53. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  54. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  55. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  56. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  57. xinference/types.py +7 -4
  58. xinference/web/ui/build/asset-manifest.json +6 -6
  59. xinference/web/ui/build/index.html +1 -1
  60. xinference/web/ui/build/static/css/{main.632e9148.css → main.5061c4c3.css} +2 -2
  61. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  62. xinference/web/ui/build/static/js/{main.9cfafbd6.js → main.754740c0.js} +3 -3
  63. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  64. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  65. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  66. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/METADATA +9 -3
  67. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/RECORD +72 -74
  68. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  69. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  72. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  73. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  74. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  75. xinference/web/ui/build/static/css/main.632e9148.css.map +0 -1
  76. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +0 -1
  78. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +0 -1
  79. /xinference/web/ui/build/static/js/{main.9cfafbd6.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +0 -0
  80. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  81. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  82. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  83. {xinference-0.15.0.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-09-06T16:29:42+0800",
+ "date": "2024-09-14T13:22:13+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "e2618be96293f112709c9ceed639a3443455a0e7",
- "version": "0.15.0"
+ "full-revisionid": "961d355102007e3cd7963a353105b2422a31d4fd",
+ "version": "0.15.1"
 }
 ''' # END VERSION_JSON
 
xinference/api/restful_api.py CHANGED
@@ -63,6 +63,7 @@ from ..types import (
     CreateCompletion,
     ImageList,
     PeftModelConfig,
+    SDAPIResult,
     VideoList,
     max_tokens_field,
 )
@@ -122,6 +123,43 @@ class TextToImageRequest(BaseModel):
     user: Optional[str] = None
 
 
+class SDAPIOptionsRequest(BaseModel):
+    sd_model_checkpoint: Optional[str] = None
+
+
+class SDAPITxt2imgRequst(BaseModel):
+    model: Optional[str]
+    prompt: Optional[str] = ""
+    negative_prompt: Optional[str] = ""
+    steps: Optional[int] = None
+    seed: Optional[int] = -1
+    cfg_scale: Optional[float] = 7.0
+    override_settings: Optional[dict] = {}
+    width: Optional[int] = 512
+    height: Optional[int] = 512
+    sampler_name: Optional[str] = None
+    denoising_strength: Optional[float] = None
+    kwargs: Optional[str] = None
+    user: Optional[str] = None
+
+
+class SDAPIImg2imgRequst(BaseModel):
+    model: Optional[str]
+    init_images: Optional[list]
+    prompt: Optional[str] = ""
+    negative_prompt: Optional[str] = ""
+    steps: Optional[int] = None
+    seed: Optional[int] = -1
+    cfg_scale: Optional[float] = 7.0
+    override_settings: Optional[dict] = {}
+    width: Optional[int] = 512
+    height: Optional[int] = 512
+    sampler_name: Optional[str] = None
+    denoising_strength: Optional[float] = None
+    kwargs: Optional[str] = None
+    user: Optional[str] = None
+
+
 class TextToVideoRequest(BaseModel):
     model: str
     prompt: Union[str, List[str]] = Field(description="The input to embed.")
@@ -163,7 +201,7 @@ class BuildGradioImageInterfaceRequest(BaseModel):
     model_name: str
     model_family: str
     model_id: str
-    controlnet: Union[None, List[Dict[str, Union[str, None]]]]
+    controlnet: Union[None, List[Dict[str, Union[str, dict, None]]]]
     model_revision: str
     model_ability: List[str]
 
@@ -519,6 +557,59 @@ class RESTfulAPI:
                 else None
             ),
         )
+        # SD WebUI API
+        self._router.add_api_route(
+            "/sdapi/v1/options",
+            self.sdapi_options,
+            methods=["POST"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/sdapi/v1/sd-models",
+            self.sdapi_sd_models,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/sdapi/v1/samplers",
+            self.sdapi_samplers,
+            methods=["GET"],
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/sdapi/v1/txt2img",
+            self.sdapi_txt2img,
+            methods=["POST"],
+            response_model=SDAPIResult,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
+        self._router.add_api_route(
+            "/sdapi/v1/img2img",
+            self.sdapi_img2img,
+            methods=["POST"],
+            response_model=SDAPIResult,
+            dependencies=(
+                [Security(self._auth_service, scopes=["models:read"])]
+                if self.is_authenticated()
+                else None
+            ),
+        )
         self._router.add_api_route(
             "/v1/video/generations",
             self.create_videos,
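
The five routes registered above mirror the Stable Diffusion WebUI API surface, so WebUI-oriented clients can point at Xinference directly. A minimal discovery sketch, assuming an Xinference server at http://127.0.0.1:9997 with authentication disabled (both assumptions, not part of this diff):

import requests

BASE = "http://127.0.0.1:9997"  # hypothetical endpoint; substitute your own host/port

# List launched image models as WebUI-style checkpoints.
print(requests.get(f"{BASE}/sdapi/v1/sd-models").json())

# List the sampler names the diffusion backend reports (SAMPLING_METHODS).
print(requests.get(f"{BASE}/sdapi/v1/samplers").json())

# Select the model to use, WebUI-style; the value is the Xinference model UID.
requests.post(
    f"{BASE}/sdapi/v1/options",
    json={"sd_model_checkpoint": "my-sd-model"},  # hypothetical model UID
)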
@@ -1429,6 +1520,118 @@ class RESTfulAPI:
             await self._report_error_event(model_uid, str(e))
             raise HTTPException(status_code=500, detail=str(e))
 
+    async def sdapi_options(self, request: Request) -> Response:
+        body = SDAPIOptionsRequest.parse_obj(await request.json())
+        model_uid = body.sd_model_checkpoint
+
+        try:
+            if not model_uid:
+                raise ValueError("Unknown model")
+            await (await self._get_supervisor_ref()).get_model(model_uid)
+            return Response()
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def sdapi_sd_models(self, request: Request) -> Response:
+        try:
+            models = await (await self._get_supervisor_ref()).list_models()
+            sd_models = []
+            for model_name, info in models.items():
+                if info["model_type"] != "image":
+                    continue
+                sd_models.append({"model_name": model_name, "config": None})
+            return JSONResponse(content=sd_models)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def sdapi_samplers(self, request: Request) -> Response:
+        try:
+            from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
+            samplers = [
+                {"name": sample_method, "alias": [], "options": {}}
+                for sample_method in SAMPLING_METHODS
+            ]
+            return JSONResponse(content=samplers)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def sdapi_txt2img(self, request: Request) -> Response:
+        body = SDAPITxt2imgRequst.parse_obj(await request.json())
+        model_uid = body.model or body.override_settings.get("sd_model_checkpoint")
+
+        try:
+            if not model_uid:
+                raise ValueError("Unknown model")
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            kwargs = dict(body)
+            kwargs.update(json.loads(body.kwargs) if body.kwargs else {})
+            image_list = await model.txt2img(
+                **kwargs,
+            )
+            return Response(content=image_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+    async def sdapi_img2img(self, request: Request) -> Response:
+        body = SDAPIImg2imgRequst.parse_obj(await request.json())
+        model_uid = body.model or body.override_settings.get("sd_model_checkpoint")
+
+        try:
+            if not model_uid:
+                raise ValueError("Unknown model")
+            model = await (await self._get_supervisor_ref()).get_model(model_uid)
+        except ValueError as ve:
+            logger.error(str(ve), exc_info=True)
+            await self._report_error_event(model_uid, str(ve))
+            raise HTTPException(status_code=400, detail=str(ve))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
+        try:
+            kwargs = dict(body)
+            kwargs.update(json.loads(body.kwargs) if body.kwargs else {})
+            image_list = await model.img2img(
+                **kwargs,
+            )
+            return Response(content=image_list, media_type="application/json")
+        except RuntimeError as re:
+            logger.error(re, exc_info=True)
+            await self._report_error_event(model_uid, str(re))
+            self.handle_request_limit_error(re)
+            raise HTTPException(status_code=400, detail=str(re))
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            await self._report_error_event(model_uid, str(e))
+            raise HTTPException(status_code=500, detail=str(e))
+
     async def create_variations(
         self,
         model: str = Form(...),
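
Combined with the request models added earlier, a WebUI-compatible text-to-image call looks roughly like the sketch below. The field names come from SDAPITxt2imgRequst above; the response shape is an assumption (SDAPIResult is defined in xinference/types.py and not shown in this diff), here treated as the WebUI convention of base64-encoded images under "images":

import base64
import requests

BASE = "http://127.0.0.1:9997"  # hypothetical endpoint

payload = {
    "model": "my-sd-model",            # hypothetical model UID
    "prompt": "an astronaut riding a horse",
    "negative_prompt": "blurry, low quality",
    "steps": 25,
    "cfg_scale": 7.0,
    "width": 512,
    "height": 512,
}

resp = requests.post(f"{BASE}/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()

# Assuming a WebUI-style result: decode each base64-encoded image to a PNG file.
for i, b64_image in enumerate(resp.json().get("images", [])):
    with open(f"txt2img_{i}.png", "wb") as f:
        f.write(base64.b64decode(b64_image))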
xinference/client/restful/restful_client.py CHANGED
@@ -709,10 +709,12 @@ class RESTfulAudioModelHandle(RESTfulModelHandle):
                 )
             )
             response = requests.post(
-                url, data=params, files=files, headers=self.auth_headers
+                url, data=params, files=files, headers=self.auth_headers, stream=stream
             )
         else:
-            response = requests.post(url, json=params, headers=self.auth_headers)
+            response = requests.post(
+                url, json=params, headers=self.auth_headers, stream=stream
+            )
         if response.status_code != 200:
             raise RuntimeError(
                 f"Failed to speech the text, detail: {_get_error_string(response)}"
xinference/core/image_interface.py CHANGED
@@ -73,13 +73,17 @@ class ImageInterface:
         return interface
 
     def text2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def text_generate_image(
             prompt: str,
             n: int,
             size_width: int,
             size_height: int,
+            guidance_scale: int,
             num_inference_steps: int,
             negative_prompt: Optional[str] = None,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -89,16 +93,20 @@ class ImageInterface:
             assert isinstance(model, RESTfulImageModelHandle)
 
             size = f"{int(size_width)}*{int(size_height)}"
+            guidance_scale = None if guidance_scale == -1 else guidance_scale  # type: ignore
             num_inference_steps = (
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             response = model.text_to_image(
                 prompt=prompt,
                 n=n,
                 size=size,
                 num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
                 negative_prompt=negative_prompt,
+                sampler_name=sampler_name,
                 response_format="b64_json",
             )
 
@@ -132,9 +140,16 @@ class ImageInterface:
                     n = gr.Number(label="Number of Images", value=1)
                     size_width = gr.Number(label="Width", value=1024)
                     size_height = gr.Number(label="Height", value=1024)
+                with gr.Row():
+                    guidance_scale = gr.Number(label="Guidance scale", value=-1)
                     num_inference_steps = gr.Number(
                         label="Inference Step Number", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
                 with gr.Column():
                     image_output = gr.Gallery()
@@ -146,8 +161,10 @@ class ImageInterface:
                     n,
                     size_width,
                     size_height,
+                    guidance_scale,
                     num_inference_steps,
                     negative_prompt,
+                    sampler_name,
                 ],
                 outputs=image_output,
             )
@@ -155,6 +172,8 @@ class ImageInterface:
         return text2image_vl_interface
 
     def image2image_interface(self) -> "gr.Blocks":
+        from ..model.image.stable_diffusion.core import SAMPLING_METHODS
+
         def image_generate_image(
             prompt: str,
             negative_prompt: str,
@@ -164,6 +183,7 @@ class ImageInterface:
             size_height: int,
             num_inference_steps: int,
             padding_image_to_multiple: int,
+            sampler_name: Optional[str] = None,
         ) -> PIL.Image.Image:
             from ..client import RESTfulClient
 
@@ -180,6 +200,7 @@ class ImageInterface:
                 None if num_inference_steps == -1 else num_inference_steps  # type: ignore
             )
             padding_image_to_multiple = None if padding_image_to_multiple == -1 else padding_image_to_multiple  # type: ignore
+            sampler_name = None if sampler_name == "default" else sampler_name
 
             bio = io.BytesIO()
             image.save(bio, format="png")
@@ -193,6 +214,7 @@ class ImageInterface:
                 response_format="b64_json",
                 num_inference_steps=num_inference_steps,
                 padding_image_to_multiple=padding_image_to_multiple,
+                sampler_name=sampler_name,
             )
 
             images = []
@@ -233,6 +255,11 @@ class ImageInterface:
                     padding_image_to_multiple = gr.Number(
                         label="Padding image to multiple", value=-1
                     )
+                    sampler_name = gr.Dropdown(
+                        choices=SAMPLING_METHODS,
+                        value="default",
+                        label="Sampling method",
+                    )
 
                 with gr.Row():
                     with gr.Column(scale=1):
@@ -251,6 +278,7 @@ class ImageInterface:
                     size_height,
                     num_inference_steps,
                     padding_image_to_multiple,
+                    sampler_name,
                 ],
                 outputs=output_gallery,
             )
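
The new guidance_scale and sampler_name controls map one-to-one onto the RESTful client call the interface makes, so the same options can be set programmatically. A minimal sketch, assuming a local endpoint and an already-launched image model (both hypothetical):

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")   # hypothetical endpoint
model = client.get_model("my-sd-model")           # hypothetical image model UID

# guidance_scale and sampler_name are forwarded to the diffusion pipeline;
# passing None (the "default" choice in the UI) keeps the pipeline defaults.
result = model.text_to_image(
    prompt="a watercolor painting of a lighthouse",
    n=1,
    size="1024*1024",
    num_inference_steps=25,
    guidance_scale=7.5,
    negative_prompt="low quality",
    sampler_name=None,
    response_format="b64_json",
)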
xinference/core/model.py CHANGED
@@ -747,6 +747,20 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating image."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def txt2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "txt2img"):
+            return await self._call_wrapper_json(
+                self._model.txt2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
+
     @log_async(
         logger=logger,
         ignore_kwargs=["image"],
@@ -779,6 +793,20 @@ class ModelActor(xo.StatelessActor):
                 f"Model {self._model.model_spec} is not for creating image."
             )
 
+    @request_limit
+    @log_async(logger=logger)
+    async def img2img(
+        self,
+        **kwargs,
+    ):
+        kwargs.pop("request_id", None)
+        if hasattr(self._model, "img2img"):
+            return await self._call_wrapper_json(
+                self._model.img2img,
+                **kwargs,
+            )
+        raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
+
     @log_async(
         logger=logger,
         ignore_kwargs=["image"],
xinference/core/supervisor.py CHANGED
@@ -315,6 +315,7 @@ class SupervisorActor(xo.StatelessActor):
     @staticmethod
     async def get_builtin_families() -> Dict[str, List[str]]:
         from ..model.llm.llm_family import (
+            BUILTIN_LLM_FAMILIES,
             BUILTIN_LLM_MODEL_CHAT_FAMILIES,
             BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
             BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
@@ -324,6 +325,11 @@ class SupervisorActor(xo.StatelessActor):
             "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
             "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
             "tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
+            "vision": [
+                family.model_name
+                for family in BUILTIN_LLM_FAMILIES
+                if "vision" in family.model_ability
+            ],
         }
 
     async def get_devices_count(self) -> int:
xinference/model/audio/fish_speech.py CHANGED
@@ -92,7 +92,7 @@ class FishSpeechModel:
 
         checkpoint_path = os.path.join(
             self._model_path,
-            "firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
+            "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
         )
         self._model = load_decoder_model(
             config_name="firefly_gan_vq",
@@ -159,11 +159,11 @@ class FishSpeechModel:
         segments = []
 
         while True:
-            result: WrappedGenerateResponse = response_queue.get()
+            result: WrappedGenerateResponse = response_queue.get()  # type: ignore
             if result.status == "error":
                 raise Exception(str(result.response))
 
-            result: GenerateResponse = result.response
+            result: GenerateResponse = result.response  # type: ignore
             if result.action == "next":
                 break
 
@@ -213,12 +213,12 @@ class FishSpeechModel:
                 text=input,
                 enable_reference_audio=False,
                 reference_audio=None,
-                reference_text="",
-                max_new_tokens=0,
-                chunk_length=100,
-                top_p=0.7,
-                repetition_penalty=1.2,
-                temperature=0.7,
+                reference_text=kwargs.get("reference_text", ""),
+                max_new_tokens=kwargs.get("max_new_tokens", 1024),
+                chunk_length=kwargs.get("chunk_length", 200),
+                top_p=kwargs.get("top_p", 0.7),
+                repetition_penalty=kwargs.get("repetition_penalty", 1.2),
+                temperature=kwargs.get("temperature", 0.7),
             )
         )
         sample_rate, audio = result[0][1]
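
With the hard-coded values replaced by kwargs.get lookups, FishSpeech generation options become tunable per request. A hedged sketch of supplying them through the client; the speech signature and the kwargs pass-through path are assumptions not shown in this diff:

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")    # hypothetical endpoint
model = client.get_model("my-fishspeech-model")    # hypothetical FishSpeech model UID

# Keys mirror the kwargs.get(...) lookups above; omitted keys fall back to the new defaults.
audio = model.speech(
    "Hello from FishSpeech 1.4",
    chunk_length=200,
    max_new_tokens=1024,
    top_p=0.7,
    repetition_penalty=1.2,
    temperature=0.7,
)
with open("speech.mp3", "wb") as f:
    f.write(audio)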
xinference/model/audio/model_spec.json CHANGED
@@ -126,32 +126,32 @@
   {
     "model_name": "CosyVoice-300M",
     "model_family": "CosyVoice",
-    "model_id": "model-scope/CosyVoice-300M",
-    "model_revision": "ca4e036d2db2aa4731cc1747859a68044b6a4694",
+    "model_id": "FunAudioLLM/CosyVoice-300M",
+    "model_revision": "39c4e13d46bd4dfb840d214547623e5fcd2428e2",
     "model_ability": "audio-to-audio",
     "multilingual": true
   },
   {
     "model_name": "CosyVoice-300M-SFT",
     "model_family": "CosyVoice",
-    "model_id": "model-scope/CosyVoice-300M-SFT",
-    "model_revision": "ab918940c6c134b1fc1f069246e67bad6b66abcb",
+    "model_id": "FunAudioLLM/CosyVoice-300M-SFT",
+    "model_revision": "096a5cff8d497fabb3dec2756a200f3688457a1b",
     "model_ability": "text-to-audio",
     "multilingual": true
   },
   {
     "model_name": "CosyVoice-300M-Instruct",
     "model_family": "CosyVoice",
-    "model_id": "model-scope/CosyVoice-300M-Instruct",
-    "model_revision": "fb5f676733139f35670bed9b59a77d476b1aa898",
+    "model_id": "FunAudioLLM/CosyVoice-300M-Instruct",
+    "model_revision": "ba5265d9a3169c1fedce145122c9dd4bc24e062c",
     "model_ability": "text-to-audio",
     "multilingual": true
   },
   {
-    "model_name": "FishSpeech-1.2-SFT",
+    "model_name": "FishSpeech-1.4",
     "model_family": "FishAudio",
-    "model_id": "fishaudio/fish-speech-1.2-sft",
-    "model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
+    "model_id": "fishaudio/fish-speech-1.4",
+    "model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
     "model_ability": "text-to-audio",
     "multilingual": true
   }
xinference/model/audio/whisper.py CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from glob import glob
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from ...device_utils import (
@@ -56,12 +58,13 @@ class WhisperModel:
             raise ValueError(f"Device {self._device} is not available!")
 
         torch_dtype = get_device_preferred_dtype(self._device)
+        use_safetensors = any(glob(os.path.join(self._model_path, "*.safetensors")))
 
         model = AutoModelForSpeechSeq2Seq.from_pretrained(
             self._model_path,
             torch_dtype=torch_dtype,
             low_cpu_mem_usage=True,
-            use_safetensors=True,
+            use_safetensors=use_safetensors,
         )
         model.to(self._device)
 
xinference/model/image/core.py CHANGED
@@ -47,6 +47,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_generate_config: Optional[dict] = {}
 
 
 class ImageModelDescription(ModelDescription):
@@ -238,7 +239,7 @@ def create_image_model_instance(
         lora_model_paths=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
-        abilities=model_spec.model_ability,
+        model_spec=model_spec,
         **kwargs,
     )
     model_description = ImageModelDescription(
xinference/model/image/model_spec.json CHANGED
@@ -5,7 +5,9 @@
     "model_id": "black-forest-labs/FLUX.1-schnell",
     "model_revision": "768d12a373ed5cc9ef9a9dea7504dc09fcc14842",
     "model_ability": [
-      "text2image"
+      "text2image",
+      "image2image",
+      "inpainting"
     ]
   },
   {
@@ -14,7 +16,9 @@
     "model_id": "black-forest-labs/FLUX.1-dev",
     "model_revision": "01aa605f2c300568dd6515476f04565a954fcb59",
     "model_ability": [
-      "text2image"
+      "text2image",
+      "image2image",
+      "inpainting"
     ]
   },
   {
@@ -35,7 +39,11 @@
     "model_revision": "1681ed09e0cff58eeb41e878a49893228b78b94c",
     "model_ability": [
       "text2image"
-    ]
+    ],
+    "default_generate_config": {
+      "guidance_scale": 0.0,
+      "num_inference_steps": 1
+    }
   },
   {
     "model_name": "sdxl-turbo",
@@ -44,7 +52,11 @@
     "model_revision": "f4b0486b498f84668e828044de1d0c8ba486e05b",
     "model_ability": [
       "text2image"
-    ]
+    ],
+    "default_generate_config": {
+      "guidance_scale": 0.0,
+      "num_inference_steps": 1
+    }
   },
   {
     "model_name": "stable-diffusion-v1.5",
xinference/model/image/model_spec_modelscope.json CHANGED
@@ -6,7 +6,9 @@
     "model_id": "AI-ModelScope/FLUX.1-schnell",
     "model_revision": "master",
     "model_ability": [
-      "text2image"
+      "text2image",
+      "image2image",
+      "inpainting"
     ]
   },
   {
@@ -16,7 +18,9 @@
     "model_id": "AI-ModelScope/FLUX.1-dev",
     "model_revision": "master",
     "model_ability": [
-      "text2image"
+      "text2image",
+      "image2image",
+      "inpainting"
    ]
   },
   {
@@ -39,7 +43,11 @@
     "model_revision": "master",
     "model_ability": [
       "text2image"
-    ]
+    ],
+    "default_generate_config": {
+      "guidance_scale": 0.0,
+      "num_inference_steps": 1
+    }
   },
   {
     "model_name": "sdxl-turbo",
@@ -49,7 +57,11 @@
     "model_revision": "master",
     "model_ability": [
       "text2image"
-    ]
+    ],
+    "default_generate_config": {
+      "guidance_scale": 0.0,
+      "num_inference_steps": 1
+    }
  },
   {
     "model_name": "stable-diffusion-v1.5",