xinference 0.16.2__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (60)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +62 -11
  3. xinference/client/restful/restful_client.py +8 -2
  4. xinference/conftest.py +0 -8
  5. xinference/constants.py +2 -0
  6. xinference/core/model.py +44 -5
  7. xinference/core/supervisor.py +13 -7
  8. xinference/core/utils.py +76 -12
  9. xinference/core/worker.py +5 -4
  10. xinference/deploy/cmdline.py +5 -0
  11. xinference/deploy/utils.py +7 -4
  12. xinference/model/audio/model_spec.json +2 -2
  13. xinference/model/image/stable_diffusion/core.py +5 -2
  14. xinference/model/llm/core.py +1 -3
  15. xinference/model/llm/llm_family.json +263 -4
  16. xinference/model/llm/llm_family_modelscope.json +302 -0
  17. xinference/model/llm/mlx/core.py +45 -2
  18. xinference/model/llm/vllm/core.py +2 -1
  19. xinference/model/rerank/core.py +11 -4
  20. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  21. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  22. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  23. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  24. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  25. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  26. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  27. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  28. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  29. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  30. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  32. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  33. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  34. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  35. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  36. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  37. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  38. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  39. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  40. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  41. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  42. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  43. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  44. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/METADATA +26 -3
  45. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/RECORD +49 -56
  46. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
  47. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  51. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  53. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  54. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  55. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  56. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  58. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
  59. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
  60. {xinference-0.16.2.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/api.py

@@ -1,16 +1,16 @@
- import base64
  import io
- import json
+ import os
  import queue
- import random
- import sys
+ import re
+ import time
  import traceback
  import wave
  from argparse import ArgumentParser
  from http import HTTPStatus
  from pathlib import Path
- from typing import Annotated, Any, Literal, Optional
+ from typing import Annotated, Any

+ import librosa
  import numpy as np
  import ormsgpack
  # import pyrootutils
@@ -28,27 +28,74 @@ import torchaudio
  # Kui,
  # OpenAPI,
  # StreamResponse,
+ # request,
  # )
  # from kui.asgi.routing import MultimethodRoutes
  from loguru import logger
- from pydantic import BaseModel, Field, conint
+ from transformers import AutoTokenizer

  # pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+ import struct
+ from threading import Lock
+
+ import httpx
+ from cachetools import LRUCache, cached
+ from funasr import AutoModel
+ from silero_vad import get_speech_timestamps, load_silero_vad
+
+ from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
+ from fish_speech.models.text2semantic.llama import BaseModelArgs

  # from fish_speech.models.vqgan.lit_module import VQGAN
  from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
  from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
- from fish_speech.utils import autocast_exclude_mps
- from tools.commons import ServeReferenceAudio, ServeTTSRequest
+ from fish_speech.utils import autocast_exclude_mps, set_seed
  from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
  from tools.llama.generate import (
      GenerateRequest,
      GenerateResponse,
      WrappedGenerateResponse,
      launch_thread_safe_queue,
+     launch_thread_safe_queue_agent,
+ )
+ from tools.schema import (
+     GLOBAL_NUM_SAMPLES,
+     ASRPackRequest,
+     ServeASRRequest,
+     ServeASRResponse,
+     ServeASRSegment,
+     ServeAudioPart,
+     ServeForwardMessage,
+     ServeMessage,
+     ServeRequest,
+     ServeResponse,
+     ServeStreamDelta,
+     ServeStreamResponse,
+     ServeTextPart,
+     ServeTimedASRResponse,
+     ServeTTSRequest,
+     ServeVQGANDecodeRequest,
+     ServeVQGANDecodeResponse,
+     ServeVQGANEncodeRequest,
+     ServeVQGANEncodeResponse,
+     ServeVQPart,
  )
  from tools.vqgan.inference import load_model as load_decoder_model

+ global_lock = Lock()
+
+ # Whether to disable keepalive (which is helpful if the server is in the same cluster)
+ DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
+ async_client = httpx.AsyncClient(
+     timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
+ )
+ backends = torchaudio.list_audio_backends()
+
+ if "ffmpeg" in backends:
+     backend = "ffmpeg"
+ else:
+     backend = "soundfile"
+

  def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
      buffer = io.BytesIO()
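Note: two of the new module-level globals above change runtime behavior: HTTP keepalive for the shared httpx.AsyncClient can now be switched off via the DISABLE_KEEPALIVE environment variable, and the torchaudio backend is probed at import time (preferring ffmpeg, falling back to soundfile) instead of being chosen per platform. A minimal standalone sketch of the same two patterns, outside the diff (the "reference.wav" path is purely illustrative):

    import os

    import httpx
    import torchaudio

    # Keepalive toggle: keepalive_expiry=0 closes connections right after use,
    # which can help when the server and its peers run in the same cluster.
    disable_keepalive = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
    client = httpx.AsyncClient(
        timeout=120,
        limits=httpx.Limits(keepalive_expiry=0 if disable_keepalive else None),
    )

    # Backend probe: prefer ffmpeg (widest format support), fall back to soundfile,
    # replacing the old per-platform "sox"/"soundfile" choice removed in the next hunk.
    backend = "ffmpeg" if "ffmpeg" in torchaudio.list_audio_backends() else "soundfile"
    waveform, sr = torchaudio.load("reference.wav", backend=backend)  # illustrative path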
@@ -91,9 +138,7 @@ def load_audio(reference_audio, sr):
          audio_data = reference_audio
          reference_audio = io.BytesIO(audio_data)

-     waveform, original_sr = torchaudio.load(
-         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
-     )
+     waveform, original_sr = torchaudio.load(reference_audio, backend=backend)

      if waveform.shape[0] > 1:
          waveform = torch.mean(waveform, dim=0, keepdim=True)
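Note: the next hunk adds batch_encode and vqgan_decode, which both rely on the same ragged-batch pattern: pad every item to the longest length, run the model once on the stacked batch, then trim each output back to its true length. A generic sketch of that pattern (the pad_stack helper name is made up for illustration):

    import torch

    def pad_stack(tensors: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        # Pad along the last dimension and remember the true lengths so the
        # per-item outputs can be trimmed back after the batched model call.
        lengths = torch.tensor([t.shape[-1] for t in tensors])
        max_len = int(lengths.max())
        padded = torch.stack(
            [torch.nn.functional.pad(t, (0, max_len - t.shape[-1])) for t in tensors]
        )
        return padded, lengths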
@@ -167,9 +212,390 @@ def get_content_type(audio_format):
      return "application/octet-stream"


+ @torch.no_grad()
+ @torch.autocast(device_type="cuda", dtype=torch.half)
+ def batch_encode(model, audios: list[bytes | torch.Tensor]):
+     audios = [
+         (
+             torch.from_numpy(
+                 librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+             )[None]
+             if isinstance(audio, bytes)
+             else audio
+         )
+         for audio in audios
+     ]
+
+     # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
+     # raise ValueError("Single audio length is too long (>120s)")
+
+     max_length = max(audio.shape[-1] for audio in audios)
+     print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+     lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+     max_length = lengths.max().item()
+     padded = torch.stack(
+         [
+             torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
+             for audio in audios
+         ]
+     ).to(model.device)
+
+     features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+     features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+     return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+ @cached(
+     cache=LRUCache(maxsize=10000),
+     key=lambda model, audios: (model.device, tuple(audios)),
+ )
+ def cached_vqgan_batch_encode(model, audios: list[bytes]):
+     return batch_encode(model, audios)
+
+
+ # @routes.http.post("/v1/vqgan/encode")
+ # def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+ #
+ # start_time = time.time()
+ # tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
+ # logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+ #
+ # return ormsgpack.packb(
+ # ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+ # option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ # )
+
+
+ @torch.no_grad()
+ @torch.autocast(device_type="cuda", dtype=torch.half)
+ def vqgan_decode(model, features):
+     lengths = torch.tensor(
+         [feature.shape[-1] for feature in features], device=model.device
+     )
+     max_length = lengths.max().item()
+     padded = torch.stack(
+         [
+             torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+             for feature in features
+         ]
+     ).to(model.device)
+
+     # If bs too large, we do micro batch decode
+     audios, audio_lengths = [], []
+     for i in range(0, padded.shape[0], 8):
+         audio, audio_length = model.decode(
+             padded[i : i + 8], feature_lengths=lengths[i : i + 8]
+         )
+         audios.append(audio)
+         audio_lengths.append(audio_length)
+     audios = torch.cat(audios, dim=0)
+     audio_lengths = torch.cat(audio_lengths, dim=0)
+     audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+     return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+ # @routes.http.post("/v1/vqgan/decode")
+ # def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+ # tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
+ # start_time = time.time()
+ # audios = vqgan_decode(decoder_model, tokens)
+ # logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+ # audios = [audio.astype(np.float16).tobytes() for audio in audios]
+ # return ormsgpack.packb(
+ # ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+ # )
+
+
+ @torch.no_grad()
+ def batch_asr(model, audios, sr, language="auto"):
+     resampled_audios = []
+     for audio in audios:
+         audio = torchaudio.functional.resample(audio, sr, 16000)
+         assert audio.ndim == 1
+         resampled_audios.append(audio)
+
+     with global_lock:
+         res = model.generate(
+             input=resampled_audios,
+             batch_size=len(resampled_audios),
+             language=language,
+             use_itn=True,
+         )
+
+     results = []
+     for r, audio in zip(res, audios):
+         text = r["text"]
+         text = re.sub(r"<\|.*?\|>", "", text)
+         duration = len(audio) / sr * 1000
+         huge_gap = False
+
+         if "timestamp" in r and len(r["timestamp"]) > 2:
+             for timestamp_a, timestamp_b in zip(
+                 r["timestamp"][:-1], r["timestamp"][1:]
+             ):
+                 # If there is a gap of more than 5 seconds, we consider it as a huge gap
+                 if timestamp_b[0] - timestamp_a[1] > 5000:
+                     huge_gap = True
+                     break
+
+             # Doesn't make sense to have a huge gap at the end
+             if duration - r["timestamp"][-1][1] > 3000:
+                 huge_gap = True
+
+         results.append(
+             {
+                 "text": text,
+                 "duration": duration,
+                 "huge_gap": huge_gap,
+             }
+         )
+
+     return results
+
+
+ # @routes.http.post("/v1/asr")
+ # def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
+ # start_time = time.time()
+ # audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
+ # audios = [torch.from_numpy(audio).float() for audio in audios]
+ #
+ # if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
+ # raise HTTPException(status_code=400, detail="Audio length is too long")
+ #
+ # transcriptions = batch_asr(
+ # asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
+ # )
+ # logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+ #
+ # return ormsgpack.packb(
+ # ServeASRResponse(transcriptions=transcriptions),
+ # option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+ # )
+
+
+ from fish_speech.conversation import Conversation, Message
+
+
+ def execute_request(
+     input_queue: queue.Queue,
+     tokenizer: AutoTokenizer,
+     config: BaseModelArgs,
+     request: ServeRequest,
+     device: str = "cuda:0",
+ ):
+     semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
+         [SEMANTIC_TOKEN, IM_END_TOKEN]
+     )
+     messages = []
+     for message in request.messages:
+         messages.append(message.to_conversation_message())
+
+     assert len(messages) >= 1, "At least one message is required"
+     # assert messages[-1].role == "user", "The last message must be from the user"
+
+     if messages[-1].role == "user":
+         messages.append(Message(role="assistant", parts=[], add_im_end=False))
+     else:
+         assert (
+             messages[-1].role == "assistant"
+         ), "The last message must be from the assistant"
+         messages[-1].add_im_end = False
+
+     conv = Conversation(messages=messages)
+     prompt = conv.encode_for_inference(
+         tokenizer=tokenizer, num_codebooks=config.num_codebooks
+     ).to(device)
+
+     if request.streaming:
+         for i in range(request.num_samples):
+             yield ServeStreamResponse(
+                 sample_id=i,
+                 delta=ServeStreamDelta(
+                     role="assistant",
+                 ),
+             )
+
+     req = {
+         "prompt": prompt,
+         "max_new_tokens": request.max_new_tokens,
+         "im_end_id": im_end_id,
+         "semantic_id": semantic_id,
+         "temperature": request.temperature,
+         "top_p": request.top_p,
+         "repetition_penalty": request.repetition_penalty,
+         "num_samples": request.num_samples,
+         "early_stop_threshold": request.early_stop_threshold,
+     }
+
+     start = time.time()
+     response_queue = queue.Queue()
+     input_queue.put(GenerateRequest(req, response_queue))
+
+     # Decoding
+     decode_buffer = [[] for _ in range(request.num_samples)]
+     parts = [[] for _ in range(request.num_samples)]
+
+     def send_reset_buffer(sample_id):
+         nonlocal decode_buffer
+         if len(decode_buffer[sample_id]) == 0:
+             return
+
+         decoded = tokenizer.decode(decode_buffer[sample_id])
+         part = ServeTextPart(text=decoded)
+
+         if request.streaming:
+             yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
+         else:
+             parts[sample_id].append(part)
+
+         decode_buffer[sample_id] = []
+
+     # Decode process
+     finished = [False for _ in range(request.num_samples)]
+     stats = {}
+     idx = 0
+     while True:
+         response = response_queue.get()
+
+         if response in ["stop", "error"]:
+             break
+
+         for sample_id, tokens in enumerate(response):
+             if finished[sample_id]:
+                 continue
+
+             if tokens[0] == im_end_id:
+                 finished[sample_id] = True
+                 if request.streaming:
+                     yield from send_reset_buffer(sample_id)
+                     yield ServeStreamResponse(
+                         sample_id=sample_id,
+                         finish_reason="stop",
+                         stats=stats,
+                     )
+                 continue
+
+             if tokens[0] == semantic_id and request.streaming:
+                 yield from send_reset_buffer(sample_id)
+                 # Streaming vq
+                 _tokens = tokens[1:].clone() - 1
+
+                 if config.share_codebook_embeddings is False:
+                     for i in range(len(_tokens)):
+                         _tokens[i] -= config.codebook_size * i
+
+                 yield ServeStreamResponse(
+                     sample_id=sample_id,
+                     delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+                 )
+                 continue
+
+             # Not streaming vq
+             if tokens[0] == semantic_id:
+                 yield from send_reset_buffer(sample_id)
+                 # None streaming vq
+                 if len(parts[sample_id]) == 0 or not isinstance(
+                     parts[sample_id][-1], ServeVQPart
+                 ):
+                     _tokens = tokens[1:].clone() - 1
+
+                     if config.share_codebook_embeddings is False:
+                         for i in range(len(_tokens)):
+                             _tokens[i] -= config.codebook_size * i
+
+                     parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+                 else:
+                     for codebook_id, value in enumerate(tokens[1:, :]):
+                         val = value.item() - 1
+                         if config.share_codebook_embeddings is False:
+                             val -= config.codebook_size * codebook_id
+
+                         parts[sample_id][-1].codes[codebook_id].append(val)
+                 continue
+
+             if tokens[0] != semantic_id:
+                 # Stream text decode is not supported now
+                 decode_buffer[sample_id].append(tokens[0, 0])
+
+         if idx == 0:
+             stats["time_to_first_token"] = (time.time() - start) * 1000
+
+         idx += 1
+
+     for sample_id in range(request.num_samples):
+         yield from send_reset_buffer(sample_id)
+
+     stats["total_time"] = (time.time() - start) * 1000
+     stats["total_tokens"] = idx
+
+     if request.streaming:
+         for sample_id in range(request.num_samples):
+             if finished[sample_id]:
+                 continue
+             yield ServeStreamResponse(
+                 finish_reason=response, stats=stats, sample_id=sample_id
+             )
+         return
+
+     yield ServeResponse(
+         messages=[
+             ServeMessage(role="assistant", parts=parts[i])
+             for i in range(request.num_samples)
+         ],
+         finish_reason=response,
+         stats=stats,
+     )
+
+
+ # @routes.http.post("/v1/chat")
+ # def api_invoke_chat(
+ # req: Annotated[ServeRequest, Body(exclusive=True)],
+ # ):
+ # """
+ # Invoke model and generate audio
+ # """
+ #
+ # # This makes torch compile happy
+ # assert (
+ # req.num_samples == GLOBAL_NUM_SAMPLES
+ # ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
+ #
+ # content_type = request.headers.get("Content-Type", "application/json")
+ # json_mode = "application/json" in content_type
+ #
+ # async def wrapped_generator():
+ # generator = execute_request(llama_queue, tokenizer, config, req, args.device)
+ #
+ # for i in generator:
+ # if json_mode:
+ # body = i.model_dump_json().encode("utf-8")
+ # yield b"data: " + body + b"\n\n"
+ # else:
+ # body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+ # yield struct.pack("I", len(body)) + body
+ #
+ # # Naive mode
+ # if req.streaming is False:
+ # result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
+ #
+ # if json_mode:
+ # return JSONResponse(result.model_dump())
+ # else:
+ # return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+ #
+ # return StreamResponse(
+ # iterable=wrapped_generator(), content_type="text/event-stream"
+ # )
+
+
  @torch.inference_mode()
  def inference(req: ServeTTSRequest):

+     global prompt_tokens, prompt_texts
+
      idstr: str | None = req.reference_id
      if idstr is not None:
          ref_folder = Path("references") / idstr
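Note: the commented-out /v1/chat handler above frames each non-JSON streaming response as a 4-byte length prefix (struct.pack("I", len(body))) followed by an ormsgpack-encoded payload. If that route were enabled, a client could consume the stream with a reader along these lines (the iter_msgpack_frames helper and its stream argument are hypothetical, not part of the diff):

    import struct

    import ormsgpack

    def iter_msgpack_frames(stream):
        # Each frame: 4-byte native-endian length prefix, then a msgpack body.
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            (length,) = struct.unpack("I", header)
            yield ormsgpack.unpackb(stream.read(length))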
@@ -177,33 +603,47 @@ def inference(req: ServeTTSRequest):
          ref_audios = list_files(
              ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
          )
-         prompt_tokens = [
-             encode_reference(
-                 decoder_model=decoder_model,
-                 reference_audio=audio_to_bytes(str(ref_audio)),
-                 enable_reference_audio=True,
-             )
-             for ref_audio in ref_audios
-         ]
-         prompt_texts = [
-             read_ref_text(str(ref_audio.with_suffix(".lab")))
-             for ref_audio in ref_audios
-         ]
+
+         if req.use_memory_cache == "never" or (
+             req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+         ):
+             prompt_tokens = [
+                 encode_reference(
+                     decoder_model=decoder_model,
+                     reference_audio=audio_to_bytes(str(ref_audio)),
+                     enable_reference_audio=True,
+                 )
+                 for ref_audio in ref_audios
+             ]
+             prompt_texts = [
+                 read_ref_text(str(ref_audio.with_suffix(".lab")))
+                 for ref_audio in ref_audios
+             ]
+         else:
+             logger.info("Use same references")

      else:
          # Parse reference audio aka prompt
          refs = req.references
-         if refs is None:
-             refs = []
-         prompt_tokens = [
-             encode_reference(
-                 decoder_model=decoder_model,
-                 reference_audio=ref.audio,
-                 enable_reference_audio=True,
-             )
-             for ref in refs
-         ]
-         prompt_texts = [ref.text for ref in refs]
+
+         if req.use_memory_cache == "never" or (
+             req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
+         ):
+             prompt_tokens = [
+                 encode_reference(
+                     decoder_model=decoder_model,
+                     reference_audio=ref.audio,
+                     enable_reference_audio=True,
+                 )
+                 for ref in refs
+             ]
+             prompt_texts = [ref.text for ref in refs]
+         else:
+             logger.info("Use same references")
+
+     if req.seed is not None:
+         set_seed(req.seed)
+         logger.warning(f"set seed: {req.seed}")

      # LLAMA Inference
      request = dict(
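Note: with this hunk, prompt_tokens and prompt_texts become module-level state, and the new use_memory_cache field decides whether references are re-encoded on every request. The condition reads: re-encode when the cache mode is "never", or when it is "on-demand" and nothing has been cached yet; otherwise the previously encoded references are reused. An illustrative restatement of just that rule (the helper name is made up):

    def should_reencode(use_memory_cache: str, cached_tokens: list) -> bool:
        # "never": ignore the cache and re-encode every time.
        # "on-demand": encode once, then keep reusing the cached references.
        return use_memory_cache == "never" or (
            use_memory_cache == "on-demand" and len(cached_tokens) == 0
        )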
@@ -220,7 +660,7 @@ def inference(req: ServeTTSRequest):
          compile=args.compile,
          iterative_prompt=req.chunk_length > 0,
          chunk_length=req.chunk_length,
-         max_length=2048,
+         max_length=4096,
          prompt_tokens=prompt_tokens,
          prompt_text=prompt_texts,
      )
@@ -342,6 +782,8 @@ async def buffer_to_async_generator(buffer):

  def parse_args():
      parser = ArgumentParser()
+     parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+     parser.add_argument("--load-asr-model", action="store_true")
      parser.add_argument(
          "--llama-checkpoint-path",
          type=str,
@@ -367,18 +809,26 @@ def parse_args():
  # openapi = OpenAPI(
  # {
  # "title": "Fish Speech API",
+ # "version": "1.4.2",
  # },
  # ).routes
  #
  #
  # class MsgPackRequest(HttpRequest):
- # async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+ # async def data(
+ # self,
+ # ) -> Annotated[
+ # Any, ContentType("application/msgpack"), ContentType("application/json")
+ # ]:
  # if self.content_type == "application/msgpack":
  # return ormsgpack.unpackb(await self.body)
  #
+ # elif self.content_type == "application/json":
+ # return await self.json
+ #
  # raise HTTPException(
  # HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
- # headers={"Accept": "application/msgpack"},
+ # headers={"Accept": "application/msgpack, application/json"},
  # )
  #
  #
@@ -393,48 +843,101 @@ def parse_args():
  # )


- if __name__ == "__main__":
+ def load_asr_model(*, device="cuda", hub="ms"):
+     return AutoModel(
+         model="iic/SenseVoiceSmall",
+         device=device,
+         disable_pbar=True,
+         hub=hub,
+     )

-     import uvicorn

-     args = parse_args()
-     args.precision = torch.half if args.half else torch.bfloat16
+ # Each worker process created by Uvicorn has its own memory space,
+ # meaning that models and variables are not shared between processes.
+ # Therefore, any global variables (like `llama_queue` or `decoder_model`)
+ # will not be shared across workers.

-     logger.info("Loading Llama model...")
-     llama_queue = launch_thread_safe_queue(
-         checkpoint_path=args.llama_checkpoint_path,
-         device=args.device,
-         precision=args.precision,
-         compile=args.compile,
-     )
-     logger.info("Llama model loaded, loading VQ-GAN model...")

-     decoder_model = load_decoder_model(
-         config_name=args.decoder_config_name,
-         checkpoint_path=args.decoder_checkpoint_path,
-         device=args.device,
-     )
+ # Multi-threading for deep learning can cause issues, such as inconsistent
+ # outputs if multiple threads access the same buffers simultaneously.
+ # Instead, it's better to use multiprocessing or independent models per thread.
+ # @app.on_startup
+ # def initialize_app(app: Kui):
+ #
+ # global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
+ #
+ # prompt_tokens, prompt_texts = [], []
+ #
+ # args = parse_args()  # args same as ones in other processes
+ # args.precision = torch.half if args.half else torch.bfloat16
+ #
+ # if args.load_asr_model:
+ # logger.info(f"Loading ASR model...")
+ # asr_model = load_asr_model(device=args.device)
+ #
+ # logger.info("Loading Llama model...")
+ #
+ # if args.mode == "tts":
+ # llama_queue = launch_thread_safe_queue(
+ # checkpoint_path=args.llama_checkpoint_path,
+ # device=args.device,
+ # precision=args.precision,
+ # compile=args.compile,
+ # )
+ # else:
+ # llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
+ # checkpoint_path=args.llama_checkpoint_path,
+ # device=args.device,
+ # precision=args.precision,
+ # compile=args.compile,
+ # )
+ #
+ # logger.info("Llama model loaded, loading VQ-GAN model...")
+ #
+ # decoder_model = load_decoder_model(
+ # config_name=args.decoder_config_name,
+ # checkpoint_path=args.decoder_checkpoint_path,
+ # device=args.device,
+ # )
+ #
+ # logger.info("VQ-GAN model loaded, warming up...")
+ #
+ # vad_model = load_silero_vad()
+ #
+ # logger.info("VAD model loaded, warming up...")
+ #
+ # if args.mode == "tts":
+ # # Dry run to ensure models work and avoid first-time latency
+ # list(
+ # inference(
+ # ServeTTSRequest(
+ # text="Hello world.",
+ # references=[],
+ # reference_id=None,
+ # max_new_tokens=0,
+ # chunk_length=200,
+ # top_p=0.7,
+ # repetition_penalty=1.2,
+ # temperature=0.7,
+ # emotion=None,
+ # format="wav",
+ # )
+ # )
+ # )
+ #
+ # logger.info(f"Warming up done, starting server at http://{args.listen}")

-     logger.info("VQ-GAN model loaded, warming up...")
-
-     # Dry run to check if the model is loaded correctly and avoid the first-time latency
-     list(
-         inference(
-             ServeTTSRequest(
-                 text="Hello world.",
-                 references=[],
-                 reference_id=None,
-                 max_new_tokens=1024,
-                 chunk_length=200,
-                 top_p=0.7,
-                 repetition_penalty=1.2,
-                 temperature=0.7,
-                 emotion=None,
-                 format="wav",
-             )
-         )
-     )

-     logger.info(f"Warming up done, starting server at http://{args.listen}")
+ if __name__ == "__main__":
+
+     import uvicorn
+
+     args = parse_args()
      host, port = args.listen.split(":")
-     uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
+     uvicorn.run(
+         "tools.api:app",
+         host=host,
+         port=int(port),
+         workers=args.workers,
+         log_level="info",
+     )
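
Note: the old code passed the app object straight to uvicorn.run, which only works in a single process; uvicorn requires an import string such as "tools.api:app" when workers is greater than 1, because each worker process re-imports the module (hence the comment above about globals not being shared across workers). A minimal sketch of the new invocation with illustrative host and port values:

    import uvicorn

    # The import string lets uvicorn spawn several worker processes, each of
    # which imports tools.api and builds its own model state.
    uvicorn.run("tools.api:app", host="127.0.0.1", port=8080, workers=2, log_level="info")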