livekit-plugins-neuphonic 1.0.23__tar.gz → 1.1.1__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -146,6 +146,9 @@ venv.bak/
  .dmypy.json
  dmypy.json

+ # trunk
+ .trunk/
+
  # Pyre type checker
  .pyre/

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: livekit-plugins-neuphonic
- Version: 1.0.23
+ Version: 1.1.1
  Summary: Neuphonic inference plugin for LiveKit Agents
  Project-URL: Documentation, https://docs.livekit.io
  Project-URL: Website, https://livekit.io/
@@ -16,7 +16,7 @@ Classifier: Programming Language :: Python :: 3.12
  Classifier: Topic :: Multimedia :: Sound/Audio
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9.0
- Requires-Dist: livekit-agents>=1.0.23
+ Requires-Dist: livekit-agents>=1.1.1
  Description-Content-Type: text/markdown

  # Neuphonic plugin for LiveKit Agents
@@ -28,7 +28,7 @@ from .log import logger


  class NeuphonicPlugin(Plugin):
-     def __init__(self):
+     def __init__(self) -> None:
          super().__init__(__name__, __version__, __package__, logger)


@@ -0,0 +1,3 @@
+ from typing import Literal
+
+ TTSLangCodes = Literal["en", "nl", "es", "de", "hi", "en-hi", "ar"]
@@ -0,0 +1,233 @@
+ # Copyright 2023 LiveKit, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import asyncio
+ import base64
+ import json
+ import os
+ from dataclasses import dataclass, replace
+
+ import aiohttp
+
+ from livekit.agents import (
+     APIConnectionError,
+     APIConnectOptions,
+     APIStatusError,
+     APITimeoutError,
+     tts,
+     utils,
+ )
+ from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
+ from livekit.agents.utils import is_given
+
+ from .models import TTSLangCodes
+
+ API_BASE_URL = "api.neuphonic.com"
+ AUTHORIZATION_HEADER = "X-API-KEY"
+
+
+ @dataclass
+ class _TTSOptions:
+     base_url: str
+     lang_code: TTSLangCodes | str
+     api_key: str
+     sample_rate: int
+     speed: float
+     voice_id: str | None
+
+
+ class TTS(tts.TTS):
+     def __init__(
+         self,
+         *,
+         voice_id: str = "8e9c4bc8-3979-48ab-8626-df53befc2090",
+         api_key: str | None = None,
+         lang_code: TTSLangCodes | str = "en",
+         speed: float = 1.0,
+         sample_rate: int = 22050,
+         http_session: aiohttp.ClientSession | None = None,
+         base_url: str = API_BASE_URL,
+     ) -> None:
+         """
+         Create a new instance of the Neuphonic TTS.
+
+         See https://docs.neuphonic.com for more documentation on all of these options, or go to https://app.neuphonic.com/ to test out different options.
+
+         Args:
+             voice_id (str, optional): The voice ID for the desired voice. Defaults to None.
+             lang_code (TTSLanguages | str, optional): The language code for synthesis. Defaults to "en".
+             encoding (TTSEncodings | str, optional): The audio encoding format. Defaults to "pcm_mulaw".
+             speed (float, optional): The audio playback speed. Defaults to 1.0.
+             sample_rate (int, optional): The audio sample rate in Hz. Defaults to 22050.
+             api_key (str | None, optional): The Neuphonic API key. If not provided, it will be read from the NEUPHONIC_API_KEY environment variable.
+             http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
+             base_url (str, optional): The base URL for the Neuphonic API. Defaults to "api.neuphonic.com".
+         """  # noqa: E501
+         super().__init__(
+             capabilities=tts.TTSCapabilities(streaming=True),
+             sample_rate=sample_rate,
+             num_channels=1,
+         )
+
+         api_key = api_key or os.environ.get("NEUPHONIC_API_KEY")
+         if not api_key:
+             raise ValueError("API key must be provided or set in NEUPHONIC_API_KEY")
+
+         self._opts = _TTSOptions(
+             voice_id=voice_id,
+             lang_code=lang_code,
+             api_key=api_key,
+             speed=speed,
+             sample_rate=sample_rate,
+             base_url=base_url,
+         )
+         self._session = http_session
+
+     def _ensure_session(self) -> aiohttp.ClientSession:
+         if not self._session:
+             self._session = utils.http_context.http_session()
+
+         return self._session
+
+     def update_options(
+         self,
+         *,
+         voice_id: NotGivenOr[str] = NOT_GIVEN,
+         lang_code: NotGivenOr[TTSLangCodes] = NOT_GIVEN,
+         speed: NotGivenOr[float] = NOT_GIVEN,
+         sample_rate: NotGivenOr[int] = NOT_GIVEN,
+     ) -> None:
+         """
+         Update the Text-to-Speech (TTS) configuration options.
+
+         This method allows updating the TTS settings, including model type, voice_id, lang_code,
+         encoding, speed and sample_rate. If any parameter is not provided, the existing value will be
+         retained.
+
+         Args:
+             model (TTSModels | str, optional): The Neuphonic model to use.
+             voice_id (str, optional): The voice ID for the desired voice.
+             lang_code (TTSLanguages | str, optional): The language code for synthesis..
+             encoding (TTSEncodings | str, optional): The audio encoding format.
+             speed (float, optional): The audio playback speed.
+             sample_rate (int, optional): The audio sample rate in Hz.
+         """  # noqa: E501
+         if is_given(voice_id):
+             self._opts.voice_id = voice_id
+         if is_given(lang_code):
+             self._opts.lang_code = lang_code
+         if is_given(speed):
+             self._opts.speed = speed
+         if is_given(sample_rate):
+             self._opts.sample_rate = sample_rate
+
+     def synthesize(
+         self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
+     ) -> ChunkedStream:
+         return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)
+
+
+ class ChunkedStream(tts.ChunkedStream):
+     """Synthesize chunked text using the SSE endpoint"""
+
+     def __init__(
+         self,
+         *,
+         tts: TTS,
+         input_text: str,
+         conn_options: APIConnectOptions,
+     ) -> None:
+         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
+         self._tts: TTS = tts
+         self._opts = replace(tts._opts)
+
+     async def _run(self, output_emitter: tts.AudioEmitter) -> None:
+         try:
+             async with self._tts._ensure_session().post(
+                 f"https://{self._opts.base_url}/sse/speak/{self._opts.lang_code}",
+                 headers={AUTHORIZATION_HEADER: self._opts.api_key},
+                 json={
+                     "text": self._input_text,
+                     "voice_id": self._opts.voice_id,
+                     "lang_code": self._opts.lang_code,
+                     "encoding": "pcm_linear",
+                     "sampling_rate": self._opts.sample_rate,
+                     "speed": self._opts.speed,
+                 },
+                 timeout=aiohttp.ClientTimeout(
+                     total=30,
+                     sock_connect=self._conn_options.timeout,
+                 ),
+                 # large read_bufsize to avoid `ValueError: Chunk too big`
+                 read_bufsize=10 * 1024 * 1024,
+             ) as resp:
+                 resp.raise_for_status()
+
+                 output_emitter.initialize(
+                     request_id=utils.shortuuid(),
+                     sample_rate=self._opts.sample_rate,
+                     num_channels=1,
+                     mime_type="audio/pcm",
+                 )
+
+                 async for line in resp.content:
+                     message = line.decode("utf-8")
+                     if not message:
+                         continue
+
+                     parsed_message = _parse_sse_message(message)
+
+                     if (
+                         parsed_message is not None
+                         and parsed_message.get("data", {}).get("audio") is not None
+                     ):
+                         audio_bytes = base64.b64decode(parsed_message["data"]["audio"])
+                         output_emitter.push(audio_bytes)
+
+                 output_emitter.flush()
+         except asyncio.TimeoutError:
+             raise APITimeoutError() from None
+         except aiohttp.ClientResponseError as e:
+             raise APIStatusError(
+                 message=e.message, status_code=e.status, request_id=None, body=None
+             ) from None
+         except Exception as e:
+             raise APIConnectionError() from e
+
+
+ def _parse_sse_message(message: str) -> dict | None:
+     """
+     Parse each response from the SSE endpoint.
+
+     The message will either be a string reading:
+     - `event: error`
+     - `event: message`
+     - `data: { "status_code": 200, "data": {"audio": ... } }`
+     """
+     message = message.strip()
+
+     if not message or "data" not in message:
+         return None
+
+     _, value = message.split(": ", 1)
+     message_dict: dict = json.loads(value)
+
+     if message_dict.get("errors") is not None:
+         raise Exception(
+             f"received error status {message_dict['status_code']}: {message_dict['errors']}"
+         )
+
+     return message_dict
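
The new module above narrows the plugin to an SSE-only `ChunkedStream` path: the constructor now takes a plain `api_key` string, the encoding is fixed to `pcm_linear`, and the websocket-based `stream()` support from 1.0.23 is removed (see the deleted file further below). The following is a minimal usage sketch based only on the signatures added in this hunk; the `livekit.plugins.neuphonic` import path and the audio-event iteration follow LiveKit's usual plugin conventions and are assumptions, not part of this diff:

    import asyncio

    import aiohttp

    from livekit.plugins import neuphonic  # assumed import path for this plugin


    async def main() -> None:
        # An explicit session lets the sketch run outside an agent job context,
        # where utils.http_context.http_session() would not be available.
        async with aiohttp.ClientSession() as session:
            tts = neuphonic.TTS(
                # api_key omitted: the constructor falls back to NEUPHONIC_API_KEY
                lang_code="en",
                speed=1.0,
                sample_rate=22050,
                http_session=session,
            )
            stream = tts.synthesize("Hello from livekit-plugins-neuphonic 1.1.1")
            async for event in stream:  # assumed: yields audio events with a .frame attribute
                ...  # forward event.frame (PCM audio) to your sink
            await stream.aclose()


    asyncio.run(main())

Inside an agent session the plugin would normally be passed directly as the `tts` component, with the agents' shared HTTP session used in place of the explicit one shown here.
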
@@ -12,4 +12,4 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- __version__ = "1.0.23"
+ __version__ = "1.1.1"
@@ -21,7 +21,7 @@ classifiers = [
      "Programming Language :: Python :: 3 :: Only",
  ]
  dependencies = [
-     "livekit-agents>=1.0.23",
+     "livekit-agents>=1.1.1",
  ]

  [project.urls]
@@ -1,10 +0,0 @@
- from typing import Literal
-
- TTSEncodings = Literal[
-     "pcm_linear",
-     "pcm_mulaw",
- ]
-
- TTSModels = Literal["neu-fast", "neu-hq"]
-
- TTSLangCodes = Literal["en", "nl", "es", "de", "hi", "en-hi", "ar"]
@@ -1,420 +0,0 @@
- # Copyright 2023 LiveKit, Inc.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import annotations
-
- import asyncio
- import base64
- import json
- import os
- import weakref
- from dataclasses import dataclass
-
- import aiohttp
-
- from livekit.agents import (
-     APIConnectionError,
-     APIConnectOptions,
-     APIStatusError,
-     APITimeoutError,
-     tts,
-     utils,
- )
- from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
- from livekit.agents.utils import is_given
-
- from .log import logger
- from .models import TTSEncodings, TTSLangCodes, TTSModels
-
- API_BASE_URL = "api.neuphonic.com"
- AUTHORIZATION_HEADER = "X-API-KEY"
- NUM_CHANNELS = 1
-
-
- @dataclass
- class _TTSOptions:
-     base_url: str
-     api_key: str
-     model: TTSModels | str
-     lang_code: TTSLangCodes | str
-     encoding: TTSEncodings | str
-     sampling_rate: int
-     speed: float
-     voice_id: NotGivenOr[str] = NOT_GIVEN
-
-     @property
-     def model_params(self) -> dict:
-         """Returns a dictionary of model parameters for API requests."""
-         params = {
-             "voice_id": self.voice_id,
-             "model": self.model,
-             "lang_code": self.lang_code,
-             "encoding": self.encoding,
-             "sampling_rate": self.sampling_rate,
-             "speed": self.speed,
-         }
-         return {k: v for k, v in params.items() if is_given(v) and v is not None}
-
-     def get_query_param_string(self):
-         """Forms the query parameter string from all model parameters."""
-         queries = []
-         for key, value in self.model_params.items():
-             queries.append(f"{key}={value}")
-
-         return "?" + "&".join(queries)
-
-
- def _parse_sse_message(message: str) -> dict:
-     """
-     Parse each response from the SSE endpoint.
-
-     The message will either be a string reading:
-     - `event: error`
-     - `event: message`
-     - `data: { "status_code": 200, "data": {"audio": ... } }`
-     """
-     message = message.strip()
-
-     if not message or "data" not in message:
-         return None
-
-     _, value = message.split(": ", 1)
-     message = json.loads(value)
-
-     if message.get("errors") is not None:
-         raise Exception(f"Status {message.status_code} error received: {message.errors}.")
-
-     return message
-
-
- class TTS(tts.TTS):
-     def __init__(
-         self,
-         *,
-         model: TTSModels | str = "neu_hq",
-         voice_id: NotGivenOr[str] = NOT_GIVEN,
-         lang_code: TTSLangCodes | str = "en",
-         encoding: TTSEncodings | str = "pcm_linear",
-         speed: float = 1.0,
-         sample_rate: int = 22050,
-         api_key: NotGivenOr[str] = NOT_GIVEN,
-         http_session: aiohttp.ClientSession | None = None,
-         base_url: str = API_BASE_URL,
-     ) -> None:
-         """
-         Create a new instance of the Neuphonic TTS.
-
-         See https://docs.neuphonic.com for more documentation on all of these options, or go to https://app.neuphonic.com/ to test out different options.
-
-         Args:
-             model (TTSModels | str, optional): The Neuphonic model to use. See Defaults to "neu_hq".
-             voice_id (str, optional): The voice ID for the desired voice. Defaults to None.
-             lang_code (TTSLanguages | str, optional): The language code for synthesis. Defaults to "en".
-             encoding (TTSEncodings | str, optional): The audio encoding format. Defaults to "pcm_mulaw".
-             speed (float, optional): The audio playback speed. Defaults to 1.0.
-             sample_rate (int, optional): The audio sample rate in Hz. Defaults to 22050.
-             api_key (str | None, optional): The Neuphonic API key. If not provided, it will be read from the NEUPHONIC_API_TOKEN environment variable.
-             http_session (aiohttp.ClientSession | None, optional): An existing aiohttp ClientSession to use. If not provided, a new session will be created.
-             base_url (str, optional): The base URL for the Neuphonic API. Defaults to "api.neuphonic.com".
-         """  # noqa: E501
-         super().__init__(
-             capabilities=tts.TTSCapabilities(streaming=True),
-             sample_rate=sample_rate,
-             num_channels=NUM_CHANNELS,
-         )
-
-         neuphonic_api_key = (
-             api_key
-             if is_given(api_key)
-             else os.environ.get("NEUPHONIC_API_KEY") or os.environ.get("NEUPHONIC_API_TOKEN")
-         )
-
-         if not neuphonic_api_key:
-             raise ValueError("API key must be provided or set in NEUPHONIC_API_KEY")
-
-         self._opts = _TTSOptions(
-             model=model,
-             voice_id=voice_id,
-             lang_code=lang_code,
-             encoding=encoding,
-             speed=speed,
-             sampling_rate=sample_rate,
-             api_key=neuphonic_api_key,
-             base_url=base_url,
-         )
-
-         self._session = http_session
-         self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
-             connect_cb=self._connect_ws,
-             close_cb=self._close_ws,
-             max_session_duration=90,
-             mark_refreshed_on_get=True,
-         )
-         self._streams = weakref.WeakSet[SynthesizeStream]()
-
-     async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
-         session = self._ensure_session()
-         url = f"wss://{self._opts.base_url}/speak/{self._opts.lang_code}{self._opts.get_query_param_string()}"
-
-         return await asyncio.wait_for(
-             session.ws_connect(url, headers={AUTHORIZATION_HEADER: self._opts.api_key}),
-             self._conn_options.timeout,
-         )
-
-     async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse):
-         await ws.close()
-
-     def _ensure_session(self) -> aiohttp.ClientSession:
-         if not self._session:
-             self._session = utils.http_context.http_session()
-
-         return self._session
-
-     def prewarm(self) -> None:
-         self._pool.prewarm()
-
-     def update_options(
-         self,
-         *,
-         model: NotGivenOr[TTSModels] = NOT_GIVEN,
-         voice_id: NotGivenOr[str] = NOT_GIVEN,
-         lang_code: NotGivenOr[TTSLangCodes] = NOT_GIVEN,
-         encoding: NotGivenOr[TTSEncodings] = NOT_GIVEN,
-         speed: NotGivenOr[float] = NOT_GIVEN,
-         sample_rate: NotGivenOr[int] = NOT_GIVEN,
-     ) -> None:
-         """
-         Update the Text-to-Speech (TTS) configuration options.
-
-         This method allows updating the TTS settings, including model type, voice_id, lang_code,
-         encoding, speed and sample_rate. If any parameter is not provided, the existing value will be
-         retained.
-
-         Args:
-             model (TTSModels | str, optional): The Neuphonic model to use.
-             voice_id (str, optional): The voice ID for the desired voice.
-             lang_code (TTSLanguages | str, optional): The language code for synthesis..
-             encoding (TTSEncodings | str, optional): The audio encoding format.
-             speed (float, optional): The audio playback speed.
-             sample_rate (int, optional): The audio sample rate in Hz.
-         """  # noqa: E501
-         if is_given(model):
-             self._opts.model = model
-         if is_given(voice_id):
-             self._opts.voice_id = voice_id
-         if is_given(lang_code):
-             self._opts.lang_code = lang_code
-         if is_given(encoding):
-             self._opts.encoding = encoding
-         if is_given(speed):
-             self._opts.speed = speed
-         if is_given(sample_rate):
-             self._opts.sampling_rate = sample_rate
-         self._pool.invalidate()
-
-     def synthesize(
-         self,
-         text: str,
-         *,
-         conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-     ) -> ChunkedStream:
-         return ChunkedStream(
-             tts=self,
-             input_text=text,
-             conn_options=conn_options,
-             opts=self._opts,
-             session=self._ensure_session(),
-         )
-
-     def stream(
-         self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
-     ) -> SynthesizeStream:
-         stream = SynthesizeStream(
-             tts=self,
-             pool=self._pool,
-             opts=self._opts,
-         )
-         self._streams.add(stream)
-         return stream
-
-     async def aclose(self) -> None:
-         for stream in list(self._streams):
-             await stream.aclose()
-         self._streams.clear()
-         await self._pool.aclose()
-         await super().aclose()
-
-
- class ChunkedStream(tts.ChunkedStream):
-     """Synthesize chunked text using the SSE endpoint"""
-
-     def __init__(
-         self,
-         *,
-         tts: TTS,
-         input_text: str,
-         opts: _TTSOptions,
-         session: aiohttp.ClientSession,
-         conn_options: APIConnectOptions,
-     ) -> None:
-         super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
-         self._opts, self._session = opts, session
-
-     async def _run(self) -> None:
-         request_id = utils.shortuuid()
-         bstream = utils.audio.AudioByteStream(
-             sample_rate=self._opts.sampling_rate, num_channels=NUM_CHANNELS
-         )
-
-         json_data = {
-             "text": self._input_text,
-             **self._opts.model_params,
-         }
-
-         headers = {
-             AUTHORIZATION_HEADER: self._opts.api_key,
-         }
-
-         try:
-             async with self._session.post(
-                 f"https://{self._opts.base_url}/sse/speak/{self._opts.lang_code}",
-                 headers=headers,
-                 json=json_data,
-                 timeout=aiohttp.ClientTimeout(
-                     total=30,
-                     sock_connect=self._conn_options.timeout,
-                 ),
-                 read_bufsize=10
-                 * 1024
-                 * 1024,  # large read_bufsize to avoid `ValueError: Chunk too big`
-             ) as response:
-                 response.raise_for_status()
-                 emitter = tts.SynthesizedAudioEmitter(
-                     event_ch=self._event_ch,
-                     request_id=request_id,
-                 )
-
-                 async for line in response.content:
-                     message = line.decode("utf-8").strip()
-                     if message:
-                         parsed_message = _parse_sse_message(message)
-
-                         if (
-                             parsed_message is not None
-                             and parsed_message.get("data", {}).get("audio") is not None
-                         ):
-                             audio_bytes = base64.b64decode(parsed_message["data"]["audio"])
-
-                             for frame in bstream.write(audio_bytes):
-                                 emitter.push(frame)
-
-                 for frame in bstream.flush():
-                     emitter.push(frame)
-                 emitter.flush()
-         except asyncio.TimeoutError as e:
-             raise APITimeoutError() from e
-         except aiohttp.ClientResponseError as e:
-             raise APIStatusError(
-                 message=e.message,
-                 status_code=e.status,
-                 request_id=None,
-                 body=None,
-             ) from e
-         except Exception as e:
-             raise APIConnectionError() from e
-
-
- class SynthesizeStream(tts.SynthesizeStream):
-     def __init__(
-         self,
-         *,
-         tts: TTS,
-         opts: _TTSOptions,
-         pool: utils.ConnectionPool[aiohttp.ClientWebSocketResponse],
-     ):
-         super().__init__(tts=tts)
-         self._opts, self._pool = opts, pool
-
-     async def _run(self) -> None:
-         request_id = utils.shortuuid()
-
-         async def _send_task(ws: aiohttp.ClientWebSocketResponse):
-             """Stream text to the websocket."""
-             async for data in self._input_ch:
-                 self._mark_started()
-
-                 if isinstance(data, self._FlushSentinel):
-                     await ws.send_str(json.dumps({"text": "<STOP>"}))
-                     continue
-
-                 await ws.send_str(json.dumps({"text": data}))
-
-         async def _recv_task(ws: aiohttp.ClientWebSocketResponse):
-             audio_bstream = utils.audio.AudioByteStream(
-                 sample_rate=self._opts.sampling_rate,
-                 num_channels=NUM_CHANNELS,
-             )
-             emitter = tts.SynthesizedAudioEmitter(
-                 event_ch=self._event_ch,
-                 request_id=request_id,
-             )
-
-             while True:
-                 try:
-                     msg = await ws.receive()
-                 except Exception as e:
-                     raise APIStatusError(
-                         "Neuphonic connection closed unexpectedly",
-                         request_id=request_id,
-                     ) from e
-
-                 if msg.type in (
-                     aiohttp.WSMsgType.CLOSED,
-                     aiohttp.WSMsgType.CLOSE,
-                     aiohttp.WSMsgType.CLOSING,
-                 ):
-                     raise APIStatusError(
-                         "Neuphonic connection closed unexpectedly",
-                         request_id=request_id,
-                     )
-
-                 if msg.type != aiohttp.WSMsgType.TEXT:
-                     logger.warning("Unexpected Neuphonic message type %s", msg.type)
-                     continue
-
-                 data = json.loads(msg.data)
-
-                 if data.get("data"):
-                     b64data = base64.b64decode(data["data"]["audio"])
-                     for frame in audio_bstream.write(b64data):
-                         emitter.push(frame)
-
-                     if data["data"].get("stop"):  # A bool flag, is True when audio reaches "<STOP>"
-                         for frame in audio_bstream.flush():
-                             emitter.push(frame)
-                         emitter.flush()
-                         break  # we are not going to receive any more audio
-                 else:
-                     logger.error("Unexpected Neuphonic message %s", data)
-
-         async with self._pool.connection() as ws:
-             tasks = [
-                 asyncio.create_task(_send_task(ws)),
-                 asyncio.create_task(_recv_task(ws)),
-             ]
-
-             try:
-                 await asyncio.gather(*tasks)
-             finally:
-                 await utils.aio.gracefully_cancel(*tasks)