xinference 0.13.1__py3-none-any.whl → 0.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of xinference might be problematic.
Files changed (82)
  1. xinference/__init__.py +0 -1
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +99 -5
  4. xinference/client/restful/restful_client.py +98 -1
  5. xinference/core/chat_interface.py +2 -2
  6. xinference/core/model.py +85 -26
  7. xinference/core/scheduler.py +4 -4
  8. xinference/model/audio/chattts.py +40 -8
  9. xinference/model/audio/core.py +5 -2
  10. xinference/model/audio/cosyvoice.py +136 -0
  11. xinference/model/audio/model_spec.json +24 -0
  12. xinference/model/audio/model_spec_modelscope.json +27 -0
  13. xinference/model/flexible/launchers/__init__.py +1 -0
  14. xinference/model/flexible/launchers/image_process_launcher.py +70 -0
  15. xinference/model/image/core.py +3 -0
  16. xinference/model/image/model_spec.json +21 -0
  17. xinference/model/image/stable_diffusion/core.py +49 -7
  18. xinference/model/llm/llm_family.json +1065 -106
  19. xinference/model/llm/llm_family.py +26 -6
  20. xinference/model/llm/llm_family_csghub.json +39 -0
  21. xinference/model/llm/llm_family_modelscope.json +460 -47
  22. xinference/model/llm/pytorch/chatglm.py +243 -5
  23. xinference/model/llm/pytorch/cogvlm2.py +1 -1
  24. xinference/model/llm/sglang/core.py +7 -2
  25. xinference/model/llm/utils.py +78 -1
  26. xinference/model/llm/vllm/core.py +11 -0
  27. xinference/thirdparty/cosyvoice/__init__.py +0 -0
  28. xinference/thirdparty/cosyvoice/bin/__init__.py +0 -0
  29. xinference/thirdparty/cosyvoice/bin/inference.py +114 -0
  30. xinference/thirdparty/cosyvoice/bin/train.py +136 -0
  31. xinference/thirdparty/cosyvoice/cli/__init__.py +0 -0
  32. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +83 -0
  33. xinference/thirdparty/cosyvoice/cli/frontend.py +168 -0
  34. xinference/thirdparty/cosyvoice/cli/model.py +60 -0
  35. xinference/thirdparty/cosyvoice/dataset/__init__.py +0 -0
  36. xinference/thirdparty/cosyvoice/dataset/dataset.py +160 -0
  37. xinference/thirdparty/cosyvoice/dataset/processor.py +369 -0
  38. xinference/thirdparty/cosyvoice/flow/__init__.py +0 -0
  39. xinference/thirdparty/cosyvoice/flow/decoder.py +222 -0
  40. xinference/thirdparty/cosyvoice/flow/flow.py +135 -0
  41. xinference/thirdparty/cosyvoice/flow/flow_matching.py +138 -0
  42. xinference/thirdparty/cosyvoice/flow/length_regulator.py +49 -0
  43. xinference/thirdparty/cosyvoice/hifigan/__init__.py +0 -0
  44. xinference/thirdparty/cosyvoice/hifigan/f0_predictor.py +55 -0
  45. xinference/thirdparty/cosyvoice/hifigan/generator.py +391 -0
  46. xinference/thirdparty/cosyvoice/llm/__init__.py +0 -0
  47. xinference/thirdparty/cosyvoice/llm/llm.py +206 -0
  48. xinference/thirdparty/cosyvoice/transformer/__init__.py +0 -0
  49. xinference/thirdparty/cosyvoice/transformer/activation.py +84 -0
  50. xinference/thirdparty/cosyvoice/transformer/attention.py +326 -0
  51. xinference/thirdparty/cosyvoice/transformer/convolution.py +145 -0
  52. xinference/thirdparty/cosyvoice/transformer/decoder.py +396 -0
  53. xinference/thirdparty/cosyvoice/transformer/decoder_layer.py +132 -0
  54. xinference/thirdparty/cosyvoice/transformer/embedding.py +293 -0
  55. xinference/thirdparty/cosyvoice/transformer/encoder.py +472 -0
  56. xinference/thirdparty/cosyvoice/transformer/encoder_layer.py +236 -0
  57. xinference/thirdparty/cosyvoice/transformer/label_smoothing_loss.py +96 -0
  58. xinference/thirdparty/cosyvoice/transformer/positionwise_feed_forward.py +115 -0
  59. xinference/thirdparty/cosyvoice/transformer/subsampling.py +383 -0
  60. xinference/thirdparty/cosyvoice/utils/__init__.py +0 -0
  61. xinference/thirdparty/cosyvoice/utils/class_utils.py +70 -0
  62. xinference/thirdparty/cosyvoice/utils/common.py +103 -0
  63. xinference/thirdparty/cosyvoice/utils/executor.py +110 -0
  64. xinference/thirdparty/cosyvoice/utils/file_utils.py +41 -0
  65. xinference/thirdparty/cosyvoice/utils/frontend_utils.py +125 -0
  66. xinference/thirdparty/cosyvoice/utils/mask.py +227 -0
  67. xinference/thirdparty/cosyvoice/utils/scheduler.py +739 -0
  68. xinference/thirdparty/cosyvoice/utils/train_utils.py +289 -0
  69. xinference/web/ui/build/asset-manifest.json +3 -3
  70. xinference/web/ui/build/index.html +1 -1
  71. xinference/web/ui/build/static/js/{main.95c1d652.js → main.2ef0cfaf.js} +3 -3
  72. xinference/web/ui/build/static/js/main.2ef0cfaf.js.map +1 -0
  73. xinference/web/ui/node_modules/.cache/babel-loader/b6807ecc0c231fea699533518a0eb2a2bf68a081ce00d452be40600dbffa17a7.json +1 -0
  74. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/METADATA +18 -8
  75. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/RECORD +80 -36
  76. xinference/web/ui/build/static/js/main.95c1d652.js.map +0 -1
  77. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +0 -1
  78. /xinference/web/ui/build/static/js/{main.95c1d652.js.LICENSE.txt → main.2ef0cfaf.js.LICENSE.txt} +0 -0
  79. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/LICENSE +0 -0
  80. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/WHEEL +0 -0
  81. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/entry_points.txt +0 -0
  82. {xinference-0.13.1.dist-info → xinference-0.13.3.dist-info}/top_level.txt +0 -0
xinference/model/llm/pytorch/chatglm.py
@@ -11,10 +11,17 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import copy
+ import json
+ import threading
  import time
  import uuid
  from typing import Any, Dict, Iterator, List, Optional, Union

+ import torch
+ from transformers.generation.logits_process import LogitsProcessor
+ from transformers.generation.utils import LogitsProcessorList
+
  from ....core.scheduler import InferenceRequest
  from ....types import (
      SPECIAL_TOOL_PROMPT,
@@ -33,6 +40,16 @@ from ..utils import GLM4_TOOL_CALL_FAMILY
  from .core import PytorchChatModel, PytorchModelConfig


+ class InvalidScoreLogitsProcessor(LogitsProcessor):
+     def __call__(
+         self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+     ) -> torch.FloatTensor:
+         if torch.isnan(scores).any() or torch.isinf(scores).any():
+             scores.zero_()
+             scores[..., 198] = 5e4
+         return scores
+
+
  class ChatglmPytorchChatModel(PytorchChatModel):
      def __init__(
          self,
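The InvalidScoreLogitsProcessor added above zeroes out NaN/Inf logits and forces a single safe token so sampling cannot crash mid-generation. A minimal, self-contained sketch of how such a processor plugs into transformers generation follows; the model name and the forced token id are illustrative, not taken from this diff (the diff targets token 198 for GLM-4):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class NanGuardLogitsProcessor(LogitsProcessor):
    """Replace NaN/Inf scores so sampling cannot fail; force one safe token."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 5e4  # illustrative token id; the diff uses 198
        return scores


# Hypothetical usage with a small model; any causal LM accepts logits_processor.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=8,
    logits_processor=LogitsProcessorList([NanGuardLogitsProcessor()]),
)
print(tokenizer.decode(out[0], skip_special_tokens=True))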
@@ -103,9 +120,11 @@ class ChatglmPytorchChatModel(PytorchChatModel):
          tools = generate_config.pop("tools", None)
          if tools is None:
              return False
+         # Convert a iterable to a list
+         tools = list(tools)
          tool_choice = generate_config.pop("tool_choice", "none")
          if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
-             chat_history[:] = self.process_messages(
+             chat_history[:] = self._process_messages(
                  chat_history, tools=tools, tool_choice=tool_choice
              )
          return True
@@ -124,7 +143,7 @@ class ChatglmPytorchChatModel(PytorchChatModel):
          return True

      @staticmethod
-     def process_messages(messages, tools=None, tool_choice="none"):
+     def _process_messages(messages, tools=None, tool_choice="none"):
          # This method is adapted from https://github.com/THUDM/GLM-4/blob/main/basic_demo/openai_api_server.py
          _messages = messages
          processed_messages = []
@@ -210,6 +229,209 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                  break
          return processed_messages

+     @staticmethod
+     def _process_response(output, history, tools, end=False):
+         # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+         content = ""
+         history = copy.deepcopy(history)
+         if not tools and end:
+             return None, None
+         for response in output.split("<|assistant|>"):
+             if "\n" in response:
+                 metadata, content = response.split("\n", maxsplit=1)
+             else:
+                 metadata, content = "", response
+             if not metadata.strip():
+                 if tools and any(t.startswith(response) for t in tools) and not end:
+                     # Waiting for tool call complete.
+                     return None, None
+                 content = content.strip()
+                 history.append(
+                     {"role": "assistant", "metadata": metadata, "content": content}
+                 )
+                 content = content.replace("[[训练时间]]", "2023年")
+             else:
+                 if tools and metadata in tools and not end:
+                     return None, None
+                 history.append(
+                     {"role": "assistant", "metadata": metadata, "content": content}
+                 )
+                 metadata = metadata.strip()
+                 if tools and metadata in tools and end:
+                     try:
+                         parameters = json.loads(content)
+                         content = {"name": metadata.strip(), "parameters": parameters}
+                     except json.JSONDecodeError:
+                         content = {"name": metadata.strip(), "content": content}
+                 else:
+                     content = {"name": metadata.strip(), "content": content}
+         return content, history
+
+     def _get_generate_args(
+         self,
+         tokenizer,
+         query: str,
+         history: Optional[List[Dict]] = None,
+         role: str = "user",
+         past_key_values=None,
+         max_length: int = 8192,
+         do_sample=True,
+         top_p=0.8,
+         temperature=0.8,
+         logits_processor=None,
+         **kwargs,
+     ):
+         # Copy from https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/modeling_chatglm.py
+         if history is None:
+             history = []
+         if logits_processor is None:
+             logits_processor = LogitsProcessorList()
+         logits_processor.append(InvalidScoreLogitsProcessor())
+         eos_token_id = [
+             tokenizer.eos_token_id,
+             tokenizer.convert_tokens_to_ids("<|user|>"),
+             tokenizer.convert_tokens_to_ids("<|observation|>"),
+         ]
+         gen_kwargs = {
+             "max_length": max_length,
+             "do_sample": do_sample,
+             "top_p": top_p,
+             "temperature": temperature,
+             "logits_processor": logits_processor,
+             **kwargs,
+         }
+         if past_key_values is None:
+             inputs = tokenizer.apply_chat_template(
+                 history + [{"role": role, "content": query}],
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_tensors="pt",
+                 return_dict=True,
+             )
+         else:
+             inputs = tokenizer.apply_chat_template(
+                 [{"role": role, "content": query}],
+                 add_special_tokens=False,
+                 add_generation_prompt=True,
+                 tokenize=True,
+                 return_tensors="pt",
+                 return_dict=True,
+             )
+         inputs = inputs.to(self._model.device)
+         if past_key_values is not None:
+             past_length = past_key_values[0][0].shape[2]
+             inputs.position_ids += past_length
+             attention_mask = inputs.attention_mask
+             attention_mask = torch.cat(
+                 (attention_mask.new_ones(1, past_length), attention_mask), dim=1
+             )
+             inputs["attention_mask"] = attention_mask
+         history.append({"role": role, "content": query})
+         tools = history[0]["role"] == "system" and history[0].get("tools")
+         tools = (
+             [
+                 t.get("function", {}).get("name", "")
+                 for t in tools
+                 if isinstance(t, dict)
+             ]
+             if tools
+             else []
+         )
+         kwargs = dict(inputs)
+         kwargs["past_key_values"] = past_key_values
+         kwargs["eos_token_id"] = eos_token_id
+         kwargs.update(gen_kwargs)
+         return kwargs, tools
+
+     @torch.inference_mode()
+     def stream_chat(
+         self,
+         tokenizer,
+         query: str,
+         history: Optional[List[Dict]] = None,
+         role: str = "user",
+         past_key_values=None,
+         max_length: int = 8192,
+         do_sample=True,
+         top_p=0.8,
+         temperature=0.8,
+         logits_processor=None,
+         **kwargs,
+     ):
+         from transformers import TextIteratorStreamer
+
+         kwargs, tools = self._get_generate_args(
+             tokenizer=tokenizer,
+             query=query,
+             history=history,
+             role=role,
+             past_key_values=past_key_values,
+             max_length=max_length,
+             do_sample=do_sample,
+             top_p=top_p,
+             temperature=temperature,
+             logits_processor=logits_processor,
+             **kwargs,
+         )
+
+         streamer = TextIteratorStreamer(
+             tokenizer, skip_prompt=True, skip_special_tokens=True
+         )
+         kwargs["streamer"] = streamer
+         thread = threading.Thread(target=self._model.generate, kwargs=kwargs)
+         thread.start()
+
+         response = ""
+         for token in streamer:
+             response += token
+             if response and response[-1] != "�":
+                 new_response, new_history = self._process_response(
+                     response, history, tools, end=False
+                 )
+                 if new_response is None:
+                     continue
+                 yield new_response, new_history
+         if tools:
+             new_response, new_history = self._process_response(
+                 response, history, tools, end=True
+             )
+             if new_response:
+                 yield new_response, new_history
+
+     @torch.inference_mode()
+     def non_stream_chat(
+         self,
+         tokenizer,
+         query: str,
+         history: Optional[List[Dict]] = None,
+         role: str = "user",
+         past_key_values=None,
+         max_length: int = 8192,
+         do_sample=True,
+         top_p=0.8,
+         temperature=0.8,
+         logits_processor=None,
+         **kwargs,
+     ):
+         kwargs, tools = self._get_generate_args(
+             tokenizer=tokenizer,
+             query=query,
+             history=history,
+             role=role,
+             past_key_values=past_key_values,
+             max_length=max_length,
+             do_sample=do_sample,
+             top_p=top_p,
+             temperature=temperature,
+             logits_processor=logits_processor,
+             **kwargs,
+         )
+
+         outputs = self._model.generate(**kwargs)
+         outputs = outputs[:, kwargs["input_ids"].shape[1] :]
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         return self._process_response(response, history, tools, end=True)
+
      def chat(
          self,
          prompt: str,
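The new stream_chat follows the standard TextIteratorStreamer pattern: generate() runs in a background thread while the caller consumes text from the streamer as it arrives, and each partial response is re-parsed for tool calls. A minimal sketch of that pattern, assuming a generic Hugging Face causal LM (the model name is illustrative):

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread while the main thread iterates.
gen_kwargs = dict(inputs, max_new_tokens=32, do_sample=True, top_p=0.8, streamer=streamer)
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

response = ""
for new_text in streamer:
    response += new_text  # accumulate the newly generated text, like stream_chat does
    print(response)
thread.join()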
@@ -247,7 +469,13 @@ class ChatglmPytorchChatModel(PytorchChatModel):
              if isinstance(stream_options, dict)
              else False
          )
-         if stream and not tools:
+         if stream and (
+             not tools or self.model_family.model_name in GLM4_TOOL_CALL_FAMILY
+         ):
+             if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                 stream_chat = self.stream_chat
+             else:
+                 stream_chat = self._model.stream_chat

              def _stream_generator():
                  last_chunk_text_length = 0
@@ -256,9 +484,14 @@ class ChatglmPytorchChatModel(PytorchChatModel):
                  inputs = self._tokenizer([prompt], return_tensors="pt")
                  inputs = inputs.to(self._model.device)
                  prompt_tokens = len(inputs["input_ids"][0])
-                 for chunk_text, _ in self._model.stream_chat(
+                 for chunk_text, _ in stream_chat(
                      self._tokenizer, prompt, chat_history, **kwargs
                  ):
+                     if tools and isinstance(chunk_text, dict):
+                         yield self._tool_calls_completion_chunk(
+                             self.model_family, self.model_uid, [chunk_text, _], tools
+                         )
+                         return
                      completion_tokens = completion_tokens + 1
                      total_tokens = prompt_tokens + completion_tokens
                      chunk_text = chunk_text[last_chunk_text_length:]
@@ -312,7 +545,12 @@ class ChatglmPytorchChatModel(PytorchChatModel):

              return self._to_chat_completion_chunks(_stream_generator())
          else:
-             response = self._model.chat(self._tokenizer, prompt, chat_history, **kwargs)
+             if self.model_family.model_name in GLM4_TOOL_CALL_FAMILY:
+                 chat = self.non_stream_chat
+             else:
+                 chat = self._model.chat
+
+             response = chat(self._tokenizer, prompt, chat_history, **kwargs)
              if tools:
                  return self._tool_calls_completion(
                      self.model_family, self.model_uid, response, tools
xinference/model/llm/pytorch/cogvlm2.py
@@ -387,7 +387,7 @@ class CogVLM2Model(PytorchChatModel):
              prompt, system_prompt=system_prompt, chat_history=chat_history
          )

-         input_by_model: dict = self._model.build_conversation_input_ids(
+         input_by_model: dict = self._model.build_conversation_input_ids(  # type: ignore
              self._tokenizer,
              query=query,
              history=history,
xinference/model/llm/sglang/core.py
@@ -269,8 +269,13 @@ class SGLANGModel(LLM):
          )
          stream = sanitized_generate_config.pop("stream")
          stream_options = sanitized_generate_config.pop("stream_options")
-         if isinstance(stream_options, dict):
-             include_usage = stream_options.pop("include_usage", False)
+
+         include_usage = (
+             stream_options.pop("include_usage")
+             if isinstance(stream_options, dict)
+             else False
+         )
+
          request_id = str(uuid.uuid1())
          state = pipeline.run(
              question=prompt,
xinference/model/llm/utils.py
@@ -483,11 +483,40 @@ Begin!"""
              else:
                  ret += role
              return ret
+         elif prompt_style.style_name == "mistral-nemo":
+             seps = [prompt_style.intra_message_sep, prompt_style.inter_message_sep]
+             ret = "<s>"
+             for i, message in enumerate(chat_history):
+                 role = get_role(message["role"])
+                 content = message["content"]
+                 if content:
+                     if i == len(chat_history) - 2 and prompt_style.system_prompt:
+                         ret += (
+                             role
+                             + " "
+                             + prompt_style.system_prompt
+                             + "\n\n"
+                             + content
+                             + seps[i % 2]
+                         )
+                     else:
+                         ret += role + " " + content + seps[i % 2]
+                 else:
+                     ret += role
+             return ret
          else:
              raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")

      @classmethod
      def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
+         choices = chunk.get("choices")
+         if (
+             chunk.get("object") == "chat.completion.chunk"
+             and choices
+             and "delta" in choices[0]
+         ):
+             # Already a ChatCompletionChunk, we don't need to convert chunk.
+             return cast(ChatCompletionChunk, chunk)
          chat_chunk = {
              "id": "chat" + chunk["id"],
              "model": chunk["model"],
@@ -497,7 +526,7 @@ Begin!"""
              {
                  "index": i,
                  "delta": {
-                     "content": choice["text"],
+                     "content": choice.get("text"),
                      **(
                          {"tool_calls": choice["tool_calls"]}
                          if "tool_calls" in choice
@@ -718,6 +747,54 @@ Begin!"""
          else:
              return lambda tokens, delta: delta

+     @classmethod
+     def _tool_calls_completion_chunk(cls, model_family, model_uid, c, tools):
+         _id = str(uuid.uuid4())
+         content, func, args = cls._eval_tool_arguments(model_family, c, tools)
+         if func:
+             d = {
+                 "role": "assistant",
+                 "content": content,
+                 "tool_calls": [
+                     {
+                         "id": f"call_{_id}",
+                         "type": "function",
+                         "function": {
+                             "name": func,
+                             "arguments": json.dumps(args),
+                         },
+                     }
+                 ],
+             }
+             finish_reason = "tool_calls"
+         else:
+             d = {"role": "assistant", "content": content, "tool_calls": []}
+             finish_reason = "stop"
+         try:
+             usage = c.get("usage")
+             assert "prompt_tokens" in usage
+         except Exception:
+             usage = {
+                 "prompt_tokens": -1,
+                 "completion_tokens": -1,
+                 "total_tokens": -1,
+             }
+         return {
+             "id": "chat" + f"cmpl-{_id}",
+             "model": model_uid,
+             "object": "chat.completion.chunk",
+             "created": int(time.time()),
+             "choices": [
+                 {
+                     "index": 0,
+                     "delta": d,
+                     "logprobs": None,
+                     "finish_reason": finish_reason,
+                 }
+             ],
+             "usage": usage,
+         }
+
      @classmethod
      def _tool_calls_completion(cls, model_family, model_uid, c, tools):
          _id = str(uuid.uuid4())
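_tool_calls_completion_chunk emits a single OpenAI-style chat.completion.chunk whose delta carries the parsed tool call. A hand-built example of that shape for reference (model uid, function name, and arguments are illustrative):

import json
import time
import uuid

_id = str(uuid.uuid4())
chunk = {
    "id": "chat" + f"cmpl-{_id}",
    "model": "glm4-chat",  # illustrative model uid
    "object": "chat.completion.chunk",
    "created": int(time.time()),
    "choices": [
        {
            "index": 0,
            "delta": {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": f"call_{_id}",
                        "type": "function",
                        "function": {
                            "name": "get_weather",  # hypothetical tool
                            "arguments": json.dumps({"city": "Paris"}),
                        },
                    }
                ],
            },
            "logprobs": None,
            "finish_reason": "tool_calls",
        }
    ],
    "usage": {"prompt_tokens": -1, "completion_tokens": -1, "total_tokens": -1},
}
print(json.dumps(chunk, indent=2))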
xinference/model/llm/vllm/core.py
@@ -112,6 +112,8 @@ VLLM_SUPPORTED_CHAT_MODELS = [
      "internlm-chat-8k",
      "internlm-chat-20b",
      "internlm2-chat",
+     "internlm2.5-chat",
+     "internlm2.5-chat-1m",
      "qwen-chat",
      "Yi-chat",
      "Yi-1.5-chat",
@@ -127,6 +129,7 @@ VLLM_SUPPORTED_CHAT_MODELS = [
      "chatglm3-128k",
      "glm4-chat",
      "glm4-chat-1m",
+     "codegeex4",
      "deepseek-chat",
      "deepseek-coder-instruct",
  ]
@@ -148,6 +151,14 @@ if VLLM_INSTALLED and vllm.__version__ >= "0.4.0":
      VLLM_SUPPORTED_CHAT_MODELS.append("qwen2-moe-instruct")
      VLLM_SUPPORTED_CHAT_MODELS.append("c4ai-command-r-v01")

+ if VLLM_INSTALLED and vllm.__version__ >= "0.5.3":
+     VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
+     VLLM_SUPPORTED_CHAT_MODELS.append("mistral-large-instruct")
+
+ if VLLM_INSTALLED and vllm.__version__ > "0.5.3":
+     VLLM_SUPPORTED_MODELS.append("llama-3.1")
+     VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")
+

  class VLLMModel(LLM):
      def __init__(
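The registrations above are gated on the installed vLLM version with plain string comparisons on vllm.__version__. For reference, a small sketch of the same gating written with packaging.version, which compares release numbers numerically; this is an illustrative alternative, not the code in the diff:

from packaging import version

# Stand-ins for the module-level registries in xinference.model.llm.vllm.core.
VLLM_SUPPORTED_MODELS = []
VLLM_SUPPORTED_CHAT_MODELS = []

try:
    import vllm
    VLLM_INSTALLED = True
except ImportError:
    vllm = None
    VLLM_INSTALLED = False

if VLLM_INSTALLED and version.parse(vllm.__version__) >= version.parse("0.5.3"):
    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-nemo-instruct")
    VLLM_SUPPORTED_CHAT_MODELS.append("mistral-large-instruct")

if VLLM_INSTALLED and version.parse(vllm.__version__) > version.parse("0.5.3"):
    VLLM_SUPPORTED_MODELS.append("llama-3.1")
    VLLM_SUPPORTED_CHAT_MODELS.append("llama-3.1-instruct")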
xinference/thirdparty/cosyvoice/bin/inference.py (new file)
@@ -0,0 +1,114 @@
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import print_function
+
+ import argparse
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+ import os
+
+ import torch
+ from torch.utils.data import DataLoader
+ import torchaudio
+ from hyperpyyaml import load_hyperpyyaml
+ from tqdm import tqdm
+ from cosyvoice.cli.model import CosyVoiceModel
+
+ from cosyvoice.dataset.dataset import Dataset
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='inference with your model')
+     parser.add_argument('--config', required=True, help='config file')
+     parser.add_argument('--prompt_data', required=True, help='prompt data file')
+     parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
+     parser.add_argument('--tts_text', required=True, help='tts input file')
+     parser.add_argument('--llm_model', required=True, help='llm model file')
+     parser.add_argument('--flow_model', required=True, help='flow model file')
+     parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
+     parser.add_argument('--gpu',
+                         type=int,
+                         default=-1,
+                         help='gpu id for this rank, -1 for cpu')
+     parser.add_argument('--mode',
+                         default='sft',
+                         choices=['sft', 'zero_shot'],
+                         help='inference mode')
+     parser.add_argument('--result_dir', required=True, help='asr result file')
+     args = parser.parse_args()
+     print(args)
+     return args
+
+
+ def main():
+     args = get_args()
+     logging.basicConfig(level=logging.DEBUG,
+                         format='%(asctime)s %(levelname)s %(message)s')
+     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+     # Init cosyvoice models from configs
+     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
+     device = torch.device('cuda' if use_cuda else 'cpu')
+     with open(args.config, 'r') as f:
+         configs = load_hyperpyyaml(f)
+
+     model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
+     model.load(args.llm_model, args.flow_model, args.hifigan_model)
+
+     test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
+
+     del configs
+     os.makedirs(args.result_dir, exist_ok=True)
+     fn = os.path.join(args.result_dir, 'wav.scp')
+     f = open(fn, 'w')
+     with torch.no_grad():
+         for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+             utts = batch["utts"]
+             assert len(utts) == 1, "inference mode only support batchsize 1"
+             text = batch["text"]
+             text_token = batch["text_token"].to(device)
+             text_token_len = batch["text_token_len"].to(device)
+             tts_text = batch["tts_text"]
+             tts_index = batch["tts_index"]
+             tts_text_token = batch["tts_text_token"].to(device)
+             tts_text_token_len = batch["tts_text_token_len"].to(device)
+             speech_token = batch["speech_token"].to(device)
+             speech_token_len = batch["speech_token_len"].to(device)
+             speech_feat = batch["speech_feat"].to(device)
+             speech_feat_len = batch["speech_feat_len"].to(device)
+             utt_embedding = batch["utt_embedding"].to(device)
+             spk_embedding = batch["spk_embedding"].to(device)
+             if args.mode == 'sft':
+                 model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                                'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
+             else:
+                 model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
+                                'prompt_text': text_token, 'prompt_text_len': text_token_len,
+                                'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
+                                'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
+                                'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
+                                'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
+             model_output = model.inference(**model_input)
+             tts_key = '{}_{}'.format(utts[0], tts_index[0])
+             tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
+             torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
+             f.write('{} {}\n'.format(tts_key, tts_fn))
+             f.flush()
+     f.close()
+     logging.info('Result wav.scp saved in {}'.format(fn))
+
+
+ if __name__ == '__main__':
+     main()