xinference 0.15.4__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (67)
  1. xinference/__init__.py +0 -4
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +48 -0
  4. xinference/client/restful/restful_client.py +19 -0
  5. xinference/constants.py +4 -4
  6. xinference/core/chat_interface.py +5 -1
  7. xinference/core/image_interface.py +5 -1
  8. xinference/core/model.py +195 -34
  9. xinference/core/scheduler.py +10 -7
  10. xinference/core/utils.py +9 -0
  11. xinference/model/__init__.py +4 -0
  12. xinference/model/audio/chattts.py +25 -14
  13. xinference/model/audio/model_spec.json +1 -1
  14. xinference/model/audio/model_spec_modelscope.json +1 -1
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/image/core.py +59 -4
  17. xinference/model/image/model_spec.json +24 -3
  18. xinference/model/image/model_spec_modelscope.json +25 -3
  19. xinference/model/image/ocr/__init__.py +13 -0
  20. xinference/model/image/ocr/got_ocr2.py +76 -0
  21. xinference/model/image/scheduler/__init__.py +13 -0
  22. xinference/model/image/scheduler/flux.py +533 -0
  23. xinference/model/image/stable_diffusion/core.py +8 -34
  24. xinference/model/image/stable_diffusion/mlx.py +221 -0
  25. xinference/model/image/utils.py +39 -3
  26. xinference/model/llm/__init__.py +2 -0
  27. xinference/model/llm/llm_family.json +178 -1
  28. xinference/model/llm/llm_family_modelscope.json +119 -0
  29. xinference/model/llm/transformers/chatglm.py +104 -0
  30. xinference/model/llm/transformers/core.py +37 -111
  31. xinference/model/llm/transformers/deepseek_v2.py +0 -226
  32. xinference/model/llm/transformers/internlm2.py +3 -95
  33. xinference/model/llm/transformers/opt.py +68 -0
  34. xinference/model/llm/transformers/utils.py +4 -284
  35. xinference/model/llm/utils.py +2 -2
  36. xinference/model/llm/vllm/core.py +16 -1
  37. xinference/thirdparty/mlx/__init__.py +13 -0
  38. xinference/thirdparty/mlx/flux/__init__.py +15 -0
  39. xinference/thirdparty/mlx/flux/autoencoder.py +357 -0
  40. xinference/thirdparty/mlx/flux/clip.py +154 -0
  41. xinference/thirdparty/mlx/flux/datasets.py +75 -0
  42. xinference/thirdparty/mlx/flux/flux.py +247 -0
  43. xinference/thirdparty/mlx/flux/layers.py +302 -0
  44. xinference/thirdparty/mlx/flux/lora.py +76 -0
  45. xinference/thirdparty/mlx/flux/model.py +134 -0
  46. xinference/thirdparty/mlx/flux/sampler.py +56 -0
  47. xinference/thirdparty/mlx/flux/t5.py +244 -0
  48. xinference/thirdparty/mlx/flux/tokenizers.py +185 -0
  49. xinference/thirdparty/mlx/flux/trainer.py +98 -0
  50. xinference/thirdparty/mlx/flux/utils.py +179 -0
  51. xinference/utils.py +2 -3
  52. xinference/web/ui/build/asset-manifest.json +3 -3
  53. xinference/web/ui/build/index.html +1 -1
  54. xinference/web/ui/build/static/js/{main.e51a356d.js → main.b76aeeb7.js} +3 -3
  55. xinference/web/ui/build/static/js/main.b76aeeb7.js.map +1 -0
  56. xinference/web/ui/node_modules/.cache/babel-loader/32ea2c04cf0bba2761b4883d2c40cc259952c94d2d6bb774e510963ca37aac0a.json +1 -0
  57. xinference/web/ui/node_modules/.cache/babel-loader/331312668fa8bd3d7401818f4a25fa98135d7f61371cd6bfff78b18cf4fbdd92.json +1 -0
  58. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/METADATA +49 -10
  59. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/RECORD +64 -44
  60. xinference/web/ui/build/static/js/main.e51a356d.js.map +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +0 -1
  62. xinference/web/ui/node_modules/.cache/babel-loader/4385c1095eefbff0a8ec3b2964ba6e5a66a05ab31be721483ca2f43e2a91f6ff.json +0 -1
  63. /xinference/web/ui/build/static/js/{main.e51a356d.js.LICENSE.txt → main.b76aeeb7.js.LICENSE.txt} +0 -0
  64. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/LICENSE +0 -0
  65. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/WHEEL +0 -0
  66. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/entry_points.txt +0 -0
  67. {xinference-0.15.4.dist-info → xinference-0.16.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/transformers/opt.py (new file)
@@ -0,0 +1,68 @@
+ # Copyright 2022-2024 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from builtins import classmethod
+ from typing import List, Optional
+
+ from ....core.scheduler import InferenceRequest
+ from ....types import LoRA
+ from ..llm_family import LLMFamilyV1, LLMSpecV1
+ from .core import PytorchModel, PytorchModelConfig
+
+
+ class OptPytorchModel(PytorchModel):
+     def __init__(
+         self,
+         model_uid: str,
+         model_family: "LLMFamilyV1",
+         model_spec: "LLMSpecV1",
+         quantization: str,
+         model_path: str,
+         pytorch_model_config: Optional[PytorchModelConfig] = None,
+         peft_model: Optional[List[LoRA]] = None,
+     ):
+         super().__init__(
+             model_uid,
+             model_family,
+             model_spec,
+             quantization,
+             model_path,
+             pytorch_model_config=pytorch_model_config,
+             peft_model=peft_model,
+         )
+
+     @classmethod
+     def match(
+         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+     ) -> bool:
+         if llm_spec.model_format != "pytorch":
+             return False
+         model_family = llm_family.model_family or llm_family.model_name
+         if model_family != "opt":
+             return False
+         return True
+
+     def build_prefill_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         """
+         Mainly for UT.
+         Transformers code in `main` branch supports `position_ids` parameter (https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L1076),
+         while in release branch, it doesn't (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/opt/modeling_opt.py#L886).
+         """
+         return None
+
+     def build_decode_position_ids(
+         self, batch_size: int, seq_length: int, reqs: List[InferenceRequest]
+     ):
+         return None
xinference/model/llm/transformers/utils.py
@@ -11,14 +11,13 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+
  import asyncio
  import functools
- import gc
  import logging
  import os
  import time
- import uuid
- from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

  import torch
  from transformers.cache_utils import DynamicCache
@@ -46,20 +45,6 @@ if TYPE_CHECKING:
  logger = logging.getLogger(__name__)


- def is_sentence_complete(output: str):
-     """Check whether the output is a complete sentence."""
-     end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”")
-     return output.endswith(end_symbols)
-
-
- def is_partial_stop(output: str, stop_str: str):
-     """Check whether the output contains a partial stop str."""
-     for i in range(0, min(len(output), len(stop_str))):
-         if stop_str.startswith(output[-i:]):
-             return True
-     return False
-
-
  def get_context_length(config) -> int:
      """Get the context length of a model from a huggingface model config."""
      if (
@@ -99,273 +84,6 @@ def prepare_logits_processor(
      return processor_list


- @torch.inference_mode()
- def generate_stream(
-     model_uid,
-     model,
-     tokenizer,
-     prompt,
-     device,
-     generate_config,
-     judge_sent_end=False,
- ) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
-     context_len = get_context_length(model.config)
-     stream_interval = generate_config.get("stream_interval", 2)
-     stream = generate_config.get("stream", False)
-     stream_options = generate_config.pop("stream_options", None)
-     include_usage = (
-         stream_options["include_usage"] if isinstance(stream_options, dict) else False
-     )
-
-     len_prompt = len(prompt)
-
-     temperature = float(generate_config.get("temperature", 1.0))
-     repetition_penalty = float(generate_config.get("repetition_penalty", 1.0))
-     top_p = float(generate_config.get("top_p", 1.0))
-     top_k = int(generate_config.get("top_k", -1))  # -1 means disable
-     max_new_tokens = int(generate_config.get("max_tokens", max_tokens_field.default))
-     echo = bool(generate_config.get("echo", False))
-     stop_str = generate_config.get("stop", None)
-     stop_token_ids = generate_config.get("stop_token_ids", None) or []
-     if tokenizer.eos_token_id not in stop_token_ids:
-         stop_token_ids.append(tokenizer.eos_token_id)
-     chunk_id = str(uuid.uuid4())
-
-     logits_processor = prepare_logits_processor(
-         temperature, repetition_penalty, top_p, top_k
-     )
-
-     if ".modeling_qwen." in str(type(model)).lower():
-         # TODO: hacky
-         input_ids = tokenizer(prompt, allowed_special="all").input_ids
-     else:
-         input_ids = tokenizer(prompt).input_ids
-     output_ids = list(input_ids)
-
-     if model.config.is_encoder_decoder:
-         max_src_len = context_len
-     else:
-         max_src_len = context_len - max_new_tokens - 8
-         if max_src_len < 0:
-             raise ValueError("Max tokens exceeds model's max length")
-
-     input_ids = input_ids[-max_src_len:]
-     input_echo_len = len(input_ids)
-
-     if model.config.is_encoder_decoder:
-         encoder_output = model.encoder(
-             input_ids=torch.as_tensor([input_ids], device=device)
-         )[0]
-         start_ids = torch.as_tensor(
-             [[model.generation_config.decoder_start_token_id]],
-             dtype=torch.int64,
-             device=device,
-         )
-
-     start = time.time()
-     past_key_values = out = None
-     sent_interrupt = False
-     token = None
-     last_output_length = 0
-     for i in range(max_new_tokens):
-         if i == 0:
-             if model.config.is_encoder_decoder:
-                 out = model.decoder(
-                     input_ids=start_ids,
-                     encoder_hidden_states=encoder_output,
-                     use_cache=True,
-                 )
-                 logits = model.lm_head(out[0])
-             else:
-                 out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
-                 logits = out.logits
-             past_key_values = out.past_key_values
-         else:
-             if model.config.is_encoder_decoder:
-                 out = model.decoder(
-                     input_ids=torch.as_tensor(
-                         [[token] if not sent_interrupt else output_ids], device=device
-                     ),
-                     encoder_hidden_states=encoder_output,
-                     use_cache=True,
-                     past_key_values=past_key_values if not sent_interrupt else None,
-                 )
-                 sent_interrupt = False
-
-                 logits = model.lm_head(out[0])
-             else:
-                 out = model(
-                     input_ids=torch.as_tensor(
-                         [[token] if not sent_interrupt else output_ids], device=device
-                     ),
-                     use_cache=True,
-                     past_key_values=past_key_values if not sent_interrupt else None,
-                 )
-                 sent_interrupt = False
-                 logits = out.logits
-             past_key_values = out.past_key_values
-
-         if logits_processor:
-             if repetition_penalty > 1.0:
-                 tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-             else:
-                 tmp_output_ids = None
-             last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-         else:
-             last_token_logits = logits[0, -1, :]
-
-         if device == "mps":
-             # Switch to CPU by avoiding some bugs in mps backend.
-             last_token_logits = last_token_logits.float().to("cpu")
-
-         if temperature < 1e-5 or top_p < 1e-8:  # greedy
-             _, indices = torch.topk(last_token_logits, 2)
-             tokens = [int(index) for index in indices.tolist()]
-         else:
-             probs = torch.softmax(last_token_logits, dim=-1)
-             indices = torch.multinomial(probs, num_samples=2)
-             tokens = [int(token) for token in indices.tolist()]
-         token = tokens[0]
-         output_ids.append(token)
-
-         if token in stop_token_ids:
-             stopped = True
-         else:
-             stopped = False
-
-         if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
-             if echo:
-                 tmp_output_ids = output_ids
-                 rfind_start = len_prompt
-             else:
-                 tmp_output_ids = output_ids[input_echo_len:]
-                 rfind_start = 0
-
-             output = tokenizer.decode(
-                 tmp_output_ids,
-                 skip_special_tokens=True,
-                 spaces_between_special_tokens=False,
-                 clean_up_tokenization_spaces=True,
-             )
-
-             # TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
-             if judge_sent_end and stopped and not is_sentence_complete(output):
-                 if len(tokens) > 1:
-                     token = tokens[1]
-                     output_ids[-1] = token
-                 else:
-                     output_ids.pop()
-                 stopped = False
-                 sent_interrupt = True
-
-             partially_stopped = False
-             if stop_str:
-                 if isinstance(stop_str, str):
-                     pos = output.rfind(stop_str, rfind_start)
-                     if pos != -1:
-                         output = output[:pos]
-                         stopped = True
-                     else:
-                         partially_stopped = is_partial_stop(output, stop_str)
-                 elif isinstance(stop_str, Iterable):
-                     for each_stop in stop_str:
-                         pos = output.rfind(each_stop, rfind_start)
-                         if pos != -1:
-                             output = output[:pos]
-                             stopped = True
-                             break
-                         else:
-                             partially_stopped = is_partial_stop(output, each_stop)
-                             if partially_stopped:
-                                 break
-                 else:
-                     raise ValueError("Invalid stop field type.")
-
-             if stream:
-                 output = output.strip("�")
-                 tmp_output_length = len(output)
-                 output = output[last_output_length:]
-                 last_output_length = tmp_output_length
-
-             # prevent yielding partial stop sequence
-             if not partially_stopped:
-                 completion_choice = CompletionChoice(
-                     text=output, index=0, logprobs=None, finish_reason=None
-                 )
-                 completion_chunk = CompletionChunk(
-                     id=chunk_id,
-                     object="text_completion",
-                     created=int(time.time()),
-                     model=model_uid,
-                     choices=[completion_choice],
-                 )
-                 completion_usage = CompletionUsage(
-                     prompt_tokens=input_echo_len,
-                     completion_tokens=i,
-                     total_tokens=(input_echo_len + i),
-                 )
-
-                 yield completion_chunk, completion_usage
-
-         if stopped:
-             break
-
-     elapsed_time = time.time() - start
-     logger.info(f"Average generation speed: {i / elapsed_time:.2f} tokens/s.")
-
-     # finish stream event, which contains finish reason
-     if stopped:
-         finish_reason = "stop"
-     elif i == max_new_tokens - 1:
-         finish_reason = "length"
-     else:
-         finish_reason = None
-
-     if stream:
-         completion_choice = CompletionChoice(
-             text=output, index=0, logprobs=None, finish_reason=finish_reason
-         )
-     else:
-         completion_choice = CompletionChoice(
-             text=output, index=0, logprobs=None, finish_reason=finish_reason
-         )
-
-     completion_chunk = CompletionChunk(
-         id=chunk_id,
-         object="text_completion",
-         created=int(time.time()),
-         model=model_uid,
-         choices=[completion_choice],
-     )
-     completion_usage = CompletionUsage(
-         prompt_tokens=input_echo_len,
-         completion_tokens=i,
-         total_tokens=(input_echo_len + i),
-     )
-
-     yield completion_chunk, completion_usage
-
-     if include_usage:
-         completion_chunk = CompletionChunk(
-             id=chunk_id,
-             object="text_completion",
-             created=int(time.time()),
-             model=model_uid,
-             choices=[],
-         )
-         completion_usage = CompletionUsage(
-             prompt_tokens=input_echo_len,
-             completion_tokens=i,
-             total_tokens=(input_echo_len + i),
-         )
-         yield completion_chunk, completion_usage
-
-     # clean
-     del past_key_values, out
-     gc.collect()
-     empty_cache()
-
-
  def _get_token_from_logits(
      req: InferenceRequest, i: int, logits, temperature, repetition_penalty, top_p, top_k
  ):
@@ -680,6 +398,7 @@ def _batch_inference_one_step_internal(
  output = output.strip("�")
  output = output[r.last_output_length :]
  r.last_output_length += len(output)
+ r.outputs.append(output)

  completion_chunk = generate_completion_chunk(
      chunk_text=output,
@@ -704,6 +423,7 @@
  )
  r.completion.append(completion_chunk)
  r.completion.append(eos_flag)
+ r.outputs.append(eos_flag)

  # last round, handle stream result
  # append usage information when enable `include_usage` for OPENAI API compatibility
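The two hunks above make the batched inference path record what it streams: each newly decoded text delta and a final EOS marker are appended to the request's `outputs` list in addition to `completion`. The sketch below is a minimal, self-contained illustration of that bookkeeping pattern only; `FakeRequest`, `emit`, and `EOS_FLAG` are hypothetical stand-ins, not xinference's real `InferenceRequest` or sentinel.

```python
from dataclasses import dataclass, field
from typing import List

EOS_FLAG = "<EOS>"  # hypothetical stand-in for the eos_flag sentinel


@dataclass
class FakeRequest:
    completion: List[str] = field(default_factory=list)  # what gets streamed out
    outputs: List[str] = field(default_factory=list)      # kept for later inspection
    last_output_length: int = 0


def emit(r: FakeRequest, decoded_so_far: str) -> None:
    # Only the newly decoded suffix is emitted, mirroring the slicing in the hunk.
    delta = decoded_so_far[r.last_output_length:]
    r.last_output_length += len(delta)
    r.completion.append(delta)
    r.outputs.append(delta)


r = FakeRequest()
emit(r, "Hello")
emit(r, "Hello, world")
r.outputs.append(EOS_FLAG)
assert r.outputs == ["Hello", ", world", EOS_FLAG]
```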
xinference/model/llm/utils.py
@@ -386,8 +386,8 @@ class ChatModelMixin:
          return result

      @classmethod
-     def _tool_calls_completion_chunk(cls, model_family, model_uid, c):
-         _id = str(uuid.uuid4())
+     def _tool_calls_completion_chunk(cls, model_family, model_uid, c, chunk_id=None):
+         _id = chunk_id if chunk_id is not None else str(uuid.uuid4())
          tool_result = cls._eval_tool_arguments(model_family, c)
          tool_calls = []
          failed_contents = []
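The change above threads an optional `chunk_id` through `_tool_calls_completion_chunk`, so every chunk of one streamed tool-call response can reuse the same id instead of minting a new UUID per chunk. A minimal sketch of that pattern, assuming nothing about xinference's internals (`make_chunk` and `stream_chunks` are illustrative names, not the project's API):

```python
import uuid
from typing import Iterator, List, Optional


def make_chunk(payload: str, chunk_id: Optional[str] = None) -> dict:
    # Reuse the caller-supplied id when present; otherwise mint a fresh one.
    _id = chunk_id if chunk_id is not None else str(uuid.uuid4())
    return {"id": _id, "object": "chat.completion.chunk", "content": payload}


def stream_chunks(pieces: List[str]) -> Iterator[dict]:
    stable_id = str(uuid.uuid4())  # one id for the whole streamed answer
    for piece in pieces:
        yield make_chunk(piece, chunk_id=stable_id)


ids = {c["id"] for c in stream_chunks(["fn(", "x=1", ")"])}
assert len(ids) == 1  # all chunks share one id, as OpenAI-style clients expect
```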
xinference/model/llm/vllm/core.py
@@ -717,11 +717,26 @@ class VLLMVisionModel(VLLMModel, ChatModelMixin):
      def match(
          cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
      ) -> bool:
-         if llm_spec.model_format != "pytorch":
+         if not cls._has_cuda_device():
+             return False
+         if not cls._is_linux():
+             return False
+         if llm_spec.model_format not in ["pytorch", "gptq", "awq", "fp8"]:
              return False
          if llm_spec.model_format == "pytorch":
              if quantization != "none" and not (quantization is None):
                  return False
+         if llm_spec.model_format == "awq":
+             # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits.
+             if "4" not in quantization:
+                 return False
+         if llm_spec.model_format == "gptq":
+             if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
+                 if not any(q in quantization for q in ("3", "4", "8")):
+                     return False
+             else:
+                 if "4" not in quantization:
+                     return False
          if isinstance(llm_family, CustomLLMFamilyV1):
              if llm_family.model_family not in VLLM_SUPPORTED_VISION_MODEL_LIST:
                  return False
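For context, the widened `match` gate above only admits vLLM vision models on Linux hosts with CUDA, and it restricts quantized formats: AWQ must be 4-bit, while GPTQ accepts 3/4/8-bit only when vLLM is at least 0.3.3 (4-bit otherwise). The standalone sketch below restates just the format/quantization part of that check; `vision_format_supported` is a hypothetical helper written for illustration, not xinference's or vLLM's API.

```python
from typing import Optional


def vision_format_supported(
    model_format: str, quantization: str, vllm_version: Optional[str] = None
) -> bool:
    """Illustrative restatement of the quantization gate in the hunk above."""
    if model_format not in ("pytorch", "gptq", "awq", "fp8"):
        return False
    if model_format == "pytorch" and quantization not in ("none",):
        return False
    if model_format == "awq" and "4" not in quantization:
        return False  # only 4-bit AWQ weights pass
    if model_format == "gptq":
        # String comparison mirrors the hunk's `vllm.__version__ >= "0.3.3"` check.
        if vllm_version is not None and vllm_version >= "0.3.3":
            return any(q in quantization for q in ("3", "4", "8"))
        return "4" in quantization
    return True  # fp8 has no extra bit-width check here


assert vision_format_supported("awq", "Int4")
assert vision_format_supported("gptq", "8-bit", vllm_version="0.6.0")
assert not vision_format_supported("gptq", "8-bit", vllm_version="0.3.0")
```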
xinference/thirdparty/mlx/__init__.py (new file)
@@ -0,0 +1,13 @@
+ # Copyright 2022-2023 XProbe Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
xinference/thirdparty/mlx/flux/__init__.py (new file)
@@ -0,0 +1,15 @@
+ # Copyright © 2024 Apple Inc.
+
+ from .datasets import Dataset, load_dataset
+ from .flux import FluxPipeline
+ from .lora import LoRALinear
+ from .sampler import FluxSampler
+ from .trainer import Trainer
+ from .utils import (
+     load_ae,
+     load_clip,
+     load_clip_tokenizer,
+     load_flow_model,
+     load_t5,
+     load_t5_tokenizer,
+ )