cartesia 0.0.4__tar.gz → 0.0.5rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.4
3
+ Version: 0.0.5rc1
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -10,8 +10,11 @@ Classifier: Programming Language :: Python :: 3
10
10
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
11
11
  Requires-Python: >=3.8.0
12
12
  Description-Content-Type: text/markdown
13
- Requires-Dist: websockets
13
+ Requires-Dist: aiohttp
14
+ Requires-Dist: httpx
15
+ Requires-Dist: pytest-asyncio
14
16
  Requires-Dist: requests
17
+ Requires-Dist: websockets
15
18
  Provides-Extra: dev
16
19
  Requires-Dist: pre-commit; extra == "dev"
17
20
  Requires-Dist: docformatter; extra == "dev"
@@ -21,6 +24,7 @@ Requires-Dist: flake8==7.0.0; extra == "dev"
21
24
  Requires-Dist: flake8-bugbear==24.2.6; extra == "dev"
22
25
  Requires-Dist: pytest>=8.0.2; extra == "dev"
23
26
  Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
27
+ Requires-Dist: twine; extra == "dev"
24
28
  Provides-Extra: all
25
29
  Requires-Dist: pre-commit; extra == "all"
26
30
  Requires-Dist: docformatter; extra == "all"
@@ -30,6 +34,7 @@ Requires-Dist: flake8==7.0.0; extra == "all"
30
34
  Requires-Dist: flake8-bugbear==24.2.6; extra == "all"
31
35
  Requires-Dist: pytest>=8.0.2; extra == "all"
32
36
  Requires-Dist: pytest-cov>=4.1.0; extra == "all"
37
+ Requires-Dist: twine; extra == "all"
33
38
 
34
39
 
35
40
  # Cartesia Python API Library
@@ -104,7 +109,37 @@ for output in client.generate(transcript=transcript, voice=voice, stream=True):
104
109
  audio_data.seek(0)
105
110
 
106
111
  # Create an Audio object from the BytesIO data
107
- audio = Audio(audio_data, rate=output["sampling_rate"])
112
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
113
+
114
+ # Display the Audio object
115
+ display(audio)
116
+ ```
117
+
118
+ You can also use the async client if you want to make asynchronous API calls. The usage is very similar:
119
+ ```python
120
+ from cartesia.tts import AsyncCartesiaTTS
121
+ from IPython.display import Audio
122
+ import io
123
+ import os
124
+
125
+ client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
126
+ voices = client.get_voices()
127
+ voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
128
+ transcript = "Hello! Welcome to Cartesia"
129
+
130
+ # Create a BytesIO object to store the audio data
131
+ audio_data = io.BytesIO()
132
+
133
+ # Generate and stream audio
134
+ async for output in client.generate(transcript=transcript, voice=voice, stream=True):
135
+ buffer = output["audio"]
136
+ audio_data.write(buffer)
137
+
138
+ # Set the cursor position to the beginning of the BytesIO object
139
+ audio_data.seek(0)
140
+
141
+ # Create an Audio object from the BytesIO data
142
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
108
143
 
109
144
  # Display the Audio object
110
145
  display(audio)
@@ -70,7 +70,37 @@ for output in client.generate(transcript=transcript, voice=voice, stream=True):
70
70
  audio_data.seek(0)
71
71
 
72
72
  # Create an Audio object from the BytesIO data
73
- audio = Audio(audio_data, rate=output["sampling_rate"])
73
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
74
+
75
+ # Display the Audio object
76
+ display(audio)
77
+ ```
78
+
79
+ You can also use the async client if you want to make asynchronous API calls. The usage is very similar:
80
+ ```python
81
+ from cartesia.tts import AsyncCartesiaTTS
82
+ from IPython.display import Audio
83
+ import io
84
+ import os
85
+
86
+ client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
87
+ voices = client.get_voices()
88
+ voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
89
+ transcript = "Hello! Welcome to Cartesia"
90
+
91
+ # Create a BytesIO object to store the audio data
92
+ audio_data = io.BytesIO()
93
+
94
+ # Generate and stream audio
95
+ async for output in client.generate(transcript=transcript, voice=voice, stream=True):
96
+ buffer = output["audio"]
97
+ audio_data.write(buffer)
98
+
99
+ # Set the cursor position to the beginning of the BytesIO object
100
+ audio_data.seek(0)
101
+
102
+ # Create an Audio object from the BytesIO data
103
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
74
104
 
75
105
  # Display the Audio object
76
106
  display(audio)
@@ -1,15 +1,20 @@
1
+ import asyncio
1
2
  import base64
2
3
  import json
3
4
  import os
4
5
  import uuid
5
- from typing import Any, Dict, Generator, List, Optional, TypedDict, Union
6
+ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, TypedDict, Union
6
7
 
8
+ import aiohttp
9
+ import httpx
7
10
  import requests
8
11
  from websockets.sync.client import connect
9
12
 
10
13
  DEFAULT_MODEL_ID = "genial-planet-1346"
11
14
  DEFAULT_BASE_URL = "api.cartesia.ai"
12
15
  DEFAULT_API_VERSION = "v0"
16
+ DEFAULT_TIMEOUT = 60 # seconds
17
+ DEFAULT_NUM_CONNECTIONS = 10 # connections per client
13
18
 
14
19
 
15
20
  class AudioOutput(TypedDict):
@@ -27,6 +32,37 @@ class VoiceMetadata(TypedDict):
27
32
  embedding: Optional[Embedding]
28
33
 
29
34
 
35
def update_buffer(buffer: str, chunk_bytes: bytes) -> Tuple[str, List[Dict[str, Any]]]:
    """Append a streamed chunk to ``buffer`` and extract every complete JSON object.

    Each object is delimited by the first ``{`` and the next ``}`` after it (the
    payloads are assumed to be flat JSON objects with no nested braces).

    Args:
        buffer: Text left over from previous calls that has not been parsed yet.
        chunk_bytes: The newly received bytes, decoded as UTF-8.

    Returns:
        A tuple ``(remaining_buffer, outputs)``. ``outputs`` is a list of dicts
        with keys ``"audio"`` (base64-decoded bytes of the object's ``"data"``)
        and ``"sampling_rate"``. ``remaining_buffer`` is the unparsed tail to
        pass back on the next call.
    """
    buffer += chunk_bytes.decode("utf-8")
    outputs: List[Dict[str, Any]] = []
    while True:
        start_index = buffer.find("{")
        if start_index == -1:
            break
        end_index = buffer.find("}", start_index)
        if end_index == -1:
            # The object starting at ``start_index`` is not complete yet. Stopping
            # here (instead of re-checking the loop condition) avoids spinning
            # forever when a stray "}" precedes the "{".
            break
        try:
            chunk_json = json.loads(buffer[start_index : end_index + 1])
        except json.JSONDecodeError:
            # Leave the buffer untouched; the caller may retry with more data.
            break
        audio = base64.b64decode(chunk_json["data"])
        outputs.append({"audio": audio, "sampling_rate": chunk_json["sampling_rate"]})
        buffer = buffer[end_index + 1 :]
    return buffer, outputs
50
+
51
+
52
def convert_response(response: Dict[str, Any], include_context_id: bool) -> Dict[str, Any]:
    """Convert a raw websocket message into an audio output dict.

    Args:
        response: A decoded message containing at least ``"data"`` (base64 audio)
            and ``"sampling_rate"``; ``"context_id"`` is also read when
            ``include_context_id`` is true.
        include_context_id: Whether to copy ``response["context_id"]`` into the output.

    Returns:
        A dict with ``"audio"`` (decoded bytes), ``"sampling_rate"`` and, when
        requested, ``"context_id"``.
    """
    output: Dict[str, Any] = {
        "audio": base64.b64decode(response["data"]),
        "sampling_rate": response["sampling_rate"],
    }
    if include_context_id:
        output["context_id"] = response["context_id"]
    return output
64
+
65
+
30
66
  class CartesiaTTS:
31
67
  """The client for Cartesia's text-to-speech library.
32
68
 
@@ -108,9 +144,9 @@ class CartesiaTTS:
108
144
  >>> audio = client.generate(transcript="Hello world!", voice=embedding)
109
145
  """
110
146
  params = {"select": "id, name, description"} if skip_embeddings else None
111
- response = requests.get(f"{self._http_url()}/voices", headers=self.headers, params=params)
147
+ response = httpx.get(f"{self._http_url()}/voices", headers=self.headers, params=params)
112
148
 
113
- if response.status_code != 200:
149
+ if not response.is_success:
114
150
  raise ValueError(f"Failed to get voices. Error: {response.text}")
115
151
 
116
152
  voices = response.json()
@@ -142,20 +178,20 @@ class CartesiaTTS:
142
178
 
143
179
  if voice_id:
144
180
  url = f"{self._http_url()}/voices/embedding/{voice_id}"
145
- response = requests.get(url, headers=self.headers)
181
+ response = httpx.get(url, headers=self.headers)
146
182
  elif filepath:
147
183
  url = f"{self._http_url()}/voices/clone/clip"
148
184
  files = {"clip": open(filepath, "rb")}
149
185
  headers = self.headers.copy()
150
186
  # The default content type of JSON is incorrect for file uploads
151
187
  headers.pop("Content-Type")
152
- response = requests.post(url, headers=headers, files=files)
188
+ response = httpx.post(url, headers=headers, files=files)
153
189
  elif link:
154
190
  url = f"{self._http_url()}/voices/clone/url"
155
191
  params = {"link": link}
156
- response = requests.post(url, headers=self.headers, params=params)
192
+ response = httpx.post(url, headers=self.headers, params=params)
157
193
 
158
- if response.status_code != 200:
194
+ if not response.is_success:
159
195
  raise ValueError(
160
196
  f"Failed to clone voice. Status Code: {response.status_code}\n"
161
197
  f"Error: {response.text}"
@@ -200,6 +236,29 @@ class CartesiaTTS:
200
236
  if transcript.strip() == "":
201
237
  raise ValueError("`transcript` must be non empty")
202
238
 
239
def _generate_request_body(
    self,
    *,
    transcript: str,
    duration: Optional[int] = None,
    chunk_time: Optional[float] = None,
    voice: Optional[Embedding] = None,
) -> Dict[str, Any]:
    """Create the request body for a generate/stream request.

    ``transcript`` and ``model_id`` (always ``DEFAULT_MODEL_ID``) are always set.
    ``duration``, ``chunk_time`` and ``voice`` are included only when they are
    not ``None``; otherwise they are omitted from the body entirely, rather than
    sent as ``null``.
    """
    body: Dict[str, Any] = {"transcript": transcript, "model_id": DEFAULT_MODEL_ID}
    optional_body = {"duration": duration, "chunk_time": chunk_time, "voice": voice}
    body.update({k: v for k, v in optional_body.items() if v is not None})
    return body
261
+
203
262
  def generate(
204
263
  self,
205
264
  *,
@@ -232,14 +291,9 @@ class CartesiaTTS:
232
291
  """
233
292
  self._check_inputs(transcript, duration, chunk_time)
234
293
 
235
- body = dict(transcript=transcript, model_id=DEFAULT_MODEL_ID)
236
-
237
- optional_body = dict(
238
- duration=duration,
239
- chunk_time=chunk_time,
240
- voice=voice,
294
+ body = self._generate_request_body(
295
+ transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
241
296
  )
242
- body.update({k: v for k, v in optional_body.items() if v is not None})
243
297
 
244
298
  if websocket:
245
299
  generator = self._generate_ws(body)
@@ -265,23 +319,14 @@ class CartesiaTTS:
265
319
  data=json.dumps(body),
266
320
  headers=self.headers,
267
321
  )
268
- if response.status_code != 200:
322
+ if not response.ok:
269
323
  raise ValueError(f"Failed to generate audio. {response.text}")
270
324
 
271
325
  buffer = ""
272
326
  for chunk_bytes in response.iter_content(chunk_size=None):
273
- buffer += chunk_bytes.decode("utf-8")
274
- while "{" in buffer and "}" in buffer:
275
- start_index = buffer.find("{")
276
- end_index = buffer.find("}", start_index)
277
- if start_index != -1 and end_index != -1:
278
- try:
279
- chunk_json = json.loads(buffer[start_index : end_index + 1])
280
- audio = base64.b64decode(chunk_json["data"])
281
- yield {"audio": audio, "sampling_rate": chunk_json["sampling_rate"]}
282
- buffer = buffer[end_index + 1 :]
283
- except json.JSONDecodeError:
284
- break
327
+ buffer, outputs = update_buffer(buffer, chunk_bytes)
328
+ for output in outputs:
329
+ yield output
285
330
 
286
331
  if buffer:
287
332
  try:
@@ -313,17 +358,8 @@ class CartesiaTTS:
313
358
  response = json.loads(self.websocket.recv())
314
359
  if response["done"]:
315
360
  break
316
- audio = base64.b64decode(response["data"])
317
361
 
318
- optional_kwargs = {}
319
- if include_context_id:
320
- optional_kwargs["context_id"] = response["context_id"]
321
-
322
- yield {
323
- "audio": audio,
324
- "sampling_rate": response["sampling_rate"],
325
- **optional_kwargs,
326
- }
362
+ yield convert_response(response, include_context_id)
327
363
 
328
364
  if self.experimental_ws_handle_interrupts:
329
365
  self.websocket.send(json.dumps({"context_id": context_id}))
@@ -347,3 +383,143 @@ class CartesiaTTS:
347
383
  def __del__(self):
348
384
  if self.websocket.socket.fileno() > -1:
349
385
  self.websocket.close()
386
+
387
+
388
class AsyncCartesiaTTS(CartesiaTTS):
    """Asynchronous variant of :class:`CartesiaTTS`.

    Audio generation (HTTP streaming and websocket) goes through a shared
    ``aiohttp.ClientSession``. Methods inherited unchanged from
    :class:`CartesiaTTS` keep their synchronous behaviour. Call ``cleanup()``
    to close the websocket and session explicitly.
    """

    def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
        self.timeout = aiohttp.ClientTimeout(total=DEFAULT_TIMEOUT)
        self.connector = aiohttp.TCPConnector(limit=DEFAULT_NUM_CONNECTIONS)
        self._session = aiohttp.ClientSession(timeout=self.timeout, connector=self.connector)
        # ``refresh_websocket`` is a no-op here, so make sure the attribute exists
        # before anything reads it; the connection is opened lazily.
        self.websocket = None
        super().__init__(
            api_key=api_key, experimental_ws_handle_interrupts=experimental_ws_handle_interrupts
        )

    def refresh_websocket(self):
        pass  # do not load the websocket for the client until asynchronously when it is needed

    async def _async_refresh_websocket(self):
        """(Re)open the websocket connection on the shared aiohttp session."""
        if self.websocket is not None and not self._is_websocket_closed():
            # aiohttp's websocket ``close()`` is a coroutine and must be awaited.
            await self.websocket.close()
        route = "audio/websocket"
        if self.experimental_ws_handle_interrupts:
            route = f"experimental/{route}"
        self.websocket = await self._session.ws_connect(
            f"{self._ws_url()}/{route}?api_key={self.api_key}"
        )

    async def generate(
        self,
        *,
        transcript: str,
        duration: int = None,
        chunk_time: float = None,
        voice: Embedding = None,
        stream: bool = False,
        websocket: bool = True,
    ) -> Union[AudioOutput, AsyncGenerator[AudioOutput, None]]:
        """Asynchronously generate audio from a transcript.

        NOTE: This overrides the non-asynchronous generate method from the base class.

        Args:
            transcript: The text to generate audio for.
            duration: Forwarded in the request body when not ``None``.
            chunk_time: Forwarded in the request body when not ``None``.
            voice: The embedding to use for generating audio.
            stream: If true, return an async generator of chunks instead of the
                concatenated result.
            websocket: Use the websocket connection if true, otherwise HTTP streaming.

        Returns:
            If ``stream`` is true, an async generator yielding dicts with ``"audio"``
            (bytes) and ``"sampling_rate"``. Otherwise a dict with ``"audio"`` (all
            chunks concatenated as bytes) and ``"sampling_rate"`` (taken from the
            first chunk, ``None`` if no chunks were received).

        Raises:
            ValueError: If the inputs fail validation (e.g. an empty transcript).
        """
        # Validate up front, matching the synchronous ``generate``.
        self._check_inputs(transcript, duration, chunk_time)

        body = self._generate_request_body(
            transcript=transcript, duration=duration, chunk_time=chunk_time, voice=voice
        )

        if websocket:
            generator = self._generate_ws(body)
        else:
            generator = self._generate_http(body)

        if stream:
            return generator

        chunks = []
        sampling_rate = None
        async for chunk in generator:
            if sampling_rate is None:
                sampling_rate = chunk["sampling_rate"]
            chunks.append(chunk["audio"])

        return {"audio": b"".join(chunks), "sampling_rate": sampling_rate}

    async def _generate_http(self, body: Dict[str, Any]):
        """Stream audio chunks from the HTTP ``/audio/stream`` endpoint."""
        async with self._session.post(
            f"{self._http_url()}/audio/stream", data=json.dumps(body), headers=self.headers
        ) as response:
            if not 200 <= response.status < 300:
                # ``ClientResponse.text`` is a coroutine method; await it so the
                # error message contains the body, not a bound-method repr.
                error_text = await response.text()
                raise ValueError(f"Failed to generate audio. {error_text}")

            buffer = ""
            async for chunk_bytes in response.content.iter_any():
                buffer, outputs = update_buffer(buffer, chunk_bytes)
                for output in outputs:
                    yield output

            if buffer:
                try:
                    chunk_json = json.loads(buffer)
                    audio = base64.b64decode(chunk_json["data"])
                    yield {"audio": audio, "sampling_rate": chunk_json["sampling_rate"]}
                except json.JSONDecodeError:
                    pass

    async def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
        """Stream audio chunks over the (lazily opened) websocket connection."""
        include_context_id = bool(context_id)

        if self.websocket is None or self._is_websocket_closed():
            await self._async_refresh_websocket()

        ws = self.websocket
        if context_id is None:
            context_id = uuid.uuid4().hex
        await ws.send_json({"data": body, "context_id": context_id})
        response = None
        try:
            while True:
                response = await ws.receive_json()
                if response["done"]:
                    break

                yield convert_response(response, include_context_id)

            if self.experimental_ws_handle_interrupts:
                await ws.send_json({"context_id": context_id})
        except GeneratorExit:
            # The exit is only called when the generator is closed or garbage collected.
            # It may not be called directly after a break statement.
            # However, the generator will be automatically cancelled on the next request.
            if self.experimental_ws_handle_interrupts:
                await ws.send_json({"context_id": context_id, "action": "cancel"})
        except Exception as e:
            raise RuntimeError(f"Failed to generate audio. {response}") from e

    def _is_websocket_closed(self):
        return self.websocket.closed

    async def cleanup(self):
        """Close the websocket (if open) and the underlying aiohttp session."""
        if getattr(self, "websocket", None) is not None and not self._is_websocket_closed():
            await self.websocket.close()
        if not self._session.closed:
            await self._session.close()

    def __del__(self):
        # ``__init__`` may have failed before the session was created.
        session = getattr(self, "_session", None)
        if session is None:
            return
        websocket = getattr(self, "websocket", None)
        if session.closed and (websocket is None or websocket.closed):
            # Nothing to release; avoid starting an event loop during finalization.
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            asyncio.run(self.cleanup())
        else:
            loop.create_task(self.cleanup())
@@ -0,0 +1 @@
1
+ __version__ = "0.0.5rc1"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.4
3
+ Version: 0.0.5rc1
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -10,8 +10,11 @@ Classifier: Programming Language :: Python :: 3
10
10
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
11
11
  Requires-Python: >=3.8.0
12
12
  Description-Content-Type: text/markdown
13
- Requires-Dist: websockets
13
+ Requires-Dist: aiohttp
14
+ Requires-Dist: httpx
15
+ Requires-Dist: pytest-asyncio
14
16
  Requires-Dist: requests
17
+ Requires-Dist: websockets
15
18
  Provides-Extra: dev
16
19
  Requires-Dist: pre-commit; extra == "dev"
17
20
  Requires-Dist: docformatter; extra == "dev"
@@ -21,6 +24,7 @@ Requires-Dist: flake8==7.0.0; extra == "dev"
21
24
  Requires-Dist: flake8-bugbear==24.2.6; extra == "dev"
22
25
  Requires-Dist: pytest>=8.0.2; extra == "dev"
23
26
  Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
27
+ Requires-Dist: twine; extra == "dev"
24
28
  Provides-Extra: all
25
29
  Requires-Dist: pre-commit; extra == "all"
26
30
  Requires-Dist: docformatter; extra == "all"
@@ -30,6 +34,7 @@ Requires-Dist: flake8==7.0.0; extra == "all"
30
34
  Requires-Dist: flake8-bugbear==24.2.6; extra == "all"
31
35
  Requires-Dist: pytest>=8.0.2; extra == "all"
32
36
  Requires-Dist: pytest-cov>=4.1.0; extra == "all"
37
+ Requires-Dist: twine; extra == "all"
33
38
 
34
39
 
35
40
  # Cartesia Python API Library
@@ -104,7 +109,37 @@ for output in client.generate(transcript=transcript, voice=voice, stream=True):
104
109
  audio_data.seek(0)
105
110
 
106
111
  # Create an Audio object from the BytesIO data
107
- audio = Audio(audio_data, rate=output["sampling_rate"])
112
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
113
+
114
+ # Display the Audio object
115
+ display(audio)
116
+ ```
117
+
118
+ You can also use the async client if you want to make asynchronous API calls. The usage is very similar:
119
+ ```python
120
+ from cartesia.tts import AsyncCartesiaTTS
121
+ from IPython.display import Audio
122
+ import io
123
+ import os
124
+
125
+ client = AsyncCartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
126
+ voices = client.get_voices()
127
+ voice = client.get_voice_embedding(voice_id=voices["Graham"]["id"])
128
+ transcript = "Hello! Welcome to Cartesia"
129
+
130
+ # Create a BytesIO object to store the audio data
131
+ audio_data = io.BytesIO()
132
+
133
+ # Generate and stream audio
134
+ async for output in client.generate(transcript=transcript, voice=voice, stream=True):
135
+ buffer = output["audio"]
136
+ audio_data.write(buffer)
137
+
138
+ # Set the cursor position to the beginning of the BytesIO object
139
+ audio_data.seek(0)
140
+
141
+ # Create an Audio object from the BytesIO data
142
+ audio = Audio(np.frombuffer(audio_data.read(), dtype=np.float32), rate=output["sampling_rate"])
108
143
 
109
144
  # Display the Audio object
110
145
  display(audio)
@@ -1,5 +1,8 @@
1
- websockets
1
+ aiohttp
2
+ httpx
3
+ pytest-asyncio
2
4
  requests
5
+ websockets
3
6
 
4
7
  [all]
5
8
  pre-commit
@@ -10,6 +13,7 @@ flake8==7.0.0
10
13
  flake8-bugbear==24.2.6
11
14
  pytest>=8.0.2
12
15
  pytest-cov>=4.1.0
16
+ twine
13
17
 
14
18
  [dev]
15
19
  pre-commit
@@ -20,3 +24,4 @@ flake8==7.0.0
20
24
  flake8-bugbear==24.2.6
21
25
  pytest>=8.0.2
22
26
  pytest-cov>=4.1.0
27
+ twine
@@ -214,7 +214,7 @@ class BumpVersionCommand(Command):
214
214
 
215
215
  # Commit the file with a message '[bumpversion] v<version>'.
216
216
  self.status(f"Commit with message '[bumpversion] v{self.version}'")
217
- err_code = os.system("git commit -m '[bumpversion] v{}'".format(current_version))
217
+ err_code = os.system("git commit -m '[bumpversion] v{}'".format(self.version))
218
218
  if err_code != 0:
219
219
  self._undo()
220
220
  raise RuntimeError("Failed to commit file to git.")
@@ -6,12 +6,14 @@ but rather for general correctness.
6
6
  """
7
7
 
8
8
  import os
9
+ import sys
9
10
  import uuid
10
- from typing import Dict, Generator, List
11
+ from typing import AsyncGenerator, Dict, Generator, List
11
12
 
12
13
  import pytest
13
14
 
14
- from cartesia.tts import DEFAULT_MODEL_ID, CartesiaTTS, VoiceMetadata
15
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
16
+ from cartesia.tts import DEFAULT_MODEL_ID, AsyncCartesiaTTS, CartesiaTTS, VoiceMetadata
15
17
 
16
18
  SAMPLE_VOICE = "Milo"
17
19
 
@@ -22,9 +24,17 @@ class _Resources:
22
24
  self.voices = voices
23
25
 
24
26
 
27
def create_client():
    """Build a synchronous client authenticated with ``CARTESIA_API_KEY`` from the environment."""
    api_key = os.environ.get("CARTESIA_API_KEY")
    return CartesiaTTS(api_key=api_key)
29
+
30
+
31
def create_async_client():
    """Build an asynchronous client authenticated with ``CARTESIA_API_KEY`` from the environment."""
    api_key = os.environ.get("CARTESIA_API_KEY")
    return AsyncCartesiaTTS(api_key=api_key)
33
+
34
+
25
35
  @pytest.fixture(scope="session")
26
36
  def client():
27
- return CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
37
+ return create_client()
28
38
 
29
39
 
30
40
  @pytest.fixture(scope="session")
@@ -101,6 +111,41 @@ def test_generate_stream(resources: _Resources, websocket: bool):
101
111
  assert isinstance(output["sampling_rate"], int)
102
112
 
103
113
 
114
@pytest.mark.parametrize("websocket", [True, False])
@pytest.mark.asyncio
async def test_async_generate(resources: _Resources, websocket: bool):
    """Non-streaming async generation returns concatenated audio bytes and a sampling rate."""
    voices = resources.voices
    embedding = voices[SAMPLE_VOICE]["embedding"]
    transcript = "Hello, world!"

    async_client = create_async_client()
    try:
        output = await async_client.generate(
            transcript=transcript, voice=embedding, websocket=websocket
        )
    finally:
        # Release the aiohttp session/websocket explicitly rather than relying on __del__.
        await async_client.cleanup()

    assert output.keys() == {"audio", "sampling_rate"}
    assert isinstance(output["audio"], bytes)
    assert isinstance(output["sampling_rate"], int)
129
+
130
+
131
@pytest.mark.parametrize("websocket", [True, False])
@pytest.mark.asyncio
async def test_async_generate_stream(resources: _Resources, websocket: bool):
    """Streaming async generation yields audio chunks over both websocket and HTTP."""
    voices = resources.voices
    embedding = voices[SAMPLE_VOICE]["embedding"]
    transcript = "Hello, world!"

    async_client = create_async_client()
    try:
        # Pass the parametrized transport through; without it both cases hit the websocket path.
        generator = await async_client.generate(
            transcript=transcript, voice=embedding, stream=True, websocket=websocket
        )
        assert isinstance(generator, AsyncGenerator)

        async for output in generator:
            assert output.keys() == {"audio", "sampling_rate"}
            assert isinstance(output["audio"], bytes)
            assert isinstance(output["sampling_rate"], int)
    finally:
        # Release the aiohttp session/websocket explicitly rather than relying on __del__.
        await async_client.cleanup()
147
+
148
+
104
149
  @pytest.mark.parametrize(
105
150
  "actions",
106
151
  [
@@ -1 +0,0 @@
1
- __version__ = "0.0.4"
File without changes
File without changes