xinference 0.14.4.post1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (149)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +5 -39
  4. xinference/client/restful/restful_client.py +3 -24
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/model.py +82 -31
  11. xinference/core/scheduler.py +37 -37
  12. xinference/core/status_guard.py +1 -1
  13. xinference/core/supervisor.py +11 -10
  14. xinference/core/utils.py +80 -22
  15. xinference/core/worker.py +17 -16
  16. xinference/deploy/cmdline.py +8 -16
  17. xinference/deploy/local.py +1 -1
  18. xinference/deploy/supervisor.py +1 -1
  19. xinference/deploy/utils.py +1 -1
  20. xinference/deploy/worker.py +1 -1
  21. xinference/model/audio/cosyvoice.py +86 -41
  22. xinference/model/embedding/core.py +52 -31
  23. xinference/model/image/stable_diffusion/core.py +18 -1
  24. xinference/model/llm/__init__.py +21 -11
  25. xinference/model/llm/llama_cpp/core.py +16 -33
  26. xinference/model/llm/llm_family.json +619 -1297
  27. xinference/model/llm/llm_family.py +31 -52
  28. xinference/model/llm/llm_family_csghub.json +18 -35
  29. xinference/model/llm/llm_family_modelscope.json +573 -1119
  30. xinference/model/llm/lmdeploy/core.py +56 -88
  31. xinference/model/llm/mlx/core.py +46 -69
  32. xinference/model/llm/sglang/core.py +33 -18
  33. xinference/model/llm/transformers/chatglm.py +167 -305
  34. xinference/model/llm/transformers/cogvlm2.py +36 -63
  35. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  36. xinference/model/llm/transformers/core.py +49 -50
  37. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  38. xinference/model/llm/transformers/glm4v.py +55 -111
  39. xinference/model/llm/transformers/intern_vl.py +39 -70
  40. xinference/model/llm/transformers/internlm2.py +32 -54
  41. xinference/model/llm/transformers/minicpmv25.py +22 -55
  42. xinference/model/llm/transformers/minicpmv26.py +158 -68
  43. xinference/model/llm/transformers/omnilmm.py +5 -28
  44. xinference/model/llm/transformers/qwen2_vl.py +208 -0
  45. xinference/model/llm/transformers/qwen_vl.py +34 -86
  46. xinference/model/llm/transformers/utils.py +32 -38
  47. xinference/model/llm/transformers/yi_vl.py +32 -72
  48. xinference/model/llm/utils.py +195 -489
  49. xinference/model/llm/vllm/core.py +153 -100
  50. xinference/model/rerank/core.py +41 -8
  51. xinference/model/rerank/model_spec.json +7 -0
  52. xinference/model/rerank/model_spec_modelscope.json +7 -1
  53. xinference/model/utils.py +1 -31
  54. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  55. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  56. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  57. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  58. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  59. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  60. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  61. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  62. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  63. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  64. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  65. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  66. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  67. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  68. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  69. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  70. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  71. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  72. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  73. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  74. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  75. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  76. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  77. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  78. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  79. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  80. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  81. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +34 -0
  82. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  83. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  84. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  85. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  88. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  89. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  90. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  91. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  92. xinference/thirdparty/matcha/VERSION +1 -0
  93. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  94. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  95. xinference/thirdparty/omnilmm/LICENSE +201 -0
  96. xinference/thirdparty/whisper/__init__.py +156 -0
  97. xinference/thirdparty/whisper/__main__.py +3 -0
  98. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  99. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  100. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  101. xinference/thirdparty/whisper/audio.py +157 -0
  102. xinference/thirdparty/whisper/decoding.py +826 -0
  103. xinference/thirdparty/whisper/model.py +314 -0
  104. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  105. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  106. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  107. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  108. xinference/thirdparty/whisper/timing.py +386 -0
  109. xinference/thirdparty/whisper/tokenizer.py +395 -0
  110. xinference/thirdparty/whisper/transcribe.py +605 -0
  111. xinference/thirdparty/whisper/triton_ops.py +109 -0
  112. xinference/thirdparty/whisper/utils.py +316 -0
  113. xinference/thirdparty/whisper/version.py +1 -0
  114. xinference/types.py +7 -49
  115. xinference/web/ui/build/asset-manifest.json +6 -6
  116. xinference/web/ui/build/index.html +1 -1
  117. xinference/web/ui/build/static/css/{main.4bafd904.css → main.632e9148.css} +2 -2
  118. xinference/web/ui/build/static/css/main.632e9148.css.map +1 -0
  119. xinference/web/ui/build/static/js/main.9cfafbd6.js +3 -0
  120. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.9cfafbd6.js.LICENSE.txt} +2 -0
  121. xinference/web/ui/build/static/js/main.9cfafbd6.js.map +1 -0
  122. xinference/web/ui/node_modules/.cache/babel-loader/01d6d198156bacbd436c51435edbd4b2cacd47a79db929105eba30f74b67d48d.json +1 -0
  123. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  124. xinference/web/ui/node_modules/.cache/babel-loader/59eb25f514afcc4fefd1b309d192b2455f1e0aec68a9de598ca4b2333fe2c774.json +1 -0
  125. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  126. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  127. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  128. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  129. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  130. xinference/web/ui/node_modules/.package-lock.json +37 -0
  131. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  132. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  133. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  134. xinference/web/ui/package-lock.json +38 -0
  135. xinference/web/ui/package.json +1 -0
  136. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/METADATA +8 -8
  137. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/RECORD +141 -87
  138. xinference/model/llm/transformers/llama_2.py +0 -108
  139. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  140. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  141. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  142. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  144. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  145. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  146. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/LICENSE +0 -0
  147. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/WHEEL +0 -0
  148. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/entry_points.txt +0 -0
  149. {xinference-0.14.4.post1.dist-info → xinference-0.15.0.dist-info}/top_level.txt +0 -0

xinference/model/audio/cosyvoice.py

@@ -53,7 +53,82 @@ class CosyVoiceModel:
 
         from cosyvoice.cli.cosyvoice import CosyVoice
 
-        self._model = CosyVoice(self._model_path)
+        self._model = CosyVoice(
+            self._model_path, load_jit=self._kwargs.get("load_jit", False)
+        )
+
+    def _speech_handle(
+        self,
+        stream,
+        input,
+        instruct_text,
+        prompt_speech,
+        prompt_text,
+        voice,
+        response_format,
+    ):
+        if prompt_speech:
+            from cosyvoice.utils.file_utils import load_wav
+
+            with io.BytesIO(prompt_speech) as prompt_speech_io:
+                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
+
+                if prompt_text:
+                    logger.info("CosyVoice inference_zero_shot")
+                    output = self._model.inference_zero_shot(
+                        input, prompt_text, prompt_speech_16k, stream=stream
+                    )
+                else:
+                    logger.info("CosyVoice inference_cross_lingual")
+                    output = self._model.inference_cross_lingual(
+                        input, prompt_speech_16k, stream=stream
+                    )
+        else:
+            available_speakers = self._model.list_avaliable_spks()
+            if not voice:
+                voice = available_speakers[0]
+            else:
+                assert (
+                    voice in available_speakers
+                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
+            if instruct_text:
+                logger.info("CosyVoice inference_instruct")
+                output = self._model.inference_instruct(
+                    input, voice, instruct_text=instruct_text, stream=stream
+                )
+            else:
+                logger.info("CosyVoice inference_sft")
+                output = self._model.inference_sft(input, voice, stream=stream)
+
+        import torch
+        import torchaudio
+
+        def _generator_stream():
+            with BytesIO() as out:
+                writer = torchaudio.io.StreamWriter(out, format=response_format)
+                writer.add_audio_stream(sample_rate=22050, num_channels=1)
+                i = 0
+                last_pos = 0
+                with writer.open():
+                    for chunk in output:
+                        chunk = chunk["tts_speech"]
+                        trans_chunk = torch.transpose(chunk, 0, 1)
+                        writer.write_audio_chunk(i, trans_chunk)
+                        new_last_pos = out.tell()
+                        if new_last_pos != last_pos:
+                            out.seek(last_pos)
+                            encoded_bytes = out.read()
+                            yield encoded_bytes
+                            last_pos = new_last_pos
+
+        def _generator_block():
+            chunk = next(output)
+            assert isinstance(chunk, dict), "Expected data to be of type dict"
+            with BytesIO() as out:
+                torchaudio.save(out, chunk["tts_speech"], 22050, format=response_format)
+                return out.getvalue()
+
+        return _generator_stream() if stream else _generator_block()
 
     def speech(
         self,
@@ -64,12 +139,6 @@ class CosyVoiceModel:
         stream: bool = False,
         **kwargs,
     ):
-        if stream:
-            raise Exception("CosyVoiceModel does not support stream.")
-
-        import torchaudio
-        from cosyvoice.utils.file_utils import load_wav
-
         prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
         prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
         instruct_text: Optional[str] = kwargs.pop("instruct_text", None)
@@ -103,39 +172,15 @@ class CosyVoiceModel:
         ), "CosyVoice model does not support instruct_text"
 
         assert self._model is not None
+
         set_all_random_seed(seed)
-        if prompt_speech:
-            assert not voice, "voice can't be set with prompt speech."
-            with io.BytesIO(prompt_speech) as prompt_speech_io:
-                prompt_speech_16k = load_wav(prompt_speech_io, 16000)
-                if prompt_text:
-                    logger.info("CosyVoice inference_zero_shot")
-                    output = self._model.inference_zero_shot(
-                        input, prompt_text, prompt_speech_16k
-                    )
-                else:
-                    logger.info("CosyVoice inference_cross_lingual")
-                    output = self._model.inference_cross_lingual(
-                        input, prompt_speech_16k
-                    )
-        else:
-            available_speakers = self._model.list_avaliable_spks()
-            if not voice:
-                voice = available_speakers[0]
-            else:
-                assert (
-                    voice in available_speakers
-                ), f"Invalid voice {voice}, CosyVoice available speakers: {available_speakers}"
-            if instruct_text:
-                logger.info("CosyVoice inference_instruct")
-                output = self._model.inference_instruct(
-                    input, voice, instruct_text=instruct_text
-                )
-            else:
-                logger.info("CosyVoice inference_sft")
-                output = self._model.inference_sft(input, voice)
 
-        # Save the generated audio
-        with BytesIO() as out:
-            torchaudio.save(out, output["tts_speech"], 22050, format=response_format)
-            return out.getvalue()
+        return self._speech_handle(
+            stream,
+            input,
+            instruct_text,
+            prompt_speech,
+            prompt_text,
+            voice,
+            response_format,
+        )
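
With this change `speech()` no longer rejects `stream=True`: it returns the encoded audio as bytes when `stream` is false and a generator of encoded chunks when it is true. A minimal consumer-side sketch of handling both return shapes (the `model` handle and its keyword arguments are illustrative, not taken from this diff):

    # `model` stands for an already-launched CosyVoice audio model handle (hypothetical).
    result = model.speech("Hello from CosyVoice", response_format="mp3", stream=True)

    with open("out.mp3", "wb") as f:
        if isinstance(result, bytes):    # stream=False: one encoded blob
            f.write(result)
        else:                            # stream=True: generator of encoded chunks
            for chunk in result:
                f.write(chunk)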

xinference/model/embedding/core.py

@@ -19,6 +19,7 @@ from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union, no_type_check
 
 import numpy as np
+import torch
 
 from ...device_utils import empty_cache
 from ...types import Embedding, EmbeddingData, EmbeddingUsage
@@ -34,7 +35,11 @@ EMBEDDING_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
 EMBEDDING_EMPTY_CACHE_COUNT = int(
     os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT", "10")
 )
+EMBEDDING_EMPTY_CACHE_TOKENS = int(
+    os.getenv("XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS", "8192")
+)
 assert EMBEDDING_EMPTY_CACHE_COUNT > 0
+assert EMBEDDING_EMPTY_CACHE_TOKENS > 0
 
 
 def get_embedding_model_descriptions():
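
Both thresholds are read from the environment when the module is imported, so they have to be set before xinference loads. A minimal sketch (the values here are illustrative):

    import os

    # Must be set before xinference.model.embedding.core is imported.
    os.environ["XINFERENCE_EMBEDDING_EMPTY_CACHE_COUNT"] = "5"      # empty cache every 5 calls
    os.environ["XINFERENCE_EMBEDDING_EMPTY_CACHE_TOKENS"] = "4096"  # ...or once a batch reaches 4096 tokens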

@@ -149,6 +154,25 @@ class EmbeddingModel:
             def to(self, *args, **kwargs):
                 pass
 
+        torch_dtype = None
+        if torch_dtype_str := self._kwargs.get("torch_dtype"):
+            try:
+                torch_dtype = getattr(torch, torch_dtype_str)
+                if torch_dtype not in [
+                    torch.float16,
+                    torch.float32,
+                    torch.bfloat16,
+                ]:
+                    logger.warning(
+                        f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
+                    )
+                    torch_dtype = torch.float32
+            except AttributeError:
+                logger.warning(
+                    f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
+                )
+                torch_dtype = torch.float32
+
         from ..utils import patch_trust_remote_code
 
         patch_trust_remote_code()
@@ -156,42 +180,21 @@ class EmbeddingModel:
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
         ):
-            import torch
-
-            torch_dtype_str = self._kwargs.get("torch_dtype")
-            if torch_dtype_str is not None:
-                try:
-                    torch_dtype = getattr(torch, torch_dtype_str)
-                    if torch_dtype not in [
-                        torch.float16,
-                        torch.float32,
-                        torch.bfloat16,
-                    ]:
-                        logger.warning(
-                            f"Load embedding model with unsupported torch dtype : {torch_dtype_str}. Using default torch dtype: fp32."
-                        )
-                        torch_dtype = torch.float32
-                except AttributeError:
-                    logger.warning(
-                        f"Load embedding model with unknown torch dtype '{torch_dtype_str}'. Using default torch dtype: fp32."
-                    )
-                    torch_dtype = torch.float32
-            else:
-                torch_dtype = "auto"
+            model_kwargs = {"device_map": "auto"}
+            if torch_dtype:
+                model_kwargs["torch_dtype"] = torch_dtype
             self._model = XSentenceTransformer(
                 self._model_path,
                 device=self._device,
-                model_kwargs={"device_map": "auto", "torch_dtype": torch_dtype},
+                model_kwargs=model_kwargs,
             )
         else:
-            self._model = SentenceTransformer(self._model_path, device=self._device)
+            model_kwargs = {"torch_dtype": torch_dtype} if torch_dtype else None
+            self._model = SentenceTransformer(
+                self._model_path, device=self._device, model_kwargs=model_kwargs
+            )
 
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
-        self._counter += 1
-        if self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0:
-            logger.debug("Empty embedding cache.")
-            gc.collect()
-            empty_cache()
         from sentence_transformers import SentenceTransformer
 
         kwargs.setdefault("normalize_embeddings", True)
@@ -309,7 +312,9 @@ class EmbeddingModel:
                 features = model.tokenize(sentences_batch)
                 features = batch_to_device(features, device)
                 features.update(extra_features)
-                all_token_nums += sum([len(f) for f in features])
+                # when batching, the attention mask 1 means there is a token
+                # thus we just sum up it to get the total number of tokens
+                all_token_nums += features["attention_mask"].sum().item()
 
                 with torch.no_grad():
                     out_features = model.forward(features)
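
The old accounting iterated over the tokenized feature dict, which just sums the lengths of its key names; summing the attention mask instead counts the real (non-padding) tokens in a padded batch. A small standalone illustration of the idea (toy tensors, not the model's actual tokenizer output):

    import torch

    # Two sentences padded to length 5; 1 marks a real token, 0 marks padding.
    features = {
        "input_ids": torch.tensor([[101, 2023, 2003, 102, 0], [101, 2305, 102, 0, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 0, 0]]),
    }
    print(sum(len(f) for f in features))            # 23: the key-name lengths, not tokens
    print(features["attention_mask"].sum().item())  # 7: the actual token count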

@@ -393,13 +398,29 @@ class EmbeddingModel:
         usage = EmbeddingUsage(
             prompt_tokens=all_token_nums, total_tokens=all_token_nums
         )
-        return Embedding(
+        result = Embedding(
             object="list",
             model=self._model_uid,
             data=embedding_list,
             usage=usage,
         )
 
+        # clean cache if possible
+        self._counter += 1
+        if (
+            self._counter % EMBEDDING_EMPTY_CACHE_COUNT == 0
+            or all_token_nums >= EMBEDDING_EMPTY_CACHE_TOKENS
+        ):
+            logger.debug(
+                "Empty embedding cache, calling count %s, all_token_nums %s",
+                self._counter,
+                all_token_nums,
+            )
+            gc.collect()
+            empty_cache()
+
+        return result
+
 
 def match_embedding(
     model_name: str,

xinference/model/image/stable_diffusion/core.py

@@ -172,10 +172,21 @@ class DiffusionModel:
             "stable diffusion args: %s",
             kwargs,
         )
+        is_padded = kwargs.pop("is_padded", None)
+        origin_size = kwargs.pop("origin_size", None)
+
         model = model if model is not None else self._model
         assert callable(model)
         images = model(**kwargs).images
 
+        # revert padding if padded
+        if is_padded and origin_size:
+            new_images = []
+            x, y = origin_size
+            for img in images:
+                new_images.append(img.crop((0, 0, x, y)))
+            images = new_images
+
         # clean cache
         gc.collect()
         empty_cache()
@@ -198,7 +209,7 @@ class DiffusionModel:
 
             with ThreadPoolExecutor() as executor:
                 results = list(map(partial(executor.submit, _gen_base64_image), images))  # type: ignore
-                image_list = [Image(url=None, b64_json=s.result()) for s in results]
+                image_list = [Image(url=None, b64_json=s.result()) for s in results]  # type: ignore
             return ImageList(created=int(time.time()), data=image_list)
         else:
             raise ValueError(f"Unsupported response format: {response_format}")
@@ -265,6 +276,9 @@ class DiffusionModel:
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 image to image requires image's height and width is times of 16
             # padding the image if specified
+            origin_x, origin_y = image.size
+            kwargs["origin_size"] = (origin_x, origin_y)
+            kwargs["is_padded"] = True
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
 
         if size:
@@ -318,6 +332,9 @@ class DiffusionModel:
         if padding_image_to_multiple := kwargs.pop("padding_image_to_multiple", None):
             # Model like SD3 inpainting requires image's height and width is times of 16
             # padding the image if specified
+            origin_x, origin_y = image.size
+            kwargs["origin_size"] = (origin_x, origin_y)
+            kwargs["is_padded"] = True
             image = self.pad_to_multiple(image, multiple=int(padding_image_to_multiple))
             mask_image = self.pad_to_multiple(
                 mask_image, multiple=int(padding_image_to_multiple)
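
The padding/crop pair makes SD3-style pipelines (which want height and width that are multiples of 16) transparent to the caller: the input is padded before inference and the output is cropped back to the recorded `origin_size`. A self-contained sketch of that round trip using PIL directly, with a simplified stand-in for the model's own `pad_to_multiple` helper:

    from PIL import Image

    def pad_to_multiple(img: Image.Image, multiple: int = 16) -> Image.Image:
        # Round each dimension up to the next multiple and paste the original at (0, 0).
        w, h = img.size
        new_w = (w + multiple - 1) // multiple * multiple
        new_h = (h + multiple - 1) // multiple * multiple
        padded = Image.new(img.mode, (new_w, new_h))
        padded.paste(img, (0, 0))
        return padded

    original = Image.new("RGB", (500, 333))
    padded = pad_to_multiple(original)              # 512 x 336, acceptable to the pipeline
    restored = padded.crop((0, 0, *original.size))  # crop back, as the diff above does
    assert restored.size == original.size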

xinference/model/llm/__init__.py

@@ -45,7 +45,6 @@ from .llm_family import (
     LLMFamilyV1,
     LLMSpecV1,
     MLXLLMSpecV1,
-    PromptStyleV1,
     PytorchLLMSpecV1,
     get_cache_status,
     get_user_defined_llm_families,
@@ -141,9 +140,9 @@ def _install():
     from .transformers.glm4v import Glm4VModel
     from .transformers.intern_vl import InternVLChatModel
     from .transformers.internlm2 import Internlm2PytorchChatModel
-    from .transformers.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
     from .transformers.minicpmv25 import MiniCPMV25Model
     from .transformers.minicpmv26 import MiniCPMV26Model
+    from .transformers.qwen2_vl import Qwen2VLChatModel
     from .transformers.qwen_vl import QwenVLChatModel
     from .transformers.yi_vl import YiVLChatModel
     from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
@@ -170,11 +169,10 @@ def _install():
     TRANSFORMERS_CLASSES.extend(
         [
             ChatglmPytorchChatModel,
-            LlamaPytorchModel,
-            LlamaPytorchChatModel,
             PytorchChatModel,
             Internlm2PytorchChatModel,
             QwenVLChatModel,
+            Qwen2VLChatModel,
             YiVLChatModel,
             DeepSeekVLChatModel,
             InternVLChatModel,
@@ -204,13 +202,17 @@ def _install():
         model_spec = LLMFamilyV1.parse_obj(json_obj)
         BUILTIN_LLM_FAMILIES.append(model_spec)
 
-        # register prompt style
+        # register chat_template
         if "chat" in model_spec.model_ability and isinstance(
-            model_spec.prompt_style, PromptStyleV1
+            model_spec.chat_template, str
         ):
             # note that the key is the model name,
             # since there are multiple representations of the same prompt style name in json.
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
@@ -230,10 +232,14 @@ def _install():
         # if duplicated with huggingface json, keep it as the huggingface style
         if (
             "chat" in model_spec.model_ability
-            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and isinstance(model_spec.chat_template, str)
             and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
@@ -253,10 +259,14 @@ def _install():
         # if duplicated with huggingface json, keep it as the huggingface style
         if (
             "chat" in model_spec.model_ability
-            and isinstance(model_spec.prompt_style, PromptStyleV1)
+            and isinstance(model_spec.chat_template, str)
            and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
         ):
-            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
+            BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = {
+                "chat_template": model_spec.chat_template,
+                "stop_token_ids": model_spec.stop_token_ids,
+                "stop": model_spec.stop,
+            }
         # register model family
         if "chat" in model_spec.model_ability:
             BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
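
After this change each `BUILTIN_LLM_PROMPT_STYLE` entry is a plain dict carrying the family's Jinja chat template plus its stop configuration, rather than a `PromptStyleV1` object. Roughly, an entry now looks like the following (the model name and values are illustrative, not copied from the JSON files):

    BUILTIN_LLM_PROMPT_STYLE["some-chat-model"] = {
        "chat_template": "{% for message in messages %}...{% endfor %}",  # Jinja template string
        "stop_token_ids": [2],
        "stop": ["</s>"],
    }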

xinference/model/llm/llama_cpp/core.py

@@ -14,12 +14,11 @@
 import logging
 import os
 import time
-from typing import Iterable, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 from ....types import (
     ChatCompletion,
     ChatCompletionChunk,
-    ChatCompletionMessage,
     Completion,
     CompletionChunk,
     CompletionUsage,
@@ -181,13 +180,12 @@ class LlamaCppModel(LLM):
             for index, _completion_chunk in enumerate(
                 self._llm(prompt=_prompt, **_generate_config)
             ):
+                _completion_chunk["model"] = self.model_uid
                 request_id = _completion_chunk["id"]
-                choice = _completion_chunk["choices"][0]
-                if choice["finish_reason"] is not None:
-                    completion_tokens = index
+                completion_tokens = index + 1
                 total_tokens = prompt_tokens + completion_tokens
                 _completion_chunk["usage"] = CompletionUsage(
-                    prompt_tokens=total_tokens,
+                    prompt_tokens=prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=total_tokens,
                 )
@@ -262,39 +260,26 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
         self, generate_config: Optional[LlamaCppGenerateConfig]
     ) -> LlamaCppGenerateConfig:
         generate_config = super()._sanitize_generate_config(generate_config)
-        if self.model_family.prompt_style and self.model_family.prompt_style.stop:
-            generate_config["stop"] = self.model_family.prompt_style.stop
+        if self.model_family.stop and self.model_family.stop:
+            generate_config["stop"] = self.model_family.stop.copy()
         return generate_config
 
     def chat(
         self,
-        prompt: str,
-        system_prompt: Optional[str] = None,
-        chat_history: Optional[List[ChatCompletionMessage]] = None,
+        messages: List[Dict],
         generate_config: Optional[LlamaCppGenerateConfig] = None,
     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-        assert self.model_family.prompt_style is not None
-        prompt_style = self.model_family.prompt_style.copy()
-        if system_prompt:
-            prompt_style.system_prompt = system_prompt
-
-        chat_history = chat_history or []
-        assert prompt_style is not None
+        model_family = self.model_family.model_family or self.model_family.model_name
         tools = generate_config.pop("tools", []) if generate_config else None
-        full_prompt = self.get_prompt(prompt, chat_history, prompt_style, tools=tools)
+        full_context_kwargs = {}
+        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
+            full_context_kwargs["tools"] = tools
+        assert self.model_family.chat_template is not None
+        full_prompt = self.get_full_context(
+            messages, self.model_family.chat_template, **full_context_kwargs
+        )
 
         generate_config = self._sanitize_generate_config(generate_config)
-        # TODO(codingl2k1): qwen hacky to set stop for function call.
-        model_family = self.model_family.model_family or self.model_family.model_name
-        if tools and model_family in QWEN_TOOL_CALL_FAMILY:
-            stop = generate_config.get("stop")
-            if isinstance(stop, str):
-                generate_config["stop"] = [stop, "Observation:"]
-            elif isinstance(stop, Iterable):
-                assert not isinstance(stop, str)
-                generate_config["stop"] = stop + ["Observation:"]  # type: ignore
-            else:
-                generate_config["stop"] = "Observation:"
 
         stream = generate_config.get("stream", False)
         if stream:
@@ -305,7 +290,5 @@ class LlamaCppChatModel(LlamaCppModel, ChatModelMixin):
             c = self.generate(full_prompt, generate_config)
             assert not isinstance(c, Iterator)
             if tools:
-                return self._tool_calls_completion(
-                    self.model_family, self.model_uid, c, tools
-                )
+                return self._tool_calls_completion(self.model_family, self.model_uid, c)
             return self._to_chat_completion(c)
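
`chat()` now takes an OpenAI-style `messages` list instead of `prompt`/`system_prompt`/`chat_history`, and renders the prompt through the family's `chat_template` via `get_full_context`. A sketch of what a call looks like after the change (`llm` stands for an already-constructed `LlamaCppChatModel`; the construction itself is omitted):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give me one sentence about llamas."},
    ]
    completion = llm.chat(messages, generate_config={"max_tokens": 64, "stream": False})
    print(completion["choices"][0]["message"]["content"])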