cartesia 0.0.3__tar.gz → 0.0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -110,4 +110,6 @@ audio = Audio(audio_data, rate=output["sampling_rate"])
110
110
  display(audio)
111
111
  ```
112
112
 
113
- We recommend using [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file so that your API Key is not stored in the source code.
113
+ To avoid storing your API key in the source code, we recommend doing one of the following:
114
+ 1. Use [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file.
115
+ 1. Set the `CARTESIA_API_KEY` environment variable, preferably in a secure shell init file (e.g. `~/.zshrc`, `~/.bashrc`).
@@ -76,4 +76,6 @@ audio = Audio(audio_data, rate=output["sampling_rate"])
76
76
  display(audio)
77
77
  ```
78
78
 
79
- We recommend using [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file so that your API Key is not stored in the source code.
79
+ To avoid storing your API key in the source code, we recommend doing one of the following:
80
+ 1. Use [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file.
81
+ 1. Set the `CARTESIA_API_KEY` environment variable, preferably in a secure shell init file (e.g. `~/.zshrc`, `~/.bashrc`).
@@ -31,7 +31,11 @@ class CartesiaTTS:
31
31
  """The client for Cartesia's text-to-speech library.
32
32
 
33
33
  This client contains methods to interact with the Cartesia text-to-speech API.
34
- The API offers
34
+ The client can be used to retrieve available voices, compute new voice embeddings,
35
+ and generate speech from text.
36
+
37
+ The client also supports generating audio using a websocket for lower latency.
38
+ To enable interrupt handling along the websocket, set `experimental_ws_handle_interrupts=True`.
35
39
 
36
40
  Examples:
37
41
 
@@ -55,18 +59,22 @@ class CartesiaTTS:
55
59
  ... audio, sr = audio_chunk["audio"], audio_chunk["sampling_rate"]
56
60
  """
57
61
 
58
- def __init__(self, *, api_key: str = None):
62
+ def __init__(self, *, api_key: str = None, experimental_ws_handle_interrupts: bool = False):
59
63
  """
60
64
  Args:
61
65
  api_key: The API key to use for authorization.
62
66
  If not specified, the API key will be read from the environment variable
63
67
  `CARTESIA_API_KEY`.
68
+ experimental_ws_handle_interrupts: Whether to handle interrupts when generating
69
+ audio using the websocket. This is an experimental feature and may have bugs
70
+ or be deprecated in the future.
64
71
  """
65
72
  self.base_url = os.environ.get("CARTESIA_BASE_URL", DEFAULT_BASE_URL)
66
73
  self.api_key = api_key or os.environ.get("CARTESIA_API_KEY")
67
74
  self.api_version = os.environ.get("CARTESIA_API_VERSION", DEFAULT_API_VERSION)
68
75
  self.headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
69
76
  self.websocket = None
77
+ self.experimental_ws_handle_interrupts = experimental_ws_handle_interrupts
70
78
  self.refresh_websocket()
71
79
 
72
80
  def get_voices(self, skip_embeddings: bool = True) -> Dict[str, VoiceMetadata]:
@@ -167,8 +175,11 @@ class CartesiaTTS:
167
175
  """
168
176
  if self.websocket and not self._is_websocket_closed():
169
177
  self.websocket.close()
178
+ route = "audio/websocket"
179
+ if self.experimental_ws_handle_interrupts:
180
+ route = f"experimental/{route}"
170
181
  self.websocket = connect(
171
- f"{self._ws_url()}/audio/websocket?api_key={self.api_key}",
182
+ f"{self._ws_url()}/{route}?api_key={self.api_key}",
172
183
  close_timeout=None,
173
184
  )
174
185
 
@@ -280,21 +291,50 @@ class CartesiaTTS:
280
291
  except json.JSONDecodeError:
281
292
  pass
282
293
 
283
- def _generate_ws(self, body: Dict[str, Any]):
294
+ def _generate_ws(self, body: Dict[str, Any], *, context_id: str = None):
295
+ """Generate audio using the websocket connection.
296
+
297
+ Args:
298
+ body: The request body.
299
+ context_id: The context id for the request.
300
+ The context id must be globally unique for the duration this client exists.
301
+ If this is provided, the context id that is in the response will
302
+ also be returned as part of the dict. This is helpful for testing.
303
+ """
284
304
  if not self.websocket or self._is_websocket_closed():
285
305
  self.refresh_websocket()
286
306
 
287
- self.websocket.send(json.dumps({"data": body, "context_id": uuid.uuid4().hex}))
307
+ include_context_id = bool(context_id)
308
+ if context_id is None:
309
+ context_id = uuid.uuid4().hex
310
+ self.websocket.send(json.dumps({"data": body, "context_id": context_id}))
288
311
  try:
289
- response = json.loads(self.websocket.recv())
290
- while not response["done"]:
312
+ while True:
313
+ response = json.loads(self.websocket.recv())
314
+ if response["done"]:
315
+ break
291
316
  audio = base64.b64decode(response["data"])
292
- # print("timing", time.perf_counter() - start)
293
- yield {"audio": audio, "sampling_rate": response["sampling_rate"]}
294
317
 
295
- response = json.loads(self.websocket.recv())
296
- except Exception:
297
- raise RuntimeError(f"Failed to generate audio. {response}")
318
+ optional_kwargs = {}
319
+ if include_context_id:
320
+ optional_kwargs["context_id"] = response["context_id"]
321
+
322
+ yield {
323
+ "audio": audio,
324
+ "sampling_rate": response["sampling_rate"],
325
+ **optional_kwargs,
326
+ }
327
+
328
+ if self.experimental_ws_handle_interrupts:
329
+ self.websocket.send(json.dumps({"context_id": context_id}))
330
+ except GeneratorExit:
331
+ # The exit is only called when the generator is garbage collected.
332
+ # It may not be called directly after a break statement.
333
+ # However, the generator will be automatically cancelled on the next request.
334
+ if self.experimental_ws_handle_interrupts:
335
+ self.websocket.send(json.dumps({"context_id": context_id, "action": "cancel"}))
336
+ except Exception as e:
337
+ raise RuntimeError(f"Failed to generate audio. {response}") from e
298
338
 
299
339
  def _http_url(self):
300
340
  prefix = "http" if "localhost" in self.base_url else "https"
@@ -0,0 +1 @@
1
+ __version__ = "0.0.4"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: cartesia
3
- Version: 0.0.3
3
+ Version: 0.0.4
4
4
  Summary: The official Python library for the Cartesia API.
5
5
  Home-page:
6
6
  Author: Cartesia, Inc.
@@ -110,4 +110,6 @@ audio = Audio(audio_data, rate=output["sampling_rate"])
110
110
  display(audio)
111
111
  ```
112
112
 
113
- We recommend using [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file so that your API Key is not stored in the source code.
113
+ To avoid storing your API key in the source code, we recommend doing one of the following:
114
+ 1. Use [`python-dotenv`](https://pypi.org/project/python-dotenv/) to add `CARTESIA_API_KEY="my-api-key"` to your .env file.
115
+ 1. Set the `CARTESIA_API_KEY` environment variable, preferably in a secure shell init file (e.g. `~/.zshrc`, `~/.bashrc`).
@@ -78,7 +78,8 @@ class UploadCommand(Command):
78
78
  """Support setup.py upload."""
79
79
 
80
80
  description = "Build and publish the package."
81
- user_options = []
81
+ user_options = [("skip-upload", "u", "skip git tagging and pypi upload")]
82
+ boolean_options = ["skip-upload"]
82
83
 
83
84
  @staticmethod
84
85
  def status(s):
@@ -86,21 +87,26 @@ class UploadCommand(Command):
86
87
  print("\033[1m{0}\033[0m".format(s))
87
88
 
88
89
  def initialize_options(self):
89
- pass
90
+ self.skip_upload = False
90
91
 
91
92
  def finalize_options(self):
92
- pass
93
+ self.skip_upload = bool(self.skip_upload)
93
94
 
94
95
  def run(self):
95
96
  try:
96
97
  self.status("Removing previous builds…")
97
98
  rmtree(os.path.join(here, "dist"))
99
+ rmtree(os.path.join(here, "build"))
98
100
  except OSError:
99
101
  pass
100
102
 
101
103
  self.status("Building Source and Wheel (universal) distribution…")
102
104
  os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable))
103
105
 
106
+ if self.skip_upload:
107
+ self.status("Skipping git tagging and pypi upload")
108
+ sys.exit()
109
+
104
110
  self.status("Uploading the package to PyPI via Twine…")
105
111
  os.system("twine upload dist/*")
106
112
 
@@ -116,6 +122,9 @@ class BumpVersionCommand(Command):
116
122
  To use: python setup.py bumpversion -v <version>
117
123
 
118
124
  This command will push the new version directly and tag it.
125
+
126
+ Usage:
127
+ python setup.py bumpversion --version=1.0.1
119
128
  """
120
129
 
121
130
  description = "Installs the foo."
@@ -130,6 +139,11 @@ class BumpVersionCommand(Command):
130
139
 
131
140
  def initialize_options(self):
132
141
  self.version = None
142
+ self.base_branch = None
143
+ self.version_branch = None
144
+ self.updated_files = [
145
+ "cartesia/version.py",
146
+ ]
133
147
 
134
148
  def finalize_options(self):
135
149
  # This package cannot be imported at top level because it
@@ -147,14 +161,18 @@ class BumpVersionCommand(Command):
147
161
  )
148
162
 
149
163
  def _undo(self):
150
- os.system(f"git restore --staged {PACKAGE_DIR}/__init__.py")
151
- os.system(f"git checkout -- {PACKAGE_DIR}/__init__.py")
164
+ os.system(f"git restore --staged {' '.join(self.updated_files)}")
165
+ os.system(f"git checkout -- {' '.join(self.updated_files)}")
166
+
167
+ # Return to the original branch
168
+ os.system(f"git checkout {self.base_branch}")
169
+ os.system(f"git branch -D {self.version_branch}")
152
170
 
153
171
  def run(self):
154
172
  current_version = about["__version__"]
155
173
 
156
174
  self.status("Checking current branch is 'main'")
157
- current_branch = get_git_branch()
175
+ self.base_branch = current_branch = get_git_branch()
158
176
  if current_branch != "main":
159
177
  raise RuntimeError(
160
178
  "You can only bump the version from the 'main' branch. "
@@ -174,18 +192,25 @@ class BumpVersionCommand(Command):
174
192
 
175
193
  # TODO: Add check to see if all tests are passing on main.
176
194
 
195
+ # Checkout new branch
196
+ self.version_branch = f"bumpversion/v{self.version}"
197
+ self.status(f"Create branch '{self.version_branch}'")
198
+ err_code = os.system(f"git checkout -b {self.version_branch}")
199
+ if err_code != 0:
200
+ raise RuntimeError("Failed to create branch.")
201
+
177
202
  # Change the version in __init__.py
178
203
  self.status(f"Updating version {current_version} -> {self.version}")
179
204
  update_version(self.version)
180
- if current_version != self.version:
181
- self._undo()
182
- raise RuntimeError("Failed to update version.")
205
+ # if current_version != self.version:
206
+ # self._undo()
207
+ # raise RuntimeError("Failed to update version.")
183
208
 
184
- self.status(f"Adding {PACKAGE_DIR}/__init__.py to git")
185
- err_code = os.system(f"git add {PACKAGE_DIR}/__init__.py")
209
+ self.status(f"Adding {', '.join(self.updated_files)} to git")
210
+ err_code = os.system(f"git add {' '.join(self.updated_files)}")
186
211
  if err_code != 0:
187
212
  self._undo()
188
- raise RuntimeError("Failed to add file to git.")
213
+ raise RuntimeError("Failed to add files to git.")
189
214
 
190
215
  # Commit the file with a message '[bumpversion] v<version>'.
191
216
  self.status(f"Commit with message '[bumpversion] v{self.version}'")
@@ -195,12 +220,15 @@ class BumpVersionCommand(Command):
195
220
  raise RuntimeError("Failed to commit file to git.")
196
221
 
197
222
  # Push the commit to origin.
198
- # self.status("Pushing commit to origin")
199
- # err_code = os.system("git push")
200
- # if err_code != 0:
201
- # # TODO: undo the commit automatically.
202
- # raise RuntimeError("Failed to push commit to origin.")
223
+ self.status(f"Pushing commit to origin/{self.version_branch}")
224
+ err_code = os.system(f"git push --force --set-upstream origin {self.version_branch}")
225
+ if err_code != 0:
226
+ # TODO: undo the commit automatically.
227
+ self._undo()
228
+ raise RuntimeError("Failed to push commit to origin.")
203
229
 
230
+ os.system(f"git checkout {self.base_branch}")
231
+ os.system(f"git branch -D {self.version_branch}")
204
232
  sys.exit()
205
233
 
206
234
 
@@ -6,11 +6,12 @@ but rather for general correctness.
6
6
  """
7
7
 
8
8
  import os
9
- from typing import Dict, Generator
9
+ import uuid
10
+ from typing import Dict, Generator, List
10
11
 
11
12
  import pytest
12
13
 
13
- from cartesia.tts import CartesiaTTS, VoiceMetadata
14
+ from cartesia.tts import DEFAULT_MODEL_ID, CartesiaTTS, VoiceMetadata
14
15
 
15
16
  SAMPLE_VOICE = "Milo"
16
17
 
@@ -26,6 +27,13 @@ def client():
26
27
  return CartesiaTTS(api_key=os.environ.get("CARTESIA_API_KEY"))
27
28
 
28
29
 
30
+ @pytest.fixture(scope="session")
31
+ def client_with_ws_interrupt():
32
+ return CartesiaTTS(
33
+ api_key=os.environ.get("CARTESIA_API_KEY"), experimental_ws_handle_interrupts=True
34
+ )
35
+
36
+
29
37
  @pytest.fixture(scope="session")
30
38
  def resources(client: CartesiaTTS):
31
39
  voices = client.get_voices()
@@ -93,6 +101,45 @@ def test_generate_stream(resources: _Resources, websocket: bool):
93
101
  assert isinstance(output["sampling_rate"], int)
94
102
 
95
103
 
104
+ @pytest.mark.parametrize(
105
+ "actions",
106
+ [
107
+ ["cancel-5", None],
108
+ ["cancel-5", "cancel-1", None],
109
+ [None, "cancel-3", None],
110
+ [None, "cancel-1", "cancel-2"],
111
+ ],
112
+ )
113
+ def test_generate_stream_interrupt(
114
+ client_with_ws_interrupt: CartesiaTTS, resources: _Resources, actions: List[str]
115
+ ):
116
+ client = client_with_ws_interrupt
117
+ voices = resources.voices
118
+ embedding = voices[SAMPLE_VOICE]["embedding"]
119
+ transcript = "Hello, world!"
120
+
121
+ context_ids = [f"test-{uuid.uuid4().hex[:6]}" for _ in range(len(actions))]
122
+
123
+ for context_id, action in zip(context_ids, actions):
124
+ body = dict(transcript=transcript, model_id=DEFAULT_MODEL_ID, voice=embedding)
125
+
126
+ # Parse actions to see what we should expect.
127
+ if action is None:
128
+ num_turns = None
129
+ elif "cancel" in action:
130
+ num_turns = int(action.split("-")[1])
131
+
132
+ generator = client._generate_ws(body, context_id=context_id)
133
+ for idx, response in enumerate(generator):
134
+ assert response.keys() == {"audio", "sampling_rate", "context_id"}
135
+ assert response["context_id"] == context_id, (
136
+ f"Context ID from response ({response['context_id']}) does not match "
137
+ f"the expected context ID ({context_id})"
138
+ )
139
+ if idx + 1 == num_turns:
140
+ break
141
+
142
+
96
143
  @pytest.mark.parametrize("chunk_time", [0.05, 0.6])
97
144
  def test_check_inputs_invalid_chunk_time(client: CartesiaTTS, chunk_time):
98
145
  with pytest.raises(ValueError, match="`chunk_time` must be between 0.1 and 0.5"):
@@ -1 +0,0 @@
1
- __version__ = "0.0.3"
File without changes
File without changes
File without changes