xinference 0.16.3__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (69)
  1. xinference/_compat.py +22 -2
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +148 -12
  4. xinference/client/restful/restful_client.py +47 -2
  5. xinference/constants.py +1 -0
  6. xinference/core/model.py +45 -15
  7. xinference/core/supervisor.py +8 -2
  8. xinference/core/utils.py +67 -2
  9. xinference/model/audio/__init__.py +12 -0
  10. xinference/model/audio/core.py +21 -4
  11. xinference/model/audio/fish_speech.py +70 -35
  12. xinference/model/audio/model_spec.json +81 -1
  13. xinference/model/audio/whisper_mlx.py +208 -0
  14. xinference/model/embedding/core.py +259 -4
  15. xinference/model/embedding/model_spec.json +1 -1
  16. xinference/model/embedding/model_spec_modelscope.json +1 -1
  17. xinference/model/image/stable_diffusion/core.py +5 -2
  18. xinference/model/llm/__init__.py +2 -0
  19. xinference/model/llm/llm_family.json +485 -6
  20. xinference/model/llm/llm_family_modelscope.json +519 -0
  21. xinference/model/llm/mlx/core.py +45 -3
  22. xinference/model/llm/sglang/core.py +1 -0
  23. xinference/model/llm/transformers/core.py +1 -0
  24. xinference/model/llm/transformers/glm_edge_v.py +230 -0
  25. xinference/model/llm/utils.py +19 -0
  26. xinference/model/llm/vllm/core.py +84 -2
  27. xinference/model/rerank/core.py +11 -4
  28. xinference/thirdparty/fish_speech/fish_speech/conversation.py +254 -0
  29. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +2 -1
  30. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +2 -1
  31. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +2 -2
  32. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ko_KR.json +123 -0
  33. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +2 -1
  34. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +76 -11
  35. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +9 -9
  36. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +1 -1
  37. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +32 -1
  38. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +2 -1
  39. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +22 -0
  40. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +1 -1
  41. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1 -1
  42. xinference/thirdparty/fish_speech/tools/api.py +578 -75
  43. xinference/thirdparty/fish_speech/tools/e2e_webui.py +232 -0
  44. xinference/thirdparty/fish_speech/tools/fish_e2e.py +298 -0
  45. xinference/thirdparty/fish_speech/tools/llama/generate.py +393 -9
  46. xinference/thirdparty/fish_speech/tools/msgpack_api.py +90 -29
  47. xinference/thirdparty/fish_speech/tools/post_api.py +37 -15
  48. xinference/thirdparty/fish_speech/tools/schema.py +187 -0
  49. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +7 -1
  50. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +2 -3
  51. xinference/thirdparty/fish_speech/tools/webui.py +138 -75
  52. xinference/types.py +2 -1
  53. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/METADATA +30 -6
  54. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/RECORD +58 -63
  55. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/WHEEL +1 -1
  56. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  57. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  58. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  61. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  62. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  63. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  64. xinference/thirdparty/fish_speech/tools/commons.py +0 -35
  65. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  67. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/LICENSE +0 -0
  68. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/entry_points.txt +0 -0
  69. {xinference-0.16.3.dist-info → xinference-1.0.1.dist-info}/top_level.txt +0 -0
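
As a rough illustration of what the file list above summarizes, the two wheels can also be compared locally. The sketch below is not part of the release; the local wheel filenames are assumptions, and the wheels would first need to be fetched (for example with "pip download xinference==0.16.3 --no-deps" and "pip download xinference==1.0.1 --no-deps").

# Minimal sketch: list added/removed files between the two wheel archives.
# The local wheel paths below are assumptions; point them at the downloaded files.
import zipfile

OLD_WHEEL = "xinference-0.16.3-py3-none-any.whl"
NEW_WHEEL = "xinference-1.0.1-py3-none-any.whl"

old_files = set(zipfile.ZipFile(OLD_WHEEL).namelist())
new_files = set(zipfile.ZipFile(NEW_WHEEL).namelist())

print("Added files:")
for name in sorted(new_files - old_files):
    print("  +", name)

print("Removed files:")
for name in sorted(old_files - new_files):
    print("  -", name)

print(f"{len(old_files & new_files)} files present in both (possibly modified)")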

xinference/thirdparty/fish_speech/tools/e2e_webui.py (added)
@@ -0,0 +1,232 @@
+ import io
+ import re
+ import wave
+
+ import gradio as gr
+ import numpy as np
+
+ from .fish_e2e import FishE2EAgent, FishE2EEventType
+ from .schema import ServeMessage, ServeTextPart, ServeVQPart
+
+
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+     buffer = io.BytesIO()
+
+     with wave.open(buffer, "wb") as wav_file:
+         wav_file.setnchannels(channels)
+         wav_file.setsampwidth(bit_depth // 8)
+         wav_file.setframerate(sample_rate)
+
+     wav_header_bytes = buffer.getvalue()
+     buffer.close()
+     return wav_header_bytes
+
+
+ class ChatState:
+     def __init__(self):
+         self.conversation = []
+         self.added_systext = False
+         self.added_sysaudio = False
+
+     def get_history(self):
+         results = []
+         for msg in self.conversation:
+             results.append({"role": msg.role, "content": self.repr_message(msg)})
+
+         # Process assistant messages to extract questions and update user messages
+         for i, msg in enumerate(results):
+             if msg["role"] == "assistant":
+                 match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
+                 if match and i > 0 and results[i - 1]["role"] == "user":
+                     # Update previous user message with extracted question
+                     results[i - 1]["content"] += "\n" + match.group(1)
+                     # Remove the Question/Answer format from assistant message
+                     msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
+         return results
+
+     def repr_message(self, msg: ServeMessage):
+         response = ""
+         for part in msg.parts:
+             if isinstance(part, ServeTextPart):
+                 response += part.text
+             elif isinstance(part, ServeVQPart):
+                 response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
+         return response
+
+
+ def clear_fn():
+     return [], ChatState(), None, None, None
+
+
+ async def process_audio_input(
+     sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
+ ):
+     if audio_input is None and not text_input:
+         raise gr.Error("No input provided")
+
+     agent = FishE2EAgent()  # Create new agent instance for each request
+
+     # Convert audio input to numpy array
+     if isinstance(audio_input, tuple):
+         sr, audio_data = audio_input
+     elif text_input:
+         sr = 44100
+         audio_data = None
+     else:
+         raise gr.Error("Invalid audio format")
+
+     if isinstance(sys_audio_input, tuple):
+         sr, sys_audio_data = sys_audio_input
+     else:
+         sr = 44100
+         sys_audio_data = None
+
+     def append_to_chat_ctx(
+         part: ServeTextPart | ServeVQPart, role: str = "assistant"
+     ) -> None:
+         if not state.conversation or state.conversation[-1].role != role:
+             state.conversation.append(ServeMessage(role=role, parts=[part]))
+         else:
+             state.conversation[-1].parts.append(part)
+
+     if state.added_systext is False and sys_text_input:
+         state.added_systext = True
+         append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
+     if text_input:
+         append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
+         audio_data = None
+
+     result_audio = b""
+     async for event in agent.stream(
+         sys_audio_data,
+         audio_data,
+         sr,
+         1,
+         chat_ctx={
+             "messages": state.conversation,
+             "added_sysaudio": state.added_sysaudio,
+         },
+     ):
+         if event.type == FishE2EEventType.USER_CODES:
+             append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
+         elif event.type == FishE2EEventType.SPEECH_SEGMENT:
+             append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
+             yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
+         elif event.type == FishE2EEventType.TEXT_SEGMENT:
+             append_to_chat_ctx(ServeTextPart(text=event.text))
+             yield state.get_history(), None, None, None
+
+     yield state.get_history(), None, None, None
+
+
+ async def process_text_input(
+     sys_audio_input, sys_text_input, state: ChatState, text_input: str
+ ):
+     async for event in process_audio_input(
+         sys_audio_input, sys_text_input, None, state, text_input
+     ):
+         yield event
+
+
+ def create_demo():
+     with gr.Blocks() as demo:
+         state = gr.State(ChatState())
+
+         with gr.Row():
+             # Left column (70%) for chatbot and notes
+             with gr.Column(scale=7):
+                 chatbot = gr.Chatbot(
+                     [],
+                     elem_id="chatbot",
+                     bubble_full_width=False,
+                     height=600,
+                     type="messages",
+                 )
+
+                 # notes = gr.Markdown(
+                 # """
+                 # # Fish Agent
+                 # 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
+                 # 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
+                 # 3. Demo为早期灰度测试版本,推理速度尚待优化.
+                 # # 特色
+                 # 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
+                 # 2. 模型可以使用reference audio控制说话音色.
+                 # 3. 可以生成具有较强情感与韵律的音频.
+                 # """
+                 # )
+                 notes = gr.Markdown(
+                     """
+                     # Fish Agent
+                     1. This demo is Fish Audio's self-developed end-to-end language model, Fish Agent version 3B.
+                     2. You can find the code and weights in our official repo on [GitHub](https://github.com/fishaudio/fish-speech) and [Hugging Face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but all of the content is released under a CC BY-NC-SA 4.0 licence.
+                     3. The demo is an early alpha test version; the inference speed still needs to be optimised.
+                     # Features
+                     1. The model automatically integrates the ASR and TTS parts, so there is no need to plug in other models, i.e., it is truly end-to-end rather than three-stage (ASR+LLM+TTS).
+                     2. The model can use reference audio to control the speech timbre.
+                     3. The model can generate speech with strong emotion and prosody.
+                     """
+                 )
+
+             # Right column (30%) for controls
+             with gr.Column(scale=3):
+                 sys_audio_input = gr.Audio(
+                     sources=["upload"],
+                     type="numpy",
+                     label="Give a timbre for your assistant",
+                 )
+                 sys_text_input = gr.Textbox(
+                     label="What is your assistant's role?",
+                     value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
+                     type="text",
+                 )
+                 audio_input = gr.Audio(
+                     sources=["microphone"], type="numpy", label="Speak your message"
+                 )
+
+                 text_input = gr.Textbox(label="Or type your message", type="text")
+
+                 output_audio = gr.Audio(
+                     label="Assistant's Voice",
+                     streaming=True,
+                     autoplay=True,
+                     interactive=False,
+                 )
+
+                 send_button = gr.Button("Send", variant="primary")
+                 clear_button = gr.Button("Clear")
+
+         # Event handlers
+         audio_input.stop_recording(
+             process_audio_input,
+             inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
+             outputs=[chatbot, output_audio, audio_input, text_input],
+             show_progress=True,
+         )
+
+         send_button.click(
+             process_text_input,
+             inputs=[sys_audio_input, sys_text_input, state, text_input],
+             outputs=[chatbot, output_audio, audio_input, text_input],
+             show_progress=True,
+         )
+
+         text_input.submit(
+             process_text_input,
+             inputs=[sys_audio_input, sys_text_input, state, text_input],
+             outputs=[chatbot, output_audio, audio_input, text_input],
+             show_progress=True,
+         )
+
+         clear_button.click(
+             clear_fn,
+             inputs=[],
+             outputs=[chatbot, state, audio_input, output_audio, text_input],
+         )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(server_name="127.0.0.1", server_port=7860, share=True)

xinference/thirdparty/fish_speech/tools/fish_e2e.py (added)
@@ -0,0 +1,298 @@
+ import base64
+ import ctypes
+ import io
+ import json
+ import os
+ import struct
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import AsyncGenerator, Union
+
+ import httpx
+ import numpy as np
+ import ormsgpack
+ import soundfile as sf
+
+ from .schema import (
+     ServeMessage,
+     ServeRequest,
+     ServeTextPart,
+     ServeVQGANDecodeRequest,
+     ServeVQGANEncodeRequest,
+     ServeVQPart,
+ )
+
+
+ class CustomAudioFrame:
+     def __init__(self, data, sample_rate, num_channels, samples_per_channel):
+         if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
+             ctypes.c_int16
+         ):
+             raise ValueError(
+                 "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
+             )
+
+         self._data = bytearray(data)
+         self._sample_rate = sample_rate
+         self._num_channels = num_channels
+         self._samples_per_channel = samples_per_channel
+
+     @property
+     def data(self):
+         return memoryview(self._data).cast("h")
+
+     @property
+     def sample_rate(self):
+         return self._sample_rate
+
+     @property
+     def num_channels(self):
+         return self._num_channels
+
+     @property
+     def samples_per_channel(self):
+         return self._samples_per_channel
+
+     @property
+     def duration(self):
+         return self.samples_per_channel / self.sample_rate
+
+     def __repr__(self):
+         return (
+             f"CustomAudioFrame(sample_rate={self.sample_rate}, "
+             f"num_channels={self.num_channels}, "
+             f"samples_per_channel={self.samples_per_channel}, "
+             f"duration={self.duration:.3f})"
+         )
+
+
+ class FishE2EEventType(Enum):
+     SPEECH_SEGMENT = 1
+     TEXT_SEGMENT = 2
+     END_OF_TEXT = 3
+     END_OF_SPEECH = 4
+     ASR_RESULT = 5
+     USER_CODES = 6
+
+
+ @dataclass
+ class FishE2EEvent:
+     type: FishE2EEventType
+     frame: np.ndarray = None
+     text: str = None
+     vq_codes: list[list[int]] = None
+
+
+ client = httpx.AsyncClient(
+     timeout=None,
+     limits=httpx.Limits(
+         max_connections=None,
+         max_keepalive_connections=None,
+         keepalive_expiry=None,
+     ),
+ )
+
+
+ class FishE2EAgent:
+     def __init__(self):
+         self.llm_url = "http://localhost:8080/v1/chat"
+         self.vqgan_url = "http://localhost:8080"
+         self.client = httpx.AsyncClient(timeout=None)
+
+     async def get_codes(self, audio_data, sample_rate):
+         audio_buffer = io.BytesIO()
+         sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
+         audio_buffer.seek(0)
+         # Step 1: Encode audio using VQGAN
+         encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
+         encode_request_bytes = ormsgpack.packb(
+             encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+         )
+         encode_response = await self.client.post(
+             f"{self.vqgan_url}/v1/vqgan/encode",
+             data=encode_request_bytes,
+             headers={"Content-Type": "application/msgpack"},
+         )
+         encode_response_data = ormsgpack.unpackb(encode_response.content)
+         codes = encode_response_data["tokens"][0]
+         return codes
+
+     async def stream(
+         self,
+         system_audio_data: np.ndarray | None,
+         user_audio_data: np.ndarray | None,
+         sample_rate: int,
+         num_channels: int,
+         chat_ctx: dict | None = None,
+     ) -> AsyncGenerator[bytes, None]:
+
+         if system_audio_data is not None:
+             sys_codes = await self.get_codes(system_audio_data, sample_rate)
+         else:
+             sys_codes = None
+         if user_audio_data is not None:
+             user_codes = await self.get_codes(user_audio_data, sample_rate)
+         # Step 2: Prepare LLM request
+         if chat_ctx is None:
+             sys_parts = [
+                 ServeTextPart(
+                     text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
+                 ),
+             ]
+             if system_audio_data is not None:
+                 sys_parts.append(ServeVQPart(codes=sys_codes))
+             chat_ctx = {
+                 "messages": [
+                     ServeMessage(
+                         role="system",
+                         parts=sys_parts,
+                     ),
+                 ],
+             }
+         else:
+             if chat_ctx["added_sysaudio"] is False and sys_codes:
+                 chat_ctx["added_sysaudio"] = True
+                 chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
+
+         prev_messages = chat_ctx["messages"].copy()
+         if user_audio_data is not None:
+             yield FishE2EEvent(
+                 type=FishE2EEventType.USER_CODES,
+                 vq_codes=user_codes,
+             )
+         else:
+             user_codes = None
+
+         request = ServeRequest(
+             messages=prev_messages
+             + (
+                 [
+                     ServeMessage(
+                         role="user",
+                         parts=[ServeVQPart(codes=user_codes)],
+                     )
+                 ]
+                 if user_codes
+                 else []
+             ),
+             streaming=True,
+             num_samples=1,
+         )
+
+         # Step 3: Stream LLM response and decode audio
+         buffer = b""
+         vq_codes = []
+         current_vq = False
+
+         async def decode_send():
+             nonlocal current_vq
+             nonlocal vq_codes
+
+             data = np.concatenate(vq_codes, axis=1).tolist()
+             # Decode VQ codes to audio
+             decode_request = ServeVQGANDecodeRequest(tokens=[data])
+             decode_response = await self.client.post(
+                 f"{self.vqgan_url}/v1/vqgan/decode",
+                 data=ormsgpack.packb(
+                     decode_request,
+                     option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+                 ),
+                 headers={"Content-Type": "application/msgpack"},
+             )
+             decode_data = ormsgpack.unpackb(decode_response.content)
+
+             # Convert float16 audio data to int16
+             audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
+             audio_data = (audio_data * 32768).astype(np.int16).tobytes()
+
+             audio_frame = CustomAudioFrame(
+                 data=audio_data,
+                 samples_per_channel=len(audio_data) // 2,
+                 sample_rate=44100,
+                 num_channels=1,
+             )
+             yield FishE2EEvent(
+                 type=FishE2EEventType.SPEECH_SEGMENT,
+                 frame=audio_frame,
+                 vq_codes=data,
+             )
+
+             current_vq = False
+             vq_codes = []
+
+         async with self.client.stream(
+             "POST",
+             self.llm_url,
+             data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+             headers={"Content-Type": "application/msgpack"},
+         ) as response:
+
+             async for chunk in response.aiter_bytes():
+                 buffer += chunk
+
+                 while len(buffer) >= 4:
+                     read_length = struct.unpack("I", buffer[:4])[0]
+                     if len(buffer) < 4 + read_length:
+                         break
+
+                     body = buffer[4 : 4 + read_length]
+                     buffer = buffer[4 + read_length :]
+                     data = ormsgpack.unpackb(body)
+
+                     if data["delta"] and data["delta"]["part"]:
+                         if current_vq and data["delta"]["part"]["type"] == "text":
+                             async for event in decode_send():
+                                 yield event
+                         if data["delta"]["part"]["type"] == "text":
+                             yield FishE2EEvent(
+                                 type=FishE2EEventType.TEXT_SEGMENT,
+                                 text=data["delta"]["part"]["text"],
+                             )
+                         elif data["delta"]["part"]["type"] == "vq":
+                             vq_codes.append(np.array(data["delta"]["part"]["codes"]))
+                             current_vq = True
+
+             if current_vq and vq_codes:
+                 async for event in decode_send():
+                     yield event
+
+         yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
+         yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
+
+
+ # Example usage:
+ async def main():
+     import torchaudio
+
+     agent = FishE2EAgent()
+
+     # Replace this with actual audio data loading
+     with open("uz_story_en.m4a", "rb") as f:
+         audio_data = f.read()
+
+     audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
+     audio_data = (audio_data.numpy() * 32768).astype(np.int16)
+
+     stream = agent.stream(audio_data, sample_rate, 1)
+     if os.path.exists("audio_segment.wav"):
+         os.remove("audio_segment.wav")
+
+     async for event in stream:
+         if event.type == FishE2EEventType.SPEECH_SEGMENT:
+             # Handle speech segment (e.g., play audio or save to file)
+             with open("audio_segment.wav", "ab+") as f:
+                 f.write(event.frame.data)
+         elif event.type == FishE2EEventType.ASR_RESULT:
+             print(event.text, flush=True)
+         elif event.type == FishE2EEventType.TEXT_SEGMENT:
+             print(event.text, flush=True, end="")
+         elif event.type == FishE2EEventType.END_OF_TEXT:
+             print("\nEnd of text reached.")
+         elif event.type == FishE2EEventType.END_OF_SPEECH:
+             print("End of speech reached.")
+
+
+ if __name__ == "__main__":
+     import asyncio
+
+     asyncio.run(main())