livekit-plugins-azure 1.0.22__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -146,6 +146,9 @@ venv.bak/
 .dmypy.json
 dmypy.json
 
+# trunk
+.trunk/
+
 # Pyre type checker
 .pyre/
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: livekit-plugins-azure
-Version: 1.0.22
+Version: 1.1.0
 Summary: Agent Framework plugin for services from Azure
 Project-URL: Documentation, https://docs.livekit.io
 Project-URL: Website, https://livekit.io/
@@ -19,7 +19,7 @@ Classifier: Topic :: Multimedia :: Video
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9.0
 Requires-Dist: azure-cognitiveservices-speech>=1.43.0
-Requires-Dist: livekit-agents>=1.0.22
+Requires-Dist: livekit-agents>=1.1.0
 Description-Content-Type: text/markdown
 
 # Azure plugin for LiveKit Agents
@@ -29,7 +29,7 @@ from .log import logger
 
 
 class AzurePlugin(Plugin):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(__name__, __version__, __package__, logger)
 
 
@@ -18,6 +18,7 @@ import os
 import weakref
 from copy import deepcopy
 from dataclasses import dataclass
+from typing import cast
 
 import azure.cognitiveservices.speech as speechsdk # type: ignore
 from livekit import rtc
@@ -95,13 +96,13 @@ class STT(stt.STT):
             language = [language]
 
         if not is_given(speech_host):
-            speech_host = os.environ.get("AZURE_SPEECH_HOST")
+            speech_host = os.environ.get("AZURE_SPEECH_HOST") or NOT_GIVEN
 
         if not is_given(speech_key):
-            speech_key = os.environ.get("AZURE_SPEECH_KEY")
+            speech_key = os.environ.get("AZURE_SPEECH_KEY") or NOT_GIVEN
 
         if not is_given(speech_region):
-            speech_region = os.environ.get("AZURE_SPEECH_REGION")
+            speech_region = os.environ.get("AZURE_SPEECH_REGION") or NOT_GIVEN
 
         if not (
             is_given(speech_host)
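Note on the hunk above: `os.environ.get()` returns `None` when the variable is unset, which previously left a `None` in a field typed as `NotGivenOr[str]`; appending `or NOT_GIVEN` folds an unset or empty environment variable back into the sentinel so later `is_given()` checks stay consistent. A minimal sketch of the pattern, assuming only that `livekit-agents` is installed (the helper name `resolve_key` is illustrative, not part of the plugin):

```python
from __future__ import annotations

import os

from livekit.agents.types import NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given


def resolve_key(speech_key: NotGivenOr[str] = NOT_GIVEN) -> NotGivenOr[str]:
    if not is_given(speech_key):
        # An unset or empty env var becomes NOT_GIVEN rather than None/""
        speech_key = os.environ.get("AZURE_SPEECH_KEY") or NOT_GIVEN
    return speech_key
```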
@@ -155,10 +156,11 @@ class STT(stt.STT):
         self._streams.add(stream)
         return stream
 
-    def update_options(self, *, language: NotGivenOr[list[str] | str] = NOT_GIVEN):
+    def update_options(self, *, language: NotGivenOr[list[str] | str] = NOT_GIVEN) -> None:
         if is_given(language):
             if isinstance(language, str):
                 language = [language]
+            language = cast(list[str], language)
             self._config.language = language
             for stream in self._streams:
                 stream.update_options(language=language)
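The `cast` added above is a typing aid only: after the `is_given()` and `isinstance()` checks the value can only be `list[str]`, but the checker cannot narrow the `NotGivenOr[list[str] | str]` union on its own. A small sketch of the same narrowing, assuming `livekit-agents` is installed (`normalize_language` is an illustrative name, not part of the plugin):

```python
from __future__ import annotations

from typing import cast

from livekit.agents.types import NOT_GIVEN, NotGivenOr
from livekit.agents.utils import is_given


def normalize_language(language: NotGivenOr[list[str] | str] = NOT_GIVEN) -> list[str] | None:
    if not is_given(language):
        return None
    if isinstance(language, str):
        language = [language]
    # Runtime no-op; tells the type checker the union has collapsed to list[str]
    return cast(list[str], language)
```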
@@ -176,7 +178,7 @@ class SpeechStream(stt.SpeechStream):
         self._loop = asyncio.get_running_loop()
         self._reconnect_event = asyncio.Event()
 
-    def update_options(self, *, language: list[str]):
+    def update_options(self, *, language: list[str]) -> None:
         self._opts.language = language
         self._reconnect_event.set()
 
@@ -203,7 +205,7 @@ class SpeechStream(stt.SpeechStream):
                    self._session_started_event.wait(), self._conn_options.timeout
                )
 
-                async def process_input():
+                async def process_input() -> None:
                    async for input in self._input_ch:
                        if isinstance(input, rtc.AudioFrame):
                            self._stream.write(input.data.tobytes())
@@ -234,13 +236,13 @@ class SpeechStream(stt.SpeechStream):
                await self._session_stopped_event.wait()
            finally:
 
-                def _cleanup():
+                def _cleanup() -> None:
                    self._recognizer.stop_continuous_recognition()
                    del self._recognizer
 
                await asyncio.to_thread(_cleanup)
 
-    def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language
         text = evt.result.text.strip()
         if not text:
@@ -259,7 +261,7 @@ class SpeechStream(stt.SpeechStream):
                 ),
             )
 
-    def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_recognizing(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language
         text = evt.result.text.strip()
         if not text:
@@ -279,7 +281,7 @@ class SpeechStream(stt.SpeechStream):
                 ),
             )
 
-    def _on_speech_start(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_speech_start(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         if self._speaking:
             return
 
@@ -291,7 +293,7 @@ class SpeechStream(stt.SpeechStream):
                 stt.SpeechEvent(type=stt.SpeechEventType.START_OF_SPEECH),
             )
 
-    def _on_speech_end(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_speech_end(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         if not self._speaking:
             return
 
@@ -303,13 +305,13 @@ class SpeechStream(stt.SpeechStream):
                 stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH),
             )
 
-    def _on_session_started(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_session_started(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         self._session_started_event.set()
 
         with contextlib.suppress(RuntimeError):
             self._loop.call_soon_threadsafe(self._session_started_event.set)
 
-    def _on_session_stopped(self, evt: speechsdk.SpeechRecognitionEventArgs):
+    def _on_session_stopped(self, evt: speechsdk.SpeechRecognitionEventArgs) -> None:
         with contextlib.suppress(RuntimeError):
             self._loop.call_soon_threadsafe(self._session_stopped_event.set)
 
@@ -354,7 +356,7 @@ def _create_speech_recognizer(
     speech_recognizer = speechsdk.SpeechRecognizer(
         speech_config=speech_config,
         audio_config=audio_config,
-        auto_detect_source_language_config=auto_detect_source_language_config, # type: ignore
+        auto_detect_source_language_config=auto_detect_source_language_config,
     )
 
     return speech_recognizer
@@ -0,0 +1,298 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import os
+from dataclasses import dataclass, replace
+from typing import Literal
+
+import aiohttp
+
+from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, tts, utils
+from livekit.agents.types import (
+    DEFAULT_API_CONNECT_OPTIONS,
+    NOT_GIVEN,
+    APIConnectOptions,
+    NotGivenOr,
+)
+from livekit.agents.utils import is_given
+
+SUPPORTED_OUTPUT_FORMATS = {
+    8000: "raw-8khz-16bit-mono-pcm",
+    16000: "raw-16khz-16bit-mono-pcm",
+    22050: "raw-22050hz-16bit-mono-pcm",
+    24000: "raw-24khz-16bit-mono-pcm",
+    44100: "raw-44100hz-16bit-mono-pcm",
+    48000: "raw-48khz-16bit-mono-pcm",
+}
+
+
+@dataclass
+class ProsodyConfig:
+    rate: Literal["x-slow", "slow", "medium", "fast", "x-fast"] | float | None = None
+    volume: Literal["silent", "x-soft", "soft", "medium", "loud", "x-loud"] | float | None = None
+    pitch: Literal["x-low", "low", "medium", "high", "x-high"] | None = None
+
+    def validate(self) -> None:
+        if self.rate:
+            if isinstance(self.rate, float) and not 0.5 <= self.rate <= 2:
+                raise ValueError("Prosody rate must be between 0.5 and 2")
+            if isinstance(self.rate, str) and self.rate not in [
+                "x-slow",
+                "slow",
+                "medium",
+                "fast",
+                "x-fast",
+            ]:
+                raise ValueError(
+                    "Prosody rate must be one of 'x-slow', 'slow', 'medium', 'fast', 'x-fast'"
+                )
+        if self.volume:
+            if isinstance(self.volume, float) and not 0 <= self.volume <= 100:
+                raise ValueError("Prosody volume must be between 0 and 100")
+            if isinstance(self.volume, str) and self.volume not in [
+                "silent",
+                "x-soft",
+                "soft",
+                "medium",
+                "loud",
+                "x-loud",
+            ]:
+                raise ValueError(
+                    "Prosody volume must be one of 'silent', 'x-soft', 'soft', 'medium', 'loud', 'x-loud'" # noqa: E501
+                )
+        if self.pitch and self.pitch not in [
+            "x-low",
+            "low",
+            "medium",
+            "high",
+            "x-high",
+        ]:
+            raise ValueError(
+                "Prosody pitch must be one of 'x-low', 'low', 'medium', 'high', 'x-high'"
+            )
+
+    def __post_init__(self) -> None:
+        self.validate()
+
+
+@dataclass
+class StyleConfig:
+    style: str
+    degree: float | None = None
+
+    def validate(self) -> None:
+        if self.degree is not None and not 0.1 <= self.degree <= 2.0:
+            raise ValueError("Style degree must be between 0.1 and 2.0")
+
+    def __post_init__(self) -> None:
+        self.validate()
+
+
+@dataclass
+class _TTSOptions:
+    sample_rate: int
+    subscription_key: str | None
+    region: str | None
+    voice: str
+    language: str | None
+    speech_endpoint: str | None
+    deployment_id: str | None
+    prosody: NotGivenOr[ProsodyConfig]
+    style: NotGivenOr[StyleConfig]
+    auth_token: str | None = None
+
+    def get_endpoint_url(self) -> str:
+        base = (
+            self.speech_endpoint
+            or f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
+        )
+        if self.deployment_id:
+            return f"{base}?deploymentId={self.deployment_id}"
+        return base
+
+
+class TTS(tts.TTS):
+    def __init__(
+        self,
+        *,
+        voice: str = "en-US-JennyNeural",
+        language: str | None = None,
+        sample_rate: int = 24000,
+        prosody: NotGivenOr[ProsodyConfig] = NOT_GIVEN,
+        style: NotGivenOr[StyleConfig] = NOT_GIVEN,
+        speech_key: str | None = None,
+        speech_region: str | None = None,
+        speech_endpoint: str | None = None,
+        deployment_id: str | None = None,
+        speech_auth_token: str | None = None,
+        http_session: aiohttp.ClientSession | None = None,
+    ) -> None:
+        super().__init__(
+            capabilities=tts.TTSCapabilities(streaming=False),
+            sample_rate=sample_rate,
+            num_channels=1,
+        )
+        if sample_rate not in SUPPORTED_OUTPUT_FORMATS:
+            raise ValueError(
+                f"Unsupported sample rate {sample_rate}. Supported: {list(SUPPORTED_OUTPUT_FORMATS)}" # noqa: E501
+            )
+
+        if not speech_key:
+            speech_key = os.environ.get("AZURE_SPEECH_KEY")
+
+        if not speech_region:
+            speech_region = os.environ.get("AZURE_SPEECH_REGION")
+
+        if not speech_endpoint:
+            speech_endpoint = os.environ.get("AZURE_SPEECH_ENDPOINT")
+
+        has_endpoint = bool(speech_endpoint)
+        has_key_and_region = bool(speech_key and speech_region)
+        has_token_and_region = bool(speech_auth_token and speech_region)
+        if not (has_endpoint or has_key_and_region or has_token_and_region):
+            raise ValueError(
+                "Authentication requires one of: speech_endpoint (AZURE_SPEECH_ENDPOINT), "
+                "speech_key & speech_region (AZURE_SPEECH_KEY & AZURE_SPEECH_REGION), "
+                "or speech_auth_token & speech_region."
+            )
+
+        if is_given(prosody):
+            prosody.validate()
+        if is_given(style):
+            style.validate()
+
+        self._session = http_session
+        self._opts = _TTSOptions(
+            sample_rate=sample_rate,
+            subscription_key=speech_key,
+            region=speech_region,
+            speech_endpoint=speech_endpoint,
+            voice=voice,
+            deployment_id=deployment_id,
+            language=language,
+            prosody=prosody,
+            style=style,
+            auth_token=speech_auth_token,
+        )
+
+    def update_options(
+        self,
+        *,
+        voice: NotGivenOr[str] = NOT_GIVEN,
+        language: NotGivenOr[str] = NOT_GIVEN,
+        prosody: NotGivenOr[ProsodyConfig] = NOT_GIVEN,
+        style: NotGivenOr[StyleConfig] = NOT_GIVEN,
+    ) -> None:
+        if is_given(voice):
+            self._opts.voice = voice
+        if is_given(language):
+            self._opts.language = language
+        if is_given(prosody):
+            prosody.validate()
+            self._opts.prosody = prosody
+        if is_given(style):
+            style.validate()
+            self._opts.style = style
+
+    def _ensure_session(self) -> aiohttp.ClientSession:
+        if not self._session:
+            self._session = utils.http_context.http_session()
+        return self._session
+
+    def synthesize(
+        self,
+        text: str,
+        *,
+        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
+    ) -> tts.ChunkedStream:
+        return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)
+
+
+class ChunkedStream(tts.ChunkedStream):
+    def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
+        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
+        self._tts: TTS = tts
+        self._opts = replace(tts._opts)
+
+    def _build_ssml(self) -> str:
+        lang = self._opts.language or "en-US"
+        ssml = (
+            f'<speak version="1.0" '
+            f'xmlns="http://www.w3.org/2001/10/synthesis" '
+            f'xmlns:mstts="http://www.w3.org/2001/mstts" '
+            f'xml:lang="{lang}">'
+        )
+        ssml += f'<voice name="{self._opts.voice}">'
+        if is_given(self._opts.style):
+            degree = f' styledegree="{self._opts.style.degree}"' if self._opts.style.degree else ""
+            ssml += f'<mstts:express-as style="{self._opts.style.style}"{degree}>'
+
+        if is_given(self._opts.prosody):
+            p = self._opts.prosody
+
+            rate_attr = f' rate="{p.rate}"' if p.rate is not None else ""
+            vol_attr = f' volume="{p.volume}"' if p.volume is not None else ""
+            pitch_attr = f' pitch="{p.pitch}"' if p.pitch is not None else ""
+            ssml += f"<prosody{rate_attr}{vol_attr}{pitch_attr}>{self.input_text}</prosody>"
+        else:
+            ssml += self.input_text
+
+        if is_given(self._opts.style):
+            ssml += "</mstts:express-as>"
+
+        ssml += "</voice></speak>"
+        return ssml
+
+    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
+        headers = {
+            "Content-Type": "application/ssml+xml",
+            "X-Microsoft-OutputFormat": SUPPORTED_OUTPUT_FORMATS[self._opts.sample_rate],
+            "User-Agent": "LiveKit Agents",
+        }
+        if self._opts.auth_token:
+            headers["Authorization"] = f"Bearer {self._opts.auth_token}"
+
+        elif self._opts.subscription_key:
+            headers["Ocp-Apim-Subscription-Key"] = self._opts.subscription_key
+
+        output_emitter.initialize(
+            request_id=utils.shortuuid(),
+            sample_rate=self._opts.sample_rate,
+            num_channels=1,
+            mime_type="audio/pcm",
+        )
+
+        try:
+            async with self._tts._ensure_session().post(
+                url=self._opts.get_endpoint_url(),
+                headers=headers,
+                data=self._build_ssml(),
+                timeout=aiohttp.ClientTimeout(total=30, sock_connect=self._conn_options.timeout),
+            ) as resp:
+                resp.raise_for_status()
+                async for data, _ in resp.content.iter_chunks():
+                    output_emitter.push(data)
+
+        except asyncio.TimeoutError:
+            raise APITimeoutError() from None
+        except aiohttp.ClientResponseError as e:
+            raise APIStatusError(
+                message=e.message,
+                status_code=e.status,
+                request_id=None,
+                body=None,
+            ) from None
+        except Exception as e:
+            raise APIConnectionError(str(e)) from e
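The new module above replaces the Azure Speech SDK synthesizer with a plain HTTPS request to the Azure TTS endpoint (SSML request body, raw PCM streamed back through the `AudioEmitter`). A hedged usage sketch follows; the `livekit.plugins.azure` import path assumes the usual LiveKit plugin layout, and the credentials are assumed to come from `AZURE_SPEECH_KEY`/`AZURE_SPEECH_REGION`:

```python
from livekit.plugins import azure

# Credentials may also be passed explicitly via speech_key/speech_region,
# speech_endpoint, or speech_auth_token, mirroring the constructor above.
azure_tts = azure.TTS(voice="en-US-JennyNeural", sample_rate=24000)

# synthesize() returns a tts.ChunkedStream that an AgentSession (or any other
# asyncio consumer) drains; no Azure Speech SDK synthesizer is created anymore.
stream = azure_tts.synthesize("Hello from the REST-based Azure TTS")
```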
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "1.0.22"
+__version__ = "1.1.0"
@@ -23,7 +23,7 @@ classifiers = [
     "Programming Language :: Python :: 3 :: Only",
 ]
 dependencies = [
-    "livekit-agents>=1.0.22",
+    "livekit-agents>=1.1.0",
     "azure-cognitiveservices-speech>=1.43.0",
 ]
 
@@ -1,464 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import annotations
-
-import asyncio
-import contextlib
-import os
-from dataclasses import dataclass
-from typing import Callable, Literal
-
-import azure.cognitiveservices.speech as speechsdk # type: ignore
-from livekit.agents import (
-    APIConnectionError,
-    APIConnectOptions,
-    APITimeoutError,
-    tts,
-    utils,
-)
-from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, NOT_GIVEN, NotGivenOr
-from livekit.agents.utils import is_given
-
-from .log import logger
-
-# only raw & pcm
-SUPPORTED_SAMPLE_RATE = {
-    8000: speechsdk.SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm,
-    16000: speechsdk.SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm,
-    22050: speechsdk.SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm,
-    24000: speechsdk.SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm,
-    44100: speechsdk.SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm,
-    48000: speechsdk.SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm,
-}
-
-
-@dataclass
-class ProsodyConfig:
-    """
-    Prosody configuration for Azure TTS.
-
-    Args:
-        rate: Speaking rate. Can be one of "x-slow", "slow", "medium", "fast", "x-fast", or a float. A float value of 1.0 represents normal speed.
-        volume: Speaking volume. Can be one of "silent", "x-soft", "soft", "medium", "loud", "x-loud", or a float. A float value of 100 (x-loud) represents the highest volume and it's the default pitch.
-        pitch: Speaking pitch. Can be one of "x-low", "low", "medium", "high", "x-high". The default pitch is "medium".
-    """ # noqa: E501
-
-    rate: Literal["x-slow", "slow", "medium", "fast", "x-fast"] | float | None = None
-    volume: Literal["silent", "x-soft", "soft", "medium", "loud", "x-loud"] | float | None = None
-    pitch: Literal["x-low", "low", "medium", "high", "x-high"] | None = None
-
-    def validate(self) -> None:
-        if self.rate:
-            if isinstance(self.rate, float) and not 0.5 <= self.rate <= 2:
-                raise ValueError("Prosody rate must be between 0.5 and 2")
-            if isinstance(self.rate, str) and self.rate not in [
-                "x-slow",
-                "slow",
-                "medium",
-                "fast",
-                "x-fast",
-            ]:
-                raise ValueError(
-                    "Prosody rate must be one of 'x-slow', 'slow', 'medium', 'fast', 'x-fast'"
-                )
-        if self.volume:
-            if isinstance(self.volume, float) and not 0 <= self.volume <= 100:
-                raise ValueError("Prosody volume must be between 0 and 100")
-            if isinstance(self.volume, str) and self.volume not in [
-                "silent",
-                "x-soft",
-                "soft",
-                "medium",
-                "loud",
-                "x-loud",
-            ]:
-                raise ValueError(
-                    "Prosody volume must be one of 'silent', 'x-soft', 'soft', 'medium', 'loud', 'x-loud'" # noqa: E501
-                )
-
-        if self.pitch and self.pitch not in [
-            "x-low",
-            "low",
-            "medium",
-            "high",
-            "x-high",
-        ]:
-            raise ValueError(
-                "Prosody pitch must be one of 'x-low', 'low', 'medium', 'high', 'x-high'"
-            )
-
-    def __post_init__(self):
-        self.validate()
-
-
-@dataclass
-class StyleConfig:
-    """
-    Style configuration for Azure TTS neural voices.
-
-    Args:
-        style: Speaking style for neural voices. Examples: "cheerful", "sad", "angry", etc.
-        degree: Intensity of the style, from 0.1 to 2.0.
-    """
-
-    style: str
-    degree: float | None = None
-
-    def validate(self) -> None:
-        if self.degree is not None and not 0.1 <= self.degree <= 2.0:
-            raise ValueError("Style degree must be between 0.1 and 2.0")
-
-    def __post_init__(self):
-        self.validate()
-
-
-@dataclass
-class _TTSOptions:
-    sample_rate: int
-    speech_key: NotGivenOr[str] = NOT_GIVEN
-    speech_region: NotGivenOr[str] = NOT_GIVEN
-    # see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-container-ntts?tabs=container#use-the-container
-    speech_host: NotGivenOr[str] = NOT_GIVEN
-    # see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts
-    voice: NotGivenOr[str] = NOT_GIVEN
-    # for using custom voices (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis?tabs=browserjs%2Cterminal&pivots=programming-language-python#use-a-custom-endpoint)
-    endpoint_id: NotGivenOr[str] = NOT_GIVEN
-    # for using Microsoft Entra auth (see https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-configure-azure-ad-auth?tabs=portal&pivots=programming-language-python)
-    speech_auth_token: NotGivenOr[str] = NOT_GIVEN
-    # Useful to specify the language with multi-language voices
-    language: NotGivenOr[str] = NOT_GIVEN
-    # See https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice#adjust-prosody
-    prosody: NotGivenOr[ProsodyConfig] = NOT_GIVEN
-    speech_endpoint: NotGivenOr[str] = NOT_GIVEN
-    style: NotGivenOr[StyleConfig] = NOT_GIVEN
-    # See https://learn.microsoft.com/en-us/azure/ai-services/speech-service/how-to-speech-synthesis?tabs=browserjs%2Cterminal&pivots=programming-language-python
-    on_bookmark_reached_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_synthesis_canceled_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_synthesis_completed_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_synthesis_started_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_synthesizing_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_viseme_event: NotGivenOr[Callable] = NOT_GIVEN
-    on_word_boundary_event: NotGivenOr[Callable] = NOT_GIVEN
-
-
-class TTS(tts.TTS):
-    def __init__(
-        self,
-        *,
-        sample_rate: int = 24000,
-        voice: NotGivenOr[str] = NOT_GIVEN,
-        language: NotGivenOr[str] = NOT_GIVEN,
-        prosody: NotGivenOr[ProsodyConfig] = NOT_GIVEN,
-        speech_key: NotGivenOr[str] = NOT_GIVEN,
-        speech_region: NotGivenOr[str] = NOT_GIVEN,
-        speech_host: NotGivenOr[str] = NOT_GIVEN,
-        speech_auth_token: NotGivenOr[str] = NOT_GIVEN,
-        endpoint_id: NotGivenOr[str] = NOT_GIVEN,
-        style: NotGivenOr[StyleConfig] = NOT_GIVEN,
-        on_bookmark_reached_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_canceled_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_completed_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_started_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesizing_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_viseme_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_word_boundary_event: NotGivenOr[Callable] = NOT_GIVEN,
-        speech_endpoint: NotGivenOr[str] = NOT_GIVEN,
-    ) -> None:
-        """
-        Create a new instance of Azure TTS.
-
-        Either ``speech_host`` or ``speech_key`` and ``speech_region`` or
-        ``speech_auth_token`` and ``speech_region`` must be set using arguments.
-        Alternatively, set the ``AZURE_SPEECH_HOST``, ``AZURE_SPEECH_KEY``
-        and ``AZURE_SPEECH_REGION`` environmental variables, respectively.
-        ``speech_auth_token`` must be set using the arguments as it's an ephemeral token.
-        """
-
-        if sample_rate not in SUPPORTED_SAMPLE_RATE:
-            raise ValueError(
-                f"Unsupported sample rate {sample_rate}. Supported sample rates: {list(SUPPORTED_SAMPLE_RATE.keys())}" # noqa: E501
-            )
-
-        super().__init__(
-            capabilities=tts.TTSCapabilities(
-                streaming=False,
-            ),
-            sample_rate=sample_rate,
-            num_channels=1,
-        )
-
-        if not is_given(speech_host):
-            speech_host = os.environ.get("AZURE_SPEECH_HOST")
-
-        if not is_given(speech_key):
-            speech_key = os.environ.get("AZURE_SPEECH_KEY")
-
-        if not is_given(speech_region):
-            speech_region = os.environ.get("AZURE_SPEECH_REGION")
-
-        if not (
-            is_given(speech_host)
-            or (is_given(speech_key) and is_given(speech_region))
-            or (is_given(speech_auth_token) and is_given(speech_region))
-            or (is_given(speech_key) and is_given(speech_endpoint))
-        ):
-            raise ValueError(
-                "AZURE_SPEECH_HOST or AZURE_SPEECH_KEY and AZURE_SPEECH_REGION or speech_auth_token and AZURE_SPEECH_REGION or AZURE_SPEECH_KEY and speech_endpoint must be set" # noqa: E501
-            )
-
-        if is_given(prosody):
-            prosody.validate()
-
-        if is_given(style):
-            style.validate()
-
-        self._opts = _TTSOptions(
-            sample_rate=sample_rate,
-            speech_key=speech_key,
-            speech_region=speech_region,
-            speech_host=speech_host,
-            speech_auth_token=speech_auth_token,
-            voice=voice,
-            endpoint_id=endpoint_id,
-            language=language,
-            prosody=prosody,
-            style=style,
-            on_bookmark_reached_event=on_bookmark_reached_event,
-            on_synthesis_canceled_event=on_synthesis_canceled_event,
-            on_synthesis_completed_event=on_synthesis_completed_event,
-            on_synthesis_started_event=on_synthesis_started_event,
-            on_synthesizing_event=on_synthesizing_event,
-            on_viseme_event=on_viseme_event,
-            on_word_boundary_event=on_word_boundary_event,
-            speech_endpoint=speech_endpoint,
-        )
-
-    def update_options(
-        self,
-        *,
-        voice: NotGivenOr[str] = NOT_GIVEN,
-        language: NotGivenOr[str] = NOT_GIVEN,
-        prosody: NotGivenOr[ProsodyConfig] = NOT_GIVEN,
-        style: NotGivenOr[StyleConfig] = NOT_GIVEN,
-        on_bookmark_reached_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_canceled_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_completed_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesis_started_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_synthesizing_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_viseme_event: NotGivenOr[Callable] = NOT_GIVEN,
-        on_word_boundary_event: NotGivenOr[Callable] = NOT_GIVEN,
-    ) -> None:
-        if is_given(voice):
-            self._opts.voice = voice
-        if is_given(language):
-            self._opts.language = language
-        if is_given(prosody):
-            self._opts.prosody = prosody
-        if is_given(style):
-            self._opts.style = style
-
-        if is_given(on_bookmark_reached_event):
-            self._opts.on_bookmark_reached_event = on_bookmark_reached_event
-        if is_given(on_synthesis_canceled_event):
-            self._opts.on_synthesis_canceled_event = on_synthesis_canceled_event
-        if is_given(on_synthesis_completed_event):
-            self._opts.on_synthesis_completed_event = on_synthesis_completed_event
-        if is_given(on_synthesis_started_event):
-            self._opts.on_synthesis_started_event = on_synthesis_started_event
-        if is_given(on_synthesizing_event):
-            self._opts.on_synthesizing_event = on_synthesizing_event
-        if is_given(on_viseme_event):
-            self._opts.on_viseme_event = on_viseme_event
-        if is_given(on_word_boundary_event):
-            self._opts.on_word_boundary_event = on_word_boundary_event
-
-    def synthesize(
-        self,
-        text: str,
-        *,
-        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
-    ) -> ChunkedStream:
-        return ChunkedStream(tts=self, input_text=text, conn_options=conn_options, opts=self._opts)
-
-
-class ChunkedStream(tts.ChunkedStream):
-    def __init__(
-        self,
-        *,
-        tts: TTS,
-        input_text: str,
-        opts: _TTSOptions,
-        conn_options: APIConnectOptions,
-    ) -> None:
-        super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
-        self._opts = opts
-
-    async def _run(self):
-        stream_callback = speechsdk.audio.PushAudioOutputStream(
-            _PushAudioOutputStreamCallback(
-                self._opts.sample_rate, asyncio.get_running_loop(), self._event_ch
-            )
-        )
-        synthesizer = _create_speech_synthesizer(
-            config=self._opts,
-            stream=stream_callback,
-        )
-
-        def _synthesize() -> speechsdk.SpeechSynthesisResult:
-            if self._opts.prosody or self._opts.style:
-                ssml = (
-                    '<speak version="1.0" '
-                    'xmlns="http://www.w3.org/2001/10/synthesis" '
-                    'xmlns:mstts="http://www.w3.org/2001/mstts" '
-                    f'xml:lang="{self._opts.language or "en-US"}">'
-                )
-                ssml += f'<voice name="{self._opts.voice}">'
-
-                # Add style if specified
-                if self._opts.style:
-                    style_degree = (
-                        f' styledegree="{self._opts.style.degree}"'
-                        if self._opts.style.degree
-                        else ""
-                    )
-                    ssml += f'<mstts:express-as style="{self._opts.style.style}"{style_degree}>'
-
-                # Add prosody if specified
-                if self._opts.prosody:
-                    ssml += "<prosody"
-                    if self._opts.prosody.rate:
-                        ssml += f' rate="{self._opts.prosody.rate}"'
-                    if self._opts.prosody.volume:
-                        ssml += f' volume="{self._opts.prosody.volume}"'
-                    if self._opts.prosody.pitch:
-                        ssml += f' pitch="{self._opts.prosody.pitch}"'
-                    ssml += ">"
-                    ssml += self._input_text
-                    ssml += "</prosody>"
-                else:
-                    ssml += self._input_text
-
-                # Close style tag if it was opened
-                if self._opts.style:
-                    ssml += "</mstts:express-as>"
-
-                ssml += "</voice></speak>"
-                return synthesizer.speak_ssml_async(ssml).get() # type: ignore
-
-            return synthesizer.speak_text_async(self.input_text).get() # type: ignore
-
-        result = None
-        try:
-            result = await asyncio.to_thread(_synthesize)
-            if result.reason != speechsdk.ResultReason.SynthesizingAudioCompleted:
-                if (
-                    result.cancellation_details.error_code
-                    == speechsdk.CancellationErrorCode.ServiceTimeout
-                ):
-                    raise APITimeoutError()
-                else:
-                    cancel_details = result.cancellation_details
-                    raise APIConnectionError(cancel_details.error_details)
-        finally:
-
-            def _cleanup() -> None:
-                # cleanup resources inside an Executor
-                # to avoid blocking the event loop
-                nonlocal synthesizer, stream_callback, result
-                del synthesizer
-                del stream_callback
-
-                if result is not None:
-                    del result
-
-            try:
-                await asyncio.to_thread(_cleanup)
-            except Exception:
-                logger.exception("failed to cleanup Azure TTS resources")
-
-
-class _PushAudioOutputStreamCallback(speechsdk.audio.PushAudioOutputStreamCallback):
-    def __init__(
-        self,
-        sample_rate: int,
-        loop: asyncio.AbstractEventLoop,
-        event_ch: utils.aio.ChanSender[tts.SynthesizedAudio],
-    ):
-        super().__init__()
-        self._event_ch = event_ch
-        self._loop = loop
-        self._request_id = utils.shortuuid()
-
-        self._bstream = utils.audio.AudioByteStream(sample_rate=sample_rate, num_channels=1)
-
-    def write(self, audio_buffer: memoryview) -> int:
-        for frame in self._bstream.write(audio_buffer.tobytes()):
-            audio = tts.SynthesizedAudio(
-                request_id=self._request_id,
-                frame=frame,
-            )
-            with contextlib.suppress(RuntimeError):
-                self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
-
-        return audio_buffer.nbytes
-
-    def close(self) -> None:
-        for frame in self._bstream.flush():
-            audio = tts.SynthesizedAudio(
-                request_id=self._request_id,
-                frame=frame,
-            )
-            with contextlib.suppress(RuntimeError):
-                self._loop.call_soon_threadsafe(self._event_ch.send_nowait, audio)
-
-
-def _create_speech_synthesizer(
-    *, config: _TTSOptions, stream: speechsdk.audio.AudioOutputStream
-) -> speechsdk.SpeechSynthesizer:
-    # let the SpeechConfig constructor to validate the arguments
-    speech_config = speechsdk.SpeechConfig(
-        subscription=config.speech_key if is_given(config.speech_key) else None,
-        region=config.speech_region if is_given(config.speech_region) else None,
-        endpoint=config.speech_endpoint if is_given(config.speech_endpoint) else None,
-        host=config.speech_host if is_given(config.speech_host) else None,
-        auth_token=config.speech_auth_token if is_given(config.speech_auth_token) else None,
-        speech_recognition_language=config.language if is_given(config.language) else "en-US",
-    )
-
-    speech_config.set_speech_synthesis_output_format(SUPPORTED_SAMPLE_RATE[config.sample_rate])
-    stream_config = speechsdk.audio.AudioOutputConfig(stream=stream)
-    if is_given(config.voice):
-        speech_config.speech_synthesis_voice_name = config.voice
-    if is_given(config.endpoint_id):
-        speech_config.endpoint_id = config.endpoint_id
-
-    synthesizer = speechsdk.SpeechSynthesizer(
-        speech_config=speech_config, audio_config=stream_config
-    )
-
-    if is_given(config.on_bookmark_reached_event):
-        synthesizer.bookmark_reached.connect(config.on_bookmark_reached_event)
-    if is_given(config.on_synthesis_canceled_event):
-        synthesizer.synthesis_canceled.connect(config.on_synthesis_canceled_event)
-    if is_given(config.on_synthesis_completed_event):
-        synthesizer.synthesis_completed.connect(config.on_synthesis_completed_event)
-    if is_given(config.on_synthesis_started_event):
-        synthesizer.synthesis_started.connect(config.on_synthesis_started_event)
-    if is_given(config.on_synthesizing_event):
-        synthesizer.synthesizing.connect(config.on_synthesizing_event)
-    if is_given(config.on_viseme_event):
-        synthesizer.viseme_received.connect(config.on_viseme_event)
-    if is_given(config.on_word_boundary_event):
-        synthesizer.synthesis_word_boundary.connect(config.on_word_boundary_event)
-
-    return synthesizer