xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (69)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +148 -12
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/model.py +45 -15
  7. xinference/core/supervisor.py +8 -2
  8. xinference/core/utils.py +67 -2
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +21 -4
  11. xinference/model/audio/fish_speech.py +70 -35
  12. xinference/model/audio/model_spec.json +81 -1
  13. xinference/model/audio/whisper_mlx.py +208 -0
  14. xinference/model/embedding/core.py +259 -4
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/embedding/model_spec_modelscope.json +1 -1
  17. xinference/model/image/stable_diffusion/core.py +5 -2
  18. xinference/model/llm/__init__.py +2 -0
  19. xinference/model/llm/llm_family.json +485 -6
  20. xinference/model/llm/llm_family_modelscope.json +519 -0
  21. xinference/model/llm/mlx/core.py +45 -3
  22. xinference/model/llm/sglang/core.py +1 -0
  23. xinference/model/llm/transformers/core.py +1 -0
  24. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  25. xinference/model/llm/utils.py +19 -0
  26. xinference/model/llm/vllm/core.py +84 -2
  27. xinference/model/rerank/core.py +11 -4
  28. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  37. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  38. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  39. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  40. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  42. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  43. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  44. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  45. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  46. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  47. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  48. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  49. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  50. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  51. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  52. xinference/types.py +2 -1
  53. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
  54. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
  55. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
  56. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  58. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  63. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  64. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  65. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  67. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
  68. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
  69. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/tools/post_api.py CHANGED
@@ -8,14 +8,15 @@ import requests
  from pydub import AudioSegment
  from pydub.playback import play
 
- from tools.commons import ServeReferenceAudio, ServeTTSRequest
  from tools.file import audio_to_bytes, read_ref_text
+ from tools.schema import ServeReferenceAudio, ServeTTSRequest
 
 
  def parse_args():
 
  parser = argparse.ArgumentParser(
- description="Send a WAV file and text to a server and receive synthesized audio."
+ description="Send a WAV file and text to a server and receive synthesized audio.",
+ formatter_class=argparse.RawTextHelpFormatter,
  )
 
  parser.add_argument(
@@ -33,7 +34,7 @@ def parse_args():
  "-id",
  type=str,
  default=None,
- help="ID of the reference model o be used for the speech",
+ help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
  )
  parser.add_argument(
  "--reference_audio",
@@ -41,7 +42,7 @@
  type=str,
  nargs="+",
  default=None,
- help="Path to the WAV file",
+ help="Path to the audio file",
  )
  parser.add_argument(
  "--reference_text",
@@ -68,17 +69,25 @@ def parse_args():
  parser.add_argument(
  "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
  )
- parser.add_argument("--mp3_bitrate", type=int, default=64)
+ parser.add_argument(
+ "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
+ )
  parser.add_argument("--opus_bitrate", type=int, default=-1000)
- parser.add_argument("--latency", type=str, default="normal", help="延迟选项")
+ parser.add_argument(
+ "--latency",
+ type=str,
+ default="normal",
+ choices=["normal", "balanced"],
+ help="Used in api.fish.audio/v1/tts",
+ )
  parser.add_argument(
  "--max_new_tokens",
  type=int,
- default=1024,
- help="Maximum new tokens to generate",
+ default=0,
+ help="Maximum new tokens to generate. \n0 means no limit.",
  )
  parser.add_argument(
- "--chunk_length", type=int, default=100, help="Chunk length for synthesis"
+ "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
  )
  parser.add_argument(
  "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
@@ -92,10 +101,7 @@ def parse_args():
  parser.add_argument(
  "--temperature", type=float, default=0.7, help="Temperature for sampling"
  )
- parser.add_argument(
- "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
- )
- parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
+
  parser.add_argument(
  "--streaming", type=bool, default=False, help="Enable streaming response"
  )
@@ -103,6 +109,22 @@ def parse_args():
  "--channels", type=int, default=1, help="Number of audio channels"
  )
  parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
+ parser.add_argument(
+ "--use_memory_cache",
+ type=str,
+ default="never",
+ choices=["on-demand", "never"],
+ help="Cache encoded references codes in memory.\n"
+ "If `on-demand`, the server will use cached encodings\n "
+ "instead of encoding reference audio again.",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="`None` means randomized inference, otherwise deterministic.\n"
+ "It can't be used for fixing a timbre.",
+ )
 
  return parser.parse_args()
 
@@ -145,9 +167,9 @@ if __name__ == "__main__":
  "top_p": args.top_p,
  "repetition_penalty": args.repetition_penalty,
  "temperature": args.temperature,
- "speaker": args.speaker,
- "emotion": args.emotion,
  "streaming": args.streaming,
+ "use_memory_cache": args.use_memory_cache,
+ "seed": args.seed,
  }
 
  pydantic_data = ServeTTSRequest(**data)
xinference/thirdparty/fish_speech/tools/schema.py ADDED
@@ -0,0 +1,187 @@
+ import os
+ import queue
+ from dataclasses import dataclass
+ from typing import Annotated, Literal, Optional
+
+ import torch
+ from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
+ from pydantic.functional_validators import SkipValidation
+
+ from fish_speech.conversation import Message, TextPart, VQPart
+
+ GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
+
+
+ class ServeVQPart(BaseModel):
+ type: Literal["vq"] = "vq"
+ codes: SkipValidation[list[list[int]]]
+
+
+ class ServeTextPart(BaseModel):
+ type: Literal["text"] = "text"
+ text: str
+
+
+ class ServeAudioPart(BaseModel):
+ type: Literal["audio"] = "audio"
+ audio: bytes
+
+
+ @dataclass
+ class ASRPackRequest:
+ audio: torch.Tensor
+ result_queue: queue.Queue
+ language: str
+
+
+ class ServeASRRequest(BaseModel):
+ # The audio should be an uncompressed PCM float16 audio
+ audios: list[bytes]
+ sample_rate: int = 44100
+ language: Literal["zh", "en", "ja", "auto"] = "auto"
+
+
+ class ServeASRTranscription(BaseModel):
+ text: str
+ duration: float
+ huge_gap: bool
+
+
+ class ServeASRSegment(BaseModel):
+ text: str
+ start: float
+ end: float
+
+
+ class ServeTimedASRResponse(BaseModel):
+ text: str
+ segments: list[ServeASRSegment]
+ duration: float
+
+
+ class ServeASRResponse(BaseModel):
+ transcriptions: list[ServeASRTranscription]
+
+
+ class ServeMessage(BaseModel):
+ role: Literal["system", "assistant", "user"]
+ parts: list[ServeVQPart | ServeTextPart]
+
+ def to_conversation_message(self):
+ new_message = Message(role=self.role, parts=[])
+ for part in self.parts:
+ if isinstance(part, ServeTextPart):
+ new_message.parts.append(TextPart(text=part.text))
+ elif isinstance(part, ServeVQPart):
+ new_message.parts.append(
+ VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
+ )
+ else:
+ raise ValueError(f"Unsupported part type: {part}")
+
+ return new_message
+
+
+ class ServeRequest(BaseModel):
+ messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
+ max_new_tokens: int = 1024
+ top_p: float = 0.7
+ repetition_penalty: float = 1.2
+ temperature: float = 0.7
+ streaming: bool = False
+ num_samples: int = 1
+ early_stop_threshold: float = 1.0
+
+
+ class ServeVQGANEncodeRequest(BaseModel):
+ # The audio here should be in wav, mp3, etc
+ audios: list[bytes]
+
+
+ class ServeVQGANEncodeResponse(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+ class ServeVQGANDecodeRequest(BaseModel):
+ tokens: SkipValidation[list[list[list[int]]]]
+
+
+ class ServeVQGANDecodeResponse(BaseModel):
+ # The audio here should be in PCM float16 format
+ audios: list[bytes]
+
+
+ class ServeReferenceAudio(BaseModel):
+ audio: bytes
+ text: str
+
+
+ class ServeForwardMessage(BaseModel):
+ role: str
+ content: str
+
+
+ class ServeResponse(BaseModel):
+ messages: list[ServeMessage]
+ finish_reason: Literal["stop", "error"] | None = None
+ stats: dict[str, int | float | str] = {}
+
+
+ class ServeStreamDelta(BaseModel):
+ role: Literal["system", "assistant", "user"] | None = None
+ part: ServeVQPart | ServeTextPart | None = None
+
+
+ class ServeStreamResponse(BaseModel):
+ sample_id: int = 0
+ delta: ServeStreamDelta | None = None
+ finish_reason: Literal["stop", "error"] | None = None
+ stats: dict[str, int | float | str] | None = None
+
+
+ class ServeReferenceAudio(BaseModel):
+ audio: bytes
+ text: str
+
+ def __repr__(self) -> str:
+ return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
+
+
+ class ServeChatRequestV1(BaseModel):
+ model: str = "llama3-8b"
+ messages: list[ServeForwardMessage] = []
+ audio: bytes | None = None
+ temperature: float = 1.0
+ top_p: float = 1.0
+ max_tokens: int = 256
+ voice: str = "jessica"
+ tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
+ tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
+
+
+ class ServeTTSRequest(BaseModel):
+ text: str
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+ # Audio format
+ format: Literal["wav", "pcm", "mp3"] = "wav"
+ mp3_bitrate: Literal[64, 128, 192] = 128
+ # References audios for in-context learning
+ references: list[ServeReferenceAudio] = []
+ # Reference id
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+ reference_id: str | None = None
+ seed: int | None = None
+ use_memory_cache: Literal["on-demand", "never"] = "never"
+ # Normalize text for en & zh, this increase stability for numbers
+ normalize: bool = True
+ mp3_bitrate: Optional[int] = 64
+ opus_bitrate: Optional[int] = -1000
+ # Balance mode will reduce latency to 300ms, but may decrease stability
+ latency: Literal["normal", "balanced"] = "normal"
+ # not usually used below
+ streaming: bool = False
+ max_new_tokens: int = 1024
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
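For orientation, a minimal client-side sketch of building a request against the new schema above; the text, transcript, and seed values are placeholders, and ServeTTSRequest enforces the declared constraints (for example chunk_length must fall in 100-300):

from tools.schema import ServeReferenceAudio, ServeTTSRequest

# Placeholder reference clip: raw bytes of an audio file plus its transcript.
ref = ServeReferenceAudio(audio=b"...", text="transcript of the reference clip")

request = ServeTTSRequest(
    text="Hello from fish-speech",   # placeholder text
    references=[ref],
    chunk_length=200,                 # validated: conint(ge=100, le=300)
    use_memory_cache="never",         # or "on-demand" to reuse encoded references
    seed=42,                          # None means randomized inference
)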
xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py CHANGED
@@ -24,6 +24,12 @@ OmegaConf.register_new_resolver("eval", eval)
  # This file is used to convert the audio files to text files using the Whisper model.
  # It's mainly used to generate the training data for the VQ model.
 
+ backends = torchaudio.list_audio_backends()
+
+ if "ffmpeg" in backends:
+ backend = "ffmpeg"
+ else:
+ backend = "soundfile"
 
  RANK = int(os.environ.get("SLURM_PROCID", 0))
  WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
@@ -81,7 +87,7 @@ def process_batch(files: list[Path], model) -> float:
  for file in files:
  try:
  wav, sr = torchaudio.load(
- str(file), backend="sox" if sys.platform == "linux" else "soundfile"
+ str(file), backend=backend
  ) # Need to install libsox-dev
  except Exception as e:
  logger.error(f"Error reading {file}: {e}")
xinference/thirdparty/fish_speech/tools/vqgan/inference.py CHANGED
@@ -24,8 +24,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
 
  model = instantiate(cfg)
  state_dict = torch.load(
- checkpoint_path,
- map_location=device,
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
  )
  if "state_dict" in state_dict:
  state_dict = state_dict["state_dict"]
@@ -37,7 +36,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
  if "generator." in k
  }
 
- result = model.load_state_dict(state_dict, strict=False)
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
  model.eval()
  model.to(device)
 
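The checkpoint-loading change above combines torch.load(mmap=True, weights_only=True) with load_state_dict(assign=True). A minimal standalone sketch of the same pattern, with a placeholder module and path (not code from the package):

import torch
from torch import nn

model = nn.Linear(4, 4)  # placeholder module
# mmap=True memory-maps the checkpoint instead of copying it fully into RAM,
# weights_only=True restricts unpickling to plain tensors and containers,
# and assign=True lets the module adopt the loaded tensors directly.
state_dict = torch.load("checkpoint.pth", map_location="cpu", mmap=True, weights_only=True)
model.load_state_dict(state_dict, strict=False, assign=True)
model.eval()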
xinference/thirdparty/fish_speech/tools/webui.py CHANGED
@@ -21,8 +21,9 @@ from transformers import AutoTokenizer
 
  from fish_speech.i18n import i18n
  from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
- from fish_speech.utils import autocast_exclude_mps
+ from fish_speech.utils import autocast_exclude_mps, set_seed
  from tools.api import decode_vq_tokens, encode_reference
+ from tools.file import AUDIO_EXTENSIONS, list_files
  from tools.llama.generate import (
  GenerateRequest,
  GenerateResponse,
@@ -70,6 +71,7 @@ def inference(
  top_p,
  repetition_penalty,
  temperature,
+ seed="0",
  streaming=False,
  ):
  if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
@@ -81,6 +83,11 @@ def inference(
  ),
  )
 
+ seed = int(seed)
+ if seed != 0:
+ set_seed(seed)
+ logger.warning(f"set seed: {seed}")
+
  # Parse reference audio aka prompt
  prompt_tokens = encode_reference(
  decoder_model=decoder_model,
@@ -139,7 +146,9 @@ def inference(
  segments.append(fake_audios)
 
  if streaming:
- yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
+ wav_header = wav_chunk_header()
+ audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
+ yield wav_header + audio_data, None, None
 
  if len(segments) == 0:
  return (
@@ -177,6 +186,7 @@ def inference_wrapper(
  top_p,
  repetition_penalty,
  temperature,
+ seed,
  batch_infer_num,
  ):
  audios = []
@@ -193,6 +203,7 @@
  top_p,
  repetition_penalty,
  temperature,
+ seed,
  )
 
  _, audio_data, error_message = next(result)
@@ -235,7 +246,11 @@ def normalize_text(user_input, use_normalization):
  return user_input
 
 
- asr_model = None
+ def update_examples():
+ examples_dir = Path("references")
+ examples_dir.mkdir(parents=True, exist_ok=True)
+ example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
+ return gr.Dropdown(choices=example_audios + [""])
 
 
  def build_app():
@@ -273,76 +288,100 @@ def build_app():
  )
 
  with gr.Row():
- with gr.Tab(label=i18n("Advanced Config")):
- chunk_length = gr.Slider(
- label=i18n("Iterative Prompt Length, 0 means off"),
- minimum=50,
- maximum=300,
- value=200,
- step=8,
- )
-
- max_new_tokens = gr.Slider(
- label=i18n("Maximum tokens per batch, 0 means no limit"),
- minimum=0,
- maximum=2048,
- value=1024, # 0 means no limit
- step=8,
- )
-
- top_p = gr.Slider(
- label="Top-P",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
- step=0.01,
- )
-
- repetition_penalty = gr.Slider(
- label=i18n("Repetition Penalty"),
- minimum=1,
- maximum=1.5,
- value=1.2,
- step=0.01,
- )
-
- temperature = gr.Slider(
- label="Temperature",
- minimum=0.6,
- maximum=0.9,
- value=0.7,
- step=0.01,
- )
-
- with gr.Tab(label=i18n("Reference Audio")):
- gr.Markdown(
- i18n(
- "5 to 10 seconds of reference audio, useful for specifying speaker."
- )
- )
-
- enable_reference_audio = gr.Checkbox(
- label=i18n("Enable Reference Audio"),
- )
- reference_audio = gr.Audio(
- label=i18n("Reference Audio"),
- type="filepath",
- )
- with gr.Row():
- reference_text = gr.Textbox(
- label=i18n("Reference Text"),
- lines=1,
- placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
- value="",
- )
- with gr.Tab(label=i18n("Batch Inference")):
- batch_infer_num = gr.Slider(
- label="Batch infer nums",
- minimum=1,
- maximum=n_audios,
- step=1,
- value=1,
- )
+ with gr.Column():
+ with gr.Tab(label=i18n("Advanced Config")):
+ with gr.Row():
+ chunk_length = gr.Slider(
+ label=i18n("Iterative Prompt Length, 0 means off"),
+ minimum=50,
+ maximum=300,
+ value=200,
+ step=8,
+ )
+
+ max_new_tokens = gr.Slider(
+ label=i18n(
+ "Maximum tokens per batch, 0 means no limit"
+ ),
+ minimum=0,
+ maximum=2048,
+ value=0, # 0 means no limit
+ step=8,
+ )
+
+ with gr.Row():
+ top_p = gr.Slider(
+ label="Top-P",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+
+ repetition_penalty = gr.Slider(
+ label=i18n("Repetition Penalty"),
+ minimum=1,
+ maximum=1.5,
+ value=1.2,
+ step=0.01,
+ )
+
+ with gr.Row():
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0.6,
+ maximum=0.9,
+ value=0.7,
+ step=0.01,
+ )
+ seed = gr.Textbox(
+ label="Seed",
+ info="0 means randomized inference, otherwise deterministic",
+ placeholder="any 32-bit-integer",
+ value="0",
+ )
+
+ with gr.Tab(label=i18n("Reference Audio")):
+ with gr.Row():
+ gr.Markdown(
+ i18n(
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
+ )
+ )
+ with gr.Row():
+ enable_reference_audio = gr.Checkbox(
+ label=i18n("Enable Reference Audio"),
+ )
+
+ with gr.Row():
+ example_audio_dropdown = gr.Dropdown(
+ label=i18n("Select Example Audio"),
+ choices=[""],
+ value="",
+ interactive=True,
+ allow_custom_value=True,
+ )
+ with gr.Row():
+ reference_audio = gr.Audio(
+ label=i18n("Reference Audio"),
+ type="filepath",
+ )
+ with gr.Row():
+ reference_text = gr.Textbox(
+ label=i18n("Reference Text"),
+ lines=1,
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+ value="",
+ )
+ with gr.Tab(label=i18n("Batch Inference")):
+ with gr.Row():
+ batch_infer_num = gr.Slider(
+ label="Batch infer nums",
+ minimum=1,
+ maximum=n_audios,
+ step=1,
+ value=1,
+ )
 
  with gr.Column(scale=3):
  for _ in range(n_audios):
@@ -383,6 +422,28 @@ def build_app():
  fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
  )
 
+ def select_example_audio(audio_path):
+ audio_path = Path(audio_path)
+ if audio_path.is_file():
+ lab_file = Path(audio_path.with_suffix(".lab"))
+
+ if lab_file.exists():
+ lab_content = lab_file.read_text(encoding="utf-8").strip()
+ else:
+ lab_content = ""
+
+ return str(audio_path), lab_content, True
+ return None, "", False
+
+ # Connect the dropdown to update reference audio and text
+ example_audio_dropdown.change(
+ fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
+ ).then(
+ fn=select_example_audio,
+ inputs=[example_audio_dropdown],
+ outputs=[reference_audio, reference_text, enable_reference_audio],
+ )
+
 
  # # Submit
  generate.click(
@@ -396,6 +457,7 @@ def build_app():
  top_p,
  repetition_penalty,
  temperature,
+ seed,
  batch_infer_num,
  ],
  [stream_audio, *global_audio_list, *global_error_list],
@@ -414,9 +476,10 @@ def build_app():
  top_p,
  repetition_penalty,
  temperature,
+ seed,
  ],
  [stream_audio, global_audio_list[0], global_error_list[0]],
- concurrency_limit=10,
+ concurrency_limit=1,
  )
  return app
 
@@ -471,7 +534,7 @@ if __name__ == "__main__":
  enable_reference_audio=False,
  reference_audio=None,
  reference_text="",
- max_new_tokens=1024,
+ max_new_tokens=0,
  chunk_length=200,
  top_p=0.7,
  repetition_penalty=1.2,
xinference/types.py CHANGED
@@ -71,7 +71,8 @@ class EmbeddingUsage(TypedDict):
  class EmbeddingData(TypedDict):
  index: int
  object: str
- embedding: List[float]
+ # support sparse embedding
+ embedding: Union[List[float], Dict[str, float]]
 
 
  class Embedding(TypedDict):
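A small illustration of what the widened EmbeddingData type now admits; the field values are hypothetical, with a dense embedding remaining a list of floats and a sparse embedding mapping token ids to weights:

from xinference.types import EmbeddingData

# Hypothetical payloads: the same TypedDict now covers both shapes.
dense: EmbeddingData = {"index": 0, "object": "embedding", "embedding": [0.12, -0.03, 0.57]}
sparse: EmbeddingData = {"index": 1, "object": "embedding", "embedding": {"1045": 0.82, "2310": 0.11}}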