xinference 0.16.3__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -11
- xinference/client/restful/restful_client.py +8 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +10 -3
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/llm_family.json +176 -4
- xinference/model/llm/llm_family_modelscope.json +211 -0
- xinference/model/llm/mlx/core.py +45 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/llama/generate.py

@@ -15,8 +15,10 @@ import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -28,6 +30,8 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True
 
 
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 from fish_speech.models.text2semantic.llama import (
     BaseTransformer,
     DualARTransformer,
@@ -74,6 +78,45 @@ def logits_to_probs(
     return probs
 
 
+def multinomial_sample_one_no_sync_agent(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
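For orientation, the new *_agent sampling helpers above combine a repetition penalty, top-p (nucleus) filtering, and the exponential-noise argmax trick that draws a categorical sample without a CUDA synchronization. The snippet below is a standalone sketch of the same idea with dummy tensors and a hypothetical top_p_filter helper; it is an illustration, not code from this release.

import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Keep the smallest set of highest-probability tokens whose cumulative mass exceeds top_p.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[..., 0] = False  # always keep the most likely token
    remove = remove.scatter(dim=-1, index=sorted_indices, src=remove)
    return logits.masked_fill(remove, float("-inf"))

logits = torch.randn(2, 32)                      # (batch, vocab) dummy logits
probs = torch.softmax(top_p_filter(logits, 0.8) / 0.7, dim=-1)  # temperature 0.7
q = torch.empty_like(probs).exponential_(1)      # exponential noise; argmax(probs / q) ~ multinomial draw
next_token = torch.argmax(probs / q, dim=-1, keepdim=True)
print(next_token.shape)                          # torch.Size([2, 1])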
@@ -86,20 +129,135 @@ def sample(
     return idx_next, probs
 
 
-def decode_one_token_ar(
+def sample_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs_agent(
+        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync_agent(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    # print(x, input_pos)
     x = model.forward_generate(x, input_pos)
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]
 
     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
     sampling_kwargs_main["top_p"] = 0.1
     sampling_kwargs_main["repetition_penalty"] = 1.0
 
+    codebooks = [
+        sample_agent(
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample_agent(
+            logits,
+            previous_tokens=(
+                previous_tokens[:, codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    # for i in range(codebooks.size(1) - 1):
+    #     codebooks[:, i + 1, :] = torch.masked_fill(
+    #         codebooks[:, i + 1, :],
+    #         codebooks[:, :1, :] != semantic_id,
+    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
+    #     )
+
+    # print(codebooks)
+
+    return codebooks
+
+
+def decode_one_token_naive_agent(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    codebooks = [
+        sample(
+            x.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample_agent(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[:, i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 0,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
             x.logits,
@@ -130,7 +288,12 @@ def decode_one_token_ar(
         x = model.fast_embeddings(a)
         codebooks.append(a)
 
-    … (removed line not preserved in this diff view)
+    codebooks = torch.stack(codebooks, dim=0)
+    codebooks[1:, :] = torch.masked_fill(
+        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
 
 
 def decode_one_token_naive(
@@ -176,6 +339,7 @@ def decode_n_tokens(
     num_new_tokens: int,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
+    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -204,6 +368,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
+                semantic_id=semantic_id,
                 **sampling_kwargs,
             )
 
@@ -236,12 +401,25 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
 
     device, dtype = prompt.device, prompt.dtype
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(
+    empty = torch.empty(
+        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+    )
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
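The truncation logic added to generate() above simply caps the requested budget to the remaining context window. A tiny, hypothetical helper (not part of the diff) showing the same arithmetic:

def clamp_new_tokens(prompt_len: int, max_new_tokens: int, max_seq_len: int) -> int:
    # Cap the requested budget so prompt + generated tokens never exceed max_seq_len.
    if max_new_tokens:
        return min(max_new_tokens, max_seq_len - prompt_len)
    # No explicit budget: spend whatever room is left in the context window.
    return max_seq_len - prompt_len

assert clamp_new_tokens(100, 2048, 4096) == 2048   # budget fits, kept as-is
assert clamp_new_tokens(3000, 2048, 4096) == 1096  # budget truncated to the remaining room
assert clamp_new_tokens(100, 0, 4096) == 3996      # no budget: fill the window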
@@ -254,7 +432,11 @@ def generate(
     )
 
     next_token = prefill_decode(
-        model,
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token
 
@@ -266,6 +448,7 @@ def generate(
         max_new_tokens - 1,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
+        semantic_id=semantic_id,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -275,6 +458,142 @@ def generate(
     return seq
 
 
+def decode_n_tokens_agent(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    batch_size = cur_token.size(0)
+    previous_tokens = torch.zeros(
+        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+    finished = finished | (cur_token[:, 0, -1] == im_end_id)
+    start_time = time.time()
+
+    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :, :win_size]
+        else:
+            window = previous_tokens[:, :, i - win_size : i]
+
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_id=semantic_id,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, :, i : i + 1] = next_token.view(
+            batch_size, model.config.num_codebooks + 1, -1
+        )
+
+        yield cur_token.cpu()
+
+        finished = finished | (cur_token[:, 0, -1] == im_end_id)
+        if finished.all() or (
+            0 < early_stop_threshold < 1
+            and finished.sum() >= round(batch_size * early_stop_threshold)
+        ):
+            break
+
+    total_time = time.time() - start_time
+    generated_tokens = i + 1
+    tokens_per_second = (generated_tokens / total_time) * batch_size
+    logger.info(
+        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+    )
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate_agent(
+    *,
+    model: BaseTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    num_samples: int = 1,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+
+    codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive_agent
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar_agent
+    )
+    next_token = prefill_decode(
+        model,
+        prompt,
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
+    ).view(num_samples, codebook_dim, -1)
+    yield next_token.cpu()
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    yield from decode_n_tokens_agent(
+        model,
+        next_token,
+        input_pos,
+        max_new_tokens - 1,
+        im_end_id=im_end_id,
+        semantic_id=semantic_id,
+        decode_one_token=decode_one_token,
+        early_stop_threshold=early_stop_threshold,
+        **sampling_kwargs,
+    )
+
+
 def encode_tokens(
     tokenizer,
     string,
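decode_n_tokens_agent and generate_agent above are generators: each decoding step is yielded as a CPU tensor of shape (num_samples, num_codebooks + 1, 1), and generation stops early once enough samples in the batch have emitted im_end_id. A minimal, hypothetical consumer of such a stream (illustration only; the real keyword arguments are those shown in the diff):

import torch

def consume(token_stream, im_end_id: int = 4, max_steps: int = 1024):
    steps = []
    for step, token in enumerate(token_stream):
        steps.append(token)                      # each item: (num_samples, codebooks + 1, 1) on CPU
        if (token[:, 0, -1] == im_end_id).all() or step + 1 >= max_steps:
            break                                # stop once every sample emitted the end token
    return torch.cat(steps, dim=-1)              # (num_samples, codebooks + 1, steps)

# Example with a fake stream of random codes:
fake_stream = (torch.randint(0, 10, (2, 9, 1)) for _ in range(5))
print(consume(fake_stream).shape)                # torch.Size([2, 9, 5])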
@@ -339,7 +658,7 @@ def encode_tokens(
     return prompt
 
 
-def load_model(checkpoint_path, device, precision, compile=False):
+def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
     model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
         checkpoint_path, load_weights=True
     )
@@ -348,10 +667,14 @@ def load_model(checkpoint_path, device, precision, compile=False):
     logger.info(f"Restored model from checkpoint")
 
     if isinstance(model, DualARTransformer):
-        decode_one_token =
+        decode_one_token = (
+            decode_one_token_ar_agent if is_agent else decode_one_token_ar
+        )
         logger.info("Using DualARTransformer")
     else:
-        decode_one_token =
+        decode_one_token = (
+            decode_one_token_naive_agent if is_agent else decode_one_token_naive
+        )
         logger.info("Using NaiveTransformer")
 
     if compile:
@@ -563,7 +886,9 @@ def launch_thread_safe_queue(
         )
         with torch.device(device):
             model.setup_caches(
-                max_batch_size=1,
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
             )
         init_event.set()
 
@@ -591,6 +916,60 @@ def launch_thread_safe_queue(
     return input_queue
 
 
+def launch_thread_safe_queue_agent(
+    checkpoint_path,
+    device,
+    precision,
+    compile: bool = False,
+):
+    input_queue = queue.Queue()
+    init_event = threading.Event()
+
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+    config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+    def worker():
+        model, decode_one_token = load_model(
+            checkpoint_path, device, precision, compile=compile, is_agent=True
+        )
+
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
+            )
+        init_event.set()
+
+        while True:
+            item: GenerateRequest | None = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item.request
+            response_queue = item.response_queue
+
+            try:
+                for token in generate_agent(
+                    model=model,
+                    decode_one_token=decode_one_token,
+                    **kwargs,
+                ):
+                    response_queue.put(token)
+
+                response_queue.put("stop")
+            except Exception as e:
+                import traceback
+
+                logger.exception(f"Error in worker: {traceback.format_exc()}")
+                response_queue.put("error")
+
+    threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
+
+    return input_queue, tokenizer, config
+
+
 @click.command()
 @click.option(
     "--text",
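launch_thread_safe_queue_agent above follows a standard producer/consumer layout: a daemon thread loads the model, signals readiness through an Event, then serves requests pulled from a Queue and streams results back on a per-request response queue. A stripped-down, generic sketch of that pattern, independent of the model code and with illustrative names:

import queue
import threading
from dataclasses import dataclass, field

@dataclass
class Request:
    payload: str
    response_queue: "queue.Queue[str]" = field(default_factory=queue.Queue)

def launch_worker():
    input_queue: "queue.Queue[Request | None]" = queue.Queue()
    ready = threading.Event()

    def worker():
        # ... expensive setup (e.g. loading a model) would happen here ...
        ready.set()
        while True:
            item = input_queue.get()
            if item is None:                       # sentinel: shut the worker down
                break
            for chunk in item.payload.split():     # pretend to stream results
                item.response_queue.put(chunk)
            item.response_queue.put("stop")

    threading.Thread(target=worker, daemon=True).start()
    ready.wait()                                   # block until setup finished
    return input_queue

q = launch_worker()
req = Request("hello thread safe queue")
q.put(req)
while (tok := req.response_queue.get()) != "stop":
    print(tok)
q.put(None)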
@@ -650,7 +1029,12 @@ def main(
     model, decode_one_token = load_model(
         checkpoint_path, device, precision, compile=compile
     )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
     if torch.cuda.is_available():
         torch.cuda.synchronize()
 
xinference/thirdparty/fish_speech/tools/msgpack_api.py

@@ -1,34 +1,95 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
 import httpx
 import ormsgpack
 
-from tools.
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
 (the remaining lines removed from the 0.16.3 version of this file are not preserved in this diff view)
+
+api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
+
+
+def audio_request():
+    # priority: ref_id > references
+    request = ServeTTSRequest(
+        text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+        # reference_id="114514",
+        references=[
+            ServeReferenceAudio(
+                audio=open("lengyue.wav", "rb").read(),
+                text=open("lengyue.lab", "r", encoding="utf-8").read(),
+            )
+        ],
+        streaming=True,
+    )
+
+    api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
+
+    with (
+        httpx.Client() as client,
+        open("hello.wav", "wb") as f,
+    ):
+        with client.stream(
+            "POST",
+            "http://127.0.0.1:8080/v1/tts",
+            content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+            headers={
+                "authorization": f"Bearer {api_key}",
+                "content-type": "application/msgpack",
+            },
+            timeout=None,
+        ) as response:
+            for chunk in response.iter_bytes():
+                f.write(chunk)
+
+
+def asr_request(audio_path: Path):
+
+    # Read the audio file
+    with open(
+        str(audio_path),
+        "rb",
+    ) as audio_file:
+        audio_data = audio_file.read()
+
+    # Prepare the request data
+    request_data = {
+        "audio": audio_data,
+        "language": "en",  # Optional: specify the language
+        "ignore_timestamps": False,  # Optional: set to True to ignore precise timestamps
+    }
 
+    # Send the request
+    with httpx.Client() as client:
+        response = client.post(
+            "https://api.fish.audio/v1/asr",
+            headers={
+                "Authorization": f"Bearer {api_key}",
+                "Content-Type": "application/msgpack",
+            },
+            content=ormsgpack.packb(request_data),
         )
+
+    # Parse the response
+    result = response.json()
+
+    print(f"Transcribed text: {result['text']}")
+    print(f"Audio duration: {result['duration']} seconds")
+
+    for segment in result["segments"]:
+        print(f"Segment: {segment['text']}")
+        print(f"Start time: {segment['start']}, End time: {segment['end']}")
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    asr_request(args.audio_path)