xinference 0.13.0__py3-none-any.whl → 0.13.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +123 -3
- xinference/client/restful/restful_client.py +131 -2
- xinference/core/model.py +93 -24
- xinference/core/supervisor.py +132 -15
- xinference/core/worker.py +165 -8
- xinference/deploy/cmdline.py +5 -0
- xinference/model/audio/chattts.py +46 -14
- xinference/model/audio/core.py +23 -15
- xinference/model/core.py +12 -3
- xinference/model/embedding/core.py +25 -16
- xinference/model/flexible/__init__.py +40 -0
- xinference/model/flexible/core.py +228 -0
- xinference/model/flexible/launchers/__init__.py +15 -0
- xinference/model/flexible/launchers/transformers_launcher.py +63 -0
- xinference/model/flexible/utils.py +33 -0
- xinference/model/image/core.py +21 -14
- xinference/model/image/custom.py +1 -1
- xinference/model/image/model_spec.json +14 -0
- xinference/model/image/stable_diffusion/core.py +43 -6
- xinference/model/llm/__init__.py +0 -2
- xinference/model/llm/core.py +3 -2
- xinference/model/llm/ggml/llamacpp.py +1 -10
- xinference/model/llm/llm_family.json +292 -36
- xinference/model/llm/llm_family.py +97 -52
- xinference/model/llm/llm_family_modelscope.json +220 -27
- xinference/model/llm/pytorch/core.py +0 -80
- xinference/model/llm/sglang/core.py +7 -2
- xinference/model/llm/utils.py +4 -2
- xinference/model/llm/vllm/core.py +3 -0
- xinference/model/rerank/core.py +24 -25
- xinference/types.py +0 -1
- xinference/web/ui/build/asset-manifest.json +3 -3
- xinference/web/ui/build/index.html +1 -1
- xinference/web/ui/build/static/js/{main.0fb6f3ab.js → main.95c1d652.js} +3 -3
- xinference/web/ui/build/static/js/main.95c1d652.js.map +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/07ce9e632e6aff24d7aa3ad8e48224433bbfeb0d633fca723453f1fcae0c9f1c.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/40f17338fc75ae095de7d2b4d8eae0d5ca0193a7e2bcece4ee745b22a7a2f4b7.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/709711edada3f1596b309d571285fd31f1c364d66f4425bc28723d0088cc351a.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/70fa8c07463a5fe57c68bf92502910105a8f647371836fe8c3a7408246ca7ba0.json +1 -0
- xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +1 -0
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/METADATA +9 -11
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/RECORD +49 -58
- xinference/model/llm/ggml/chatglm.py +0 -457
- xinference/thirdparty/ChatTTS/__init__.py +0 -1
- xinference/thirdparty/ChatTTS/core.py +0 -200
- xinference/thirdparty/ChatTTS/experimental/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/experimental/llm.py +0 -40
- xinference/thirdparty/ChatTTS/infer/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/infer/api.py +0 -125
- xinference/thirdparty/ChatTTS/model/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/model/dvae.py +0 -155
- xinference/thirdparty/ChatTTS/model/gpt.py +0 -265
- xinference/thirdparty/ChatTTS/utils/__init__.py +0 -0
- xinference/thirdparty/ChatTTS/utils/gpu_utils.py +0 -23
- xinference/thirdparty/ChatTTS/utils/infer_utils.py +0 -141
- xinference/thirdparty/ChatTTS/utils/io_utils.py +0 -14
- xinference/web/ui/build/static/js/main.0fb6f3ab.js.map +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/0f6b391abec76271137faad13a3793fe7acc1024e8cd2269c147b653ecd3a73b.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/30a0c79d8025d6441eb75b2df5bc2750a14f30119c869ef02570d294dff65c2f.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/40486e655c3c5801f087e2cf206c0b5511aaa0dfdba78046b7181bf9c17e54c5.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/b5507cd57f16a3a230aa0128e39fe103e928de139ea29e2679e4c64dcbba3b3a.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d779b915f83f9c7b5a72515b6932fdd114f1822cef90ae01cc0d12bca59abc2d.json +0 -1
- xinference/web/ui/node_modules/.cache/babel-loader/d87824cb266194447a9c0c69ebab2d507bfc3e3148976173760d18c035e9dd26.json +0 -1
- /xinference/web/ui/build/static/js/{main.0fb6f3ab.js.LICENSE.txt → main.95c1d652.js.LICENSE.txt} +0 -0
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/LICENSE +0 -0
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/WHEEL +0 -0
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/entry_points.txt +0 -0
- {xinference-0.13.0.dist-info → xinference-0.13.2.dist-info}/top_level.txt +0 -0
|
@@ -1,457 +0,0 @@
|
|
|
1
|
-
# Copyright 2022-2023 XProbe Inc.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
import json
|
|
15
|
-
import logging
|
|
16
|
-
import os
|
|
17
|
-
import time
|
|
18
|
-
import uuid
|
|
19
|
-
from pathlib import Path
|
|
20
|
-
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
|
21
|
-
|
|
22
|
-
from ....types import (
|
|
23
|
-
SPECIAL_TOOL_PROMPT,
|
|
24
|
-
ChatCompletion,
|
|
25
|
-
ChatCompletionChunk,
|
|
26
|
-
ChatCompletionMessage,
|
|
27
|
-
ChatglmCppGenerateConfig,
|
|
28
|
-
ChatglmCppModelConfig,
|
|
29
|
-
Completion,
|
|
30
|
-
CompletionChunk,
|
|
31
|
-
)
|
|
32
|
-
from .. import LLMFamilyV1, LLMSpecV1
|
|
33
|
-
from ..core import LLM
|
|
34
|
-
|
|
35
|
-
if TYPE_CHECKING:
|
|
36
|
-
from chatglm_cpp import Pipeline
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
logger = logging.getLogger(__name__)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class ChatglmCppChatModel(LLM):
|
|
43
|
-
def __init__(
|
|
44
|
-
self,
|
|
45
|
-
model_uid: str,
|
|
46
|
-
model_family: "LLMFamilyV1",
|
|
47
|
-
model_spec: "LLMSpecV1",
|
|
48
|
-
quantization: str,
|
|
49
|
-
model_path: str,
|
|
50
|
-
model_config: Optional[ChatglmCppModelConfig] = None,
|
|
51
|
-
):
|
|
52
|
-
super().__init__(model_uid, model_family, model_spec, quantization, model_path)
|
|
53
|
-
self._llm: Optional["Pipeline"] = None
|
|
54
|
-
|
|
55
|
-
# just a placeholder for now as the chatglm_cpp repo doesn't support model config.
|
|
56
|
-
self._model_config = model_config
|
|
57
|
-
|
|
58
|
-
@classmethod
|
|
59
|
-
def _sanitize_generate_config(
|
|
60
|
-
cls,
|
|
61
|
-
chatglmcpp_generate_config: Optional[ChatglmCppGenerateConfig],
|
|
62
|
-
) -> ChatglmCppGenerateConfig:
|
|
63
|
-
if chatglmcpp_generate_config is None:
|
|
64
|
-
chatglmcpp_generate_config = ChatglmCppGenerateConfig()
|
|
65
|
-
chatglmcpp_generate_config.setdefault("stream", False)
|
|
66
|
-
return chatglmcpp_generate_config
|
|
67
|
-
|
|
68
|
-
def load(self):
|
|
69
|
-
try:
|
|
70
|
-
import chatglm_cpp
|
|
71
|
-
except ImportError:
|
|
72
|
-
error_message = "Failed to import module 'chatglm_cpp'"
|
|
73
|
-
installation_guide = [
|
|
74
|
-
"Please make sure 'chatglm_cpp' is installed. ",
|
|
75
|
-
"You can install it by running the following command in the terminal:\n",
|
|
76
|
-
"pip install git+https://github.com/li-plus/chatglm.cpp.git@main\n\n",
|
|
77
|
-
"Or visit the original git repo if the above command fails:\n",
|
|
78
|
-
"https://github.com/li-plus/chatglm.cpp",
|
|
79
|
-
]
|
|
80
|
-
|
|
81
|
-
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
|
|
82
|
-
|
|
83
|
-
model_file_path = os.path.join(
|
|
84
|
-
self.model_path,
|
|
85
|
-
self.model_spec.model_file_name_template.format(
|
|
86
|
-
quantization=self.quantization
|
|
87
|
-
),
|
|
88
|
-
)
|
|
89
|
-
|
|
90
|
-
# handle legacy cache.
|
|
91
|
-
legacy_model_file_path = os.path.join(self.model_path, "model.bin")
|
|
92
|
-
if os.path.exists(legacy_model_file_path):
|
|
93
|
-
model_file_path = legacy_model_file_path
|
|
94
|
-
|
|
95
|
-
self._llm = chatglm_cpp.Pipeline(Path(model_file_path))
|
|
96
|
-
|
|
97
|
-
@classmethod
|
|
98
|
-
def match(
|
|
99
|
-
cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
|
|
100
|
-
) -> bool:
|
|
101
|
-
if llm_spec.model_format != "ggmlv3":
|
|
102
|
-
return False
|
|
103
|
-
if "chatglm" not in llm_family.model_name:
|
|
104
|
-
return False
|
|
105
|
-
if "chat" not in llm_family.model_ability:
|
|
106
|
-
return False
|
|
107
|
-
return True
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
def _convert_raw_text_chunks_to_chat(
|
|
111
|
-
tokens: Iterator[Any], model_name: str, include_usage: bool, input_ids
|
|
112
|
-
) -> Iterator[ChatCompletionChunk]:
|
|
113
|
-
request_id = str(uuid.uuid4())
|
|
114
|
-
yield {
|
|
115
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
116
|
-
"model": model_name,
|
|
117
|
-
"object": "chat.completion.chunk",
|
|
118
|
-
"created": int(time.time()),
|
|
119
|
-
"choices": [
|
|
120
|
-
{
|
|
121
|
-
"index": 0,
|
|
122
|
-
"delta": {
|
|
123
|
-
"role": "assistant",
|
|
124
|
-
},
|
|
125
|
-
"finish_reason": None,
|
|
126
|
-
}
|
|
127
|
-
],
|
|
128
|
-
}
|
|
129
|
-
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
|
130
|
-
for token in tokens:
|
|
131
|
-
prompt_tokens = len(input_ids)
|
|
132
|
-
completion_tokens = completion_tokens + 1
|
|
133
|
-
total_tokens = prompt_tokens + completion_tokens
|
|
134
|
-
yield {
|
|
135
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
136
|
-
"model": model_name,
|
|
137
|
-
"object": "chat.completion.chunk",
|
|
138
|
-
"created": int(time.time()),
|
|
139
|
-
"choices": [
|
|
140
|
-
{
|
|
141
|
-
"index": 0,
|
|
142
|
-
"delta": {
|
|
143
|
-
"content": (
|
|
144
|
-
token if isinstance(token, str) else token.content
|
|
145
|
-
),
|
|
146
|
-
},
|
|
147
|
-
"finish_reason": None,
|
|
148
|
-
}
|
|
149
|
-
],
|
|
150
|
-
}
|
|
151
|
-
# stop
|
|
152
|
-
yield {
|
|
153
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
154
|
-
"model": model_name,
|
|
155
|
-
"object": "chat.completion.chunk",
|
|
156
|
-
"created": int(time.time()),
|
|
157
|
-
"choices": [
|
|
158
|
-
{
|
|
159
|
-
"index": 0,
|
|
160
|
-
"delta": {
|
|
161
|
-
"content": "",
|
|
162
|
-
},
|
|
163
|
-
"finish_reason": "stop",
|
|
164
|
-
}
|
|
165
|
-
],
|
|
166
|
-
}
|
|
167
|
-
if include_usage:
|
|
168
|
-
yield {
|
|
169
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
170
|
-
"model": model_name,
|
|
171
|
-
"object": "chat.completion.chunk",
|
|
172
|
-
"created": int(time.time()),
|
|
173
|
-
"choices": [],
|
|
174
|
-
"usage": {
|
|
175
|
-
"prompt_tokens": prompt_tokens,
|
|
176
|
-
"completion_tokens": completion_tokens,
|
|
177
|
-
"total_tokens": total_tokens,
|
|
178
|
-
},
|
|
179
|
-
}
|
|
180
|
-
|
|
181
|
-
@classmethod
|
|
182
|
-
def _convert_raw_text_completion_to_chat(
|
|
183
|
-
cls, text: Any, model_name: str
|
|
184
|
-
) -> ChatCompletion:
|
|
185
|
-
_id = str(uuid.uuid4())
|
|
186
|
-
return {
|
|
187
|
-
"id": "chat" + f"cmpl-{_id}",
|
|
188
|
-
"model": model_name,
|
|
189
|
-
"object": "chat.completion",
|
|
190
|
-
"created": int(time.time()),
|
|
191
|
-
"choices": [
|
|
192
|
-
{
|
|
193
|
-
"index": 0,
|
|
194
|
-
"message": cls._message_to_json_string(_id, text),
|
|
195
|
-
"finish_reason": cls._finish_reason_from_msg(text),
|
|
196
|
-
}
|
|
197
|
-
],
|
|
198
|
-
"usage": {
|
|
199
|
-
"prompt_tokens": -1,
|
|
200
|
-
"completion_tokens": -1,
|
|
201
|
-
"total_tokens": -1,
|
|
202
|
-
},
|
|
203
|
-
}
|
|
204
|
-
|
|
205
|
-
@staticmethod
|
|
206
|
-
def _finish_reason_from_msg(msg):
|
|
207
|
-
if isinstance(msg, str):
|
|
208
|
-
return None
|
|
209
|
-
else:
|
|
210
|
-
return "tool_calls" if msg.tool_calls else "stop"
|
|
211
|
-
|
|
212
|
-
@staticmethod
|
|
213
|
-
def _eval_arguments(arguments):
|
|
214
|
-
def tool_call(**kwargs):
|
|
215
|
-
return kwargs
|
|
216
|
-
|
|
217
|
-
try:
|
|
218
|
-
return json.dumps(eval(arguments, dict(tool_call=tool_call)))
|
|
219
|
-
except Exception:
|
|
220
|
-
return f"Invalid arguments {arguments}"
|
|
221
|
-
|
|
222
|
-
@classmethod
|
|
223
|
-
def _message_to_json_string(cls, _id, msg) -> ChatCompletionMessage:
|
|
224
|
-
if isinstance(msg, str):
|
|
225
|
-
return {
|
|
226
|
-
"role": "assistant",
|
|
227
|
-
"content": msg,
|
|
228
|
-
}
|
|
229
|
-
else:
|
|
230
|
-
return {
|
|
231
|
-
"role": msg.role,
|
|
232
|
-
"content": msg.content,
|
|
233
|
-
"tool_calls": [
|
|
234
|
-
{
|
|
235
|
-
"id": f"call_{_id}",
|
|
236
|
-
"type": tc.type,
|
|
237
|
-
"function": {
|
|
238
|
-
"name": tc.function.name,
|
|
239
|
-
"arguments": cls._eval_arguments(tc.function.arguments),
|
|
240
|
-
},
|
|
241
|
-
}
|
|
242
|
-
for tc in msg.tool_calls
|
|
243
|
-
],
|
|
244
|
-
}
|
|
245
|
-
|
|
246
|
-
@staticmethod
|
|
247
|
-
def _handle_tools(generate_config) -> Optional[ChatCompletionMessage]:
|
|
248
|
-
"""Convert openai tools to ChatGLM tools."""
|
|
249
|
-
if generate_config is None:
|
|
250
|
-
return None
|
|
251
|
-
tools = generate_config.pop("tools", None)
|
|
252
|
-
if tools is None:
|
|
253
|
-
return None
|
|
254
|
-
chatglm_tools = []
|
|
255
|
-
for elem in tools:
|
|
256
|
-
if elem.get("type") != "function" or "function" not in elem:
|
|
257
|
-
raise ValueError("ChatGLM tools only support function type.")
|
|
258
|
-
chatglm_tools.append(elem["function"])
|
|
259
|
-
return {
|
|
260
|
-
"role": "system",
|
|
261
|
-
"content": (
|
|
262
|
-
f"Answer the following questions as best as you can. You have access to the following tools:\n"
|
|
263
|
-
f"{json.dumps(chatglm_tools, indent=4, ensure_ascii=False)}"
|
|
264
|
-
),
|
|
265
|
-
}
|
|
266
|
-
|
|
267
|
-
@staticmethod
|
|
268
|
-
def _to_chatglm_chat_messages(history_list: List[Any]):
|
|
269
|
-
from chatglm_cpp import ChatMessage
|
|
270
|
-
|
|
271
|
-
return [ChatMessage(role=v["role"], content=v["content"]) for v in history_list]
|
|
272
|
-
|
|
273
|
-
def chat(
|
|
274
|
-
self,
|
|
275
|
-
prompt: str,
|
|
276
|
-
system_prompt: Optional[str] = None,
|
|
277
|
-
chat_history: Optional[List[ChatCompletionMessage]] = None,
|
|
278
|
-
generate_config: Optional[ChatglmCppGenerateConfig] = None,
|
|
279
|
-
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
280
|
-
chat_history_list = []
|
|
281
|
-
if system_prompt is not None:
|
|
282
|
-
chat_history_list.append({"role": "system", "content": system_prompt})
|
|
283
|
-
if chat_history is not None:
|
|
284
|
-
chat_history_list.extend(chat_history) # type: ignore
|
|
285
|
-
|
|
286
|
-
tool_message = self._handle_tools(generate_config)
|
|
287
|
-
if tool_message is not None:
|
|
288
|
-
chat_history_list.insert(0, tool_message) # type: ignore
|
|
289
|
-
|
|
290
|
-
# We drop the message which contains tool calls to walkaround the issue:
|
|
291
|
-
# https://github.com/li-plus/chatglm.cpp/issues/231
|
|
292
|
-
chat_history_list = [m for m in chat_history_list if not m.get("tool_calls")]
|
|
293
|
-
for idx, m in enumerate(chat_history_list):
|
|
294
|
-
if m.get("role") == "tool":
|
|
295
|
-
# Reconstruct a simple tool message.
|
|
296
|
-
chat_history_list[idx] = {
|
|
297
|
-
"content": m["content"],
|
|
298
|
-
"role": "observation",
|
|
299
|
-
}
|
|
300
|
-
break
|
|
301
|
-
|
|
302
|
-
if prompt != SPECIAL_TOOL_PROMPT:
|
|
303
|
-
chat_history_list.append({"role": "user", "content": prompt})
|
|
304
|
-
logger.debug("Full conversation history:\n%s", str(chat_history_list))
|
|
305
|
-
|
|
306
|
-
generate_config = self._sanitize_generate_config(generate_config)
|
|
307
|
-
|
|
308
|
-
params = {
|
|
309
|
-
"max_length": generate_config.get("max_tokens"),
|
|
310
|
-
"max_context_length": generate_config.get("max_tokens", 1024),
|
|
311
|
-
"top_k": generate_config.get("top_k"),
|
|
312
|
-
"top_p": generate_config.get("top_p"),
|
|
313
|
-
"temperature": generate_config.get("temperature"),
|
|
314
|
-
"stream": generate_config.get("stream", False),
|
|
315
|
-
}
|
|
316
|
-
|
|
317
|
-
# Remove None values to exclude missing keys from params
|
|
318
|
-
params = {k: v for k, v in params.items() if v is not None}
|
|
319
|
-
|
|
320
|
-
assert self._llm is not None
|
|
321
|
-
chat_history_messages = self._to_chatglm_chat_messages(chat_history_list)
|
|
322
|
-
|
|
323
|
-
stream = generate_config.get("stream")
|
|
324
|
-
stream_options = generate_config.get("stream_options", None)
|
|
325
|
-
include_usage = (
|
|
326
|
-
stream_options["include_usage"]
|
|
327
|
-
if isinstance(stream_options, dict)
|
|
328
|
-
else False
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
if stream:
|
|
332
|
-
it = self._llm.chat(
|
|
333
|
-
chat_history_messages,
|
|
334
|
-
**params,
|
|
335
|
-
)
|
|
336
|
-
assert not isinstance(it, str)
|
|
337
|
-
input_ids = self._llm.tokenizer.encode_messages(
|
|
338
|
-
chat_history_messages, params["max_context_length"]
|
|
339
|
-
)
|
|
340
|
-
return self._convert_raw_text_chunks_to_chat(
|
|
341
|
-
it, self.model_uid, include_usage, input_ids
|
|
342
|
-
)
|
|
343
|
-
|
|
344
|
-
else:
|
|
345
|
-
c = self._llm.chat(
|
|
346
|
-
chat_history_messages,
|
|
347
|
-
**params,
|
|
348
|
-
)
|
|
349
|
-
assert not isinstance(c, Iterator)
|
|
350
|
-
return self._convert_raw_text_completion_to_chat(c, self.model_uid)
|
|
351
|
-
|
|
352
|
-
@staticmethod
|
|
353
|
-
def _convert_str_to_completion(data: str, model_name: str) -> Completion:
|
|
354
|
-
return {
|
|
355
|
-
"id": "generate" + f"-{str(uuid.uuid4())}",
|
|
356
|
-
"model": model_name,
|
|
357
|
-
"object": "text_completion",
|
|
358
|
-
"created": int(time.time()),
|
|
359
|
-
"choices": [
|
|
360
|
-
{"index": 0, "text": data, "finish_reason": None, "logprobs": None}
|
|
361
|
-
],
|
|
362
|
-
"usage": {
|
|
363
|
-
"prompt_tokens": -1,
|
|
364
|
-
"completion_tokens": -1,
|
|
365
|
-
"total_tokens": -1,
|
|
366
|
-
},
|
|
367
|
-
}
|
|
368
|
-
|
|
369
|
-
@staticmethod
|
|
370
|
-
def _convert_str_to_completion_chunk(
|
|
371
|
-
tokens: Iterator[str], model_name: str, include_usage: bool, input_ids
|
|
372
|
-
) -> Iterator[CompletionChunk]:
|
|
373
|
-
request_id = str(uuid.uuid4())
|
|
374
|
-
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
|
375
|
-
for i, token in enumerate(tokens):
|
|
376
|
-
yield {
|
|
377
|
-
"id": "generate" + f"-{request_id}",
|
|
378
|
-
"model": model_name,
|
|
379
|
-
"object": "text_completion",
|
|
380
|
-
"created": int(time.time()),
|
|
381
|
-
"choices": [
|
|
382
|
-
{"index": 0, "text": token, "finish_reason": None, "logprobs": None}
|
|
383
|
-
],
|
|
384
|
-
}
|
|
385
|
-
prompt_tokens = len(input_ids)
|
|
386
|
-
completion_tokens = i
|
|
387
|
-
total_tokens = prompt_tokens + completion_tokens
|
|
388
|
-
# stop
|
|
389
|
-
yield {
|
|
390
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
391
|
-
"model": model_name,
|
|
392
|
-
"object": "text_completion",
|
|
393
|
-
"created": int(time.time()),
|
|
394
|
-
"choices": [
|
|
395
|
-
{"index": 0, "text": "", "finish_reason": "stop", "logprobs": None}
|
|
396
|
-
],
|
|
397
|
-
}
|
|
398
|
-
if include_usage:
|
|
399
|
-
yield {
|
|
400
|
-
"id": "chat" + f"cmpl-{request_id}",
|
|
401
|
-
"model": model_name,
|
|
402
|
-
"object": "text_completion",
|
|
403
|
-
"created": int(time.time()),
|
|
404
|
-
"choices": [],
|
|
405
|
-
"usage": {
|
|
406
|
-
"prompt_tokens": prompt_tokens,
|
|
407
|
-
"completion_tokens": completion_tokens,
|
|
408
|
-
"total_tokens": total_tokens,
|
|
409
|
-
},
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
def generate(
|
|
413
|
-
self,
|
|
414
|
-
prompt: str,
|
|
415
|
-
generate_config: Optional[ChatglmCppGenerateConfig] = None,
|
|
416
|
-
) -> Union[Completion, Iterator[CompletionChunk]]:
|
|
417
|
-
logger.debug(f"Prompt for generate:\n{prompt}")
|
|
418
|
-
|
|
419
|
-
generate_config = self._sanitize_generate_config(generate_config)
|
|
420
|
-
|
|
421
|
-
params = {
|
|
422
|
-
"max_length": generate_config.get("max_tokens"),
|
|
423
|
-
"max_context_length": generate_config.get("max_tokens", 1024),
|
|
424
|
-
"top_k": generate_config.get("top_k"),
|
|
425
|
-
"top_p": generate_config.get("top_p"),
|
|
426
|
-
"temperature": generate_config.get("temperature"),
|
|
427
|
-
"stream": generate_config.get("stream", False),
|
|
428
|
-
}
|
|
429
|
-
|
|
430
|
-
# Remove None values to exclude missing keys from params
|
|
431
|
-
params = {k: v for k, v in params.items() if v is not None}
|
|
432
|
-
|
|
433
|
-
assert self._llm is not None
|
|
434
|
-
stream = generate_config.get("stream")
|
|
435
|
-
stream_options = generate_config.get("stream_options", None)
|
|
436
|
-
include_usage = (
|
|
437
|
-
stream_options["include_usage"]
|
|
438
|
-
if isinstance(stream_options, dict)
|
|
439
|
-
else False
|
|
440
|
-
)
|
|
441
|
-
if stream:
|
|
442
|
-
it = self._llm.generate(
|
|
443
|
-
prompt,
|
|
444
|
-
**params,
|
|
445
|
-
)
|
|
446
|
-
assert not isinstance(it, str)
|
|
447
|
-
input_ids = self._llm.tokenizer.encode(prompt, params["max_context_length"])
|
|
448
|
-
return self._convert_str_to_completion_chunk(
|
|
449
|
-
it, self.model_uid, include_usage, input_ids
|
|
450
|
-
)
|
|
451
|
-
else:
|
|
452
|
-
c = self._llm.generate(
|
|
453
|
-
prompt,
|
|
454
|
-
**params,
|
|
455
|
-
)
|
|
456
|
-
assert not isinstance(c, Iterator)
|
|
457
|
-
return self._convert_str_to_completion(c, self.model_uid)
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .core import Chat
|
|
@@ -1,200 +0,0 @@
|
|
|
1
|
-
|
|
2
|
-
import os
|
|
3
|
-
import logging
|
|
4
|
-
from functools import partial
|
|
5
|
-
from omegaconf import OmegaConf
|
|
6
|
-
|
|
7
|
-
import torch
|
|
8
|
-
from vocos import Vocos
|
|
9
|
-
from .model.dvae import DVAE
|
|
10
|
-
from .model.gpt import GPT_warpper
|
|
11
|
-
from .utils.gpu_utils import select_device
|
|
12
|
-
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map
|
|
13
|
-
from .utils.io_utils import get_latest_modified_file
|
|
14
|
-
from .infer.api import refine_text, infer_code
|
|
15
|
-
|
|
16
|
-
from huggingface_hub import snapshot_download
|
|
17
|
-
|
|
18
|
-
logging.basicConfig(level = logging.INFO)
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class Chat:
|
|
22
|
-
def __init__(self, ):
|
|
23
|
-
self.pretrain_models = {}
|
|
24
|
-
self.normalizer = {}
|
|
25
|
-
self.logger = logging.getLogger(__name__)
|
|
26
|
-
|
|
27
|
-
def check_model(self, level = logging.INFO, use_decoder = False):
|
|
28
|
-
not_finish = False
|
|
29
|
-
check_list = ['vocos', 'gpt', 'tokenizer']
|
|
30
|
-
|
|
31
|
-
if use_decoder:
|
|
32
|
-
check_list.append('decoder')
|
|
33
|
-
else:
|
|
34
|
-
check_list.append('dvae')
|
|
35
|
-
|
|
36
|
-
for module in check_list:
|
|
37
|
-
if module not in self.pretrain_models:
|
|
38
|
-
self.logger.log(logging.WARNING, f'{module} not initialized.')
|
|
39
|
-
not_finish = True
|
|
40
|
-
|
|
41
|
-
if not not_finish:
|
|
42
|
-
self.logger.log(level, f'All initialized.')
|
|
43
|
-
|
|
44
|
-
return not not_finish
|
|
45
|
-
|
|
46
|
-
def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
|
|
47
|
-
if source == 'huggingface':
|
|
48
|
-
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
|
|
49
|
-
try:
|
|
50
|
-
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
|
|
51
|
-
except:
|
|
52
|
-
download_path = None
|
|
53
|
-
if download_path is None or force_redownload:
|
|
54
|
-
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
|
|
55
|
-
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
|
|
56
|
-
else:
|
|
57
|
-
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
|
|
58
|
-
elif source == 'local':
|
|
59
|
-
self.logger.log(logging.INFO, f'Load from local: {local_path}')
|
|
60
|
-
download_path = local_path
|
|
61
|
-
|
|
62
|
-
self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
|
|
63
|
-
|
|
64
|
-
def _load(
|
|
65
|
-
self,
|
|
66
|
-
vocos_config_path: str = None,
|
|
67
|
-
vocos_ckpt_path: str = None,
|
|
68
|
-
dvae_config_path: str = None,
|
|
69
|
-
dvae_ckpt_path: str = None,
|
|
70
|
-
gpt_config_path: str = None,
|
|
71
|
-
gpt_ckpt_path: str = None,
|
|
72
|
-
decoder_config_path: str = None,
|
|
73
|
-
decoder_ckpt_path: str = None,
|
|
74
|
-
tokenizer_path: str = None,
|
|
75
|
-
device: str = None,
|
|
76
|
-
compile: bool = True,
|
|
77
|
-
):
|
|
78
|
-
if not device:
|
|
79
|
-
device = select_device(4096)
|
|
80
|
-
self.logger.log(logging.INFO, f'use {device}')
|
|
81
|
-
|
|
82
|
-
if vocos_config_path:
|
|
83
|
-
vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
|
|
84
|
-
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
|
|
85
|
-
vocos.load_state_dict(torch.load(vocos_ckpt_path))
|
|
86
|
-
self.pretrain_models['vocos'] = vocos
|
|
87
|
-
self.logger.log(logging.INFO, 'vocos loaded.')
|
|
88
|
-
|
|
89
|
-
if dvae_config_path:
|
|
90
|
-
cfg = OmegaConf.load(dvae_config_path)
|
|
91
|
-
dvae = DVAE(**cfg).to(device).eval()
|
|
92
|
-
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
|
|
93
|
-
dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
|
|
94
|
-
self.pretrain_models['dvae'] = dvae
|
|
95
|
-
self.logger.log(logging.INFO, 'dvae loaded.')
|
|
96
|
-
|
|
97
|
-
if gpt_config_path:
|
|
98
|
-
cfg = OmegaConf.load(gpt_config_path)
|
|
99
|
-
gpt = GPT_warpper(**cfg).to(device).eval()
|
|
100
|
-
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
|
|
101
|
-
gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
|
|
102
|
-
if compile and 'cuda' in str(device):
|
|
103
|
-
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
|
|
104
|
-
self.pretrain_models['gpt'] = gpt
|
|
105
|
-
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
|
|
106
|
-
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
|
|
107
|
-
self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
|
|
108
|
-
self.logger.log(logging.INFO, 'gpt loaded.')
|
|
109
|
-
|
|
110
|
-
if decoder_config_path:
|
|
111
|
-
cfg = OmegaConf.load(decoder_config_path)
|
|
112
|
-
decoder = DVAE(**cfg).to(device).eval()
|
|
113
|
-
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
|
|
114
|
-
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
|
|
115
|
-
self.pretrain_models['decoder'] = decoder
|
|
116
|
-
self.logger.log(logging.INFO, 'decoder loaded.')
|
|
117
|
-
|
|
118
|
-
if tokenizer_path:
|
|
119
|
-
tokenizer = torch.load(tokenizer_path, map_location='cpu')
|
|
120
|
-
tokenizer.padding_side = 'left'
|
|
121
|
-
self.pretrain_models['tokenizer'] = tokenizer
|
|
122
|
-
self.logger.log(logging.INFO, 'tokenizer loaded.')
|
|
123
|
-
|
|
124
|
-
self.check_model()
|
|
125
|
-
|
|
126
|
-
def infer(
|
|
127
|
-
self,
|
|
128
|
-
text,
|
|
129
|
-
skip_refine_text=False,
|
|
130
|
-
refine_text_only=False,
|
|
131
|
-
params_refine_text={},
|
|
132
|
-
params_infer_code={'prompt':'[speed_5]'},
|
|
133
|
-
use_decoder=True,
|
|
134
|
-
do_text_normalization=True,
|
|
135
|
-
lang=None,
|
|
136
|
-
):
|
|
137
|
-
|
|
138
|
-
assert self.check_model(use_decoder=use_decoder)
|
|
139
|
-
|
|
140
|
-
if not isinstance(text, list):
|
|
141
|
-
text = [text]
|
|
142
|
-
|
|
143
|
-
if do_text_normalization:
|
|
144
|
-
for i, t in enumerate(text):
|
|
145
|
-
_lang = detect_language(t) if lang is None else lang
|
|
146
|
-
self.init_normalizer(_lang)
|
|
147
|
-
text[i] = self.normalizer[_lang](t)
|
|
148
|
-
if _lang == 'zh':
|
|
149
|
-
text[i] = apply_half2full_map(text[i])
|
|
150
|
-
|
|
151
|
-
for i, t in enumerate(text):
|
|
152
|
-
invalid_characters = count_invalid_characters(t)
|
|
153
|
-
if len(invalid_characters):
|
|
154
|
-
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
|
|
155
|
-
text[i] = apply_character_map(t)
|
|
156
|
-
|
|
157
|
-
if not skip_refine_text:
|
|
158
|
-
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
|
|
159
|
-
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
|
|
160
|
-
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
|
|
161
|
-
if refine_text_only:
|
|
162
|
-
return text
|
|
163
|
-
|
|
164
|
-
text = [params_infer_code.get('prompt', '') + i for i in text]
|
|
165
|
-
params_infer_code.pop('prompt', '')
|
|
166
|
-
result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
|
|
167
|
-
|
|
168
|
-
if use_decoder:
|
|
169
|
-
mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
|
|
170
|
-
else:
|
|
171
|
-
mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
|
|
172
|
-
|
|
173
|
-
wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
|
|
174
|
-
|
|
175
|
-
return wav
|
|
176
|
-
|
|
177
|
-
def sample_random_speaker(self, ):
|
|
178
|
-
|
|
179
|
-
dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
|
|
180
|
-
std, mean = self.pretrain_models['spk_stat'].chunk(2)
|
|
181
|
-
return torch.randn(dim, device=std.device) * std + mean
|
|
182
|
-
|
|
183
|
-
def init_normalizer(self, lang):
|
|
184
|
-
|
|
185
|
-
if lang not in self.normalizer:
|
|
186
|
-
if lang == 'zh':
|
|
187
|
-
try:
|
|
188
|
-
from tn.chinese.normalizer import Normalizer
|
|
189
|
-
except:
|
|
190
|
-
self.logger.log(logging.WARNING, f'Package WeTextProcessing not found! \
|
|
191
|
-
Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing')
|
|
192
|
-
self.normalizer[lang] = Normalizer().normalize
|
|
193
|
-
else:
|
|
194
|
-
try:
|
|
195
|
-
from nemo_text_processing.text_normalization.normalize import Normalizer
|
|
196
|
-
except:
|
|
197
|
-
self.logger.log(logging.WARNING, f'Package nemo_text_processing not found! \
|
|
198
|
-
Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing')
|
|
199
|
-
self.normalizer[lang] = partial(Normalizer(input_case='cased', lang=lang).normalize, verbose=False, punct_post_process=True)
|
|
200
|
-
|
|
File without changes
|