xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- xinference/__init__.py +0 -4
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +48 -0
- xinference/client/restful/restful_client.py +19 -0
- xinference/constants.py +4 -4
- xinference/core/chat_interface.py +5 -1
- xinference/core/image_interface.py +5 -1
- xinference/core/model.py +195 -34
- xinference/core/scheduler.py +10 -7
- xinference/core/utils.py +9 -0
- xinference/model/__init__.py +4 -0
- xinference/model/audio/chattts.py +25 -14
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/audio/model_spec_modelscope.json +1 -1
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/image/core.py +59 -4
- xinference/model/image/model_spec.json +24 -3
- xinference/model/image/model_spec_modelscope.json +25 -3
- xinference/model/image/ocr/__init__.py +13 -0
- xinference/model/image/ocr/got_ocr2.py +76 -0
- xinference/model/image/scheduler/__init__.py +13 -0
- xinference/model/image/scheduler/flux.py +533 -0
- xinference/model/image/stable_diffusion/core.py +8 -34
- xinference/model/image/stable_diffusion/mlx.py +221 -0
- xinference/model/image/utils.py +39 -3
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +178 -1
- xinference/model/llm/llm_family_modelscope.json +119 -0
- xinference/model/llm/transformers/chatglm.py +104 -0
- xinference/model/llm/transformers/core.py +37 -111
- xinference/model/llm/transformers/deepseek_v2.py +0 -226
- xinference/model/llm/transformers/internlm2.py +3 -95
- xinference/model/llm/transformers/opt.py +68 -0
- xinference/model/llm/transformers/utils.py +4 -284
- xinference/model/llm/utils.py +2 -2
- xinference/model/llm/vllm/core.py +16 -1
- xinference/thirdparty/mlx/__init__.py +13 -0
- xinference/thirdparty/mlx/flux/__init__.py +15 -0
- xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
- xinference/thirdparty/mlx/flux/clip.py +154 -0
- xinference/thirdparty/mlx/flux/datasets.py +75 -0
- xinference/thirdparty/mlx/flux/flux.py +247 -0
- xinference/thirdparty/mlx/flux/layers.py +302 -0
- xinference/thirdparty/mlx/flux/lora.py +76 -0
- xinference/thirdparty/mlx/flux/model.py +134 -0
- xinference/thirdparty/mlx/flux/sampler.py +56 -0
- xinference/thirdparty/mlx/flux/t5.py +244 -0
- xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
- xinference/thirdparty/mlx/flux/trainer.py +98 -0
- xinference/thirdparty/mlx/flux/utils.py +179 -0
- xinference/utils.py +2 -3
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
- xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
- xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
- /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/opt.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright 2022-2024 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from builtins import classmethod
+from typing import List, Optional
+
+from ....core.scheduler import InferenceRequest
+from ....types import LoRA
+from ..llm_family import LLMFamilyV1, LLMSpecV1
+from .core import PytorchModel, PytorchModelConfig
+
+
+class OptPytorchModel(PytorchModel):
+    def __init__(
+        self,
+        model_uid: str,
+        model_family: "LLMFamilyV1",
+        model_spec: "LLMSpecV1",
+        quantization: str,
+        model_path: str,
+        pytorch_model_config: Optional[PytorchModelConfig] = None,
+        peft_model: Optional[List[LoRA]] = None,
+    ):
+        super().__init__(
+            model_uid,
+            model_family,
+            model_spec,
+            quantization,
+            model_path,
+            pytorch_model_config=pytorch_model_config,
+            peft_model=peft_model,
+        )
+
+    @classmethod
+    def match(
+        cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ) -> bool:
+        if llm_spec.model_format != "pytorch":
+            return False
+        model_family = llm_family.model_family or llm_family.model_name
+        if model_family != "opt":
+            return False
+        return True
+
+    def build_prefill_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        """
+        Mainly for UT.
+        Transformers code in `main` branch supports `position_ids` parameter (https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L1076),
+        while in release branch, it doesn't (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/opt/modeling_opt.py#L886).
+        """
+        return None
+
+    def build_decode_position_ids(
+        self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+    ):
+        return None
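The new OptPytorchModel is wired in through the usual `match` classmethod and overrides the two position-id hooks to return `None`, so the batched code path simply omits `position_ids` on released transformers versions that do not accept it for OPT (as the docstring above notes, this exists mainly for unit tests). A minimal check of the `match` logic, assuming xinference 0.16.x is installed; the `SimpleNamespace` objects below are stand-ins for the real spec models, carrying only the attributes `match()` actually reads:

# Stand-in spec objects: only the attributes match() reads are provided.
from types import SimpleNamespace

from xinference.model.llm.transformers.opt import OptPytorchModel

family = SimpleNamespace(model_family="opt", model_name="opt")
pytorch_spec = SimpleNamespace(model_format="pytorch")
gguf_spec = SimpleNamespace(model_format="ggufv2")

assert OptPytorchModel.match(family, pytorch_spec, quantization="none") is True
assert OptPytorchModel.match(family, gguf_spec, quantization="none") is False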
xinference/model/llm/transformers/utils.py
CHANGED
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import asyncio
 import functools
-import gc
 import logging
 import os
 import time
-import uuid
-from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

 import torch
 from transformers.cache_utils import DynamicCache
@@ -46,20 +45,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


-def is_sentence_complete(output: str):
-    """Check whether the output is a complete sentence."""
-    end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
-    return output.endswith(end_symbols)
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output contains a partial stop str."""
-    for i in range(0, min(len(output), len(stop_str))):
-        if stop_str.startswith(output[-i:]):
-            return True
-    return False
-
-
 def get_context_length(config) -> int:
     """Get the context length of a model from a huggingface model config."""
     if (
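The removed `is_partial_stop` helper answered one question for the old streaming loop (deleted in the next hunk): could the current tail of the output still grow into a stop string? If so, that text was held back instead of being emitted. A self-contained copy of the check with a small usage example:

def is_partial_stop(output: str, stop_str: str) -> bool:
    """Return True if the end of `output` is a prefix of `stop_str`."""
    for i in range(0, min(len(output), len(stop_str))):
        if stop_str.startswith(output[-i:]):
            return True
    return False


# "</s" could still grow into the stop string "</s>", so a streamer should
# hold it back; "Hello." cannot, so it is safe to emit immediately.
assert is_partial_stop("The answer is </s", "</s>") is True
assert is_partial_stop("Hello.", "</s>") is False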
@@ -99,273 +84,6 @@ def prepare_logits_processor(
     return processor_list


-@torch.inference_mode()
-def generate_stream(
-    model_uid,
-    model,
-    tokenizer,
-    prompt,
-    device,
-    generate_config,
-    judge_sent_end=False,
-) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-    context_len = get_context_length(model.config)
-    stream_interval = generate_config.get("stream_interval", 2)
-    stream = generate_config.get("stream", False)
-    stream_options = generate_config.pop("stream_options", None)
-    include_usage = (
-        stream_options["include_usage"] if isinstance(stream_options, dict) else False
-    )
-
-    len_prompt = len(prompt)
-
-    temperature = float(generate_config.get("temperature", 1.0))
-    repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-    top_p = float(generate_config.get("top_p", 1.0))
-    top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
-    echo = bool(generate_config.get("echo", False))
-    stop_str = generate_config.get("stop", None)
-    stop_token_ids = generate_config.get("stop_token_ids", None) or []
-    if tokenizer.eos_token_id not in stop_token_ids:
-        stop_token_ids.append(tokenizer.eos_token_id)
-    chunk_id = str(uuid.uuid4())
-
-    logits_processor = prepare_logits_processor(
-        temperature, repetition_penalty, top_p, top_k
-    )
-
-    if ".modeling_qwen." in str(type(model)).lower():
-        # TODO: hacky
-        input_ids = tokenizer(prompt, allowed_special="all").input_ids
-    else:
-        input_ids = tokenizer(prompt).input_ids
-    output_ids = list(input_ids)
-
-    if model.config.is_encoder_decoder:
-        max_src_len = context_len
-    else:
-        max_src_len = context_len - max_new_tokens - 8
-        if max_src_len < 0:
-            raise ValueError("Max tokens exceeds model's max length")
-
-    input_ids = input_ids[-max_src_len:]
-    input_echo_len = len(input_ids)
-
-    if model.config.is_encoder_decoder:
-        encoder_output = model.encoder(
-            input_ids=torch.as_tensor([input_ids], device=device)
-        )[0]
-        start_ids = torch.as_tensor(
-            [[model.generation_config.decoder_start_token_id]],
-            dtype=torch.int64,
-            device=device,
-        )
-
-    start = time.time()
-    past_key_values = out = None
-    sent_interrupt = False
-    token = None
-    last_output_length = 0
-    for i in range(max_new_tokens):
-        if i == 0:
-            if model.config.is_encoder_decoder:
-                out = model.decoder(
-                    input_ids=start_ids,
-                    encoder_hidden_states=encoder_output,
-                    use_cache=True,
-                )
-                logits = model.lm_head(out[0])
-            else:
-                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
-                logits = out.logits
-            past_key_values = out.past_key_values
-        else:
-            if model.config.is_encoder_decoder:
-                out = model.decoder(
-                    input_ids=torch.as_tensor(
-                        [[token] if not sent_interrupt else output_ids], device=device
-                    ),
-                    encoder_hidden_states=encoder_output,
-                    use_cache=True,
-                    past_key_values=past_key_values if not sent_interrupt else None,
-                )
-                sent_interrupt = False
-
-                logits = model.lm_head(out[0])
-            else:
-                out = model(
-                    input_ids=torch.as_tensor(
-                        [[token] if not sent_interrupt else output_ids], device=device
-                    ),
-                    use_cache=True,
-                    past_key_values=past_key_values if not sent_interrupt else None,
-                )
-                sent_interrupt = False
-                logits = out.logits
-            past_key_values = out.past_key_values
-
-        if logits_processor:
-            if repetition_penalty > 1.0:
-                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-            else:
-                tmp_output_ids = None
-            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-        else:
-            last_token_logits = logits[0, -1, :]
-
-        if device == "mps":
-            # Switch to CPU by avoiding some bugs in mps backend.
-            last_token_logits = last_token_logits.float().to("cpu")
-
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            _, indices = torch.topk(last_token_logits, 2)
-            tokens = [int(index) for index in indices.tolist()]
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            indices = torch.multinomial(probs, num_samples=2)
-            tokens = [int(token) for token in indices.tolist()]
-        token = tokens[0]
-        output_ids.append(token)
-
-        if token in stop_token_ids:
-            stopped = True
-        else:
-            stopped = False
-
-        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
-            if echo:
-                tmp_output_ids = output_ids
-                rfind_start = len_prompt
-            else:
-                tmp_output_ids = output_ids[input_echo_len:]
-                rfind_start = 0
-
-            output = tokenizer.decode(
-                tmp_output_ids,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
-
-            # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
-            if judge_sent_end and stopped and not is_sentence_complete(output):
-                if len(tokens) > 1:
-                    token = tokens[1]
-                    output_ids[-1] = token
-                else:
-                    output_ids.pop()
-                stopped = False
-                sent_interrupt = True
-
-            partially_stopped = False
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                        stopped = True
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            stopped = True
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError("Invalid stop field type.")
-
-            if stream:
-                output = output.strip("�")
-                tmp_output_length = len(output)
-                output = output[last_output_length:]
-                last_output_length = tmp_output_length
-
-            # prevent yielding partial stop sequence
-            if not partially_stopped:
-                completion_choice = CompletionChoice(
-                    text=output, index=0, logprobs=None, finish_reason=None
-                )
-                completion_chunk = CompletionChunk(
-                    id=chunk_id,
-                    object="text_completion",
-                    created=int(time.time()),
-                    model=model_uid,
-                    choices=[completion_choice],
-                )
-                completion_usage = CompletionUsage(
-                    prompt_tokens=input_echo_len,
-                    completion_tokens=i,
-                    total_tokens=(input_echo_len + i),
-                )
-
-                yield completion_chunk, completion_usage
-
-        if stopped:
-            break
-
-    elapsed_time = time.time() - start
-    logger.info(f"Average generation speed: {i / elapsed_time:.2f} tokens/s.")
-
-    # finish stream event, which contains finish reason
-    if stopped:
-        finish_reason = "stop"
-    elif i == max_new_tokens - 1:
-        finish_reason = "length"
-    else:
-        finish_reason = None
-
-    if stream:
-        completion_choice = CompletionChoice(
-            text=output, index=0, logprobs=None, finish_reason=finish_reason
-        )
-    else:
-        completion_choice = CompletionChoice(
-            text=output, index=0, logprobs=None, finish_reason=finish_reason
-        )
-
-    completion_chunk = CompletionChunk(
-        id=chunk_id,
-        object="text_completion",
-        created=int(time.time()),
-        model=model_uid,
-        choices=[completion_choice],
-    )
-    completion_usage = CompletionUsage(
-        prompt_tokens=input_echo_len,
-        completion_tokens=i,
-        total_tokens=(input_echo_len + i),
-    )
-
-    yield completion_chunk, completion_usage
-
-    if include_usage:
-        completion_chunk = CompletionChunk(
-            id=chunk_id,
-            object="text_completion",
-            created=int(time.time()),
-            model=model_uid,
-            choices=[],
-        )
-        completion_usage = CompletionUsage(
-            prompt_tokens=input_echo_len,
-            completion_tokens=i,
-            total_tokens=(input_echo_len + i),
-        )
-        yield completion_chunk, completion_usage
-
-    # clean
-    del past_key_values, out
-    gc.collect()
-    empty_cache()
-
-
 def _get_token_from_logits(
     req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
 ):
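The deleted `generate_stream` (adapted from FastChat) was the old single-request decoding loop; from this release the transformers backend goes through the batched `_batch_inference_one_step_internal` path that the next two hunks touch. Its token-selection step is worth restating: it fell back to greedy decoding for near-zero `temperature`/`top_p` and always kept a runner-up token so `judge_sent_end` could swap it in when generation stopped mid-sentence. A self-contained sketch of just that step:

import torch


def pick_next_tokens(last_token_logits: torch.Tensor, temperature: float, top_p: float):
    """Mirror the removed selection step: return the chosen token id and a runner-up."""
    if temperature < 1e-5 or top_p < 1e-8:  # effectively greedy decoding
        _, indices = torch.topk(last_token_logits, 2)
    else:
        probs = torch.softmax(last_token_logits, dim=-1)
        indices = torch.multinomial(probs, num_samples=2)
    tokens = [int(i) for i in indices.tolist()]
    return tokens[0], tokens[1]


logits = torch.tensor([0.1, 2.5, 0.3, 1.0])
best, runner_up = pick_next_tokens(logits, temperature=0.0, top_p=1.0)
assert best == 1  # index of the largest logit under greedy decoding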
@@ -680,6 +398,7 @@ def _batch_inference_one_step_internal(
                 output = output.strip("�")
                 output = output[r.last_output_length :]
                 r.last_output_length += len(output)
+                r.outputs.append(output)

                 completion_chunk = generate_completion_chunk(
                     chunk_text=output,
@@ -704,6 +423,7 @@ def _batch_inference_one_step_internal(
                 )
                 r.completion.append(completion_chunk)
                 r.completion.append(eos_flag)
+                r.outputs.append(eos_flag)

             # last round, handle stream result
             # append usage information when enable `include_usage` for OPENAI API compatibility
xinference/model/llm/utils.py
CHANGED
@@ -386,8 +386,8 @@ class ChatModelMixin:
         return result

     @classmethod
-    def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
-        _id = str(uuid.uuid4())
+    def _tool_calls_completion_chunk(cls, model_family, model_uid, c, chunk_id=None):
+        _id = chunk_id if chunk_id is not None else str(uuid.uuid4())
         tool_result = cls._eval_tool_arguments(model_family, c)
         tool_calls = []
         failed_contents = []
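The new optional `chunk_id` argument lets a caller reuse a single id across every chunk of a streamed tool-call response instead of minting a fresh UUID per chunk; omitting it keeps the old behaviour. A hedged sketch of the calling pattern (the `stream_tool_call_chunks` wrapper is hypothetical, not part of xinference):

import uuid


def stream_tool_call_chunks(mixin_cls, model_family, model_uid, partial_outputs):
    """Hypothetical wrapper: emit tool-call chunks that all share one chunk id."""
    shared_id = str(uuid.uuid4())
    for c in partial_outputs:
        # Passing chunk_id keeps the OpenAI-style `id` field stable across the stream.
        yield mixin_cls._tool_calls_completion_chunk(
            model_family, model_uid, c, chunk_id=shared_id
        )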
xinference/model/llm/vllm/core.py
CHANGED
@@ -717,11 +717,26 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
     ) -> bool:
-        if
+        if not cls._has_cuda_device():
+            return False
+        if not cls._is_linux():
+            return False
+        if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
             return False
         if llm_spec.model_format == "pytorch":
             if quantization != "none" and not (quantization is None):
                 return False
+        if llm_spec.model_format == "awq":
+            # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits.
+            if "4" not in quantization:
+                return False
+        if llm_spec.model_format == "gptq":
+            if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
+                if not any(q in quantization for q in ("3", "4", "8")):
+                    return False
+            else:
+                if "4" not in quantization:
+                    return False
         if isinstance(llm_family, CustomLLMFamilyV1):
             if llm_family.model_family not in VLLM_SUPPORTED_VISION_MODEL_LIST:
                 return False
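The expanded `match` gate encodes which weight formats the vLLM vision path accepts: unquantized pytorch weights, AWQ only in 4-bit, and GPTQ in 3/4/8-bit on vLLM >= 0.3.3 but 4-bit only on older vLLM. A standalone restatement of that rule, handy for eyeballing combinations (this helper is illustrative, not part of xinference, and mirrors the string-based version comparison used above):

def vllm_vision_quant_ok(model_format: str, quantization: str, vllm_version: str) -> bool:
    """Illustrative restatement of the quantization checks added to VLLMVisionModel.match."""
    if model_format == "pytorch":
        return quantization in (None, "none")
    if model_format == "awq":
        return "4" in quantization
    if model_format == "gptq":
        if vllm_version >= "0.3.3":  # same lexical comparison as the hunk above
            return any(q in quantization for q in ("3", "4", "8"))
        return "4" in quantization
    return model_format == "fp8"  # fp8 has no extra quantization constraint here


assert vllm_vision_quant_ok("awq", "Int4", "0.6.0") is True
assert vllm_vision_quant_ok("awq", "Int8", "0.6.0") is False
assert vllm_vision_quant_ok("gptq", "Int8", "0.2.7") is False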
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
xinference/thirdparty/mlx/flux/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# Copyright © 2024 Apple Inc.
+
+from .datasets import Dataset, load_dataset
+from .flux import FluxPipeline
+from .lora import LoRALinear
+from .sampler import FluxSampler
+from .trainer import Trainer
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)