xinference 0.13.2__py3-none-any.whl → 0.13.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/__init__.py +0 -1
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +26 -4
- xinference/client/restful/restful_client.py +16 -1
- xinference/core/chat_interface.py +2 -2
- xinference/core/model.py +8 -3
- xinference/core/scheduler.py +4 -4
- xinference/model/audio/core.py +5 -2
- xinference/model/audio/cosyvoice.py +136 -0
- xinference/model/audio/model_spec.json +24 -0
- xinference/model/audio/model_spec_modelscope.json +27 -0
- xinference/model/flexible/launchers/__init__.py +1 -0
- xinference/model/flexible/launchers/image_process_launcher.py +70 -0
- xinference/model/image/model_spec.json +7 -0
- xinference/model/image/stable_diffusion/core.py +6 -1
- xinference/model/llm/llm_family.json +802 -82
- xinference/model/llm/llm_family_csghub.json +39 -0
- xinference/model/llm/llm_family_modelscope.json +295 -47
- xinference/model/llm/pytorch/chatglm.py +243 -5
- xinference/model/llm/pytorch/cogvlm2.py +1 -1
- xinference/model/llm/utils.py +78 -1
- xinference/model/llm/vllm/core.py +8 -0
- xinference/thirdparty/cosyvoice/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
- xinference/thirdparty/cosyvoice/bin/train.py +136 -0
- xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
- xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
- xinference/thirdparty/cosyvoice/cli/model.py +60 -0
- xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
- xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
- xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
- xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
- xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
- xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
- xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
- xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
- xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
- xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
- xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
- xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
- xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
- xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
- xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
- xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
- xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
- xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
- xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
- xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
- xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
- xinference/thirdparty/cosyvoice/utils/common.py +103 -0
- xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
- xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
- xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
- xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
- xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
- xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
- xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/METADATA +16 -8
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/RECORD +76 -32
- xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
- /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.2.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/chatglm.py
CHANGED
@@ -11,10 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
+import json
+import threading
 import time
 import uuid
 from typing import Any, Dict, Iterator, List, Optional, Union
 
+import torch
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from ....core.scheduler import InferenceRequest
 from ....types import (
     SPECIAL_TOOL_PROMPT,
@@ -33,6 +40,16 @@ from ..utils import GLM4_TOOL_CALL_FAMILY
 from .core import PytorchChatModel, PytorchModelConfig
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+    ) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 198] = 5e4
+        return scores
+
+
 class ChatglmPytorchChatModel(PytorchChatModel):
     def __init__(
         self,
@@ -103,9 +120,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         tools = generate_config.pop("tools", None)
         if tools is None:
             return False
+        # Convert a iterable to a list
+        tools = list(tools)
         tool_choice = generate_config.pop("tool_choice", "none")
         if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-            chat_history[:] = self.
+            chat_history[:] = self._process_messages(
                 chat_history, tools=tools, tool_choice=tool_choice
             )
             return True
@@ -124,7 +143,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
         return True
 
     @staticmethod
-    def
+    def _process_messages(messages, tools=None, tool_choice="none"):
         # This method is adapted from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py
         _messages = messages
         processed_messages = []
@@ -210,6 +229,209 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                     break
         return processed_messages
 
+    @staticmethod
+    def _process_response(output, history, tools, end=False):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        content = ""
+        history = copy.deepcopy(history)
+        if not tools and end:
+            return None, None
+        for response in output.split("<|assistant|>"):
+            if "\n" in response:
+                metadata, content = response.split("\n", maxsplit=1)
+            else:
+                metadata, content = "", response
+            if not metadata.strip():
+                if tools and any(t.startswith(response) for t in tools) and not end:
+                    # Waiting for tool call complete.
+                    return None, None
+                content = content.strip()
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                content = content.replace("[[训练时间]]", "2023年")
+            else:
+                if tools and metadata in tools and not end:
+                    return None, None
+                history.append(
+                    {"role": "assistant", "metadata": metadata, "content": content}
+                )
+                metadata = metadata.strip()
+                if tools and metadata in tools and end:
+                    try:
+                        parameters = json.loads(content)
+                        content = {"name": metadata.strip(), "parameters": parameters}
+                    except json.JSONDecodeError:
+                        content = {"name": metadata.strip(), "content": content}
+                else:
+                    content = {"name": metadata.strip(), "content": content}
+        return content, history
+
+    def _get_generate_args(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        eos_token_id = [
+            tokenizer.eos_token_id,
+            tokenizer.convert_tokens_to_ids("<|user|>"),
+            tokenizer.convert_tokens_to_ids("<|observation|>"),
+        ]
+        gen_kwargs = {
+            "max_length": max_length,
+            "do_sample": do_sample,
+            "top_p": top_p,
+            "temperature": temperature,
+            "logits_processor": logits_processor,
+            **kwargs,
+        }
+        if past_key_values is None:
+            inputs = tokenizer.apply_chat_template(
+                history + [{"role": role, "content": query}],
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        else:
+            inputs = tokenizer.apply_chat_template(
+                [{"role": role, "content": query}],
+                add_special_tokens=False,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_tensors="pt",
+                return_dict=True,
+            )
+        inputs = inputs.to(self._model.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+            inputs.position_ids += past_length
+            attention_mask = inputs.attention_mask
+            attention_mask = torch.cat(
+                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
+            )
+            inputs["attention_mask"] = attention_mask
+        history.append({"role": role, "content": query})
+        tools = history[0]["role"] == "system" and history[0].get("tools")
+        tools = (
+            [
+                t.get("function", {}).get("name", "")
+                for t in tools
+                if isinstance(t, dict)
+            ]
+            if tools
+            else []
+        )
+        kwargs = dict(inputs)
+        kwargs["past_key_values"] = past_key_values
+        kwargs["eos_token_id"] = eos_token_id
+        kwargs.update(gen_kwargs)
+        return kwargs, tools
+
+    @torch.inference_mode()
+    def stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        from transformers import TextIteratorStreamer
+
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+        kwargs["streamer"] = streamer
+        thread = threading.Thread(target=self._model.generate, kwargs=kwargs)
+        thread.start()
+
+        response = ""
+        for token in streamer:
+            response += token
+            if response and response[-1] != "�":
+                new_response, new_history = self._process_response(
+                    response, history, tools, end=False
+                )
+                if new_response is None:
+                    continue
+                yield new_response, new_history
+        if tools:
+            new_response, new_history = self._process_response(
+                response, history, tools, end=True
+            )
+            if new_response:
+                yield new_response, new_history
+
+    @torch.inference_mode()
+    def non_stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: Optional[List[Dict]] = None,
+        role: str = "user",
+        past_key_values=None,
+        max_length: int = 8192,
+        do_sample=True,
+        top_p=0.8,
+        temperature=0.8,
+        logits_processor=None,
+        **kwargs,
+    ):
+        kwargs, tools = self._get_generate_args(
+            tokenizer=tokenizer,
+            query=query,
+            history=history,
+            role=role,
+            past_key_values=past_key_values,
+            max_length=max_length,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+        outputs = self._model.generate(**kwargs)
+        outputs = outputs[:, kwargs["input_ids"].shape[1] :]
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return self._process_response(response, history, tools, end=True)
+
     def chat(
         self,
         prompt: str,
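The two helpers added above mirror the GLM-4 reference implementation: stream_chat runs generation on a background thread and yields (response, history) pairs, where the response is plain text or, once a tool call is complete, a dict carrying the tool name and arguments, while non_stream_chat returns the final pair. A minimal consumption sketch follows; it is illustrative only, and `model` and `tokenizer` are assumed stand-ins for an already loaded ChatglmPytorchChatModel and its tokenizer.

# Illustrative sketch only; `model` and `tokenizer` are assumed to be an
# already loaded ChatglmPytorchChatModel instance and its tokenizer.
history = [{"role": "system", "content": "You are a helpful assistant."}]
for response, new_history in model.stream_chat(tokenizer, "What is 2 + 2?", history):
    if isinstance(response, dict):
        # A finished tool call: {"name": ..., "parameters": ...} or {"name": ..., "content": ...}
        print("tool call:", response)
    else:
        print("partial text:", response)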
@@ -247,7 +469,13 @@ class ChatglmPytorchChatModel(PytorchChatModel):
             if isinstance(stream_options, dict)
             else False
         )
-        if stream and
+        if stream and (
+            not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
+        ):
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                stream_chat = self.stream_chat
+            else:
+                stream_chat = self._model.stream_chat
 
             def _stream_generator():
                 last_chunk_text_length = 0
@@ -256,9 +484,14 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                 inputs = self._tokenizer([prompt], return_tensors="pt")
                 inputs = inputs.to(self._model.device)
                 prompt_tokens = len(inputs["input_ids"][0])
-                for chunk_text, _ in
+                for chunk_text, _ in stream_chat(
                     self._tokenizer, prompt, chat_history, **kwargs
                 ):
+                    if tools and isinstance(chunk_text, dict):
+                        yield self._tool_calls_completion_chunk(
+                            self.model_family, self.model_uid, [chunk_text, _], tools
+                        )
+                        return
                     completion_tokens = completion_tokens + 1
                     total_tokens = prompt_tokens + completion_tokens
                     chunk_text = chunk_text[last_chunk_text_length:]
@@ -312,7 +545,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):
 
             return self._to_chat_completion_chunks(_stream_generator())
         else:
-
+            if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                chat = self.non_stream_chat
+            else:
+                chat = self._model.chat
+
+            response = chat(self._tokenizer, prompt, chat_history, **kwargs)
             if tools:
                 return self._tool_calls_completion(
                     self.model_family, self.model_uid, response, tools
xinference/model/llm/pytorch/cogvlm2.py
CHANGED
@@ -387,7 +387,7 @@ class CogVLM2Model(PytorchChatModel):
             prompt, system_prompt=system_prompt, chat_history=chat_history
         )
 
-        input_by_model: dict = self._model.build_conversation_input_ids(
+        input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
             self._tokenizer,
             query=query,
             history=history,
xinference/model/llm/utils.py
CHANGED
@@ -483,11 +483,40 @@ Begin!"""
                 else:
                     ret += role
             return ret
+        elif prompt_style.style_name == "mistral-nemo":
+            seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
+            ret = "<s>"
+            for i, message in enumerate(chat_history):
+                role = get_role(message["role"])
+                content = message["content"]
+                if content:
+                    if i == len(chat_history) - 2 and prompt_style.system_prompt:
+                        ret += (
+                            role
+                            + " "
+                            + prompt_style.system_prompt
+                            + "\n\n"
+                            + content
+                            + seps[i % 2]
+                        )
+                    else:
+                        ret += role + " " + content + seps[i % 2]
+                else:
+                    ret += role
+            return ret
         else:
             raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")
 
     @classmethod
     def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
+        choices = chunk.get("choices")
+        if (
+            chunk.get("object") == "chat.completion.chunk"
+            and choices
+            and "delta" in choices[0]
+        ):
+            # Already a ChatCompletionChunk, we don't need to convert chunk.
+            return cast(ChatCompletionChunk, chunk)
         chat_chunk = {
             "id": "chat" + chunk["id"],
             "model": chunk["model"],
@@ -497,7 +526,7 @@ Begin!"""
             {
                 "index": i,
                 "delta": {
-                    "content": choice
+                    "content": choice.get("text"),
                     **(
                         {"tool_calls": choice["tool_calls"]}
                         if "tool_calls" in choice
@@ -718,6 +747,54 @@ Begin!"""
         else:
             return lambda tokens, delta: delta
 
+    @classmethod
+    def _tool_calls_completion_chunk(cls, model_family, model_uid, c, tools):
+        _id = str(uuid.uuid4())
+        content, func, args = cls._eval_tool_arguments(model_family, c, tools)
+        if func:
+            d = {
+                "role": "assistant",
+                "content": content,
+                "tool_calls": [
+                    {
+                        "id": f"call_{_id}",
+                        "type": "function",
+                        "function": {
+                            "name": func,
+                            "arguments": json.dumps(args),
+                        },
+                    }
+                ],
+            }
+            finish_reason = "tool_calls"
+        else:
+            d = {"role": "assistant", "content": content, "tool_calls": []}
+            finish_reason = "stop"
+        try:
+            usage = c.get("usage")
+            assert "prompt_tokens" in usage
+        except Exception:
+            usage = {
+                "prompt_tokens": -1,
+                "completion_tokens": -1,
+                "total_tokens": -1,
+            }
+        return {
+            "id": "chat" + f"cmpl-{_id}",
+            "model": model_uid,
+            "object": "chat.completion.chunk",
+            "created": int(time.time()),
+            "choices": [
+                {
+                    "index": 0,
+                    "delta": d,
+                    "logprobs": None,
+                    "finish_reason": finish_reason,
+                }
+            ],
+            "usage": usage,
+        }
+
     @classmethod
     def _tool_calls_completion(cls, model_family, model_uid, c, tools):
         _id = str(uuid.uuid4())
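The dict built above follows the OpenAI chat.completion.chunk layout, so a streaming consumer reads the tool call back out of choices[0]["delta"]. A small sketch of that read path, illustrative only and not part of the diff:

import json

def extract_tool_call(chunk: dict):
    # Illustrative only: pull the function name and arguments out of a chunk
    # shaped like the return value of _tool_calls_completion_chunk above.
    choice = chunk["choices"][0]
    if choice["finish_reason"] != "tool_calls":
        return None
    call = choice["delta"]["tool_calls"][0]["function"]
    return call["name"], json.loads(call["arguments"])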
xinference/model/llm/vllm/core.py
CHANGED
@@ -151,6 +151,14 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
     VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
     VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")
 
+if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
+    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-large-instruct")
+
+if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
+    VLLM_SUPPORTED_MODELS.append("llama-3.1")
+    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+
 
 class VLLMModel(LLM):
     def __init__(
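The new registrations are gated on the installed vLLM version, following the same pattern as the existing ">= 0.4.0" check. A quick local check of which newly added families would be enabled, illustrative only:

# Illustrative only: mirrors the version gates added above.
import vllm

if vllm.__version__ >= "0.5.3":
    print("mistral-nemo-instruct and mistral-large-instruct enabled")
if vllm.__version__ > "0.5.3":
    print("llama-3.1 and llama-3.1-instruct enabled")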
xinference/thirdparty/cosyvoice/__init__.py
File without changes
xinference/thirdparty/cosyvoice/bin/__init__.py
File without changes
xinference/thirdparty/cosyvoice/bin/inference.py
ADDED
@@ -0,0 +1,114 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+
+import torch
+from torch.utils.data import DataLoader
+import torchaudio
+from hyperpyyaml import load_hyperpyyaml
+from tqdm import tqdm
+from cosyvoice.cli.model import CosyVoiceModel
+
+from cosyvoice.dataset.dataset import Dataset
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='inference with your model')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument('--prompt_data', required=True, help='prompt data file')
+    parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+    parser.add_argument('--tts_text', required=True, help='tts input file')
+    parser.add_argument('--llm_model', required=True, help='llm model file')
+    parser.add_argument('--flow_model', required=True, help='flow model file')
+    parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+    parser.add_argument('--gpu',
+                        type=int,
+                        default=-1,
+                        help='gpu id for this rank, -1 for cpu')
+    parser.add_argument('--mode',
+                        default='sft',
+                        choices=['sft', 'zero_shot'],
+                        help='inference mode')
+    parser.add_argument('--result_dir', required=True, help='asr result file')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+    # Init cosyvoice models from configs
+    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+    device = torch.device('cuda' if use_cuda else 'cpu')
+    with open(args.config, 'r') as f:
+        configs = load_hyperpyyaml(f)
+
+    model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+    model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+    test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+    del configs
+    os.makedirs(args.result_dir, exist_ok=True)
+    fn = os.path.join(args.result_dir, 'wav.scp')
+    f = open(fn, 'w')
+    with torch.no_grad():
+        for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+            utts = batch["utts"]
+            assert len(utts) == 1, "inference mode only support batchsize 1"
+            text = batch["text"]
+            text_token = batch["text_token"].to(device)
+            text_token_len = batch["text_token_len"].to(device)
+            tts_text = batch["tts_text"]
+            tts_index = batch["tts_index"]
+            tts_text_token = batch["tts_text_token"].to(device)
+            tts_text_token_len = batch["tts_text_token_len"].to(device)
+            speech_token = batch["speech_token"].to(device)
+            speech_token_len = batch["speech_token_len"].to(device)
+            speech_feat = batch["speech_feat"].to(device)
+            speech_feat_len = batch["speech_feat_len"].to(device)
+            utt_embedding = batch["utt_embedding"].to(device)
+            spk_embedding = batch["spk_embedding"].to(device)
+            if args.mode == 'sft':
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+            else:
+                model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                               'prompt_text': text_token, 'prompt_text_len': text_token_len,
+                               'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                               'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                               'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                               'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+            model_output = model.inference(**model_input)
+            tts_key = '{}_{}'.format(utts[0], tts_index[0])
+            tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+            torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+            f.write('{} {}\n'.format(tts_key, tts_fn))
+            f.flush()
+    f.close()
+    logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+if __name__ == '__main__':
+    main()