xinference 0.15.3__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +29 -2
- xinference/client/restful/restful_client.py +10 -0
- xinference/constants.py +7 -3
- xinference/core/image_interface.py +76 -23
- xinference/core/model.py +158 -46
- xinference/core/progress_tracker.py +187 -0
- xinference/core/scheduler.py +10 -7
- xinference/core/supervisor.py +11 -0
- xinference/core/utils.py +9 -0
- xinference/core/worker.py +1 -0
- xinference/deploy/supervisor.py +4 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +2 -1
- xinference/model/audio/core.py +0 -2
- xinference/model/audio/model_spec.json +8 -0
- xinference/model/audio/model_spec_modelscope.json +9 -0
- xinference/model/image/core.py +6 -7
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/sdapi.py +35 -4
- xinference/model/image/stable_diffusion/core.py +215 -110
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +185 -17
- xinference/model/llm/llm_family_modelscope.json +124 -12
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/cogvlm2.py +2 -1
- xinference/model/llm/transformers/cogvlm2_video.py +2 -0
- xinference/model/llm/transformers/core.py +43 -113
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/deepseek_vl.py +2 -0
- xinference/model/llm/transformers/glm4v.py +2 -1
- xinference/model/llm/transformers/intern_vl.py +2 -0
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/minicpmv25.py +2 -0
- xinference/model/llm/transformers/minicpmv26.py +2 -0
- xinference/model/llm/transformers/omnilmm.py +2 -0
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/qwen2_audio.py +11 -4
- xinference/model/llm/transformers/qwen2_vl.py +2 -28
- xinference/model/llm/transformers/qwen_vl.py +2 -1
- xinference/model/llm/transformers/utils.py +36 -283
- xinference/model/llm/transformers/yi_vl.py +2 -0
- xinference/model/llm/utils.py +60 -16
- xinference/model/llm/vllm/core.py +68 -9
- xinference/model/llm/vllm/utils.py +0 -1
- xinference/model/utils.py +7 -4
- xinference/model/video/core.py +0 -2
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.f7da0140.js} +3 -3
- xinference/web/ui/build/static/js/main.f7da0140.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/METADATA +38 -6
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/RECORD +63 -59
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.f7da0140.js.LICENSE.txt} +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/LICENSE +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/WHEEL +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.3.dist-info → xinference-0.16.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import dataclasses
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import time
|
|
20
|
+
from typing import Dict, List, Optional, Tuple
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
import xoscar as xo
|
|
24
|
+
|
|
25
|
+
# Seconds a finished request's progress is retained after its last update
# before the tracker purges it (env XINFERENCE_REMOVE_PROGRESS_INTERVAL).
TO_REMOVE_PROGRESS_INTERVAL = float(
    os.getenv("XINFERENCE_REMOVE_PROGRESS_INTERVAL", 5 * 60)
)  # 5min
# Seconds the tracker sleeps between two purge passes
# (env XINFERENCE_CHECK_PROGRESS_INTERVAL).
CHECK_PROGRESS_INTERVAL = float(
    os.getenv("XINFERENCE_CHECK_PROGRESS_INTERVAL", 1 * 60)
)  # 1min
# Minimum increase in overall progress before a Progressor uploads a new
# value to the tracker (env XINFERENCE_UPLOAD_PROGRESS_SPAN).
UPLOAD_PROGRESS_SPAN = float(
    os.getenv("XINFERENCE_UPLOAD_PROGRESS_SPAN", 0.05)
)  # do not upload when the change is less than 0.05

logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclasses.dataclass
|
|
39
|
+
class _ProgressInfo:
|
|
40
|
+
progress: float
|
|
41
|
+
last_updated: float
|
|
42
|
+
info: Optional[str] = None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ProgressTrackerActor(xo.StatelessActor):
    """Actor that stores the progress of requests, keyed by request id.

    A background task, started in ``__post_create__``, wakes up every
    ``check_interval`` seconds and drops requests whose progress reached 1.0
    and has not been updated for more than ``to_remove_interval`` seconds.
    """

    _request_id_to_progress: Dict[str, _ProgressInfo]

    @classmethod
    def default_uid(cls) -> str:
        return "progress_tracker"

    def __init__(
        self,
        to_remove_interval: float = TO_REMOVE_PROGRESS_INTERVAL,
        check_interval: float = CHECK_PROGRESS_INTERVAL,
    ):
        """
        Args:
            to_remove_interval: seconds a finished request is retained after
                its last update before it is purged.
            check_interval: seconds to sleep between two purge passes.
        """
        super().__init__()

        self._request_id_to_progress = {}
        # Handle of the background purge task; created in __post_create__.
        self._clear_finished_task = None
        self._to_remove_interval = to_remove_interval
        self._check_interval = check_interval

    async def __post_create__(self):
        self._clear_finished_task = asyncio.create_task(self._clear_finished())

    async def __pre_destroy__(self):
        if self._clear_finished_task:
            self._clear_finished_task.cancel()

    async def _clear_finished(self):
        """Periodically purge requests that finished long enough ago.

        Loops forever; it is stopped by the cancel in ``__pre_destroy__``.
        """
        while True:
            now = time.time()
            # Build the list afresh on every pass. Reusing one list across
            # passes would retry deleting ids purged earlier, raise KeyError
            # and end this background task, so nothing would be purged again.
            # A request qualifies when it is finished (progress == 1.0 within
            # tolerance) and idle longer than the retention window.
            to_remove_request_ids = [
                request_id
                for request_id, progress in self._request_id_to_progress.items()
                if abs(progress.progress - 1.0) <= 1e-5
                and now - progress.last_updated > self._to_remove_interval
            ]

            for rid in to_remove_request_ids:
                self._request_id_to_progress.pop(rid, None)

            if to_remove_request_ids:
                logger.debug(
                    "Remove requests %s due to it's finished for over %s seconds",
                    to_remove_request_ids,
                    self._to_remove_interval,
                )

            await asyncio.sleep(self._check_interval)

    def start(self, request_id: str):
        """Register ``request_id`` (or reset it) with progress 0.0."""
        self._request_id_to_progress[request_id] = _ProgressInfo(
            progress=0.0, last_updated=time.time()
        )

    def set_progress(self, request_id: str, progress: float):
        """Store ``progress`` for an already started request.

        Raises:
            AssertionError: if ``progress`` is greater than 1.0.
            KeyError: if ``request_id`` was never started or was purged.
        """
        assert progress <= 1.0
        info = self._request_id_to_progress[request_id]
        info.progress = progress
        info.last_updated = time.time()
        logger.debug(
            "Setting progress, request id: %s, progress: %s", request_id, progress
        )

    def get_progress(self, request_id: str) -> float:
        """Return the stored progress of ``request_id``.

        Raises:
            KeyError: if ``request_id`` was never started or was purged.
        """
        return self._request_id_to_progress[request_id].progress
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class Progressor:
    """Report staged progress of one request to a ProgressTrackerActor.

    Progress is modelled as nested sub-ranges of [0.0, 1.0] kept on a stack.
    ``split_stages`` divides the active range into consecutive sub-ranges.
    Each ``with progressor:`` block activates the range on top of the stack
    and pops it on exit. ``set_progress`` maps a fraction of the active range
    onto the overall range. When ``request_id`` is falsy every method is a
    no-op, so callers can use the object unconditionally.
    """

    _sub_progress_stack: List[Tuple[float, float]]

    def __init__(
        self,
        request_id: str,
        progress_tracker_ref: xo.ActorRefType["ProgressTrackerActor"],
        loop: asyncio.AbstractEventLoop,
        upload_span: float = UPLOAD_PROGRESS_SPAN,
    ):
        """
        Args:
            request_id: id the progress is recorded under; falsy disables
                all tracking.
            progress_tracker_ref: reference to the actor receiving updates.
            loop: event loop on which the upload coroutines are scheduled.
            upload_span: minimum increase in overall progress between two
                uploads, so the tracker is not flooded with tiny updates.
        """
        self.request_id = request_id
        self.progress_tracker_ref = progress_tracker_ref
        self.loop = loop
        self._upload_span = upload_span

        # Overall progress last sent to the tracker / last computed locally.
        self._last_report_progress = 0.0
        self._current_progress = 0.0
        # Stack of (start, end) ranges: __enter__ activates the top entry,
        # __exit__ pops it.
        self._sub_progress_stack = [(0.0, 1.0)]
        # Bounds of the currently active range that set_progress scales into.
        self._current_sub_progress_start, self._current_sub_progress_end = 0.0, 1.0

    async def start(self):
        """Register ``request_id`` with the tracker actor."""
        if not self.request_id:
            return
        await self.progress_tracker_ref.start(self.request_id)

    def split_stages(self, n_stage: int, stage_weight: Optional[List[float]] = None):
        """Push ``n_stage`` consecutive sub-ranges onto the stack.

        Without ``stage_weight``, the active range is split into equal parts.
        With it, ``stage_weight`` must hold ``n_stage + 1`` boundary values.
        These are used as given and are not rescaled into the active range
        (NOTE(review): confirm this is intended for nested splits). The first
        stage ends up on top of the stack, so it is entered first.

        Raises:
            ValueError: if tracking is enabled and ``stage_weight`` does not
                have ``n_stage + 1`` entries.
        """
        if not self.request_id:
            return
        if stage_weight is None:
            bounds = np.linspace(
                self._current_sub_progress_start,
                self._current_sub_progress_end,
                n_stage + 1,
            )
        elif len(stage_weight) == n_stage + 1:
            bounds = stage_weight
        else:
            raise ValueError(
                f"stage_weight should have size {n_stage + 1}, got {len(stage_weight)}"
            )
        # Push the last stage first so the first stage is on top.
        for idx in reversed(range(n_stage)):
            self._sub_progress_stack.append((bounds[idx], bounds[idx + 1]))

    def __enter__(self):
        """Activate the range on top of the stack (without popping it)."""
        if self.request_id:
            start, end = self._sub_progress_stack[-1]
            self._current_sub_progress_start = start
            self._current_sub_progress_end = end

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the active range and mark it complete; never suppresses errors."""
        if not self.request_id:
            return False
        self._sub_progress_stack.pop()
        # Report the stage as finished whether or not its body completed.
        self.set_progress(1.0)
        return False

    def set_progress(self, progress: float):
        """Record ``progress`` (a fraction of the active range), maybe upload it.

        The overall value is uploaded when it grew by at least ``upload_span``
        since the last upload, or when ``progress`` is within 1e-5 of (or
        above) 1.0. The upload is scheduled on ``self.loop`` via
        ``asyncio.run_coroutine_threadsafe`` and is not awaited.
        """
        if not self.request_id:
            return
        low = self._current_sub_progress_start
        high = self._current_sub_progress_end
        self._current_progress = low + (high - low) * progress

        grew_enough = (
            self._current_progress - self._last_report_progress >= self._upload_span
        )
        stage_finished = 1.0 - progress < 1e-5
        if not (grew_enough or stage_finished):
            return

        upload = self.progress_tracker_ref.set_progress(
            self.request_id, self._current_progress
        )
        asyncio.run_coroutine_threadsafe(upload, self.loop)  # type: ignore
        self._last_report_progress = self._current_progress
|
xinference/core/scheduler.py
CHANGED
|
@@ -17,11 +17,12 @@ import functools
|
|
|
17
17
|
import logging
|
|
18
18
|
import uuid
|
|
19
19
|
from collections import deque
|
|
20
|
-
from enum import Enum
|
|
21
20
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
|
22
21
|
|
|
23
22
|
import xoscar as xo
|
|
24
23
|
|
|
24
|
+
from .utils import AbortRequestMessage
|
|
25
|
+
|
|
25
26
|
logger = logging.getLogger(__name__)
|
|
26
27
|
|
|
27
28
|
XINFERENCE_STREAMING_DONE_FLAG = "<XINFERENCE_STREAMING_DONE>"
|
|
@@ -30,12 +31,6 @@ XINFERENCE_STREAMING_ABORT_FLAG = "<XINFERENCE_STREAMING_ABORT>"
|
|
|
30
31
|
XINFERENCE_NON_STREAMING_ABORT_FLAG = "<XINFERENCE_NON_STREAMING_ABORT>"
|
|
31
32
|
|
|
32
33
|
|
|
33
|
-
class AbortRequestMessage(Enum):
|
|
34
|
-
NOT_FOUND = 1
|
|
35
|
-
DONE = 2
|
|
36
|
-
NO_OP = 3
|
|
37
|
-
|
|
38
|
-
|
|
39
34
|
class InferenceRequest:
|
|
40
35
|
def __init__(
|
|
41
36
|
self,
|
|
@@ -81,6 +76,10 @@ class InferenceRequest:
|
|
|
81
76
|
self.padding_len = 0
|
|
82
77
|
# Use in stream mode
|
|
83
78
|
self.last_output_length = 0
|
|
79
|
+
# For tool call
|
|
80
|
+
self.tools = None
|
|
81
|
+
# Currently, for storing tool call streaming results.
|
|
82
|
+
self.outputs: List[str] = []
|
|
84
83
|
# inference results,
|
|
85
84
|
# it is a list type because when stream=True,
|
|
86
85
|
# self.completion contains all the results in a decode round.
|
|
@@ -112,6 +111,10 @@ class InferenceRequest:
|
|
|
112
111
|
"""
|
|
113
112
|
return self._prompt
|
|
114
113
|
|
|
114
|
+
@prompt.setter
|
|
115
|
+
def prompt(self, value: str):
|
|
116
|
+
self._prompt = value
|
|
117
|
+
|
|
115
118
|
@property
|
|
116
119
|
def call_ability(self):
|
|
117
120
|
return self._call_ability
|
xinference/core/supervisor.py
CHANGED
|
@@ -130,6 +130,7 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
130
130
|
)
|
|
131
131
|
logger.info(f"Xinference supervisor {self.address} started")
|
|
132
132
|
from .cache_tracker import CacheTrackerActor
|
|
133
|
+
from .progress_tracker import ProgressTrackerActor
|
|
133
134
|
from .status_guard import StatusGuardActor
|
|
134
135
|
|
|
135
136
|
self._status_guard_ref: xo.ActorRefType[ # type: ignore
|
|
@@ -142,6 +143,13 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
142
143
|
] = await xo.create_actor(
|
|
143
144
|
CacheTrackerActor, address=self.address, uid=CacheTrackerActor.default_uid()
|
|
144
145
|
)
|
|
146
|
+
self._progress_tracker: xo.ActorRefType[ # type: ignore
|
|
147
|
+
"ProgressTrackerActor"
|
|
148
|
+
] = await xo.create_actor(
|
|
149
|
+
ProgressTrackerActor,
|
|
150
|
+
address=self.address,
|
|
151
|
+
uid=ProgressTrackerActor.default_uid(),
|
|
152
|
+
)
|
|
145
153
|
|
|
146
154
|
from .event import EventCollectorActor
|
|
147
155
|
|
|
@@ -1360,3 +1368,6 @@ class SupervisorActor(xo.StatelessActor):
|
|
|
1360
1368
|
@staticmethod
|
|
1361
1369
|
def record_metrics(name, op, kwargs):
|
|
1362
1370
|
record_metrics(name, op, kwargs)
|
|
1371
|
+
|
|
1372
|
+
async def get_progress(self, request_id: str) -> float:
|
|
1373
|
+
return await self._progress_tracker.get_progress(request_id)
|
xinference/core/utils.py
CHANGED
|
@@ -16,6 +16,7 @@ import os
|
|
|
16
16
|
import random
|
|
17
17
|
import string
|
|
18
18
|
import uuid
|
|
19
|
+
from enum import Enum
|
|
19
20
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
|
20
21
|
|
|
21
22
|
import orjson
|
|
@@ -27,6 +28,12 @@ from ..constants import XINFERENCE_LOG_ARG_MAX_LENGTH
|
|
|
27
28
|
logger = logging.getLogger(__name__)
|
|
28
29
|
|
|
29
30
|
|
|
31
|
+
class AbortRequestMessage(Enum):
    """Status codes for an attempt to abort a request.

    NOTE(review): member meanings are inferred from their names only
    (request not found / abort done / nothing to do); confirm against the
    scheduler that returns them.
    """

    NOT_FOUND = 1
    DONE = 2
    NO_OP = 3
|
|
35
|
+
|
|
36
|
+
|
|
30
37
|
def truncate_log_arg(arg) -> str:
|
|
31
38
|
s = str(arg)
|
|
32
39
|
if len(s) > XINFERENCE_LOG_ARG_MAX_LENGTH:
|
|
@@ -51,6 +58,8 @@ def log_async(
|
|
|
51
58
|
request_id_str = kwargs.get("request_id", "")
|
|
52
59
|
if not request_id_str:
|
|
53
60
|
request_id_str = uuid.uuid1()
|
|
61
|
+
if func_name == "text_to_image":
|
|
62
|
+
kwargs["request_id"] = request_id_str
|
|
54
63
|
request_id_str = f"[request {request_id_str}]"
|
|
55
64
|
formatted_args = ",".join(map(truncate_log_arg, args))
|
|
56
65
|
formatted_kwargs = ",".join(
|
xinference/core/worker.py
CHANGED
xinference/deploy/supervisor.py
CHANGED
|
@@ -31,6 +31,10 @@ from .utils import health_check
|
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
|
+
from ..model import _install as install_model
|
|
35
|
+
|
|
36
|
+
install_model()
|
|
37
|
+
|
|
34
38
|
|
|
35
39
|
async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
|
|
36
40
|
logging.config.dictConfig(logging_conf) # type: ignore
|
xinference/model/__init__.py
CHANGED
|
@@ -53,7 +53,8 @@ class ChatTTSModel:
|
|
|
53
53
|
torch._dynamo.config.suppress_errors = True
|
|
54
54
|
torch.set_float32_matmul_precision("high")
|
|
55
55
|
self._model = ChatTTS.Chat()
|
|
56
|
-
|
|
56
|
+
logger.info("Load ChatTTS model with kwargs: %s", self._kwargs)
|
|
57
|
+
self._model.load(source="custom", custom_path=self._model_path, **self._kwargs)
|
|
57
58
|
|
|
58
59
|
def speech(
|
|
59
60
|
self,
|
xinference/model/audio/core.py
CHANGED
|
@@ -71,6 +71,14 @@
|
|
|
71
71
|
"model_ability": "audio-to-text",
|
|
72
72
|
"multilingual": true
|
|
73
73
|
},
|
|
74
|
+
{
|
|
75
|
+
"model_name": "whisper-large-v3-turbo",
|
|
76
|
+
"model_family": "whisper",
|
|
77
|
+
"model_id": "openai/whisper-large-v3-turbo",
|
|
78
|
+
"model_revision": "41f01f3fe87f28c78e2fbf8b568835947dd65ed9",
|
|
79
|
+
"model_ability": "audio-to-text",
|
|
80
|
+
"multilingual": true
|
|
81
|
+
},
|
|
74
82
|
{
|
|
75
83
|
"model_name": "Belle-distilwhisper-large-v2-zh",
|
|
76
84
|
"model_family": "whisper",
|
|
@@ -8,6 +8,15 @@
|
|
|
8
8
|
"model_ability": "audio-to-text",
|
|
9
9
|
"multilingual": true
|
|
10
10
|
},
|
|
11
|
+
{
|
|
12
|
+
"model_name": "whisper-large-v3-turbo",
|
|
13
|
+
"model_family": "whisper",
|
|
14
|
+
"model_hub": "modelscope",
|
|
15
|
+
"model_id": "AI-ModelScope/whisper-large-v3-turbo",
|
|
16
|
+
"model_revision": "master",
|
|
17
|
+
"model_ability": "audio-to-text",
|
|
18
|
+
"multilingual": true
|
|
19
|
+
},
|
|
11
20
|
{
|
|
12
21
|
"model_name": "SenseVoiceSmall",
|
|
13
22
|
"model_family": "funasr",
|
xinference/model/image/core.py
CHANGED
|
@@ -23,8 +23,6 @@ from ..core import CacheableModelSpec, ModelDescription
|
|
|
23
23
|
from ..utils import valid_model_revision
|
|
24
24
|
from .stable_diffusion.core import DiffusionModel
|
|
25
25
|
|
|
26
|
-
MAX_ATTEMPTS = 3
|
|
27
|
-
|
|
28
26
|
logger = logging.getLogger(__name__)
|
|
29
27
|
|
|
30
28
|
MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
|
|
@@ -210,18 +208,19 @@ def create_image_model_instance(
|
|
|
210
208
|
for name in controlnet:
|
|
211
209
|
for cn_model_spec in model_spec.controlnet:
|
|
212
210
|
if cn_model_spec.model_name == name:
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
controlnet_model_paths.append(model_path)
|
|
211
|
+
controlnet_model_path = cache(cn_model_spec)
|
|
212
|
+
controlnet_model_paths.append(controlnet_model_path)
|
|
216
213
|
break
|
|
217
214
|
else:
|
|
218
215
|
raise ValueError(
|
|
219
216
|
f"controlnet `{name}` is not supported for model `{model_name}`."
|
|
220
217
|
)
|
|
221
218
|
if len(controlnet_model_paths) == 1:
|
|
222
|
-
kwargs["controlnet"] = controlnet_model_paths[0]
|
|
219
|
+
kwargs["controlnet"] = (controlnet[0], controlnet_model_paths[0])
|
|
223
220
|
else:
|
|
224
|
-
kwargs["controlnet"] =
|
|
221
|
+
kwargs["controlnet"] = [
|
|
222
|
+
(n, path) for n, path in zip(controlnet, controlnet_model_paths)
|
|
223
|
+
]
|
|
225
224
|
if not model_path:
|
|
226
225
|
model_path = cache(model_spec)
|
|
227
226
|
if peft_model_config is not None:
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2022-2024 XProbe Inc.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|