xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This version of xinference has been flagged as potentially problematic.

Files changed (60)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/conftest.py +0 -8
  5. xinference/constants.py +2 -0
  6. xinference/core/model.py +44 -5
  7. xinference/core/supervisor.py +13 -7
  8. xinference/core/utils.py +76 -12
  9. xinference/core/worker.py +5 -4
  10. xinference/deploy/cmdline.py +5 -0
  11. xinference/deploy/utils.py +7 -4
  12. xinference/model/audio/model_spec.json +2 -2
  13. xinference/model/image/stable_diffusion/core.py +5 -2
  14. xinference/model/llm/core.py +1 -3
  15. xinference/model/llm/llm_family.json +263 -4
  16. xinference/model/llm/llm_family_modelscope.json +302 -0
  17. xinference/model/llm/mlx/core.py +45 -2
  18. xinference/model/llm/vllm/core.py +2 -1
  19. xinference/model/rerank/core.py +11 -4
  20. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  21. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  22. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  23. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  24. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  25. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  26. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  27. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  28. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  30. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  32. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  34. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  35. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  36. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  37. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  38. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  39. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  40. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  41. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  42. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  43. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  44. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
  45. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
  46. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  47. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  56. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  58. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  59. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  60. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/tools/llama/generate.py

@@ -15,8 +15,10 @@ import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
+from transformers import AutoTokenizer

 from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
+from fish_speech.models.text2semantic.llama import BaseModelArgs
 from fish_speech.text import clean_text, split_text

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -28,6 +30,8 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
     torch._inductor.config.fx_graph_cache = True


+from torch.nn.attention import SDPBackend, sdpa_kernel
+
 from fish_speech.models.text2semantic.llama import (
     BaseTransformer,
     DualARTransformer,
@@ -74,6 +78,45 @@ def logits_to_probs(
     return probs


+def multinomial_sample_one_no_sync_agent(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    temperature: torch.Tensor = 1.0,
+    top_p: torch.Tensor = 1.0,
+    repetition_penalty: torch.Tensor = 1.0,
+) -> torch.Tensor:
+    # Apply repetition penalty
+    if previous_tokens is not None:
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=-1, index=previous_tokens)
+        score = torch.where(
+            score < 0, score * repetition_penalty, score / repetition_penalty
+        )
+        logits.scatter_(dim=-1, index=previous_tokens, src=score)
+
+    # Apply top-p sampling
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 0] = False  # keep at least one option
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
+    )
+    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+    logits = logits / max(temperature, 1e-5)
+
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
 def sample(
     logits,
     previous_tokens: Optional[torch.Tensor] = None,
@@ -86,20 +129,135 @@ def sample(
     return idx_next, probs


-def decode_one_token_ar(
+def sample_agent(
+    logits,
+    previous_tokens: Optional[torch.Tensor] = None,
+    **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    probs = logits_to_probs_agent(
+        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
+    )
+    idx_next = multinomial_sample_one_no_sync_agent(probs)
+    return idx_next, probs
+
+
+def decode_one_token_ar_agent(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
     previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    # print(x, input_pos)
     x = model.forward_generate(x, input_pos)
+    logits = x.logits  # [:, -1:]
+    hidden_states = x.hidden_states  # [:, -1:]

     sampling_kwargs_main = sampling_kwargs.copy()
     sampling_kwargs_main["temperature"] = 0.1
     sampling_kwargs_main["top_p"] = 0.1
     sampling_kwargs_main["repetition_penalty"] = 1.0

+    codebooks = [
+        sample_agent(
+            logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs_main,
+        )[0]
+    ]
+
+    # Cleanup the cache
+    for layer in model.fast_layers:
+        layer.attention.kv_cache.k_cache.fill_(0)
+        layer.attention.kv_cache.v_cache.fill_(0)
+
+    for codebook_idx in range(model.config.num_codebooks):
+        input_pos = torch.tensor(
+            [codebook_idx], device=hidden_states.device, dtype=torch.long
+        )
+        logits = model.forward_generate_fast(hidden_states, input_pos)
+        a = sample_agent(
+            logits,
+            previous_tokens=(
+                previous_tokens[:, codebook_idx + 1]
+                if previous_tokens is not None
+                else None
+            ),
+            **sampling_kwargs,
+        )[0]
+        hidden_states = model.fast_embeddings(a)
+        codebooks.append(a)
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    # for i in range(codebooks.size(1) - 1):
+    #     codebooks[:, i + 1, :] = torch.masked_fill(
+    #         codebooks[:, i + 1, :],
+    #         codebooks[:, :1, :] != semantic_id,
+    #         CODEBOOK_PAD_TOKEN_ID + i * 1024,
+    #     )
+
+    # print(codebooks)
+
+    return codebooks
+
+
+def decode_one_token_naive_agent(
+    model: NaiveTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 32003,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    codebooks = [
+        sample(
+            x.token_logits,
+            previous_tokens=None,  # Disable repetition penalty for the token codebook
+            **sampling_kwargs,
+        )[0]
+    ]
+
+    for i in range(model.config.num_codebooks):
+        codebooks.append(
+            sample_agent(
+                x.codebook_logits[:, :, i],
+                previous_tokens=(
+                    previous_tokens[:, i + 1] if previous_tokens is not None else None
+                ),
+                **sampling_kwargs,
+            )[0]
+        )
+
+    codebooks = torch.stack(codebooks, dim=1)
+    codebooks[:, 1:, :] = torch.masked_fill(
+        codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks
+
+
+def decode_one_token_ar(
+    model: DualARTransformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    previous_tokens: torch.Tensor = None,
+    semantic_id: int = 0,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    x = model.forward_generate(x, input_pos)
+
+    sampling_kwargs_main = sampling_kwargs.copy()
+    # sampling_kwargs_main["temperature"] = 0.1
+    # sampling_kwargs_main["top_p"] = 0.1
+    # sampling_kwargs_main["repetition_penalty"] = 1.0
+
     codebooks = [
         sample(
             x.logits,
@@ -130,7 +288,12 @@ def decode_one_token_ar(
         x = model.fast_embeddings(a)
         codebooks.append(a)

-    return torch.stack(codebooks, dim=0)
+    codebooks = torch.stack(codebooks, dim=0)
+    codebooks[1:, :] = torch.masked_fill(
+        codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
+    )
+
+    return codebooks


 def decode_one_token_naive(
@@ -176,6 +339,7 @@ def decode_n_tokens(
     num_new_tokens: int,
     im_end_id: int = 4,
     decode_one_token=decode_one_token_naive,
+    semantic_id: int = 0,
     **sampling_kwargs,
 ):
     previous_tokens = torch.zeros(
@@ -204,6 +368,7 @@ def decode_n_tokens(
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
+                semantic_id=semantic_id,
                 **sampling_kwargs,
             )

@@ -236,12 +401,25 @@ def generate(

     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
+    semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T

     device, dtype = prompt.device, prompt.dtype

     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty((codebook_dim, max_new_tokens), dtype=dtype, device=device)
+    empty = torch.empty(
+        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
+    )
     empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
@@ -254,7 +432,11 @@ def generate(
     )

     next_token = prefill_decode(
-        model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
+        model,
+        prompt.view(1, codebook_dim, -1),
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
     )
     seq[:, T : T + 1] = next_token

@@ -266,6 +448,7 @@ def generate(
         max_new_tokens - 1,
         im_end_id=im_end_id,
         decode_one_token=decode_one_token,
+        semantic_id=semantic_id,
         **sampling_kwargs,
     )
     # x = torch.cat(generated_tokens, dim=1)
@@ -275,6 +458,142 @@ def generate(
     return seq


+def decode_n_tokens_agent(
+    model: NaiveTransformer,
+    cur_token: torch.Tensor,
+    input_pos: torch.Tensor,
+    num_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    batch_size = cur_token.size(0)
+    previous_tokens = torch.zeros(
+        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
+        dtype=torch.int,
+        device=cur_token.device,
+    )
+    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
+    finished = finished | (cur_token[:, 0, -1] == im_end_id)
+    start_time = time.time()
+
+    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
+        # We need to get windowed repeat penalty
+        win_size = 16
+        if i < win_size:
+            window = previous_tokens[:, :, :win_size]
+        else:
+            window = previous_tokens[:, :, i - win_size : i]
+
+        with sdpa_kernel(
+            SDPBackend.MATH
+        ):  # Actually better for Inductor to codegen attention here
+            next_token = decode_one_token(
+                model=model,
+                x=cur_token,
+                input_pos=input_pos,
+                previous_tokens=window,
+                semantic_id=semantic_id,
+                **sampling_kwargs,
+            )
+
+        input_pos += 1
+        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
+        previous_tokens[:, :, i : i + 1] = next_token.view(
+            batch_size, model.config.num_codebooks + 1, -1
+        )
+
+        yield cur_token.cpu()
+
+        finished = finished | (cur_token[:, 0, -1] == im_end_id)
+        if finished.all() or (
+            0 < early_stop_threshold < 1
+            and finished.sum() >= round(batch_size * early_stop_threshold)
+        ):
+            break
+
+    total_time = time.time() - start_time
+    generated_tokens = i + 1
+    tokens_per_second = (generated_tokens / total_time) * batch_size
+    logger.info(
+        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
+    )
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def generate_agent(
+    *,
+    model: BaseTransformer,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    im_end_id: int = 4,
+    semantic_id: int = 32003,
+    decode_one_token=decode_one_token_naive_agent,
+    num_samples: int = 1,
+    early_stop_threshold: float = 0.6,
+    **sampling_kwargs,
+):
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(1)
+    prompt = prompt[None].repeat(num_samples, 1, 1)
+
+    if T >= model.config.max_seq_len:
+        raise ValueError(
+            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
+        )
+
+    if max_new_tokens:
+        if T + max_new_tokens > model.config.max_seq_len:
+            max_new_tokens = model.config.max_seq_len - T
+            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
+
+        T_new = T + max_new_tokens
+    else:
+        T_new = model.config.max_seq_len
+        max_new_tokens = T_new - T
+
+    device, dtype = prompt.device, prompt.dtype
+
+    codebook_dim = 1 + model.config.num_codebooks
+    input_pos = torch.arange(0, T, device=device)
+
+    # Use non-accelerated version for now, to avoid compilation overhead
+    prefill_decode = (
+        decode_one_token_naive_agent
+        if isinstance(model, NaiveTransformer)
+        else decode_one_token_ar_agent
+    )
+    next_token = prefill_decode(
+        model,
+        prompt,
+        input_pos,
+        semantic_id=semantic_id,
+        **sampling_kwargs,
+    ).view(num_samples, codebook_dim, -1)
+    yield next_token.cpu()
+
+    input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    yield from decode_n_tokens_agent(
+        model,
+        next_token,
+        input_pos,
+        max_new_tokens - 1,
+        im_end_id=im_end_id,
+        semantic_id=semantic_id,
+        decode_one_token=decode_one_token,
+        early_stop_threshold=early_stop_threshold,
+        **sampling_kwargs,
+    )
+
+
 def encode_tokens(
     tokenizer,
     string,
@@ -339,7 +658,7 @@ def encode_tokens(
     return prompt


-def load_model(checkpoint_path, device, precision, compile=False):
+def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True
    )
@@ -348,10 +667,14 @@ def load_model(checkpoint_path, device, precision, compile=False):
    logger.info(f"Restored model from checkpoint")

    if isinstance(model, DualARTransformer):
-        decode_one_token = decode_one_token_ar
+        decode_one_token = (
+            decode_one_token_ar_agent if is_agent else decode_one_token_ar
+        )
        logger.info("Using DualARTransformer")
    else:
-        decode_one_token = decode_one_token_naive
+        decode_one_token = (
+            decode_one_token_naive_agent if is_agent else decode_one_token_naive
+        )
        logger.info("Using NaiveTransformer")

    if compile:
@@ -563,7 +886,9 @@ def launch_thread_safe_queue(
        )
        with torch.device(device):
            model.setup_caches(
-                max_batch_size=1, max_seq_len=2048, dtype=next(model.parameters()).dtype
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
            )
        init_event.set()

@@ -591,6 +916,60 @@ def launch_thread_safe_queue(
    return input_queue


+def launch_thread_safe_queue_agent(
+    checkpoint_path,
+    device,
+    precision,
+    compile: bool = False,
+):
+    input_queue = queue.Queue()
+    init_event = threading.Event()
+
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+    config = BaseModelArgs.from_pretrained(checkpoint_path)
+
+    def worker():
+        model, decode_one_token = load_model(
+            checkpoint_path, device, precision, compile=compile, is_agent=True
+        )
+
+        with torch.device(device):
+            model.setup_caches(
+                max_batch_size=1,
+                max_seq_len=model.config.max_seq_len,
+                dtype=next(model.parameters()).dtype,
+            )
+        init_event.set()
+
+        while True:
+            item: GenerateRequest | None = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item.request
+            response_queue = item.response_queue
+
+            try:
+                for token in generate_agent(
+                    model=model,
+                    decode_one_token=decode_one_token,
+                    **kwargs,
+                ):
+                    response_queue.put(token)
+
+                response_queue.put("stop")
+            except Exception as e:
+                import traceback
+
+                logger.exception(f"Error in worker: {traceback.format_exc()}")
+                response_queue.put("error")
+
+    threading.Thread(target=worker, daemon=True).start()
+    init_event.wait()
+
+    return input_queue, tokenizer, config
+
+
 @click.command()
 @click.option(
     "--text",
@@ -650,7 +1029,12 @@ def main(
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
-
+    with torch.device(device):
+        model.setup_caches(
+            max_batch_size=1,
+            max_seq_len=model.config.max_seq_len,
+            dtype=next(model.parameters()).dtype,
+        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()

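For context on the sampling changes in generate.py above: the new *_agent helpers apply a repetition penalty over a window of recently generated tokens, then top-p (nucleus) filtering, and finally draw a sample without forcing a CUDA synchronization. Below is a minimal standalone sketch of that combination in plain PyTorch; the function name, defaults, and shapes are illustrative only and are not part of the xinference or fish_speech APIs.

    import torch

    def sample_top_p_with_penalty(
        logits: torch.Tensor,                  # shape (batch, vocab)
        previous_tokens: torch.Tensor = None,  # shape (batch, window) of recent token ids
        temperature: float = 0.7,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
    ) -> torch.Tensor:
        # Repetition penalty: shrink positive logits and grow negative ones for seen tokens.
        if previous_tokens is not None:
            idx = previous_tokens.long()
            score = torch.gather(logits, -1, idx)
            score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
            logits = logits.scatter(-1, idx, score)

        # Top-p (nucleus) filtering: mask tokens outside the smallest set whose
        # cumulative probability stays within top_p, always keeping the best token.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove_sorted = cum_probs > top_p
        remove_sorted[..., 0] = False
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

        # Temperature-scaled softmax, then the exponential-race trick to sample
        # a token id without a device-to-host synchronization.
        probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
        q = torch.empty_like(probs).exponential_(1)
        return torch.argmax(probs / q, dim=-1, keepdim=True)

    # e.g. next_id = sample_top_p_with_penalty(torch.randn(1, 32000))
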
xinference/thirdparty/fish_speech/tools/msgpack_api.py

@@ -1,34 +1,95 @@
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
 import httpx
 import ormsgpack

-from tools.commons import ServeReferenceAudio, ServeTTSRequest
+from tools.schema import ServeReferenceAudio, ServeTTSRequest
+
+api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
+
+
+def audio_request():
+    # priority: ref_id > references
+    request = ServeTTSRequest(
+        text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+        # reference_id="114514",
+        references=[
+            ServeReferenceAudio(
+                audio=open("lengyue.wav", "rb").read(),
+                text=open("lengyue.lab", "r", encoding="utf-8").read(),
+            )
+        ],
+        streaming=True,
+    )
+
+    api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
+
+    with (
+        httpx.Client() as client,
+        open("hello.wav", "wb") as f,
+    ):
+        with client.stream(
+            "POST",
+            "http://127.0.0.1:8080/v1/tts",
+            content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+            headers={
+                "authorization": f"Bearer {api_key}",
+                "content-type": "application/msgpack",
+            },
+            timeout=None,
+        ) as response:
+            for chunk in response.iter_bytes():
+                f.write(chunk)
+
+
+def asr_request(audio_path: Path):
+
+    # Read the audio file
+    with open(
+        str(audio_path),
+        "rb",
+    ) as audio_file:
+        audio_data = audio_file.read()
+
+    # Prepare the request data
+    request_data = {
+        "audio": audio_data,
+        "language": "en",  # Optional: specify the language
+        "ignore_timestamps": False,  # Optional: set to True to ignore precise timestamps
+    }

-# priority: ref_id > references
-request = ServeTTSRequest(
-    text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
-    # reference_id="114514",
-    references=[
-        ServeReferenceAudio(
-            audio=open("lengyue.wav", "rb").read(),
-            text=open("lengyue.lab", "r", encoding="utf-8").read(),
+    # Send the request
+    with httpx.Client() as client:
+        response = client.post(
+            "https://api.fish.audio/v1/asr",
+            headers={
+                "Authorization": f"Bearer {api_key}",
+                "Content-Type": "application/msgpack",
+            },
+            content=ormsgpack.packb(request_data),
         )
-    ],
-    streaming=True,
-)
-
-with (
-    httpx.Client() as client,
-    open("hello.wav", "wb") as f,
-):
-    with client.stream(
-        "POST",
-        "http://127.0.0.1:8080/v1/tts",
-        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
-        headers={
-            "authorization": "Bearer YOUR_API_KEY",
-            "content-type": "application/msgpack",
-        },
-        timeout=None,
-    ) as response:
-        for chunk in response.iter_bytes():
-            f.write(chunk)
+
+    # Parse the response
+    result = response.json()
+
+    print(f"Transcribed text: {result['text']}")
+    print(f"Audio duration: {result['duration']} seconds")
+
+    for segment in result["segments"]:
+        print(f"Segment: {segment['text']}")
+        print(f"Start time: {segment['start']}, End time: {segment['end']}")
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    asr_request(args.audio_path)
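
A side note on the msgpack payloads used by msgpack_api.py above: ServeTTSRequest is a pydantic model, and ormsgpack's OPT_SERIALIZE_PYDANTIC option serializes it directly into the request body. The sketch below illustrates that round trip with a hypothetical stand-in model rather than the real tools.schema classes, assuming an ormsgpack build with pydantic support.

    import ormsgpack
    from pydantic import BaseModel

    class DemoTTSRequest(BaseModel):
        # Hypothetical stand-in for tools.schema.ServeTTSRequest, for illustration only.
        text: str
        streaming: bool = False

    req = DemoTTSRequest(text="hello", streaming=True)

    # OPT_SERIALIZE_PYDANTIC packs the model's fields as a msgpack map,
    # matching the "application/msgpack" content type sent to the endpoints above.
    payload = ormsgpack.packb(req, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    # The receiving side gets a plain dict back after unpacking.
    print(ormsgpack.unpackb(payload))  # expected: {'text': 'hello', 'streaming': True}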