xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic. Click here for more details.
- xinference/_compat.py +22 -2
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +148 -12
- xinference/client/restful/restful_client.py +47 -2
- xinference/constants.py +1 -0
- xinference/core/model.py +45 -15
- xinference/core/supervisor.py +8 -2
- xinference/core/utils.py +67 -2
- xinference/model/audio/__init__.py +12 -0
- xinference/model/audio/core.py +21 -4
- xinference/model/audio/fish_speech.py +70 -35
- xinference/model/audio/model_spec.json +81 -1
- xinference/model/audio/whisper_mlx.py +208 -0
- xinference/model/embedding/core.py +259 -4
- xinference/model/embedding/model_spec.json +1 -1
- xinference/model/embedding/model_spec_modelscope.json +1 -1
- xinference/model/image/stable_diffusion/core.py +5 -2
- xinference/model/llm/__init__.py +2 -0
- xinference/model/llm/llm_family.json +485 -6
- xinference/model/llm/llm_family_modelscope.json +519 -0
- xinference/model/llm/mlx/core.py +45 -3
- xinference/model/llm/sglang/core.py +1 -0
- xinference/model/llm/transformers/core.py +1 -0
- xinference/model/llm/transformers/glm_edge_v.py +230 -0
- xinference/model/llm/utils.py +19 -0
- xinference/model/llm/vllm/core.py +84 -2
- xinference/model/rerank/core.py +11 -4
- xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
- xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
- xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
- xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
- xinference/thirdparty/fish_speech/tools/api.py +578 -75
- xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
- xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
- xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
- xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
- xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
- xinference/thirdparty/fish_speech/tools/schema.py +187 -0
- xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
- xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
- xinference/thirdparty/fish_speech/tools/webui.py +138 -75
- xinference/types.py +2 -1
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
- xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
- xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/commons.py +0 -35
- xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
- xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
- {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import re
|
|
3
|
+
import wave
|
|
4
|
+
|
|
5
|
+
import gradio as gr
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .fish_e2e import FishE2EAgent, FishE2EEventType
|
|
9
|
+
from .schema import ServeMessage, ServeTextPart, ServeVQPart
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    """Return the bytes of a WAV (RIFF) header that declares zero frames.

    Nothing is written through ``writeframes``, so the header's data length is
    zero. ``process_audio_input`` prepends it to every streamed PCM chunk so
    each chunk is sent as a self-describing WAV blob.

    Args:
        sample_rate: Frame rate written into the header, in Hz.
        bit_depth: Bits per sample; converted to a byte width for ``wave``.
        channels: Number of interleaved channels.
    """
    with io.BytesIO() as out:
        # The wave writer emits the header when it is closed, and it does not
        # close a file object it did not open, so ``out`` is still readable.
        with wave.open(out, "wb") as writer:
            writer.setframerate(sample_rate)
            writer.setsampwidth(bit_depth // 8)
            writer.setnchannels(channels)
        return out.getvalue()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ChatState:
    """Per-session chat state kept in a ``gr.State`` by the demo.

    Attributes:
        conversation: ``ServeMessage`` objects forming the chat context that is
            passed to ``FishE2EAgent.stream``.
        added_systext: Whether the system text prompt has been added.
        added_sysaudio: Whether the reference-voice codes have been attached.
    """

    # The assistant is prompted to reply as "Question: <transcript>\n\nResponse: ..."
    # (the demo's default prompt says "Answer:"), so accept both labels. The
    # trailing ``\s?`` consumes the single separator character after the colon,
    # whether it is a space or a newline.
    _QA_PATTERN = re.compile(r"Question: (.*?)\n\n(?:Response|Answer):\s?")

    # NOTE(review): used to turn VQ frame counts into seconds for display;
    # assumes roughly 21 codec frames per second of audio. Confirm against the codec.
    _VQ_FRAMES_PER_SECOND = 21

    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
        """Render ``conversation`` as Gradio ``type="messages"`` dicts.

        When an assistant reply starts with the "Question: ..." preamble and
        directly follows a user message, the transcribed question is appended to
        that user message and the preamble is stripped from the reply.
        """
        results = [
            {"role": msg.role, "content": self.repr_message(msg)}
            for msg in self.conversation
        ]

        for i, msg in enumerate(results):
            if msg["role"] != "assistant":
                continue
            match = self._QA_PATTERN.search(msg["content"])
            if match and i > 0 and results[i - 1]["role"] == "user":
                results[i - 1]["content"] += "\n" + match.group(1)
                # Slice at the match end instead of re-splitting on a fixed
                # separator, which raised IndexError when the label was not
                # followed by exactly one space.
                msg["content"] = msg["content"][match.end():]
        return results

    def repr_message(self, msg: "ServeMessage"):
        """Flatten a message to display text; audio parts become ``<audio N.NNs>``."""
        pieces = []
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                pieces.append(part.text)
            elif isinstance(part, ServeVQPart):
                seconds = len(part.codes[0]) / self._VQ_FRAMES_PER_SECOND
                pieces.append(f"<audio {seconds:.2f}s>")
        return "".join(pieces)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def clear_fn():
    """Reset the UI after "Clear" is pressed.

    Returns values for ``[chatbot, state, audio_input, output_audio, text_input]``:
    an empty transcript, a brand-new ``ChatState``, and ``None`` for the three
    media/text widgets so they are emptied.
    """
    empty_history = []
    cleared_widgets = (None, None, None)
    return (empty_history, ChatState(), *cleared_widgets)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _resample_audio(audio, src_rate, dst_rate):
    """Linearly resample ``audio`` from ``src_rate`` to ``dst_rate`` Hz.

    Samples run along axis 0, so mono ``(frames,)`` and ``(frames, channels)``
    arrays are both accepted. The result keeps the input dtype. This is plain
    linear interpolation: fine for a reference-voice clip, but not a
    high-fidelity resampler.
    """
    audio = np.asarray(audio)
    num_in = audio.shape[0]
    if src_rate == dst_rate or num_in == 0:
        return audio
    num_out = max(1, int(round(num_in * dst_rate / src_rate)))
    src_pos = np.arange(num_in)
    dst_pos = np.linspace(0, num_in - 1, num_out)
    columns = audio.reshape(num_in, -1)
    resampled = np.stack(
        [np.interp(dst_pos, src_pos, columns[:, c]) for c in range(columns.shape[1])],
        axis=1,
    ).reshape((num_out,) + audio.shape[1:])
    if np.issubdtype(audio.dtype, np.integer):
        # Round rather than truncate when going back to integer PCM.
        resampled = np.rint(resampled)
    return resampled.astype(audio.dtype)


async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: "ChatState", text_input: str
):
    """Run one chat turn through ``FishE2EAgent`` and stream UI updates.

    Args:
        sys_audio_input: ``(sample_rate, samples)`` reference-voice clip, or None.
        sys_text_input: System prompt text, added once per session.
        audio_input: ``(sample_rate, samples)`` microphone recording, or None.
        state: The session's ``ChatState``; its conversation is updated in place.
        text_input: Typed message. When non-empty it takes precedence over audio.

    Yields:
        ``(history, wav_bytes_or_None, None, None)`` for the outputs
        ``[chatbot, output_audio, audio_input, text_input]``. The trailing Nones
        clear the input widgets.

    Raises:
        gr.Error: No input was provided, or the audio is not a
            ``(rate, data)`` tuple.
    """
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    # Typed text wins over a recording made in the same turn.
    if text_input:
        user_audio, user_rate = None, None
    elif isinstance(audio_input, tuple):
        user_rate, user_audio = audio_input
    else:
        raise gr.Error("Invalid audio format")

    # The reference voice only needs to be sent until it has been attached to
    # the system message; after that, skip the redundant VQGAN encode.
    sys_audio, sys_rate = None, None
    if isinstance(sys_audio_input, tuple) and not state.added_sysaudio:
        sys_rate, sys_audio = sys_audio_input

    # FishE2EAgent.stream encodes both clips with ONE sample rate. Keep the
    # user's real rate (previously it was reset to 44100 whenever no reference
    # clip was given) and bring the reference clip to that rate.
    if user_audio is not None:
        sample_rate = user_rate
        if sys_audio is not None:
            sys_audio = _resample_audio(sys_audio, sys_rate, user_rate)
    elif sys_audio is not None:
        sample_rate = sys_rate
    else:
        sample_rate = 44100

    def append_to_chat_ctx(
        part: "ServeTextPart | ServeVQPart", role: str = "assistant"
    ) -> None:
        """Add ``part`` to the last message if it has ``role``, else start a new message."""
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    def system_message():
        """Return the leading system message, inserting an empty one if needed."""
        if not state.conversation or state.conversation[0].role != "system":
            state.conversation.insert(0, ServeMessage(role="system", parts=[]))
        return state.conversation[0]

    if not state.added_systext and sys_text_input:
        state.added_systext = True
        # Keep the prompt text ahead of any reference-voice part, even if the
        # text is supplied after the first turn.
        system_message().parts.insert(0, ServeTextPart(text=sys_text_input))
    if sys_audio is not None:
        # The agent attaches the reference voice to messages[0]; make sure that
        # is the system message (and that it exists).
        system_message()
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")

    agent = FishE2EAgent()  # Create new agent instance for each request
    chat_ctx = {
        "messages": state.conversation,
        "added_sysaudio": state.added_sysaudio,
    }
    try:
        async for event in agent.stream(
            sys_audio,
            user_audio,
            sample_rate,
            1,
            chat_ctx=chat_ctx,
        ):
            if event.type == FishE2EEventType.USER_CODES:
                append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
            elif event.type == FishE2EEventType.SPEECH_SEGMENT:
                append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
                frame = event.frame
                header = wav_chunk_header(
                    sample_rate=frame.sample_rate, channels=frame.num_channels
                )
                yield state.get_history(), header + frame.data, None, None
            elif event.type == FishE2EEventType.TEXT_SEGMENT:
                append_to_chat_ctx(ServeTextPart(text=event.text))
                yield state.get_history(), None, None, None
    finally:
        # The agent records in chat_ctx whether it attached the reference voice.
        # Persist that, or the voice would be re-sent and re-appended every turn.
        state.added_sysaudio = bool(chat_ctx.get("added_sysaudio", False))
        # One agent per request: close its HTTP client so connections don't leak.
        await agent.client.aclose()

    yield state.get_history(), None, None, None
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
    """Text-only entry point used by the Send button and Enter key.

    Delegates to ``process_audio_input`` with no recording and re-yields each
    UI update unchanged. It stays an async generator so Gradio streams it.
    """
    turn_updates = process_audio_input(
        sys_audio_input,
        sys_text_input,
        audio_input=None,
        state=state,
        text_input=text_input,
    )
    async for update in turn_updates:
        yield update
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def create_demo():
    """Build the Gradio Blocks UI for the Fish Agent end-to-end voice demo.

    Layout: a chat transcript and notes on the left (scale 7), and input and
    output controls on the right (scale 3). The handlers are
    ``process_audio_input`` (mic stop), ``process_text_input`` (Send / Enter)
    and ``clear_fn`` (Clear).

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Built from explicit lines so source indentation can never turn the
    # Markdown into an indented code block.
    notes_markdown = "\n".join(
        [
            "# Fish Agent",
            "1. This demo is Fish Audio's self-developed end-to-end language model, Fish Agent 3B.",
            "2. You can find the code and weights in our official repo on "
            "[GitHub](https://github.com/fishaudio/fish-speech) and "
            "[Hugging Face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), "
            "but the content is released under a CC BY-NC-SA 4.0 licence.",
            "3. The demo is an early alpha test version; the inference speed still needs to be optimised.",
            "# Features",
            "1. The model natively integrates ASR and TTS, so no extra models need to be plugged in: "
            "it is truly end-to-end rather than a three-stage pipeline (ASR+LLM+TTS).",
            "2. The model can use reference audio to control the speech timbre.",
            "3. The model can generate speech with strong emotion.",
        ]
    )
    default_system_prompt = (
        "You are a voice assistant created by Fish Audio, offering end-to-end "
        "voice interaction for a seamless user experience. You are required to "
        "first transcribe the user's speech, then answer it in the following "
        "format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are "
        "required to use the following voice in this conversation."
    )

    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            # Left column (70%) for chatbot and notes
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )
                gr.Markdown(notes_markdown)

            # Right column (30%) for controls
            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value=default_system_prompt,
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )

                text_input = gr.Textbox(label="Or type your message", type="text")

                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    streaming=True,
                    autoplay=True,
                    interactive=False,
                )

                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        # Every chat handler updates the transcript, streams audio, and clears
        # both input widgets.
        chat_outputs = [chatbot, output_audio, audio_input, text_input]
        text_turn_inputs = [sys_audio_input, sys_text_input, state, text_input]

        # Event handlers
        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=chat_outputs,
            show_progress=True,
        )
        for text_trigger in (send_button.click, text_input.submit):
            text_trigger(
                process_text_input,
                inputs=text_turn_inputs,
                outputs=chat_outputs,
                show_progress=True,
            )
        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a public share link even though
    # the server itself binds to 127.0.0.1. Confirm that is intended.
    demo = create_demo()
    demo.launch(share=True, server_port=7860, server_name="127.0.0.1")
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import ctypes
|
|
3
|
+
import io
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import struct
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import AsyncGenerator, Union
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
import numpy as np
|
|
13
|
+
import ormsgpack
|
|
14
|
+
import soundfile as sf
|
|
15
|
+
|
|
16
|
+
from .schema import (
|
|
17
|
+
ServeMessage,
|
|
18
|
+
ServeRequest,
|
|
19
|
+
ServeTextPart,
|
|
20
|
+
ServeVQGANDecodeRequest,
|
|
21
|
+
ServeVQGANEncodeRequest,
|
|
22
|
+
ServeVQPart,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class CustomAudioFrame:
    """A block of signed 16-bit PCM audio plus its format metadata.

    The bytes are copied into a private ``bytearray``. ``data`` exposes them as
    a memoryview of native-endian int16 samples.

    Raises:
        ValueError: ``data`` holds fewer bytes than
            ``num_channels * samples_per_channel * 2``.
    """

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        bytes_needed = (
            num_channels * samples_per_channel * ctypes.sizeof(ctypes.c_int16)
        )
        if bytes_needed > len(data):
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        """Samples as an int16 (``"h"``) memoryview over the internal buffer."""
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        """Samples per second per channel."""
        return self._sample_rate

    @property
    def num_channels(self):
        """Number of interleaved channels."""
        return self._num_channels

    @property
    def samples_per_channel(self):
        """Number of samples in each channel."""
        return self._samples_per_channel

    @property
    def duration(self):
        """Length of the frame in seconds."""
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class FishE2EEventType(Enum):
    """Kinds of ``FishE2EEvent`` that ``FishE2EAgent.stream`` can emit."""

    # Decoded assistant audio; the event carries ``frame`` and ``vq_codes``.
    SPEECH_SEGMENT = 1
    # A chunk of assistant text; the event carries ``text``.
    TEXT_SEGMENT = 2
    # Terminal markers emitted once the LLM stream has finished.
    END_OF_TEXT = 3
    END_OF_SPEECH = 4
    # Handled by the example ``main`` but not emitted by ``FishE2EAgent.stream``.
    ASR_RESULT = 5
    # VQ codes of the user's own audio; the event carries ``vq_codes``.
    USER_CODES = 6
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class FishE2EEvent:
    """One item yielded by ``FishE2EAgent.stream``.

    Only the fields relevant to ``type`` are set; the others stay None.
    """

    type: "FishE2EEventType"
    # Decoded 16-bit PCM for SPEECH_SEGMENT events. This is a CustomAudioFrame,
    # not an ndarray as the old annotation claimed.
    frame: "CustomAudioFrame | None" = None
    # Text chunk for TEXT_SEGMENT events.
    text: "str | None" = None
    # VQ codes for SPEECH_SEGMENT (assistant) and USER_CODES (user) events.
    vq_codes: "list[list[int]] | None" = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Module-level async HTTP client with no timeout and no connection limits.
# NOTE(review): nothing in this module references it (FishE2EAgent builds its
# own client); it is presumably kept for external importers. Confirm before
# removing it.
client = httpx.AsyncClient(
    limits=httpx.Limits(
        keepalive_expiry=None,
        max_connections=None,
        max_keepalive_connections=None,
    ),
    timeout=None,
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class FishE2EAgent:
|
|
97
|
+
def __init__(self):
|
|
98
|
+
self.llm_url = "http://localhost:8080/v1/chat"
|
|
99
|
+
self.vqgan_url = "http://localhost:8080"
|
|
100
|
+
self.client = httpx.AsyncClient(timeout=None)
|
|
101
|
+
|
|
102
|
+
async def get_codes(self, audio_data, sample_rate):
|
|
103
|
+
audio_buffer = io.BytesIO()
|
|
104
|
+
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
|
105
|
+
audio_buffer.seek(0)
|
|
106
|
+
# Step 1: Encode audio using VQGAN
|
|
107
|
+
encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
|
|
108
|
+
encode_request_bytes = ormsgpack.packb(
|
|
109
|
+
encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
|
110
|
+
)
|
|
111
|
+
encode_response = await self.client.post(
|
|
112
|
+
f"{self.vqgan_url}/v1/vqgan/encode",
|
|
113
|
+
data=encode_request_bytes,
|
|
114
|
+
headers={"Content-Type": "application/msgpack"},
|
|
115
|
+
)
|
|
116
|
+
encode_response_data = ormsgpack.unpackb(encode_response.content)
|
|
117
|
+
codes = encode_response_data["tokens"][0]
|
|
118
|
+
return codes
|
|
119
|
+
|
|
120
|
+
async def stream(
|
|
121
|
+
self,
|
|
122
|
+
system_audio_data: np.ndarray | None,
|
|
123
|
+
user_audio_data: np.ndarray | None,
|
|
124
|
+
sample_rate: int,
|
|
125
|
+
num_channels: int,
|
|
126
|
+
chat_ctx: dict | None = None,
|
|
127
|
+
) -> AsyncGenerator[bytes, None]:
|
|
128
|
+
|
|
129
|
+
if system_audio_data is not None:
|
|
130
|
+
sys_codes = await self.get_codes(system_audio_data, sample_rate)
|
|
131
|
+
else:
|
|
132
|
+
sys_codes = None
|
|
133
|
+
if user_audio_data is not None:
|
|
134
|
+
user_codes = await self.get_codes(user_audio_data, sample_rate)
|
|
135
|
+
# Step 2: Prepare LLM request
|
|
136
|
+
if chat_ctx is None:
|
|
137
|
+
sys_parts = [
|
|
138
|
+
ServeTextPart(
|
|
139
|
+
text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
|
|
140
|
+
),
|
|
141
|
+
]
|
|
142
|
+
if system_audio_data is not None:
|
|
143
|
+
sys_parts.append(ServeVQPart(codes=sys_codes))
|
|
144
|
+
chat_ctx = {
|
|
145
|
+
"messages": [
|
|
146
|
+
ServeMessage(
|
|
147
|
+
role="system",
|
|
148
|
+
parts=sys_parts,
|
|
149
|
+
),
|
|
150
|
+
],
|
|
151
|
+
}
|
|
152
|
+
else:
|
|
153
|
+
if chat_ctx["added_sysaudio"] is False and sys_codes:
|
|
154
|
+
chat_ctx["added_sysaudio"] = True
|
|
155
|
+
chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
|
|
156
|
+
|
|
157
|
+
prev_messages = chat_ctx["messages"].copy()
|
|
158
|
+
if user_audio_data is not None:
|
|
159
|
+
yield FishE2EEvent(
|
|
160
|
+
type=FishE2EEventType.USER_CODES,
|
|
161
|
+
vq_codes=user_codes,
|
|
162
|
+
)
|
|
163
|
+
else:
|
|
164
|
+
user_codes = None
|
|
165
|
+
|
|
166
|
+
request = ServeRequest(
|
|
167
|
+
messages=prev_messages
|
|
168
|
+
+ (
|
|
169
|
+
[
|
|
170
|
+
ServeMessage(
|
|
171
|
+
role="user",
|
|
172
|
+
parts=[ServeVQPart(codes=user_codes)],
|
|
173
|
+
)
|
|
174
|
+
]
|
|
175
|
+
if user_codes
|
|
176
|
+
else []
|
|
177
|
+
),
|
|
178
|
+
streaming=True,
|
|
179
|
+
num_samples=1,
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Step 3: Stream LLM response and decode audio
|
|
183
|
+
buffer = b""
|
|
184
|
+
vq_codes = []
|
|
185
|
+
current_vq = False
|
|
186
|
+
|
|
187
|
+
async def decode_send():
|
|
188
|
+
nonlocal current_vq
|
|
189
|
+
nonlocal vq_codes
|
|
190
|
+
|
|
191
|
+
data = np.concatenate(vq_codes, axis=1).tolist()
|
|
192
|
+
# Decode VQ codes to audio
|
|
193
|
+
decode_request = ServeVQGANDecodeRequest(tokens=[data])
|
|
194
|
+
decode_response = await self.client.post(
|
|
195
|
+
f"{self.vqgan_url}/v1/vqgan/decode",
|
|
196
|
+
data=ormsgpack.packb(
|
|
197
|
+
decode_request,
|
|
198
|
+
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
|
199
|
+
),
|
|
200
|
+
headers={"Content-Type": "application/msgpack"},
|
|
201
|
+
)
|
|
202
|
+
decode_data = ormsgpack.unpackb(decode_response.content)
|
|
203
|
+
|
|
204
|
+
# Convert float16 audio data to int16
|
|
205
|
+
audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
|
|
206
|
+
audio_data = (audio_data * 32768).astype(np.int16).tobytes()
|
|
207
|
+
|
|
208
|
+
audio_frame = CustomAudioFrame(
|
|
209
|
+
data=audio_data,
|
|
210
|
+
samples_per_channel=len(audio_data) // 2,
|
|
211
|
+
sample_rate=44100,
|
|
212
|
+
num_channels=1,
|
|
213
|
+
)
|
|
214
|
+
yield FishE2EEvent(
|
|
215
|
+
type=FishE2EEventType.SPEECH_SEGMENT,
|
|
216
|
+
frame=audio_frame,
|
|
217
|
+
vq_codes=data,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
current_vq = False
|
|
221
|
+
vq_codes = []
|
|
222
|
+
|
|
223
|
+
async with self.client.stream(
|
|
224
|
+
"POST",
|
|
225
|
+
self.llm_url,
|
|
226
|
+
data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
|
227
|
+
headers={"Content-Type": "application/msgpack"},
|
|
228
|
+
) as response:
|
|
229
|
+
|
|
230
|
+
async for chunk in response.aiter_bytes():
|
|
231
|
+
buffer += chunk
|
|
232
|
+
|
|
233
|
+
while len(buffer) >= 4:
|
|
234
|
+
read_length = struct.unpack("I", buffer[:4])[0]
|
|
235
|
+
if len(buffer) < 4 + read_length:
|
|
236
|
+
break
|
|
237
|
+
|
|
238
|
+
body = buffer[4 : 4 + read_length]
|
|
239
|
+
buffer = buffer[4 + read_length :]
|
|
240
|
+
data = ormsgpack.unpackb(body)
|
|
241
|
+
|
|
242
|
+
if data["delta"] and data["delta"]["part"]:
|
|
243
|
+
if current_vq and data["delta"]["part"]["type"] == "text":
|
|
244
|
+
async for event in decode_send():
|
|
245
|
+
yield event
|
|
246
|
+
if data["delta"]["part"]["type"] == "text":
|
|
247
|
+
yield FishE2EEvent(
|
|
248
|
+
type=FishE2EEventType.TEXT_SEGMENT,
|
|
249
|
+
text=data["delta"]["part"]["text"],
|
|
250
|
+
)
|
|
251
|
+
elif data["delta"]["part"]["type"] == "vq":
|
|
252
|
+
vq_codes.append(np.array(data["delta"]["part"]["codes"]))
|
|
253
|
+
current_vq = True
|
|
254
|
+
|
|
255
|
+
if current_vq and vq_codes:
|
|
256
|
+
async for event in decode_send():
|
|
257
|
+
yield event
|
|
258
|
+
|
|
259
|
+
yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
|
|
260
|
+
yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# Example usage:
|
|
264
|
+
# Example usage:
async def main():
    """Example: send one recorded utterance to a local agent and save the reply.

    Loads ``uz_story_en.m4a`` as the user's speech and prints the streamed text.
    The spoken reply is written to ``audio_segment.wav`` as a valid 16-bit mono
    WAV file.
    """
    import wave

    import torchaudio

    agent = FishE2EAgent()

    # Replace this with actual audio data loading
    waveform, sample_rate = torchaudio.load("uz_story_en.m4a")
    # torchaudio returns (channels, frames) floats in [-1, 1]; soundfile (used
    # by get_codes) expects (frames, channels). Scale by 32767 and clip so a
    # full-scale sample cannot overflow int16.
    audio_data = (np.clip(waveform.numpy().T, -1.0, 1.0) * 32767).astype(np.int16)

    try:
        # "wb" truncates any previous output. wave writes a correct header on
        # close, unlike the old raw PCM append.
        with wave.open("audio_segment.wav", "wb") as wav_file:
            # FishE2EAgent decodes speech as mono 16-bit PCM at 44.1 kHz.
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(44100)

            # No reference voice: the clip is the USER's speech. The old call
            # omitted an argument and put the audio in the system slot.
            async for event in agent.stream(
                None, audio_data, sample_rate, waveform.shape[0]
            ):
                if event.type == FishE2EEventType.SPEECH_SEGMENT:
                    # Handle speech segment (e.g., play audio or save to file)
                    wav_file.writeframes(event.frame.data)
                elif event.type == FishE2EEventType.ASR_RESULT:
                    print(event.text, flush=True)
                elif event.type == FishE2EEventType.TEXT_SEGMENT:
                    print(event.text, flush=True, end="")
                elif event.type == FishE2EEventType.END_OF_TEXT:
                    print("\nEnd of text reached.")
                elif event.type == FishE2EEventType.END_OF_SPEECH:
                    print("End of speech reached.")
    finally:
        await agent.client.aclose()
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
if __name__ == "__main__":
    # Run the example coroutine on a fresh event loop.
    from asyncio import run as run_coroutine

    run_coroutine(main())
|