xinference 0.14.3__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (70)
  1. xinference/_version.py +3 -3
  2. xinference/core/worker.py +18 -9
  3. xinference/model/audio/chattts.py +4 -3
  4. xinference/model/audio/cosyvoice.py +4 -3
  5. xinference/model/audio/custom.py +4 -5
  6. xinference/model/embedding/core.py +2 -0
  7. xinference/model/embedding/custom.py +4 -5
  8. xinference/model/flexible/core.py +5 -1
  9. xinference/model/image/custom.py +4 -5
  10. xinference/model/image/stable_diffusion/core.py +21 -6
  11. xinference/model/llm/llm_family.py +5 -6
  12. xinference/model/llm/sglang/core.py +7 -1
  13. xinference/model/llm/transformers/core.py +2 -0
  14. xinference/model/llm/utils.py +3 -0
  15. xinference/model/llm/vllm/core.py +0 -33
  16. xinference/model/rerank/custom.py +4 -5
  17. xinference/model/utils.py +41 -1
  18. xinference/model/video/core.py +3 -1
  19. xinference/model/video/diffusers.py +41 -38
  20. xinference/model/video/model_spec.json +24 -1
  21. xinference/model/video/model_spec_modelscope.json +25 -1
  22. xinference/thirdparty/fish_speech/tools/api.py +1 -1
  23. xinference/thirdparty/matcha/__init__.py +0 -0
  24. xinference/thirdparty/matcha/app.py +357 -0
  25. xinference/thirdparty/matcha/cli.py +419 -0
  26. xinference/thirdparty/matcha/data/__init__.py +0 -0
  27. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  28. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  29. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  30. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  31. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  32. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  33. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  34. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  35. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  36. xinference/thirdparty/matcha/models/__init__.py +0 -0
  37. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  38. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  39. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  40. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  41. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  42. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  43. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  44. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  45. xinference/thirdparty/matcha/onnx/export.py +181 -0
  46. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  47. xinference/thirdparty/matcha/text/__init__.py +53 -0
  48. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  49. xinference/thirdparty/matcha/text/numbers.py +71 -0
  50. xinference/thirdparty/matcha/text/symbols.py +17 -0
  51. xinference/thirdparty/matcha/train.py +122 -0
  52. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  53. xinference/thirdparty/matcha/utils/audio.py +82 -0
  54. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  55. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  56. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  57. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  58. xinference/thirdparty/matcha/utils/model.py +90 -0
  59. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  60. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  61. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  62. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  63. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  64. xinference/thirdparty/matcha/utils/utils.py +259 -0
  65. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/METADATA +20 -12
  66. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/RECORD +70 -28
  67. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  68. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  69. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  70. {xinference-0.14.3.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
xinference/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
 
 version_json = '''
 {
- "date": "2024-08-23T18:14:53+0800",
+ "date": "2024-08-30T18:54:16+0800",
  "dirty": false,
  "error": null,
- "full-revisionid": "b5002242e04634bca7e75cac9df0cdc6c0bf407a",
- "version": "0.14.3"
+ "full-revisionid": "f3d510eceffbbbc41ce065919fd2c48411017573",
+ "version": "0.14.4"
 }
 ''' # END VERSION_JSON
 
xinference/core/worker.py CHANGED
@@ -73,15 +73,15 @@ class WorkerActor(xo.StatelessActor):
         self._supervisor_ref: Optional[xo.ActorRefType] = None
         self._main_pool = main_pool
         self._main_pool.recover_sub_pool = self.recover_sub_pool
-        self._status_guard_ref: xo.ActorRefType[  # type: ignore
-            "StatusGuardActor"
-        ] = None
+        self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = (  # type: ignore
+            None
+        )
         self._event_collector_ref: xo.ActorRefType[  # type: ignore
             EventCollectorActor
         ] = None
-        self._cache_tracker_ref: xo.ActorRefType[  # type: ignore
-            CacheTrackerActor
-        ] = None
+        self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = (  # type: ignore
+            None
+        )
 
         # internal states.
         # temporary placeholder during model launch process:
@@ -146,7 +146,7 @@ class WorkerActor(xo.StatelessActor):
         else:
             recover_count = self._model_uid_to_recover_count.get(model_uid)
             try:
-                await self.terminate_model(model_uid)
+                await self.terminate_model(model_uid, is_model_die=True)
             except Exception:
                 pass
             if recover_count is not None:
@@ -664,6 +664,8 @@ class WorkerActor(xo.StatelessActor):
 
             ret.sort(key=sort_helper)
             return ret
+        elif model_type == "video":
+            return []
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
 
@@ -703,6 +705,8 @@ class WorkerActor(xo.StatelessActor):
             for f in get_user_defined_audios():
                 if f.model_name == model_name:
                     return f
+        elif model_type == "video":
+            return None
         elif model_type == "rerank":
             from ..model.rerank.custom import get_user_defined_reranks
 
@@ -914,7 +918,7 @@
         )
 
     @log_async(logger=logger)
-    async def terminate_model(self, model_uid: str):
+    async def terminate_model(self, model_uid: str, is_model_die=False):
         # Terminate model while its launching is not allow
         if model_uid in self._model_uid_launching_guard:
             raise ValueError(f"{model_uid} is launching")
@@ -963,11 +967,16 @@
             self._model_uid_to_recover_count.pop(model_uid, None)
             self._model_uid_to_launch_args.pop(model_uid, None)
 
+        if is_model_die:
+            status = LaunchStatus.ERROR.name
+        else:
+            status = LaunchStatus.TERMINATED.name
+
         if self._status_guard_ref is None:
            _ = await self.get_supervisor_ref()
         assert self._status_guard_ref is not None
         await self._status_guard_ref.update_instance_info(
-            origin_uid, {"status": LaunchStatus.TERMINATED.name}
+            origin_uid, {"status": status}
         )
 
     # Provide an interface for future version of supervisor to call
xinference/model/audio/chattts.py CHANGED
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import base64
 import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
 
@@ -78,9 +81,7 @@ class ChatTTSModel:
         if rnd_spk_emb is None:
             seed = xxhash.xxh32_intdigest(voice)
 
-            torch.manual_seed(seed)
-            np.random.seed(seed)
-            torch.cuda.manual_seed(seed)
+            set_all_random_seed(seed)
             torch.backends.cudnn.deterministic = True
             torch.backends.cudnn.benchmark = False
 
xinference/model/audio/cosyvoice.py CHANGED
@@ -16,6 +16,8 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
+from ..utils import set_all_random_seed
+
 if TYPE_CHECKING:
     from .core import AudioModelFamilyV1
 
@@ -67,6 +69,7 @@ class CosyVoiceModel:
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
         instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
+        seed: Optional[int] = kwargs.pop("seed", 0)
 
         if "SFT" in self._model_spec.model_name:
             # inference_sft
@@ -87,9 +90,6 @@ class CosyVoiceModel:
             assert (
                 prompt_text is None
             ), "CosyVoice Instruct model does not support prompt_text"
-            assert (
-                instruct_text is not None
-            ), "CosyVoice Instruct model expect a instruct_text"
         else:
             # inference_zero_shot
             # inference_cross_lingual
@@ -99,6 +99,7 @@
             ), "CosyVoice model does not support instruct_text"
 
         assert self._model is not None
+        set_all_random_seed(seed)
         if prompt_speech:
             assert not voice, "voice can't be set with prompt speech."
             with io.BytesIO(prompt_speech) as prompt_speech_io:
xinference/model/audio/custom.py CHANGED
@@ -88,6 +88,10 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_AUDIO_LOCK:
         for model_name in (
             list(BUILTIN_AUDIO_MODELS.keys())
@@ -102,11 +106,6 @@ def register_audio(model_spec: CustomAudioModelFamilyV1, persist: bool):
         UD_AUDIOS.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "audio", f"{model_spec.model_name}.json"
         )
xinference/model/embedding/core.py CHANGED
@@ -124,6 +124,7 @@ class EmbeddingModel:
         model_path: str,
         model_spec: EmbeddingModelSpec,
         device: Optional[str] = None,
+        **kwargs,
     ):
         self._model_uid = model_uid
         self._model_path = model_path
@@ -131,6 +132,7 @@ class EmbeddingModel:
         self._model = None
         self._counter = 0
         self._model_spec = model_spec
+        self._kwargs = kwargs
 
     def load(self):
         try:
xinference/model/embedding/custom.py CHANGED
@@ -47,6 +47,10 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_EMBEDDING_LOCK:
         for model_name in (
             list(BUILTIN_EMBEDDING_MODELS.keys())
@@ -61,11 +65,6 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
         UD_EMBEDDINGS.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "embedding", f"{model_spec.model_name}.json"
         )
xinference/model/flexible/core.py CHANGED
@@ -99,11 +99,15 @@ def get_flexible_model_descriptions():
 
 
 def register_flexible_model(model_spec: FlexibleModelSpec, persist: bool):
-    from ..utils import is_valid_model_name
+    from ..utils import is_valid_model_name, is_valid_model_uri
 
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     if model_spec.launcher_args:
         try:
             model_spec.parser_args()
xinference/model/image/custom.py CHANGED
@@ -47,6 +47,10 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}")
+
     with UD_IMAGE_LOCK:
         for model_name in (
             list(BUILTIN_IMAGE_MODELS.keys())
@@ -60,11 +64,6 @@ def register_image(model_spec: CustomImageModelFamilyV1, persist: bool):
         UD_IMAGES.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "image", f"{model_spec.model_name}.json"
         )
xinference/model/image/stable_diffusion/core.py CHANGED
@@ -257,15 +257,19 @@ class DiffusionModel:
             self._i2i_model = model = AutoPipelineForImage2Image.from_pipe(
                 self._model
             )
-        if size:
-            width, height = map(int, re.split(r"[^\d]+", size))
-            kwargs["width"] = width
-            kwargs["height"] = height
+
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 image to image requires image's height and width is times of 16
             # padding the image if specified
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
 
+        if size:
+            width, height = map(int, re.split(r"[^\d]+", size))
+            if padding_image_to_multiple:
+                width, height = image.size
+            kwargs["width"] = width
+            kwargs["height"] = height
+
         self._filter_kwargs(kwargs)
         return self._call_model(
             image=image,
@@ -279,8 +283,8 @@
 
     def inpainting(
         self,
-        image: bytes,
-        mask_image: bytes,
+        image: PIL.Image,
+        mask_image: PIL.Image,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         n: int = 1,
@@ -306,6 +310,17 @@
             model = self._model
 
         width, height = map(int, re.split(r"[^\d]+", size))
+
+        if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
+            # Model like SD3 inpainting requires image's height and width is times of 16
+            # padding the image if specified
+            image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
+            mask_image = self.pad_to_multiple(
+                mask_image, multiple=int(padding_image_to_multiple)
+            )
+            # calculate actual image size after padding
+            width, height = image.size
+
         return self._call_model(
             image=image,
             mask_image=mask_image,
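The image-to-image and inpainting paths above both rely on a pad_to_multiple helper whose body is not included in this diff. Purely as a hypothetical illustration (a minimal sketch, not the code shipped in xinference), padding a PIL image so both sides become a multiple of 16 could look like this:

    from PIL import Image

    def pad_to_multiple(image: Image.Image, multiple: int = 16) -> Image.Image:
        # Round each side up to the next multiple and paste the original image
        # onto the enlarged canvas; the extra border keeps the default fill.
        width, height = image.size
        new_width = (width + multiple - 1) // multiple * multiple
        new_height = (height + multiple - 1) // multiple * multiple
        padded = Image.new(image.mode, (new_width, new_height))
        padded.paste(image, (0, 0))
        return padded

After padding, the changed code reads the effective width and height back from image.size, so the dimensions passed to the pipeline match the padded image rather than the user-supplied size string.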
xinference/model/llm/llm_family.py CHANGED
@@ -1004,6 +1004,11 @@ def register_llm(llm_family: LLMFamilyV1, persist: bool):
     if not is_valid_model_name(llm_family.model_name):
         raise ValueError(f"Invalid model name {llm_family.model_name}.")
 
+    for spec in llm_family.model_specs:
+        model_uri = spec.model_uri
+        if model_uri and not is_valid_model_uri(model_uri):
+            raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_LLM_FAMILIES_LOCK:
         for family in BUILTIN_LLM_FAMILIES + UD_LLM_FAMILIES:
             if llm_family.model_name == family.model_name:
@@ -1015,12 +1020,6 @@ def register_llm(llm_family: LLMFamilyV1, persist: bool):
     generate_engine_config_by_model_family(llm_family)
 
     if persist:
-        # We only validate model URL when persist is True.
-        for spec in llm_family.model_specs:
-            model_uri = spec.model_uri
-            if model_uri and not is_valid_model_uri(model_uri):
-                raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
         )
xinference/model/llm/sglang/core.py CHANGED
@@ -113,6 +113,13 @@ class SGLANGModel(LLM):
             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
 
         self._model_config = self._sanitize_model_config(self._model_config)
+
+        # Fix: GH#2169
+        if sgl.__version__ >= "0.2.14":
+            self._model_config.setdefault("triton_attention_reduce_in_fp32", False)
+        else:
+            self._model_config.setdefault("attention_reduce_in_fp32", False)
+
         logger.info(
             f"Loading {self.model_uid} with following model config: {self._model_config}"
         )
@@ -152,7 +159,6 @@ class SGLANGModel(LLM):
         else:
             model_config["mem_fraction_static"] = 0.88
         model_config.setdefault("log_level", "info")
-        model_config.setdefault("attention_reduce_in_fp32", False)
 
         return model_config
 
xinference/model/llm/transformers/core.py CHANGED
@@ -319,6 +319,8 @@ class PytorchModel(LLM):
         else:
             self._model, self._tokenizer = self._load_model(**kwargs)
 
+        self._apply_lora()
+
         if not is_device_map_auto:
             self._model.to(self._device)
 
xinference/model/llm/utils.py CHANGED
@@ -32,6 +32,7 @@ from ...types import (
     Completion,
     CompletionChunk,
 )
+from ..utils import ensure_cache_cleared
 from .llm_family import (
     LlamaCppLLMSpecV1,
     LLMFamilyV1,
@@ -576,6 +577,7 @@ Begin!"""
         return cast(ChatCompletionChunk, chat_chunk)
 
     @classmethod
+    @ensure_cache_cleared
     def _to_chat_completion_chunks(
         cls,
         chunks: Iterator[CompletionChunk],
@@ -608,6 +610,7 @@ Begin!"""
             i += 1
 
     @staticmethod
+    @ensure_cache_cleared
     def _to_chat_completion(completion: Completion) -> ChatCompletion:
         return {
             "id": "chat" + completion["id"],
xinference/model/llm/vllm/core.py CHANGED
@@ -643,39 +643,6 @@ class VLLMChatModel(VLLMModel, ChatModelMixin):
 
 
 class VLLMVisionModel(VLLMModel, ChatModelMixin):
-    def load(self):
-        try:
-            import vllm
-            from vllm.engine.arg_utils import AsyncEngineArgs
-            from vllm.engine.async_llm_engine import AsyncLLMEngine
-        except ImportError:
-            error_message = "Failed to import module 'vllm'"
-            installation_guide = [
-                "Please make sure 'vllm' is installed. ",
-                "You can install it by `pip install vllm`\n",
-            ]
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        if vllm.__version__ >= "0.3.1":
-            # from vllm v0.3.1, it uses cupy as NCCL backend
-            # in which cupy will fork a process
-            # only for xoscar >= 0.3.0, new process is allowed in subpool
-            # besides, xinference set start method as forkserver for unix
-            # we need to set it to fork to make cupy NCCL work
-            multiprocessing.set_start_method("fork", force=True)
-
-        self._model_config = self._sanitize_model_config(self._model_config)
-
-        logger.info(
-            f"Loading {self.model_uid} with following model config: {self._model_config}"
-        )
-
-        engine_args = AsyncEngineArgs(
-            model=self.model_path,
-            **self._model_config,
-        )
-        self._engine = AsyncLLMEngine.from_engine_args(engine_args)
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
xinference/model/rerank/custom.py CHANGED
@@ -48,6 +48,10 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
     if not is_valid_model_name(model_spec.model_name):
         raise ValueError(f"Invalid model name {model_spec.model_name}.")
 
+    model_uri = model_spec.model_uri
+    if model_uri and not is_valid_model_uri(model_uri):
+        raise ValueError(f"Invalid model URI {model_uri}.")
+
     with UD_RERANK_LOCK:
         for model_name in (
             list(BUILTIN_RERANK_MODELS.keys())
@@ -62,11 +66,6 @@ def register_rerank(model_spec: CustomRerankModelSpec, persist: bool):
         UD_RERANKS.append(model_spec)
 
     if persist:
-        # We only validate model URL when persist is True.
-        model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
-            raise ValueError(f"Invalid model URI {model_uri}.")
-
         persist_path = os.path.join(
             XINFERENCE_MODEL_DIR, "rerank", f"{model_spec.model_name}.json"
         )
xinference/model/utils.py CHANGED
@@ -11,17 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import functools
+import gc
+import inspect
 import json
 import logging
 import os
+import random
 from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 import huggingface_hub
+import numpy as np
+import torch
 
 from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
-from ..device_utils import get_available_device, is_device_available
+from ..device_utils import empty_cache, get_available_device, is_device_available
 from .core import CacheableModelSpec
 
 logger = logging.getLogger(__name__)
@@ -348,3 +355,36 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
         return int(model_size)
     else:
         return str(model_size)
+
+
+def ensure_cache_cleared(func: Callable):
+    assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
+        func
+    )
+    if inspect.isgeneratorfunction(func):
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            for obj in func(*args, **kwargs):
+                yield obj
+            gc.collect()
+            empty_cache()
+
+    else:
+
+        @functools.wraps(func)
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            finally:
+                gc.collect()
+                empty_cache()
+
+    return inner
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
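The two helpers added here are what the audio and LLM changes above hook into: ensure_cache_cleared wraps both plain functions and generators so that gc.collect() and empty_cache() run once the work is done, and set_all_random_seed seeds Python, NumPy, and PyTorch (CPU plus all CUDA devices) in one call. A minimal usage sketch (hypothetical example code, not part of this diff):

    from xinference.model.utils import ensure_cache_cleared, set_all_random_seed

    @ensure_cache_cleared
    def stream_tokens(prompts):
        # Generator path: after the caller exhausts the iterator,
        # gc.collect() and empty_cache() release accelerator memory.
        for p in prompts:
            yield p.upper()

    set_all_random_seed(42)  # seeds random, numpy and torch in one call
    print(list(stream_tokens(["hello", "world"])))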
xinference/model/video/core.py CHANGED
@@ -14,7 +14,7 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from ...constants import XINFERENCE_CACHE_DIR
 from ..core import CacheableModelSpec, ModelDescription
@@ -44,6 +44,8 @@ class VideoModelFamilyV1(CacheableModelSpec):
     model_revision: str
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
+    default_model_config: Optional[Dict[str, Any]]
+    default_generate_config: Optional[Dict[str, Any]]
 
 
 class VideoModelDescription(ModelDescription):
xinference/model/video/diffusers.py CHANGED
@@ -15,7 +15,6 @@
 import base64
 import logging
 import os
-import sys
 import time
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -24,10 +23,9 @@ from typing import TYPE_CHECKING, List, Union
 
 import numpy as np
 import PIL.Image
-import torch
 
 from ...constants import XINFERENCE_VIDEO_DIR
-from ...device_utils import move_model_to_available_device
+from ...device_utils import gpu_count, move_model_to_available_device
 from ...types import Video, VideoList
 
 if TYPE_CHECKING:
@@ -76,41 +74,58 @@ class DiffUsersVideoModel:
     def load(self):
         import torch
 
-        torch_dtype = self._kwargs.get("torch_dtype")
-        if sys.platform != "darwin" and torch_dtype is None:
-            # The following params crashes on Mac M2
-            self._kwargs["torch_dtype"] = torch.float16
-            self._kwargs["variant"] = "fp16"
-            self._kwargs["use_safetensors"] = True
+        kwargs = self._model_spec.default_model_config.copy()
+        kwargs.update(self._kwargs)
+
+        scheduler_cls_name = kwargs.pop("scheduler", None)
+
+        torch_dtype = kwargs.get("torch_dtype")
         if isinstance(torch_dtype, str):
-            self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+            kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+        logger.debug("Loading video model with kwargs: %s", kwargs)
 
         if self._model_spec.model_family == "CogVideoX":
+            import diffusers
             from diffusers import CogVideoXPipeline
 
-            self._model = CogVideoXPipeline.from_pretrained(
-                self._model_path, **self._kwargs
+            pipeline = self._model = CogVideoXPipeline.from_pretrained(
+                self._model_path, **kwargs
             )
         else:
             raise Exception(
                 f"Unsupported model family: {self._model_spec.model_family}"
             )
 
-        if self._kwargs.get("cpu_offload", False):
+        if scheduler_cls_name:
+            logger.debug("Using scheduler: %s", scheduler_cls_name)
+            pipeline.scheduler = getattr(diffusers, scheduler_cls_name).from_config(
+                pipeline.scheduler.config, timestep_spacing="trailing"
+            )
+        if kwargs.get("compile_graph", False):
+            pipeline.transformer = torch.compile(
+                pipeline.transformer, mode="max-autotune", fullgraph=True
+            )
+        if kwargs.get("cpu_offload", False):
             logger.debug("CPU offloading model")
-            self._model.enable_model_cpu_offload()
-        elif not self._kwargs.get("device_map"):
+            pipeline.enable_model_cpu_offload()
+            if kwargs.get("sequential_cpu_offload", True):
+                pipeline.enable_sequential_cpu_offload()
+                pipeline.vae.enable_slicing()
+                pipeline.vae.enable_tiling()
+        elif not kwargs.get("device_map"):
             logger.debug("Loading model to available device")
-            self._model = move_model_to_available_device(self._model)
+            if gpu_count() > 1:
+                kwargs["device_map"] = "balanced"
+            else:
+                pipeline = move_model_to_available_device(self._model)
             # Recommended if your computer has < 64 GB of RAM
-            self._model.enable_attention_slicing()
+            pipeline.enable_attention_slicing()
 
     def text_to_video(
         self,
         prompt: str,
         n: int = 1,
         num_inference_steps: int = 50,
-        guidance_scale: int = 6,
         response_format: str = "b64_json",
         **kwargs,
     ) -> VideoList:
@@ -121,31 +136,19 @@ class DiffUsersVideoModel:
         # from diffusers.utils import export_to_video
        from ...device_utils import empty_cache
 
+        assert self._model is not None
+        assert callable(self._model)
+        generate_kwargs = self._model_spec.default_generate_config.copy()
+        generate_kwargs.update(kwargs)
+        generate_kwargs["num_videos_per_prompt"] = n
         logger.debug(
             "diffusers text_to_video args: %s",
-            kwargs,
+            generate_kwargs,
         )
-        assert self._model is not None
-        if self._kwargs.get("cpu_offload"):
-            # if enabled cpu offload,
-            # the model.device would be CPU
-            device = "cuda"
-        else:
-            device = self._model.device
-        prompt_embeds, _ = self._model.encode_prompt(
-            prompt=prompt,
-            do_classifier_free_guidance=True,
-            num_videos_per_prompt=n,
-            max_sequence_length=226,
-            device=device,
-            dtype=torch.float16,
-        )
-        assert callable(self._model)
         output = self._model(
+            prompt=prompt,
             num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            prompt_embeds=prompt_embeds,
-            **kwargs,
+            **generate_kwargs,
        )
 
         # clean cache