xinference 0.16.3 → 1.0.0 (py3-none-any.whl)
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the published versions.
Potentially problematic release.
This version of xinference might be problematic.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +62 -11
- xinference/client/restful/restful_client.py +8 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +10 -3
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/model_spec.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/llm_family.json +176 -4
- xinference/model/llm/llm_family_modelscope.json +211 -0
- xinference/model/llm/mlx/core.py +45 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/METADATA +23 -1
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/RECORD +43 -50
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.0.dist-info}/top_level.txt +0 -0
--- a/xinference/thirdparty/fish_speech/tools/post_api.py
+++ b/xinference/thirdparty/fish_speech/tools/post_api.py
@@ -8,14 +8,15 @@ import requests
 from pydub import AudioSegment
 from pydub.playback import play

-from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import audio_to_bytes, read_ref_text
+from tools.schema import ServeReferenceAudio, ServeTTSRequest


 def parse_args():

     parser = argparse.ArgumentParser(
-        description="Send a WAV file and text to a server and receive synthesized audio."
+        description="Send a WAV file and text to a server and receive synthesized audio.",
+        formatter_class=argparse.RawTextHelpFormatter,
     )

     parser.add_argument(
@@ -33,7 +34,7 @@ def parse_args():
         "-id",
         type=str,
         default=None,
-        help="ID of the reference model
+        help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
     )
     parser.add_argument(
         "--reference_audio",
@@ -41,7 +42,7 @@ def parse_args():
         type=str,
         nargs="+",
         default=None,
-        help="Path to the
+        help="Path to the audio file",
     )
     parser.add_argument(
         "--reference_text",
@@ -68,17 +69,25 @@ def parse_args():
     parser.add_argument(
         "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
     )
-    parser.add_argument(
+    parser.add_argument(
+        "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
+    )
     parser.add_argument("--opus_bitrate", type=int, default=-1000)
-    parser.add_argument(
+    parser.add_argument(
+        "--latency",
+        type=str,
+        default="normal",
+        choices=["normal", "balanced"],
+        help="Used in api.fish.audio/v1/tts",
+    )
     parser.add_argument(
         "--max_new_tokens",
         type=int,
-        default=
-        help="Maximum new tokens to generate",
+        default=0,
+        help="Maximum new tokens to generate. \n0 means no limit.",
     )
     parser.add_argument(
-        "--chunk_length", type=int, default=
+        "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
     )
     parser.add_argument(
         "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
@@ -92,10 +101,7 @@ def parse_args():
     parser.add_argument(
         "--temperature", type=float, default=0.7, help="Temperature for sampling"
     )
-    parser.add_argument(
-        "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
-    )
-    parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
+
     parser.add_argument(
         "--streaming", type=bool, default=False, help="Enable streaming response"
     )
@@ -103,6 +109,22 @@ def parse_args():
         "--channels", type=int, default=1, help="Number of audio channels"
     )
     parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+    parser.add_argument(
+        "--use_memory_cache",
+        type=str,
+        default="never",
+        choices=["on-demand", "never"],
+        help="Cache encoded references codes in memory.\n"
+        "If `on-demand`, the server will use cached encodings\n "
+        "instead of encoding reference audio again.",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="`None` means randomized inference, otherwise deterministic.\n"
+        "It can't be used for fixing a timbre.",
+    )

     return parser.parse_args()

@@ -145,9 +167,9 @@ if __name__ == "__main__":
         "top_p": args.top_p,
         "repetition_penalty": args.repetition_penalty,
         "temperature": args.temperature,
-        "speaker": args.speaker,
-        "emotion": args.emotion,
         "streaming": args.streaming,
+        "use_memory_cache": args.use_memory_cache,
+        "seed": args.seed,
     }

     pydantic_data = ServeTTSRequest(**data)
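The multi-line help strings added for --use_memory_cache and --seed depend on the switch to argparse.RawTextHelpFormatter, which keeps embedded newlines instead of re-wrapping them. A minimal, self-contained sketch of that pattern, reusing strings from the diff rather than the full script:

import argparse

# RawTextHelpFormatter preserves embedded "\n" in help text instead of
# re-wrapping it, which the multi-line --seed help string relies on.
parser = argparse.ArgumentParser(
    description="Send a WAV file and text to a server and receive synthesized audio.",
    formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="`None` means randomized inference, otherwise deterministic.\n"
    "It can't be used for fixing a timbre.",
)

args = parser.parse_args(["--seed", "42"])
print(args.seed)  # 42; omitting --seed keeps the default of None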
--- /dev/null
+++ b/xinference/thirdparty/fish_speech/tools/schema.py
@@ -0,0 +1,187 @@
+import os
+import queue
+from dataclasses import dataclass
+from typing import Annotated, Literal, Optional
+
+import torch
+from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+from pydantic.functional_validators import SkipValidation
+
+from fish_speech.conversation import Message, TextPart, VQPart
+
+GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
+
+
+class ServeVQPart(BaseModel):
+    type: Literal["vq"] = "vq"
+    codes: SkipValidation[list[list[int]]]
+
+
+class ServeTextPart(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ServeAudioPart(BaseModel):
+    type: Literal["audio"] = "audio"
+    audio: bytes
+
+
+@dataclass
+class ASRPackRequest:
+    audio: torch.Tensor
+    result_queue: queue.Queue
+    language: str
+
+
+class ServeASRRequest(BaseModel):
+    # The audio should be an uncompressed PCM float16 audio
+    audios: list[bytes]
+    sample_rate: int = 44100
+    language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+class ServeASRTranscription(BaseModel):
+    text: str
+    duration: float
+    huge_gap: bool
+
+
+class ServeASRSegment(BaseModel):
+    text: str
+    start: float
+    end: float
+
+
+class ServeTimedASRResponse(BaseModel):
+    text: str
+    segments: list[ServeASRSegment]
+    duration: float
+
+
+class ServeASRResponse(BaseModel):
+    transcriptions: list[ServeASRTranscription]
+
+
+class ServeMessage(BaseModel):
+    role: Literal["system", "assistant", "user"]
+    parts: list[ServeVQPart | ServeTextPart]
+
+    def to_conversation_message(self):
+        new_message = Message(role=self.role, parts=[])
+        for part in self.parts:
+            if isinstance(part, ServeTextPart):
+                new_message.parts.append(TextPart(text=part.text))
+            elif isinstance(part, ServeVQPart):
+                new_message.parts.append(
+                    VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
+                )
+            else:
+                raise ValueError(f"Unsupported part type: {part}")
+
+        return new_message
+
+
+class ServeRequest(BaseModel):
+    messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
+    max_new_tokens: int = 1024
+    top_p: float = 0.7
+    repetition_penalty: float = 1.2
+    temperature: float = 0.7
+    streaming: bool = False
+    num_samples: int = 1
+    early_stop_threshold: float = 1.0
+
+
+class ServeVQGANEncodeRequest(BaseModel):
+    # The audio here should be in wav, mp3, etc
+    audios: list[bytes]
+
+
+class ServeVQGANEncodeResponse(BaseModel):
+    tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeRequest(BaseModel):
+    tokens: SkipValidation[list[list[list[int]]]]
+
+
+class ServeVQGANDecodeResponse(BaseModel):
+    # The audio here should be in PCM float16 format
+    audios: list[bytes]
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+
+class ServeForwardMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ServeResponse(BaseModel):
+    messages: list[ServeMessage]
+    finish_reason: Literal["stop", "error"] | None = None
+    stats: dict[str, int | float | str] = {}
+
+
+class ServeStreamDelta(BaseModel):
+    role: Literal["system", "assistant", "user"] | None = None
+    part: ServeVQPart | ServeTextPart | None = None
+
+
+class ServeStreamResponse(BaseModel):
+    sample_id: int = 0
+    delta: ServeStreamDelta | None = None
+    finish_reason: Literal["stop", "error"] | None = None
+    stats: dict[str, int | float | str] | None = None
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+    def __repr__(self) -> str:
+        return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+class ServeChatRequestV1(BaseModel):
+    model: str = "llama3-8b"
+    messages: list[ServeForwardMessage] = []
+    audio: bytes | None = None
+    temperature: float = 1.0
+    top_p: float = 1.0
+    max_tokens: int = 256
+    voice: str = "jessica"
+    tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
+    tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # References audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    seed: int | None = None
+    use_memory_cache: Literal["on-demand", "never"] = "never"
+    # Normalize text for en & zh, this increase stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # not usually used below
+    streaming: bool = False
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
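The new module can be exercised on its own. A minimal sketch of building a TTS request against these models, assuming pydantic v2 semantics (which the pydantic.functional_validators import implies); reference.wav and the texts are placeholder values:

from tools.schema import ServeReferenceAudio, ServeTTSRequest

# One in-context-learning reference: raw audio bytes plus their transcript.
with open("reference.wav", "rb") as f:
    ref = ServeReferenceAudio(audio=f.read(), text="Transcript of the reference clip")

request = ServeTTSRequest(
    text="Hello from fish-speech",
    references=[ref],
    seed=42,                       # None (the default) means randomized inference
    use_memory_cache="on-demand",  # let the server reuse previously encoded references
)

# chunk_length is constrained to 100-300 and top_p/temperature to 0.1-1.0,
# so out-of-range values fail validation client-side instead of at the server.
payload = request.model_dump()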
--- a/xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py
+++ b/xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py
@@ -24,6 +24,12 @@ OmegaConf.register_new_resolver("eval", eval)
 # This file is used to convert the audio files to text files using the Whisper model.
 # It's mainly used to generate the training data for the VQ model.

+backends = torchaudio.list_audio_backends()
+
+if "ffmpeg" in backends:
+    backend = "ffmpeg"
+else:
+    backend = "soundfile"

 RANK = int(os.environ.get("SLURM_PROCID", 0))
 WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +87,7 @@ def process_batch(files: list[Path], model) -> float:
     for file in files:
         try:
             wav, sr = torchaudio.load(
-                str(file), backend=
+                str(file), backend=backend
             )  # Need to install libsox-dev
         except Exception as e:
             logger.error(f"Error reading {file}: {e}")
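The backend fallback added at module level is a small self-contained pattern. A sketch of the same idea, assuming a torchaudio version that exposes list_audio_backends() and accepts an explicit backend= on load (the path is a placeholder):

import torchaudio

# Prefer ffmpeg when it is available, otherwise fall back to soundfile,
# mirroring the module-level selection added in this hunk.
backends = torchaudio.list_audio_backends()
backend = "ffmpeg" if "ffmpeg" in backends else "soundfile"

wav, sr = torchaudio.load("example.wav", backend=backend)
print(wav.shape, sr)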
--- a/xinference/thirdparty/fish_speech/tools/vqgan/inference.py
+++ b/xinference/thirdparty/fish_speech/tools/vqgan/inference.py
@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):

     model = instantiate(cfg)
     state_dict = torch.load(
-        checkpoint_path,
-        map_location=device,
+        checkpoint_path, map_location=device, mmap=True, weights_only=True
     )
     if "state_dict" in state_dict:
         state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
         if "generator." in k
     }

-    result = model.load_state_dict(state_dict, strict=False)
+    result = model.load_state_dict(state_dict, strict=False, assign=True)
     model.eval()
     model.to(device)

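The new loading arguments change how the checkpoint is brought into memory: weights_only=True restricts unpickling to tensors and plain containers, mmap=True maps the file rather than reading it all into RAM, and assign=True then lets load_state_dict adopt those tensors instead of copying them into pre-allocated parameters. A minimal sketch, assuming PyTorch 2.1 or newer and a throwaway checkpoint:

import torch

# Throwaway model and checkpoint purely for illustration.
model = torch.nn.Linear(4, 4)
torch.save({"state_dict": model.state_dict()}, "checkpoint.pth")

# weights_only=True avoids arbitrary unpickling; mmap=True avoids reading
# the whole file into RAM before the state dict is actually used.
state_dict = torch.load(
    "checkpoint.pth", map_location="cpu", mmap=True, weights_only=True
)
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]

# assign=True keeps the loaded (mmap-backed) tensors as the parameters
# instead of copying them into the module's existing buffers.
result = model.load_state_dict(state_dict, strict=False, assign=True)
print(result.missing_keys, result.unexpected_keys)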
--- a/xinference/thirdparty/fish_speech/tools/webui.py
+++ b/xinference/thirdparty/fish_speech/tools/webui.py
@@ -21,8 +21,9 @@ from transformers import AutoTokenizer

 from fish_speech.i18n import i18n
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
-from fish_speech.utils import autocast_exclude_mps
+from fish_speech.utils import autocast_exclude_mps, set_seed
 from tools.api import decode_vq_tokens, encode_reference
+from tools.file import AUDIO_EXTENSIONS, list_files
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
@@ -70,6 +71,7 @@ def inference(
     top_p,
     repetition_penalty,
     temperature,
+    seed="0",
     streaming=False,
 ):
     if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
@@ -81,6 +83,11 @@ def inference(
             ),
         )

+    seed = int(seed)
+    if seed != 0:
+        set_seed(seed)
+        logger.warning(f"set seed: {seed}")
+
     # Parse reference audio aka prompt
     prompt_tokens = encode_reference(
         decoder_model=decoder_model,
@@ -139,7 +146,9 @@ def inference(
         segments.append(fake_audios)

         if streaming:
-
+            wav_header = wav_chunk_header()
+            audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
+            yield wav_header + audio_data, None, None

     if len(segments) == 0:
         return (
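In the streaming branch above, each newly generated segment is scaled from float samples to 16-bit PCM and prefixed with a WAV header before being yielded. The conversion itself is plain NumPy; a small sketch (the sine wave is a stand-in segment, and wav_chunk_header() is the header helper referenced in the diff, left commented out here):

import numpy as np

# Stand-in for one generated segment: float samples in [-1.0, 1.0] at 44.1 kHz.
fake_audios = 0.9 * np.sin(np.linspace(0, 2 * np.pi * 440, 44100, dtype=np.float32))

# Scale to the int16 range and serialize to raw little-endian PCM bytes,
# the same expression the streaming branch uses before yielding a chunk.
audio_data = (fake_audios * 32768).astype(np.int16).tobytes()

# chunk = wav_chunk_header() + audio_data  # header helper as used in webui.py
print(len(audio_data))  # two bytes per sample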
@@ -177,6 +186,7 @@ def inference_wrapper(
    top_p,
    repetition_penalty,
    temperature,
+   seed,
    batch_infer_num,
 ):
    audios = []
@@ -193,6 +203,7 @@ def inference_wrapper(
            top_p,
            repetition_penalty,
            temperature,
+           seed,
        )

        _, audio_data, error_message = next(result)
@@ -235,7 +246,11 @@ def normalize_text(user_input, use_normalization):
        return user_input


-
+def update_examples():
+    examples_dir = Path("references")
+    examples_dir.mkdir(parents=True, exist_ok=True)
+    example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
+    return gr.Dropdown(choices=example_audios + [""])


 def build_app():
@@ -273,76 +288,100 @@ def build_app():
                )

                with gr.Row():
-                   [70 removed lines (old 276-345) were not rendered in the source diff view]
+                   with gr.Column():
+                       with gr.Tab(label=i18n("Advanced Config")):
+                           with gr.Row():
+                               chunk_length = gr.Slider(
+                                   label=i18n("Iterative Prompt Length, 0 means off"),
+                                   minimum=50,
+                                   maximum=300,
+                                   value=200,
+                                   step=8,
+                               )
+
+                               max_new_tokens = gr.Slider(
+                                   label=i18n(
+                                       "Maximum tokens per batch, 0 means no limit"
+                                   ),
+                                   minimum=0,
+                                   maximum=2048,
+                                   value=0,  # 0 means no limit
+                                   step=8,
+                               )
+
+                           with gr.Row():
+                               top_p = gr.Slider(
+                                   label="Top-P",
+                                   minimum=0.6,
+                                   maximum=0.9,
+                                   value=0.7,
+                                   step=0.01,
+                               )
+
+                               repetition_penalty = gr.Slider(
+                                   label=i18n("Repetition Penalty"),
+                                   minimum=1,
+                                   maximum=1.5,
+                                   value=1.2,
+                                   step=0.01,
+                               )
+
+                           with gr.Row():
+                               temperature = gr.Slider(
+                                   label="Temperature",
+                                   minimum=0.6,
+                                   maximum=0.9,
+                                   value=0.7,
+                                   step=0.01,
+                               )
+                               seed = gr.Textbox(
+                                   label="Seed",
+                                   info="0 means randomized inference, otherwise deterministic",
+                                   placeholder="any 32-bit-integer",
+                                   value="0",
+                               )
+
+                       with gr.Tab(label=i18n("Reference Audio")):
+                           with gr.Row():
+                               gr.Markdown(
+                                   i18n(
+                                       "5 to 10 seconds of reference audio, useful for specifying speaker."
+                                   )
+                               )
+                           with gr.Row():
+                               enable_reference_audio = gr.Checkbox(
+                                   label=i18n("Enable Reference Audio"),
+                               )
+
+                           with gr.Row():
+                               example_audio_dropdown = gr.Dropdown(
+                                   label=i18n("Select Example Audio"),
+                                   choices=[""],
+                                   value="",
+                                   interactive=True,
+                                   allow_custom_value=True,
+                               )
+                           with gr.Row():
+                               reference_audio = gr.Audio(
+                                   label=i18n("Reference Audio"),
+                                   type="filepath",
+                               )
+                           with gr.Row():
+                               reference_text = gr.Textbox(
+                                   label=i18n("Reference Text"),
+                                   lines=1,
+                                   placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                   value="",
+                               )
+                       with gr.Tab(label=i18n("Batch Inference")):
+                           with gr.Row():
+                               batch_infer_num = gr.Slider(
+                                   label="Batch infer nums",
+                                   minimum=1,
+                                   maximum=n_audios,
+                                   step=1,
+                                   value=1,
+                               )

            with gr.Column(scale=3):
                for _ in range(n_audios):
@@ -383,6 +422,28 @@ def build_app():
            fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
        )

+       def select_example_audio(audio_path):
+           audio_path = Path(audio_path)
+           if audio_path.is_file():
+               lab_file = Path(audio_path.with_suffix(".lab"))
+
+               if lab_file.exists():
+                   lab_content = lab_file.read_text(encoding="utf-8").strip()
+               else:
+                   lab_content = ""
+
+               return str(audio_path), lab_content, True
+           return None, "", False
+
+       # Connect the dropdown to update reference audio and text
+       example_audio_dropdown.change(
+           fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
+       ).then(
+           fn=select_example_audio,
+           inputs=[example_audio_dropdown],
+           outputs=[reference_audio, reference_text, enable_reference_audio],
+       )
+
        # # Submit
        generate.click(
            inference_wrapper,
@@ -396,6 +457,7 @@ def build_app():
                top_p,
                repetition_penalty,
                temperature,
+               seed,
                batch_infer_num,
            ],
            [stream_audio, *global_audio_list, *global_error_list],
@@ -414,9 +476,10 @@ def build_app():
                top_p,
                repetition_penalty,
                temperature,
+               seed,
            ],
            [stream_audio, global_audio_list[0], global_error_list[0]],
-           concurrency_limit=
+           concurrency_limit=1,
        )
    return app

@@ -471,7 +534,7 @@ if __name__ == "__main__":
        enable_reference_audio=False,
        reference_audio=None,
        reference_text="",
-       max_new_tokens=
+       max_new_tokens=0,
        chunk_length=200,
        top_p=0.7,
        repetition_penalty=1.2,