xinference 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (66)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +50 -2
  3. xinference/client/restful/restful_client.py +49 -2
  4. xinference/core/model.py +15 -0
  5. xinference/core/supervisor.py +132 -15
  6. xinference/core/worker.py +165 -8
  7. xinference/deploy/cmdline.py +5 -0
  8. xinference/model/audio/chattts.py +6 -6
  9. xinference/model/audio/core.py +23 -15
  10. xinference/model/core.py +12 -3
  11. xinference/model/embedding/core.py +25 -16
  12. xinference/model/flexible/__init__.py +40 -0
  13. xinference/model/flexible/core.py +228 -0
  14. xinference/model/flexible/launchers/__init__.py +15 -0
  15. xinference/model/flexible/launchers/transformers_launcher.py +63 -0
  16. xinference/model/flexible/utils.py +33 -0
  17. xinference/model/image/core.py +18 -14
  18. xinference/model/image/custom.py +1 -1
  19. xinference/model/llm/__init__.py +0 -2
  20. xinference/model/llm/core.py +3 -2
  21. xinference/model/llm/ggml/llamacpp.py +1 -10
  22. xinference/model/llm/llm_family.json +52 -35
  23. xinference/model/llm/llm_family.py +71 -46
  24. xinference/model/llm/llm_family_modelscope.json +55 -27
  25. xinference/model/llm/pytorch/core.py +0 -80
  26. xinference/model/llm/utils.py +4 -2
  27. xinference/model/rerank/core.py +24 -25
  28. xinference/types.py +0 -1
  29. xinference/web/ui/build/asset-manifest.json +3 -3
  30. xinference/web/ui/build/index.html +1 -1
  31. xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
  32. xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
  33. xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
  34. xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
  35. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
  36. xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
  37. xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
  38. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
  39. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/METADATA +7 -11
  40. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/RECORD +45 -54
  41. xinference/model/llm/ggml/chatglm.py +0 -457
  42. xinference/thirdparty/ChatTTS/__init__.py +0 -1
  43. xinference/thirdparty/ChatTTS/core.py +0 -200
  44. xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
  45. xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
  46. xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
  47. xinference/thirdparty/ChatTTS/infer/api.py +0 -125
  48. xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
  49. xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
  50. xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
  51. xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
  52. xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
  53. xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
  54. xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
  55. xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
  56. xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
  57. xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
  58. xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
  59. xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
  60. xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
  61. xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
  62. /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
  63. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/LICENSE +0 -0
  64. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/WHEEL +0 -0
  65. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/entry_points.txt +0 -0
  66. {xinference-0.13.0.dist-info → xinference-0.13.1.dist-info}/top_level.txt +0 -0
xinference/model/llm/ggml/chatglm.py
@@ -1,457 +0,0 @@
- # Copyright 2022-2023 XProbe Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import json
- import logging
- import os
- import time
- import uuid
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
-
- from ....types import (
-     SPECIAL_TOOL_PROMPT,
-     ChatCompletion,
-     ChatCompletionChunk,
-     ChatCompletionMessage,
-     ChatglmCppGenerateConfig,
-     ChatglmCppModelConfig,
-     Completion,
-     CompletionChunk,
- )
- from .. import LLMFamilyV1, LLMSpecV1
- from ..core import LLM
-
- if TYPE_CHECKING:
-     from chatglm_cpp import Pipeline
-
-
- logger = logging.getLogger(__name__)
-
-
- class ChatglmCppChatModel(LLM):
-     def __init__(
-         self,
-         model_uid: str,
-         model_family: "LLMFamilyV1",
-         model_spec: "LLMSpecV1",
-         quantization: str,
-         model_path: str,
-         model_config: Optional[ChatglmCppModelConfig] = None,
-     ):
-         super().__init__(model_uid, model_family, model_spec, quantization, model_path)
-         self._llm: Optional["Pipeline"] = None
-
-         # just a placeholder for now as the chatglm_cpp repo doesn't support model config.
-         self._model_config = model_config
-
-     @classmethod
-     def _sanitize_generate_config(
-         cls,
-         chatglmcpp_generate_config: Optional[ChatglmCppGenerateConfig],
-     ) -> ChatglmCppGenerateConfig:
-         if chatglmcpp_generate_config is None:
-             chatglmcpp_generate_config = ChatglmCppGenerateConfig()
-         chatglmcpp_generate_config.setdefault("stream", False)
-         return chatglmcpp_generate_config
-
-     def load(self):
-         try:
-             import chatglm_cpp
-         except ImportError:
-             error_message = "Failed to import module 'chatglm_cpp'"
-             installation_guide = [
-                 "Please make sure 'chatglm_cpp' is installed. ",
-                 "You can install it by running the following command in the terminal:\n",
-                 "pip install git+https://github.com/li-plus/chatglm.cpp.git@main\n\n",
-                 "Or visit the original git repo if the above command fails:\n",
-                 "https://github.com/li-plus/chatglm.cpp",
-             ]
-
-             raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
-
-         model_file_path = os.path.join(
-             self.model_path,
-             self.model_spec.model_file_name_template.format(
-                 quantization=self.quantization
-             ),
-         )
-
-         # handle legacy cache.
-         legacy_model_file_path = os.path.join(self.model_path, "model.bin")
-         if os.path.exists(legacy_model_file_path):
-             model_file_path = legacy_model_file_path
-
-         self._llm = chatglm_cpp.Pipeline(Path(model_file_path))
-
-     @classmethod
-     def match(
-         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
-     ) -> bool:
-         if llm_spec.model_format != "ggmlv3":
-             return False
-         if "chatglm" not in llm_family.model_name:
-             return False
-         if "chat" not in llm_family.model_ability:
-             return False
-         return True
-
-     @staticmethod
-     def _convert_raw_text_chunks_to_chat(
-         tokens: Iterator[Any], model_name: str, include_usage: bool, input_ids
-     ) -> Iterator[ChatCompletionChunk]:
-         request_id = str(uuid.uuid4())
-         yield {
-             "id": "chat" + f"cmpl-{request_id}",
-             "model": model_name,
-             "object": "chat.completion.chunk",
-             "created": int(time.time()),
-             "choices": [
-                 {
-                     "index": 0,
-                     "delta": {
-                         "role": "assistant",
-                     },
-                     "finish_reason": None,
-                 }
-             ],
-         }
-         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-         for token in tokens:
-             prompt_tokens = len(input_ids)
-             completion_tokens = completion_tokens + 1
-             total_tokens = prompt_tokens + completion_tokens
-             yield {
-                 "id": "chat" + f"cmpl-{request_id}",
-                 "model": model_name,
-                 "object": "chat.completion.chunk",
-                 "created": int(time.time()),
-                 "choices": [
-                     {
-                         "index": 0,
-                         "delta": {
-                             "content": (
-                                 token if isinstance(token, str) else token.content
-                             ),
-                         },
-                         "finish_reason": None,
-                     }
-                 ],
-             }
-         # stop
-         yield {
-             "id": "chat" + f"cmpl-{request_id}",
-             "model": model_name,
-             "object": "chat.completion.chunk",
-             "created": int(time.time()),
-             "choices": [
-                 {
-                     "index": 0,
-                     "delta": {
-                         "content": "",
-                     },
-                     "finish_reason": "stop",
-                 }
-             ],
-         }
-         if include_usage:
-             yield {
-                 "id": "chat" + f"cmpl-{request_id}",
-                 "model": model_name,
-                 "object": "chat.completion.chunk",
-                 "created": int(time.time()),
-                 "choices": [],
-                 "usage": {
-                     "prompt_tokens": prompt_tokens,
-                     "completion_tokens": completion_tokens,
-                     "total_tokens": total_tokens,
-                 },
-             }
-
-     @classmethod
-     def _convert_raw_text_completion_to_chat(
-         cls, text: Any, model_name: str
-     ) -> ChatCompletion:
-         _id = str(uuid.uuid4())
-         return {
-             "id": "chat" + f"cmpl-{_id}",
-             "model": model_name,
-             "object": "chat.completion",
-             "created": int(time.time()),
-             "choices": [
-                 {
-                     "index": 0,
-                     "message": cls._message_to_json_string(_id, text),
-                     "finish_reason": cls._finish_reason_from_msg(text),
-                 }
-             ],
-             "usage": {
-                 "prompt_tokens": -1,
-                 "completion_tokens": -1,
-                 "total_tokens": -1,
-             },
-         }
-
-     @staticmethod
-     def _finish_reason_from_msg(msg):
-         if isinstance(msg, str):
-             return None
-         else:
-             return "tool_calls" if msg.tool_calls else "stop"
-
-     @staticmethod
-     def _eval_arguments(arguments):
-         def tool_call(**kwargs):
-             return kwargs
-
-         try:
-             return json.dumps(eval(arguments, dict(tool_call=tool_call)))
-         except Exception:
-             return f"Invalid arguments {arguments}"
-
-     @classmethod
-     def _message_to_json_string(cls, _id, msg) -> ChatCompletionMessage:
-         if isinstance(msg, str):
-             return {
-                 "role": "assistant",
-                 "content": msg,
-             }
-         else:
-             return {
-                 "role": msg.role,
-                 "content": msg.content,
-                 "tool_calls": [
-                     {
-                         "id": f"call_{_id}",
-                         "type": tc.type,
-                         "function": {
-                             "name": tc.function.name,
-                             "arguments": cls._eval_arguments(tc.function.arguments),
-                         },
-                     }
-                     for tc in msg.tool_calls
-                 ],
-             }
-
-     @staticmethod
-     def _handle_tools(generate_config) -> Optional[ChatCompletionMessage]:
-         """Convert openai tools to ChatGLM tools."""
-         if generate_config is None:
-             return None
-         tools = generate_config.pop("tools", None)
-         if tools is None:
-             return None
-         chatglm_tools = []
-         for elem in tools:
-             if elem.get("type") != "function" or "function" not in elem:
-                 raise ValueError("ChatGLM tools only support function type.")
-             chatglm_tools.append(elem["function"])
-         return {
-             "role": "system",
-             "content": (
-                 f"Answer the following questions as best as you can. You have access to the following tools:\n"
-                 f"{json.dumps(chatglm_tools, indent=4, ensure_ascii=False)}"
-             ),
-         }
-
-     @staticmethod
-     def _to_chatglm_chat_messages(history_list: List[Any]):
-         from chatglm_cpp import ChatMessage
-
-         return [ChatMessage(role=v["role"], content=v["content"]) for v in history_list]
-
-     def chat(
-         self,
-         prompt: str,
-         system_prompt: Optional[str] = None,
-         chat_history: Optional[List[ChatCompletionMessage]] = None,
-         generate_config: Optional[ChatglmCppGenerateConfig] = None,
-     ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
-         chat_history_list = []
-         if system_prompt is not None:
-             chat_history_list.append({"role": "system", "content": system_prompt})
-         if chat_history is not None:
-             chat_history_list.extend(chat_history)  # type: ignore
-
-         tool_message = self._handle_tools(generate_config)
-         if tool_message is not None:
-             chat_history_list.insert(0, tool_message)  # type: ignore
-
-         # We drop the message which contains tool calls to walkaround the issue:
-         # https://github.com/li-plus/chatglm.cpp/issues/231
-         chat_history_list = [m for m in chat_history_list if not m.get("tool_calls")]
-         for idx, m in enumerate(chat_history_list):
-             if m.get("role") == "tool":
-                 # Reconstruct a simple tool message.
-                 chat_history_list[idx] = {
-                     "content": m["content"],
-                     "role": "observation",
-                 }
-                 break
-
-         if prompt != SPECIAL_TOOL_PROMPT:
-             chat_history_list.append({"role": "user", "content": prompt})
-         logger.debug("Full conversation history:\n%s", str(chat_history_list))
-
-         generate_config = self._sanitize_generate_config(generate_config)
-
-         params = {
-             "max_length": generate_config.get("max_tokens"),
-             "max_context_length": generate_config.get("max_tokens", 1024),
-             "top_k": generate_config.get("top_k"),
-             "top_p": generate_config.get("top_p"),
-             "temperature": generate_config.get("temperature"),
-             "stream": generate_config.get("stream", False),
-         }
-
-         # Remove None values to exclude missing keys from params
-         params = {k: v for k, v in params.items() if v is not None}
-
-         assert self._llm is not None
-         chat_history_messages = self._to_chatglm_chat_messages(chat_history_list)
-
-         stream = generate_config.get("stream")
-         stream_options = generate_config.get("stream_options", None)
-         include_usage = (
-             stream_options["include_usage"]
-             if isinstance(stream_options, dict)
-             else False
-         )
-
-         if stream:
-             it = self._llm.chat(
-                 chat_history_messages,
-                 **params,
-             )
-             assert not isinstance(it, str)
-             input_ids = self._llm.tokenizer.encode_messages(
-                 chat_history_messages, params["max_context_length"]
-             )
-             return self._convert_raw_text_chunks_to_chat(
-                 it, self.model_uid, include_usage, input_ids
-             )
-
-         else:
-             c = self._llm.chat(
-                 chat_history_messages,
-                 **params,
-             )
-             assert not isinstance(c, Iterator)
-             return self._convert_raw_text_completion_to_chat(c, self.model_uid)
-
-     @staticmethod
-     def _convert_str_to_completion(data: str, model_name: str) -> Completion:
-         return {
-             "id": "generate" + f"-{str(uuid.uuid4())}",
-             "model": model_name,
-             "object": "text_completion",
-             "created": int(time.time()),
-             "choices": [
-                 {"index": 0, "text": data, "finish_reason": None, "logprobs": None}
-             ],
-             "usage": {
-                 "prompt_tokens": -1,
-                 "completion_tokens": -1,
-                 "total_tokens": -1,
-             },
-         }
-
-     @staticmethod
-     def _convert_str_to_completion_chunk(
-         tokens: Iterator[str], model_name: str, include_usage: bool, input_ids
-     ) -> Iterator[CompletionChunk]:
-         request_id = str(uuid.uuid4())
-         prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
-         for i, token in enumerate(tokens):
-             yield {
-                 "id": "generate" + f"-{request_id}",
-                 "model": model_name,
-                 "object": "text_completion",
-                 "created": int(time.time()),
-                 "choices": [
-                     {"index": 0, "text": token, "finish_reason": None, "logprobs": None}
-                 ],
-             }
-             prompt_tokens = len(input_ids)
-             completion_tokens = i
-             total_tokens = prompt_tokens + completion_tokens
-         # stop
-         yield {
-             "id": "chat" + f"cmpl-{request_id}",
-             "model": model_name,
-             "object": "text_completion",
-             "created": int(time.time()),
-             "choices": [
-                 {"index": 0, "text": "", "finish_reason": "stop", "logprobs": None}
-             ],
-         }
-         if include_usage:
-             yield {
-                 "id": "chat" + f"cmpl-{request_id}",
-                 "model": model_name,
-                 "object": "text_completion",
-                 "created": int(time.time()),
-                 "choices": [],
-                 "usage": {
-                     "prompt_tokens": prompt_tokens,
-                     "completion_tokens": completion_tokens,
-                     "total_tokens": total_tokens,
-                 },
-             }
-
-     def generate(
-         self,
-         prompt: str,
-         generate_config: Optional[ChatglmCppGenerateConfig] = None,
-     ) -> Union[Completion, Iterator[CompletionChunk]]:
-         logger.debug(f"Prompt for generate:\n{prompt}")
-
-         generate_config = self._sanitize_generate_config(generate_config)
-
-         params = {
-             "max_length": generate_config.get("max_tokens"),
-             "max_context_length": generate_config.get("max_tokens", 1024),
-             "top_k": generate_config.get("top_k"),
-             "top_p": generate_config.get("top_p"),
-             "temperature": generate_config.get("temperature"),
-             "stream": generate_config.get("stream", False),
-         }
-
-         # Remove None values to exclude missing keys from params
-         params = {k: v for k, v in params.items() if v is not None}
-
-         assert self._llm is not None
-         stream = generate_config.get("stream")
-         stream_options = generate_config.get("stream_options", None)
-         include_usage = (
-             stream_options["include_usage"]
-             if isinstance(stream_options, dict)
-             else False
-         )
-         if stream:
-             it = self._llm.generate(
-                 prompt,
-                 **params,
-             )
-             assert not isinstance(it, str)
-             input_ids = self._llm.tokenizer.encode(prompt, params["max_context_length"])
-             return self._convert_str_to_completion_chunk(
-                 it, self.model_uid, include_usage, input_ids
-             )
-         else:
-             c = self._llm.generate(
-                 prompt,
-                 **params,
-             )
-             assert not isinstance(c, Iterator)
-             return self._convert_str_to_completion(c, self.model_uid)
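The hunk above removes ChatglmCppChatModel, the adapter that served ggmlv3 ChatGLM models through the chatglm-cpp bindings and repackaged their output as OpenAI-style completion dicts. For reference, a minimal sketch of what that adapter wrapped, calling chatglm_cpp directly; the checkpoint path and sampling values here are illustrative, not taken from this release:

    from pathlib import Path

    import chatglm_cpp  # the bindings the removed load() imported lazily

    # Hypothetical local checkpoint; the removed load() derived this path from
    # model_file_name_template and the chosen quantization.
    pipeline = chatglm_cpp.Pipeline(Path("./chatglm3-ggml-q4_0.bin"))

    history = [
        chatglm_cpp.ChatMessage(role="system", content="You are a helpful assistant."),
        chatglm_cpp.ChatMessage(role="user", content="Say hello."),
    ]

    # Non-streaming call, mirroring the removed chat() else-branch: returns a
    # single ChatMessage that the adapter converted into a "chat.completion" dict.
    msg = pipeline.chat(history, max_length=512, top_p=0.8, temperature=0.7)
    print(msg.content)

    # Streaming call: stream=True yields message chunks, which the adapter
    # wrapped as "chat.completion.chunk" dicts keyed by a uuid4 request id.
    for chunk in pipeline.chat(history, max_length=512, stream=True):
        print(chunk.content, end="", flush=True)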
xinference/thirdparty/ChatTTS/__init__.py
@@ -1 +0,0 @@
- from .core import Chat
xinference/thirdparty/ChatTTS/core.py
@@ -1,200 +0,0 @@
-
- import os
- import logging
- from functools import partial
- from omegaconf import OmegaConf
-
- import torch
- from vocos import Vocos
- from .model.dvae import DVAE
- from .model.gpt import GPT_warpper
- from .utils.gpu_utils import select_device
- from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
- from .utils.io_utils import get_latest_modified_file
- from .infer.api import refine_text, infer_code
-
- from huggingface_hub import snapshot_download
-
- logging.basicConfig(level = logging.INFO)
-
-
- class Chat:
-     def __init__(self, ):
-         self.pretrain_models = {}
-         self.normalizer = {}
-         self.logger = logging.getLogger(__name__)
-
-     def check_model(self, level = logging.INFO, use_decoder = False):
-         not_finish = False
-         check_list = ['vocos', 'gpt', 'tokenizer']
-
-         if use_decoder:
-             check_list.append('decoder')
-         else:
-             check_list.append('dvae')
-
-         for module in check_list:
-             if module not in self.pretrain_models:
-                 self.logger.log(logging.WARNING, f'{module} not initialized.')
-                 not_finish = True
-
-         if not not_finish:
-             self.logger.log(level, f'All initialized.')
-
-         return not not_finish
-
-     def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
-         if source == 'huggingface':
-             hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
-             try:
-                 download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
-             except:
-                 download_path = None
-             if download_path is None or force_redownload:
-                 self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
-                 download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
-             else:
-                 self.logger.log(logging.INFO, f'Load from cache: {download_path}')
-         elif source == 'local':
-             self.logger.log(logging.INFO, f'Load from local: {local_path}')
-             download_path = local_path
-
-         self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
-
-     def _load(
-         self,
-         vocos_config_path: str = None,
-         vocos_ckpt_path: str = None,
-         dvae_config_path: str = None,
-         dvae_ckpt_path: str = None,
-         gpt_config_path: str = None,
-         gpt_ckpt_path: str = None,
-         decoder_config_path: str = None,
-         decoder_ckpt_path: str = None,
-         tokenizer_path: str = None,
-         device: str = None,
-         compile: bool = True,
-     ):
-         if not device:
-             device = select_device(4096)
-             self.logger.log(logging.INFO, f'use {device}')
-
-         if vocos_config_path:
-             vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
-             assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
-             vocos.load_state_dict(torch.load(vocos_ckpt_path))
-             self.pretrain_models['vocos'] = vocos
-             self.logger.log(logging.INFO, 'vocos loaded.')
-
-         if dvae_config_path:
-             cfg = OmegaConf.load(dvae_config_path)
-             dvae = DVAE(**cfg).to(device).eval()
-             assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
-             dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
-             self.pretrain_models['dvae'] = dvae
-             self.logger.log(logging.INFO, 'dvae loaded.')
-
-         if gpt_config_path:
-             cfg = OmegaConf.load(gpt_config_path)
-             gpt = GPT_warpper(**cfg).to(device).eval()
-             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
-             gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
-             if compile and 'cuda' in str(device):
-                 gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
-             self.pretrain_models['gpt'] = gpt
-             spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
-             assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
-             self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
-             self.logger.log(logging.INFO, 'gpt loaded.')
-
-         if decoder_config_path:
-             cfg = OmegaConf.load(decoder_config_path)
-             decoder = DVAE(**cfg).to(device).eval()
-             assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
-             decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
-             self.pretrain_models['decoder'] = decoder
-             self.logger.log(logging.INFO, 'decoder loaded.')
-
-         if tokenizer_path:
-             tokenizer = torch.load(tokenizer_path, map_location='cpu')
-             tokenizer.padding_side = 'left'
-             self.pretrain_models['tokenizer'] = tokenizer
-             self.logger.log(logging.INFO, 'tokenizer loaded.')
-
-         self.check_model()
-
-     def infer(
-         self,
-         text,
-         skip_refine_text=False,
-         refine_text_only=False,
-         params_refine_text={},
-         params_infer_code={'prompt':'[speed_5]'},
-         use_decoder=True,
-         do_text_normalization=True,
-         lang=None,
-     ):
-
-         assert self.check_model(use_decoder=use_decoder)
-
-         if not isinstance(text, list):
-             text = [text]
-
-         if do_text_normalization:
-             for i, t in enumerate(text):
-                 _lang = detect_language(t) if lang is None else lang
-                 self.init_normalizer(_lang)
-                 text[i] = self.normalizer[_lang](t)
-                 if _lang == 'zh':
-                     text[i] = apply_half2full_map(text[i])
-
-         for i, t in enumerate(text):
-             invalid_characters = count_invalid_characters(t)
-             if len(invalid_characters):
-                 self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
-                 text[i] = apply_character_map(t)
-
-         if not skip_refine_text:
-             text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
-             text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
-             text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
-             if refine_text_only:
-                 return text
-
-         text = [params_infer_code.get('prompt', '') + i for i in text]
-         params_infer_code.pop('prompt', '')
-         result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
-
-         if use_decoder:
-             mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
-         else:
-             mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
-
-         wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
-
-         return wav
-
-     def sample_random_speaker(self, ):
-
-         dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
-         std, mean = self.pretrain_models['spk_stat'].chunk(2)
-         return torch.randn(dim, device=std.device) * std + mean
-
-     def init_normalizer(self, lang):
-
-         if lang not in self.normalizer:
-             if lang == 'zh':
-                 try:
-                     from tn.chinese.normalizer import Normalizer
-                 except:
-                     self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
-                         Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
-                 self.normalizer[lang] = Normalizer().normalize
-             else:
-                 try:
-                     from nemo_text_processing.text_normalization.normalize import Normalizer
-                 except:
-                     self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
-                         Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
-                 self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
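The hunks above delete the vendored copy of the upstream 2Noise/ChatTTS project; given the accompanying change to xinference/model/audio/chattts.py in the file list, the library is presumably consumed as an installed dependency from this release on. A hedged sketch of the same Chat API that the removed core.py defines (the output file name is illustrative, soundfile is assumed here only for writing the waveform, and the 24 kHz rate is ChatTTS's usual output rate, not stated in this diff):

    import soundfile as sf  # assumed helper for writing the waveform to disk

    import ChatTTS  # upstream package matching the deleted vendored module

    chat = ChatTTS.Chat()
    # Mirrors the removed load_models(source='huggingface') path: fetches the
    # *.pt and *.yaml checkpoints from https://huggingface.co/2Noise/ChatTTS
    # into the HF cache on first use.
    chat.load_models()

    # infer() returns one numpy waveform per input text, as in the removed infer().
    wavs = chat.infer(["Hello from ChatTTS."])
    sf.write("output.wav", wavs[0].squeeze(), 24000)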