xinference 0.14.0.post1__py3-none-any.whl → 0.14.1__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (50)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +54 -0
  3. xinference/client/handlers.py +0 -3
  4. xinference/client/restful/restful_client.py +51 -134
  5. xinference/constants.py +1 -0
  6. xinference/core/chat_interface.py +1 -4
  7. xinference/core/image_interface.py +33 -5
  8. xinference/core/model.py +28 -2
  9. xinference/core/supervisor.py +37 -0
  10. xinference/core/worker.py +128 -84
  11. xinference/deploy/cmdline.py +1 -4
  12. xinference/model/audio/core.py +11 -3
  13. xinference/model/audio/funasr.py +114 -0
  14. xinference/model/audio/model_spec.json +20 -0
  15. xinference/model/audio/model_spec_modelscope.json +21 -0
  16. xinference/model/audio/whisper.py +1 -1
  17. xinference/model/core.py +12 -0
  18. xinference/model/image/core.py +3 -4
  19. xinference/model/image/model_spec.json +41 -13
  20. xinference/model/image/model_spec_modelscope.json +30 -10
  21. xinference/model/image/stable_diffusion/core.py +53 -2
  22. xinference/model/llm/__init__.py +2 -0
  23. xinference/model/llm/llm_family.json +83 -1
  24. xinference/model/llm/llm_family_modelscope.json +85 -1
  25. xinference/model/llm/pytorch/core.py +1 -0
  26. xinference/model/llm/pytorch/minicpmv26.py +247 -0
  27. xinference/model/llm/sglang/core.py +72 -34
  28. xinference/model/llm/vllm/core.py +38 -0
  29. xinference/model/video/__init__.py +62 -0
  30. xinference/model/video/core.py +178 -0
  31. xinference/model/video/diffusers.py +180 -0
  32. xinference/model/video/model_spec.json +11 -0
  33. xinference/model/video/model_spec_modelscope.json +12 -0
  34. xinference/types.py +10 -24
  35. xinference/web/ui/build/asset-manifest.json +3 -3
  36. xinference/web/ui/build/index.html +1 -1
  37. xinference/web/ui/build/static/js/{main.ef2a203a.js → main.17ca0398.js} +3 -3
  38. xinference/web/ui/build/static/js/main.17ca0398.js.map +1 -0
  39. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +1 -0
  40. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +1 -0
  41. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/METADATA +14 -8
  42. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/RECORD +47 -40
  43. xinference/web/ui/build/static/js/main.ef2a203a.js.map +0 -1
  44. xinference/web/ui/node_modules/.cache/babel-loader/2c63090c842376cdd368c3ded88a333ef40d94785747651343040a6f7872a223.json +0 -1
  45. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +0 -1
  46. /xinference/web/ui/build/static/js/{main.ef2a203a.js.LICENSE.txt → main.17ca0398.js.LICENSE.txt} +0 -0
  47. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/LICENSE +0 -0
  48. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/WHEEL +0 -0
  49. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/entry_points.txt +0 -0
  50. {xinference-0.14.0.post1.dist-info → xinference-0.14.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/minicpmv26.py (new file)
@@ -0,0 +1,247 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import base64
+import json
+import logging
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from io import BytesIO
+from typing import Dict, Iterator, List, Optional, Union
+
+import requests
+import torch
+from PIL import Image
+
+from ....types import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessage,
+    Completion,
+    CompletionChoice,
+    CompletionChunk,
+    CompletionUsage,
+)
+from ...utils import select_device
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchChatModel, PytorchGenerateConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MiniCPMV26Model(PytorchChatModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._device = None
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        family = model_family.model_family or model_family.model_name
+        if "MiniCPM-V-2.6".lower() in family.lower():
+            return True
+        return False
+
+    def _get_model_class(self):
+        from transformers import AutoModel
+
+        return AutoModel
+
+    def load(self, **kwargs):
+        from transformers import AutoModel, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self._pytorch_model_config.get("device", "auto")
+        self._device = select_device(device)
+        self._device = (
+            "auto"
+            if self._device == "cuda" and self.quantization is None
+            else self._device
+        )
+
+        if "int4" in self.model_path and device == "mps":
+            logger.error(
+                "Error: running int4 model with bitsandbytes on Mac is not supported right now."
+            )
+            exit()
+
+        if self._check_tensorizer_integrity():
+            self._model, self._tokenizer = self._load_tensorizer()
+            return
+
+        if "int4" in self.model_path:
+            model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
+        else:
+            model = AutoModel.from_pretrained(
+                self.model_path,
+                trust_remote_code=True,
+                torch_dtype=torch.float16,
+                device_map=self._device,
+            )
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        self._model = model.eval()
+        self._tokenizer = tokenizer
+
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+        )
+        self._save_tensorizer()
+
+    def _message_content_to_chat(self, content):
+        def _load_image(_url):
+            if _url.startswith("data:"):
+                logging.info("Parse url by base64 decoder.")
+                # https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images
+                # e.g. f"data:image/jpeg;base64,{base64_image}"
+                _type, data = _url.split(";")
+                _, ext = _type.split("/")
+                data = data[len("base64,") :]
+                data = base64.b64decode(data.encode("utf-8"))
+                return Image.open(BytesIO(data)).convert("RGB")
+            else:
+                try:
+                    response = requests.get(_url)
+                except requests.exceptions.MissingSchema:
+                    return Image.open(_url).convert("RGB")
+                else:
+                    return Image.open(BytesIO(response.content)).convert("RGB")
+
+        if not isinstance(content, str):
+            texts = []
+            image_urls = []
+            for c in content:
+                c_type = c.get("type")
+                if c_type == "text":
+                    texts.append(c["text"])
+                elif c_type == "image_url":
+                    image_urls.append(c["image_url"]["url"])
+            image_futures = []
+            with ThreadPoolExecutor() as executor:
+                for image_url in image_urls:
+                    fut = executor.submit(_load_image, image_url)
+                    image_futures.append(fut)
+            images = [fut.result() for fut in image_futures]
+            text = " ".join(texts)
+            if len(images) == 0:
+                return text, []
+            elif len(images) == 1:
+                return text, images
+            else:
+                raise RuntimeError("Only one image per message is supported")
+        return content, []
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        generate_config: Optional[PytorchGenerateConfig] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        stream = generate_config.get("stream", False) if generate_config else False
+        content, images_chat = self._message_content_to_chat(prompt)
+
+        msgs = []
+        query_to_response: List[Dict] = []
+        images_history = []
+        for h in chat_history or []:
+            role = h["role"]
+            content_h, images_tmp = self._message_content_to_chat(h["content"])
+            if images_tmp != []:
+                images_history = images_tmp
+            if len(query_to_response) == 0 and role == "user":
+                query_to_response.append({"role": "user", "content": content_h})
+            if len(query_to_response) == 1 and role == "assistant":
+                query_to_response.append({"role": "assistant", "content": content_h})
+            if len(query_to_response) == 2:
+                msgs.extend(query_to_response)
+                query_to_response = []
+        image = None
+        if len(images_chat) > 0:
+            image = images_chat[0]
+        elif len(images_history) > 0:
+            image = images_history[0]
+        msgs.append({"role": "user", "content": content})
+
+        chat = self._model.chat(
+            image=image,
+            msgs=json.dumps(msgs, ensure_ascii=True),
+            tokenizer=self._tokenizer,
+            sampling=True,
+            **generate_config
+        )
+        if stream:
+            it = self.chat_stream(chat)
+            return self._to_chat_completion_chunks(it)
+        else:
+            c = Completion(
+                id=str(uuid.uuid1()),
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[
+                    CompletionChoice(
+                        index=0, text=chat, finish_reason="stop", logprobs=None
+                    )
+                ],
+                usage=CompletionUsage(
+                    prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+                ),
+            )
+            return self._to_chat_completion(c)
+
+    def chat_stream(self, chat) -> Iterator[CompletionChunk]:
+        completion_id = str(uuid.uuid1())
+        for new_text in chat:
+            completion_choice = CompletionChoice(
+                text=new_text, index=0, logprobs=None, finish_reason=None
+            )
+            chunk = CompletionChunk(
+                id=completion_id,
+                object="text_completion",
+                created=int(time.time()),
+                model=self.model_uid,
+                choices=[completion_choice],
+            )
+            completion_usage = CompletionUsage(
+                prompt_tokens=-1,
+                completion_tokens=-1,
+                total_tokens=-1,
+            )
+            chunk["usage"] = completion_usage
+            yield chunk
+
+        completion_choice = CompletionChoice(
+            text="", index=0, logprobs=None, finish_reason="stop"
+        )
+        chunk = CompletionChunk(
+            id=completion_id,
+            object="text_completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[completion_choice],
+        )
+        completion_usage = CompletionUsage(
+            prompt_tokens=-1,
+            completion_tokens=-1,
+            total_tokens=-1,
+        )
+        chunk["usage"] = completion_usage
+        yield chunk
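Note: the new `MiniCPMV26Model.chat` accepts OpenAI-style vision content, and `_message_content_to_chat` above turns it into one text string plus at most one PIL image. The following is a minimal sketch of the message payload that code parses; the file name `example.jpg` is only a placeholder for illustration.

```python
import base64

# Placeholder image path; any local JPEG works for this sketch.
with open("example.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

# Content shape handled by _message_content_to_chat: a list mixing "text"
# parts and at most one "image_url" part, where the URL may be an http(s)
# URL, a local path, or a base64 data URL as shown here.
prompt = [
    {"type": "text", "text": "What is in this picture?"},
    {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
    },
]
# `prompt` would then be passed as the first argument to MiniCPMV26Model.chat.
```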
xinference/model/llm/sglang/core.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import time
 import uuid
@@ -122,6 +123,10 @@ class SGLANGModel(LLM):
             **self._model_config,
         )

+    def stop(self):
+        logger.info("Stopping SGLang engine")
+        self._engine.shutdown()
+
     def _sanitize_model_config(
         self, model_config: Optional[SGLANGModelConfig]
     ) -> SGLANGModelConfig:
@@ -132,18 +137,20 @@ class SGLANGModel(LLM):
         model_config.setdefault("tokenizer_mode", "auto")
         model_config.setdefault("trust_remote_code", True)
         model_config.setdefault("tp_size", cuda_count)
-        # See https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L37
-        mem_fraction_static = model_config.pop("mem_fraction_static", None)
+        # See https://github.com/sgl-project/sglang/blob/00023d622a6d484e67ef4a0e444f708b8fc861c8/python/sglang/srt/server_args.py#L100-L109
+        mem_fraction_static = model_config.get("mem_fraction_static")
         if mem_fraction_static is None:
             tp_size = model_config.get("tp_size", cuda_count)
-            if tp_size >= 8:
-                model_config["mem_fraction_static"] = 0.80
+            if tp_size >= 16:
+                model_config["mem_fraction_static"] = 0.79
+            elif tp_size >= 8:
+                model_config["mem_fraction_static"] = 0.83
             elif tp_size >= 4:
-                model_config["mem_fraction_static"] = 0.82
-            elif tp_size >= 2:
                 model_config["mem_fraction_static"] = 0.85
+            elif tp_size >= 2:
+                model_config["mem_fraction_static"] = 0.87
             else:
-                model_config["mem_fraction_static"] = 0.90
+                model_config["mem_fraction_static"] = 0.88
         model_config.setdefault("log_level", "info")
         model_config.setdefault("attention_reduce_in_fp32", False)

@@ -249,28 +256,64 @@ class SGLANGModel(LLM):
             usage=usage,
         )

+    @classmethod
+    def _filter_sampling_params(cls, sampling_params: dict):
+        if not sampling_params.get("lora_name"):
+            sampling_params.pop("lora_name", None)
+        return sampling_params
+
+    async def _stream_generate(self, prompt: str, **sampling_params):
+        import aiohttp
+
+        sampling_params = self._filter_sampling_params(sampling_params)
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "stream": True,
+        }
+        pos = 0
+
+        timeout = aiohttp.ClientTimeout(total=3 * 3600)
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+            async with session.post(
+                self._engine.generate_url, json=json_data  # type: ignore
+            ) as response:
+                async for chunk, _ in response.content.iter_chunks():
+                    chunk = chunk.decode("utf-8")
+                    if chunk and chunk.startswith("data:"):
+                        stop = "data: [DONE]\n\n"
+                        need_stop = False
+                        if chunk.endswith(stop):
+                            chunk = chunk[: -len(stop)]
+                            need_stop = True
+                        if chunk:
+                            data = json.loads(chunk[5:].strip("\n"))
+                            cur = data["text"][pos:]
+                            if cur:
+                                yield data["meta_info"], cur
+                            pos += len(cur)
+                            if need_stop:
+                                break
+
+    async def _non_stream_generate(self, prompt: str, **sampling_params) -> dict:
+        import aiohttp
+
+        sampling_params = self._filter_sampling_params(sampling_params)
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+        }
+        async with aiohttp.ClientSession(trust_env=True) as session:
+            async with session.post(
+                self._engine.generate_url, json=json_data  # type: ignore
+            ) as response:
+                return await response.json()
+
     async def async_generate(
         self,
         prompt: str,
         generate_config: Optional[SGLANGGenerateConfig] = None,
     ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
-        try:
-            import sglang as sgl
-            from sglang import assistant, gen, user
-        except ImportError:
-            error_message = "Failed to import module 'sglang'"
-            installation_guide = [
-                "Please make sure 'sglang' is installed. ",
-                "You can install it by `pip install sglang[all]`\n",
-            ]
-
-            raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-        @sgl.function
-        def pipeline(s, question):
-            s += user(question)
-            s += assistant(gen("answer"))
-
         sanitized_generate_config = self._sanitize_generate_config(generate_config)
         logger.debug(
             "Enter generate, prompt: %s, generate config: %s", prompt, generate_config
@@ -285,25 +328,20 @@ class SGLANGModel(LLM):
         )

         request_id = str(uuid.uuid1())
-        state = pipeline.run(
-            question=prompt,
-            backend=self._engine,
-            stream=stream,
-            **sanitized_generate_config,
-        )
         if not stream:
+            state = await self._non_stream_generate(prompt, **sanitized_generate_config)
             return self._convert_state_to_completion(
                 request_id,
                 model=self.model_uid,
-                output_text=state["answer"],
-                meta_info=state.get_meta_info(name="answer"),
+                output_text=state["text"],
+                meta_info=state["meta_info"],
             )
         else:

            async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
                prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-                async for out, meta_info in state.text_async_iter(
-                    var_name="answer", return_meta_data=True
+                async for meta_info, out in self._stream_generate(
+                    prompt, **sanitized_generate_config
                 ):
                     chunk = self._convert_state_to_completion_chunk(
                         request_id, self.model_uid, output_text=out
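Note: with this change `async_generate` no longer drives the sglang frontend DSL (`pipeline.run`); it posts directly to the SGLang server's generate endpoint and, when streaming, parses SSE-style `data: ...` chunks. A rough sketch of that parsing logic follows, run on made-up chunk strings rather than captured server output; each payload carries the full text generated so far, so only the delta is emitted.

```python
import json

# Illustrative chunks in the shape _stream_generate expects (values are made up).
chunks = [
    'data: {"text": "Hello", "meta_info": {"completion_tokens": 1}}\n\n',
    'data: {"text": "Hello world", "meta_info": {"completion_tokens": 2}}\n\n'
    "data: [DONE]\n\n",
]

pos = 0
for chunk in chunks:
    if chunk and chunk.startswith("data:"):
        stop = "data: [DONE]\n\n"
        need_stop = chunk.endswith(stop)
        if need_stop:
            chunk = chunk[: -len(stop)]  # drop the terminator before decoding
        if chunk:
            data = json.loads(chunk[5:].strip("\n"))
            delta = data["text"][pos:]  # new characters since the last chunk
            if delta:
                print(delta)  # prints "Hello", then " world"
            pos += len(delta)
        if need_stop:
            break
```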
xinference/model/llm/vllm/core.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 import json
 import logging
 import multiprocessing
+import os
 import time
 import uuid
 from typing import (
@@ -240,6 +242,42 @@ class VLLMModel(LLM):
         )
         self._engine = AsyncLLMEngine.from_engine_args(engine_args)

+        self._check_health_task = None
+        if hasattr(self._engine, "check_health"):
+            # vLLM introduced `check_health` since v0.4.1
+            self._check_health_task = asyncio.create_task(self._check_healthy())
+
+    def stop(self):
+        # though the vLLM engine will shutdown when deleted,
+        # but some issue e.g. GH#1682 reported
+        # when deleting, the engine exists still
+        logger.info("Stopping vLLM engine")
+        if self._check_health_task:
+            self._check_health_task.cancel()
+        if model_executor := getattr(self._engine.engine, "model_executor", None):
+            model_executor.shutdown()
+        self._engine = None
+
+    async def _check_healthy(self, interval: int = 30):
+        from vllm.engine.async_llm_engine import AsyncEngineDeadError
+
+        logger.debug("Begin to check health of vLLM")
+
+        while self._engine is not None:
+            try:
+                await self._engine.check_health()
+            except (AsyncEngineDeadError, RuntimeError):
+                logger.info("Detecting vLLM is not health, prepare to quit the process")
+                try:
+                    self.stop()
+                except:
+                    # ignore error when stop
+                    pass
+                # Just kill the process and let xinference auto-recover the model
+                os._exit(1)
+            else:
+                await asyncio.sleep(interval)
+
     def _sanitize_model_config(
         self, model_config: Optional[VLLMModelConfig]
     ) -> VLLMModelConfig:
xinference/model/video/__init__.py (new file)
@@ -0,0 +1,62 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+from itertools import chain
+
+from .core import (
+    BUILTIN_VIDEO_MODELS,
+    MODEL_NAME_TO_REVISION,
+    MODELSCOPE_VIDEO_MODELS,
+    VIDEO_MODEL_DESCRIPTIONS,
+    VideoModelFamilyV1,
+    generate_video_description,
+    get_cache_status,
+    get_video_model_descriptions,
+)
+
+_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+    os.path.dirname(__file__), "model_spec_modelscope.json"
+)
+BUILTIN_VIDEO_MODELS.update(
+    dict(
+        (spec["model_name"], VideoModelFamilyV1(**spec))
+        for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+    )
+)
+for model_name, model_spec in BUILTIN_VIDEO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+MODELSCOPE_VIDEO_MODELS.update(
+    dict(
+        (spec["model_name"], VideoModelFamilyV1(**spec))
+        for spec in json.load(
+            codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+        )
+    )
+)
+for model_name, model_spec in MODELSCOPE_VIDEO_MODELS.items():
+    MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+# register model description
+for model_name, model_spec in chain(
+    MODELSCOPE_VIDEO_MODELS.items(), BUILTIN_VIDEO_MODELS.items()
+):
+    VIDEO_MODEL_DESCRIPTIONS.update(generate_video_description(model_spec))
+
+del _model_spec_json
+del _model_spec_modelscope_json
xinference/model/video/core.py (new file)
@@ -0,0 +1,178 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from collections import defaultdict
+from typing import Dict, List, Literal, Optional, Tuple
+
+from ...constants import XINFERENCE_CACHE_DIR
+from ..core import CacheableModelSpec, ModelDescription
+from ..utils import valid_model_revision
+from .diffusers import DiffUsersVideoModel
+
+MAX_ATTEMPTS = 3
+
+logger = logging.getLogger(__name__)
+
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+VIDEO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+BUILTIN_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+MODELSCOPE_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+
+
+def get_video_model_descriptions():
+    import copy
+
+    return copy.deepcopy(VIDEO_MODEL_DESCRIPTIONS)
+
+
+class VideoModelFamilyV1(CacheableModelSpec):
+    model_family: str
+    model_name: str
+    model_id: str
+    model_revision: str
+    model_hub: str = "huggingface"
+    model_ability: Optional[List[str]]
+
+
+class VideoModelDescription(ModelDescription):
+    def __init__(
+        self,
+        address: Optional[str],
+        devices: Optional[List[str]],
+        model_spec: VideoModelFamilyV1,
+        model_path: Optional[str] = None,
+    ):
+        super().__init__(address, devices, model_path=model_path)
+        self._model_spec = model_spec
+
+    def to_dict(self):
+        return {
+            "model_type": "video",
+            "address": self.address,
+            "accelerators": self.devices,
+            "model_name": self._model_spec.model_name,
+            "model_family": self._model_spec.model_family,
+            "model_revision": self._model_spec.model_revision,
+            "model_ability": self._model_spec.model_ability,
+        }
+
+    def to_version_info(self):
+        if self._model_path is None:
+            is_cached = get_cache_status(self._model_spec)
+            file_location = get_cache_dir(self._model_spec)
+        else:
+            is_cached = True
+            file_location = self._model_path
+
+        return [
+            {
+                "model_version": self._model_spec.model_name,
+                "model_file_location": file_location,
+                "cache_status": is_cached,
+            }
+        ]
+
+
+def generate_video_description(
+    video_model: VideoModelFamilyV1,
+) -> Dict[str, List[Dict]]:
+    res = defaultdict(list)
+    res[video_model.model_name].extend(
+        VideoModelDescription(None, None, video_model).to_version_info()
+    )
+    return res
+
+
+def match_diffusion(
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> VideoModelFamilyV1:
+    from ..utils import download_from_modelscope
+    from . import BUILTIN_VIDEO_MODELS, MODELSCOPE_VIDEO_MODELS
+
+    if download_hub == "modelscope" and model_name in MODELSCOPE_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in ModelScope.")
+        return MODELSCOPE_VIDEO_MODELS[model_name]
+    elif download_hub == "huggingface" and model_name in BUILTIN_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in Huggingface.")
+        return BUILTIN_VIDEO_MODELS[model_name]
+    elif download_from_modelscope() and model_name in MODELSCOPE_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in ModelScope.")
+        return MODELSCOPE_VIDEO_MODELS[model_name]
+    elif model_name in BUILTIN_VIDEO_MODELS:
+        logger.debug(f"Video model {model_name} found in Huggingface.")
+        return BUILTIN_VIDEO_MODELS[model_name]
+    else:
+        raise ValueError(
+            f"Video model {model_name} not found, available"
+            f"model list: {BUILTIN_VIDEO_MODELS.keys()}"
+        )
+
+
+def cache(model_spec: VideoModelFamilyV1):
+    from ..utils import cache
+
+    return cache(model_spec, VideoModelDescription)
+
+
+def get_cache_dir(model_spec: VideoModelFamilyV1):
+    return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
+def get_cache_status(
+    model_spec: VideoModelFamilyV1,
+) -> bool:
+    cache_dir = get_cache_dir(model_spec)
+    meta_path = os.path.join(cache_dir, "__valid_download")
+
+    model_name = model_spec.model_name
+    if model_name in BUILTIN_VIDEO_MODELS and model_name in MODELSCOPE_VIDEO_MODELS:
+        hf_spec = BUILTIN_VIDEO_MODELS[model_name]
+        ms_spec = MODELSCOPE_VIDEO_MODELS[model_name]
+
+        return any(
+            [
+                valid_model_revision(meta_path, hf_spec.model_revision),
+                valid_model_revision(meta_path, ms_spec.model_revision),
+            ]
+        )
+    else:  # Usually for UT
+        return valid_model_revision(meta_path, model_spec.model_revision)
+
+
+def create_video_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+    model_path: Optional[str] = None,
+    **kwargs,
+) -> Tuple[DiffUsersVideoModel, VideoModelDescription]:
+    model_spec = match_diffusion(model_name, download_hub)
+    if not model_path:
+        model_path = cache(model_spec)
+    assert model_path is not None
+
+    model = DiffUsersVideoModel(
+        model_uid,
+        model_path,
+        model_spec,
+        **kwargs,
+    )
+    model_description = VideoModelDescription(
+        subpool_addr, devices, model_spec, model_path=model_path
+    )
+    return model, model_description
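Note: the worker creates video models through the new `create_video_model_instance` helper above. A hedged usage sketch follows; the address, device list, uid, and model name are placeholders chosen for illustration (the actual built-in names live in the bundled model_spec.json of this release), not values taken from this diff.

```python
# Hypothetical direct call to the helper added in xinference/model/video/core.py.
# In practice the xinference worker supplies these arguments when a video model
# is launched; calling it directly will download/cache the model weights.
from xinference.model.video.core import create_video_model_instance

model, description = create_video_model_instance(
    subpool_addr="127.0.0.1:34567",   # placeholder worker subpool address
    devices=["cuda:0"],               # placeholder device list
    model_uid="my-video-model",
    model_name="CogVideoX-2b",        # assumed spec name; check model_spec.json
)
print(description.to_dict())  # {"model_type": "video", "model_name": ..., ...}
```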