videosdk-plugins-openai 0.0.26__tar.gz → 0.0.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: videosdk-plugins-openai
- Version: 0.0.26
+ Version: 0.0.28
  Summary: VideoSDK Agent Framework plugin for OpenAI services
  Author: videosdk
  License-Expression: Apache-2.0
@@ -13,7 +13,7 @@ Classifier: Topic :: Multimedia :: Video
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.11
  Requires-Dist: openai[realtime]>=1.68.2
- Requires-Dist: videosdk-agents>=0.0.26
+ Requires-Dist: videosdk-agents>=0.0.28
  Description-Content-Type: text/markdown
 
  # VideoSDK OpenAI Plugin
@@ -21,7 +21,7 @@ classifiers = [
  "Topic :: Scientific/Engineering :: Artificial Intelligence",
  ]
  dependencies = [
- "videosdk-agents>=0.0.26",
+ "videosdk-agents>=0.0.28",
  "openai[realtime]>=1.68.2",
  ]
 
@@ -42,6 +42,7 @@ class OpenAILLM(LLM):
  self.temperature = temperature
  self.tool_choice = tool_choice
  self.max_completion_tokens = max_completion_tokens
+ self._cancelled = False
 
  self._client = openai.AsyncOpenAI(
  api_key=self.api_key,
@@ -75,6 +76,8 @@ class OpenAILLM(LLM):
  Yields:
  LLMResponse objects containing the model's responses
  """
+ self._cancelled = False
+
  def _format_content(content: Union[str, List[ChatContent]]):
  if isinstance(content, str):
  return content
@@ -139,14 +142,17 @@ class OpenAILLM(LLM):
  if formatted_tools:
  completion_params["functions"] = formatted_tools
  completion_params["function_call"] = self.tool_choice
-
  completion_params.update(kwargs)
  try:
  response_stream = await self._client.chat.completions.create(**completion_params)
+
  current_content = ""
  current_function_call = None
 
  async for chunk in response_stream:
+ if self._cancelled:
+ break
+
  if not chunk.choices:
  continue
 
@@ -178,17 +184,22 @@ class OpenAILLM(LLM):
  current_function_call = None
 
  elif delta.content is not None:
- current_content += delta.content
+ current_content = delta.content
  yield LLMResponse(
  content=current_content,
  role=ChatRole.ASSISTANT
  )
 
  except Exception as e:
- self.emit("error", e)
+ if not self._cancelled:
+ self.emit("error", e)
  raise
 
+ async def cancel_current_generation(self) -> None:
+ self._cancelled = True
+
  async def aclose(self) -> None:
  """Cleanup resources by closing the HTTP client"""
+ await self.cancel_current_generation()
  if self._client:
- await self._client.close()
+ await self._client.close()
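
Note on the OpenAILLM change above: 0.0.28 adds a cooperative cancellation flag. cancel_current_generation() sets _cancelled, the streaming loop checks it and stops consuming chunks, errors are only emitted for generations that were not cancelled, and aclose() now cancels any in-flight generation before closing the client. A minimal, self-contained sketch of that pattern (the CancellableStreamer class and fake_source generator are illustrative stand-ins, not part of the plugin):

import asyncio


class CancellableStreamer:
    """Illustrative cooperative-cancellation pattern, mirroring the 0.0.28 OpenAILLM change."""

    def __init__(self) -> None:
        self._cancelled = False

    async def stream(self, source):
        self._cancelled = False  # reset at the start of each generation
        try:
            async for chunk in source:
                if self._cancelled:
                    break  # stop consuming as soon as cancellation is requested
                yield chunk
        except Exception as exc:
            if not self._cancelled:
                print(f"error: {exc}")  # report only errors from non-cancelled generations
            raise

    async def cancel_current_generation(self) -> None:
        self._cancelled = True


async def fake_source():
    """Stand-in for a provider response stream."""
    for i in range(10):
        await asyncio.sleep(0.1)
        yield f"chunk-{i}"


async def main():
    streamer = CancellableStreamer()

    async def consume():
        async for chunk in streamer.stream(fake_source()):
            print(chunk)

    task = asyncio.create_task(consume())
    await asyncio.sleep(0.25)
    await streamer.cancel_current_generation()  # interrupt mid-stream
    await task  # returns promptly; remaining chunks are never yielded


asyncio.run(main())
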
@@ -22,7 +22,7 @@ from videosdk.agents import (
  ToolChoice,
  RealtimeBaseModel,
  global_event_emitter,
- Agent
+ Agent,
  )
  from videosdk.agents import realtime_metrics_collector
 
@@ -46,19 +46,16 @@ DEFAULT_INPUT_AUDIO_TRANSCRIPTION = InputAudioTranscription(
  )
  DEFAULT_TOOL_CHOICE = "auto"
 
- OpenAIEventTypes = Literal[
- "user_speech_started",
- "text_response",
- "error"
- ]
+ OpenAIEventTypes = Literal["user_speech_started", "text_response", "error"]
  DEFAULT_VOICE = "alloy"
  DEFAULT_INPUT_AUDIO_FORMAT = "pcm16"
  DEFAULT_OUTPUT_AUDIO_FORMAT = "pcm16"
 
+
  @dataclass
  class OpenAIRealtimeConfig:
  """Configuration for the OpenAI realtime API
-
+
  Args:
  voice: Voice ID for audio output. Default is 'alloy'
  temperature: Controls randomness in response generation. Higher values (e.g. 0.8) make output more random,
@@ -75,23 +72,31 @@ class OpenAIRealtimeConfig:
  tool_choice: How tools should be selected ('auto' or 'none'). Default is 'auto'
  modalities: List of enabled response types ["text", "audio"]. Default includes both
  """
+
  voice: str = DEFAULT_VOICE
  temperature: float = DEFAULT_TEMPERATURE
- turn_detection: TurnDetection | None = field(default_factory=lambda: DEFAULT_TURN_DETECTION)
- input_audio_transcription: InputAudioTranscription | None = field(default_factory=lambda: DEFAULT_INPUT_AUDIO_TRANSCRIPTION)
+ turn_detection: TurnDetection | None = field(
+ default_factory=lambda: DEFAULT_TURN_DETECTION
+ )
+ input_audio_transcription: InputAudioTranscription | None = field(
+ default_factory=lambda: DEFAULT_INPUT_AUDIO_TRANSCRIPTION
+ )
  tool_choice: ToolChoice | None = DEFAULT_TOOL_CHOICE
  modalities: list[str] = field(default_factory=lambda: ["text", "audio"])
 
+
  @dataclass
  class OpenAISession:
  """Represents an OpenAI WebSocket session"""
+
  ws: aiohttp.ClientWebSocketResponse
  msg_queue: asyncio.Queue[Dict[str, Any]]
  tasks: list[asyncio.Task]
 
+
  class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  """OpenAI's realtime model implementation."""
-
+
  def __init__(
  self,
  *,
@@ -102,7 +107,7 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  ) -> None:
  """
  Initialize OpenAI realtime model.
-
+
  Args:
  model: The OpenAI model identifier to use (e.g. 'gpt-4', 'gpt-3.5-turbo')
  config: Optional configuration object for customizing model behavior. Contains settings for:
@@ -114,7 +119,7 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  - modalities: List of enabled modalities ('text', 'audio')
  api_key: OpenAI API key. If not provided, will attempt to read from OPENAI_API_KEY env var
  base_url: Base URL for OpenAI API. Defaults to 'https://api.openai.com/v1'
-
+
  Raises:
  ValueError: If no API key is provided and none found in environment variables
  """
@@ -123,8 +128,13 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  self.api_key = api_key or os.getenv("OPENAI_API_KEY")
  self.base_url = base_url or OPENAI_BASE_URL
  if not self.api_key:
- self.emit("error", "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable")
- raise ValueError("OpenAI API key must be provided or set in OPENAI_API_KEY environment variable")
+ self.emit(
+ "error",
+ "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable",
+ )
+ raise ValueError(
+ "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable"
+ )
  self._http_session: Optional[aiohttp.ClientSession] = None
  self._session: Optional[OpenAISession] = None
  self._closing = False
@@ -137,34 +147,37 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  self.input_sample_rate = 48000
  self.target_sample_rate = 16000
  self._agent_speaking = False
-
+
  def set_agent(self, agent: Agent) -> None:
  self._instructions = agent.instructions
  self._tools = agent.tools
  self.tools_formatted = self._format_tools_for_session(self._tools)
  self._formatted_tools = self.tools_formatted
-
+
  async def connect(self) -> None:
  headers = {"Agent": "VideoSDK Agents"}
  headers["Authorization"] = f"Bearer {self.api_key}"
  headers["OpenAI-Beta"] = "realtime=v1"
-
+
  url = self.process_base_url(self.base_url, self.model)
-
+
  self._session = await self._create_session(url, headers)
  await self._handle_websocket(self._session)
  await self.send_first_session_update()
-
+
  async def handle_audio_input(self, audio_data: bytes) -> None:
  """Handle incoming audio data from the user"""
  if self._session and not self._closing and "audio" in self.config.modalities:
  audio_data = np.frombuffer(audio_data, dtype=np.int16)
- audio_data = signal.resample(audio_data, int(len(audio_data) * self.target_sample_rate / self.input_sample_rate))
+ audio_data = signal.resample(
+ audio_data,
+ int(len(audio_data) * self.target_sample_rate / self.input_sample_rate),
+ )
  audio_data = audio_data.astype(np.int16).tobytes()
  base64_audio_data = base64.b64encode(audio_data).decode("utf-8")
  audio_event = {
  "type": "input_audio_buffer.append",
- "audio": base64_audio_data
+ "audio": base64_audio_data,
  }
  await self.send_event(audio_event)
 
@@ -176,58 +189,69 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
 
  async def _create_session(self, url: str, headers: dict) -> OpenAISession:
  """Create a new WebSocket session"""
-
+
  http_session = await self._ensure_http_session()
- ws = await http_session.ws_connect(url, headers=headers, autoping=True, heartbeat=10, autoclose=False, timeout=30)
+ ws = await http_session.ws_connect(
+ url,
+ headers=headers,
+ autoping=True,
+ heartbeat=10,
+ autoclose=False,
+ timeout=30,
+ )
  msg_queue: asyncio.Queue = asyncio.Queue()
  tasks: list[asyncio.Task] = []
-
+
  self._closing = False
-
+
  return OpenAISession(ws=ws, msg_queue=msg_queue, tasks=tasks)
-
+
  async def send_message(self, message: str) -> None:
  """Send a message to the OpenAI realtime API"""
- await self.send_event({
- "type": "conversation.item.create",
- "item": {
- "type": "message",
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": "Repeat the user's exact message back to them:" + message + "DO NOT ADD ANYTHING ELSE",
- }
- ]
+ await self.send_event(
+ {
+ "type": "conversation.item.create",
+ "item": {
+ "type": "message",
+ "role": "assistant",
+ "content": [
+ {
+ "type": "text",
+ "text": "Repeat the user's exact message back to them:"
+ + message
+ + "DO NOT ADD ANYTHING ELSE",
+ }
+ ],
+ },
  }
- })
+ )
  await self.create_response()
-
+
  async def create_response(self) -> None:
  """Create a response to the OpenAI realtime API"""
  if not self._session:
  self.emit("error", "No active WebSocket session")
  raise RuntimeError("No active WebSocket session")
-
+
  response_event = {
  "type": "response.create",
  "event_id": str(uuid.uuid4()),
  "response": {
- "instructions": self._instructions,
- "metadata": {
- "client_event_id": str(uuid.uuid4())
- }
- }
+ "instructions": self._instructions,
+ "metadata": {"client_event_id": str(uuid.uuid4())},
+ },
  }
-
+
  await self.send_event(response_event)
-
+
  async def _handle_websocket(self, session: OpenAISession) -> None:
  """Start WebSocket send/receive tasks"""
- session.tasks.extend([
- asyncio.create_task(self._send_loop(session), name="send_loop"),
- asyncio.create_task(self._receive_loop(session), name="receive_loop")
- ])
+ session.tasks.extend(
+ [
+ asyncio.create_task(self._send_loop(session), name="send_loop"),
+ asyncio.create_task(self._receive_loop(session), name="receive_loop"),
+ ]
+ )
 
  async def _send_loop(self, session: OpenAISession) -> None:
  """Send messages from queue to WebSocket"""
@@ -248,7 +272,7 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  try:
  while not self._closing:
  msg = await session.ws.receive()
-
+
  if msg.type == aiohttp.WSMsgType.CLOSED:
  self.emit("error", f"WebSocket closed with reason: {msg.extra}")
  break
@@ -265,50 +289,50 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  async def _handle_message(self, data: dict) -> None:
  """Handle incoming WebSocket messages"""
  try:
- event_type = data.get('type')
+ event_type = data.get("type")
 
  if event_type == "input_audio_buffer.speech_started":
  await self._handle_speech_started(data)
-
+
  elif event_type == "input_audio_buffer.speech_stopped":
  await self._handle_speech_stopped(data)
 
  elif event_type == "response.created":
  await self._handle_response_created(data)
-
+
  elif event_type == "response.output_item.added":
  await self._handle_output_item_added(data)
-
+
  elif event_type == "response.content_part.added":
  await self._handle_content_part_added(data)
-
+
  elif event_type == "response.text.delta":
  await self._handle_text_delta(data)
 
  elif event_type == "response.audio.delta":
  await self._handle_audio_delta(data)
-
+
  elif event_type == "response.audio_transcript.delta":
  await self._handle_audio_transcript_delta(data)
-
+
  elif event_type == "response.done":
  await self._handle_response_done(data)
 
  elif event_type == "error":
  await self._handle_error(data)
-
+
  elif event_type == "response.function_call_arguments.delta":
  await self._handle_function_call_arguments_delta(data)
-
+
  elif event_type == "response.function_call_arguments.done":
  await self._handle_function_call_arguments_done(data)
-
+
  elif event_type == "response.output_item.done":
  await self._handle_output_item_done(data)
-
+
  elif event_type == "conversation.item.input_audio_transcription.completed":
  await self._handle_input_audio_transcription_completed(data)
-
+
  elif event_type == "response.text.done":
  await self._handle_text_done(data)
 
@@ -334,15 +358,18 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
 
  async def _handle_output_item_added(self, data: dict) -> None:
  """Handle new output item addition"""
-
+
  async def _handle_output_item_done(self, data: dict) -> None:
  """Handle output item done"""
  try:
  item = data.get("item", {})
- if item.get("type") == "function_call" and item.get("status") == "completed":
+ if (
+ item.get("type") == "function_call"
+ and item.get("status") == "completed"
+ ):
  name = item.get("name")
  arguments = json.loads(item.get("arguments", "{}"))
-
+
  if name and self._tools:
  for tool in self._tools:
  tool_info = get_tool_info(tool)
@@ -350,28 +377,34 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  try:
  await realtime_metrics_collector.add_tool_call(name)
  result = await tool(**arguments)
- await self.send_event({
- "type": "conversation.item.create",
- "item": {
- "type": "function_call_output",
- "call_id": item.get("call_id"),
- "output": json.dumps(result)
+ await self.send_event(
+ {
+ "type": "conversation.item.create",
+ "item": {
+ "type": "function_call_output",
+ "call_id": item.get("call_id"),
+ "output": json.dumps(result),
+ },
  }
- })
-
- await self.send_event({
- "type": "response.create",
- "event_id": str(uuid.uuid4()),
- "response": {
- "instructions": self._instructions,
- "metadata": {
- "client_event_id": str(uuid.uuid4())
- }
+ )
+
+ await self.send_event(
+ {
+ "type": "response.create",
+ "event_id": str(uuid.uuid4()),
+ "response": {
+ "instructions": self._instructions,
+ "metadata": {
+ "client_event_id": str(uuid.uuid4())
+ },
+ },
  }
- })
-
+ )
+
  except Exception as e:
- self.emit("error", f"Error executing function {name}: {e}")
+ self.emit(
+ "error", f"Error executing function {name}: {e}"
+ )
  break
  except Exception as e:
  self.emit("error", f"Error handling output item done: {e}")
@@ -387,7 +420,7 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  """Handle audio chunk"""
  if "audio" not in self.config.modalities:
  return
-
+
  try:
  if not self._agent_speaking:
  await realtime_metrics_collector.set_agent_speech_start()
@@ -395,18 +428,17 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  base64_audio_data = base64.b64decode(data.get("delta"))
  if base64_audio_data:
  if self.audio_track and self.loop:
- self.loop.create_task(self.audio_track.add_new_bytes(base64_audio_data))
+ asyncio.create_task(
+ self.audio_track.add_new_bytes(base64_audio_data)
+ )
  except Exception as e:
  self.emit("error", f"Error handling audio delta: {e}")
  traceback.print_exc()
-
+
  async def interrupt(self) -> None:
  """Interrupt the current response and flush audio"""
  if self._session and not self._closing:
- cancel_event = {
- "type": "response.cancel",
- "event_id": str(uuid.uuid4())
- }
+ cancel_event = {"type": "response.cancel", "event_id": str(uuid.uuid4())}
  await self.send_event(cancel_event)
  await realtime_metrics_collector.set_interrupted()
  if self.audio_track:
@@ -414,11 +446,11 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  if self._agent_speaking:
  await realtime_metrics_collector.set_agent_speech_end(timeout=1.0)
  self._agent_speaking = False
-
+
  async def _handle_audio_transcript_delta(self, data: dict) -> None:
  """Handle transcript chunk"""
  delta_content = data.get("delta", "")
- if not hasattr(self, '_current_audio_transcript'):
+ if not hasattr(self, "_current_audio_transcript"):
  self._current_audio_transcript = ""
  self._current_audio_transcript += delta_content
 
@@ -428,25 +460,35 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  if transcript:
  await realtime_metrics_collector.set_user_transcript(transcript)
  try:
- self.emit("realtime_model_transcription", {
- "role": "user",
- "text": transcript,
- "is_final": True
- })
+ self.emit(
+ "realtime_model_transcription",
+ {"role": "user", "text": transcript, "is_final": True},
+ )
  except Exception:
  pass
 
  async def _handle_response_done(self, data: dict) -> None:
  """Handle response completion for agent transcript"""
- if hasattr(self, '_current_audio_transcript') and self._current_audio_transcript:
- await realtime_metrics_collector.set_agent_response(self._current_audio_transcript)
- global_event_emitter.emit("text_response", {"text": self._current_audio_transcript, "type": "done"})
+ if (
+ hasattr(self, "_current_audio_transcript")
+ and self._current_audio_transcript
+ ):
+ await realtime_metrics_collector.set_agent_response(
+ self._current_audio_transcript
+ )
+ global_event_emitter.emit(
+ "text_response",
+ {"text": self._current_audio_transcript, "type": "done"},
+ )
  try:
- self.emit("realtime_model_transcription", {
- "role": "agent",
- "text": self._current_audio_transcript,
- "is_final": True
- })
+ self.emit(
+ "realtime_model_transcription",
+ {
+ "role": "agent",
+ "text": self._current_audio_transcript,
+ "is_final": True,
+ },
+ )
  except Exception:
  pass
  self._current_audio_transcript = ""
@@ -465,11 +507,11 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
 
  async def _cleanup_session(self, session: OpenAISession) -> None:
  """Clean up session resources"""
- if self._closing:
+ if self._closing:
  return
-
+
  self._closing = True
-
+
  for task in session.tasks:
  if not task.done():
  task.cancel()
@@ -483,7 +525,7 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  await session.ws.close()
  except Exception:
  pass
-
+
  async def send_event(self, event: Dict[str, Any]) -> None:
  """Send an event to the WebSocket"""
  if self._session and not self._closing:
@@ -493,15 +535,15 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  """Cleanup all resources"""
  if self._closing:
  return
-
+
  self._closing = True
-
+
  if self._session:
  await self._cleanup_session(self._session)
-
+
  if self._http_session and not self._http_session.closed:
  await self._http_session.close()
-
+
  async def send_first_session_update(self) -> None:
  """Send initial session update with default values after connection"""
  if not self._session:
@@ -509,41 +551,54 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
 
  turn_detection = None
  input_audio_transcription = None
-
+
  if "audio" in self.config.modalities:
- turn_detection = self.config.turn_detection.model_dump(
- by_alias=True,
- exclude_unset=True,
- exclude_defaults=True,
- ) if self.config.turn_detection else None
- input_audio_transcription = self.config.input_audio_transcription.model_dump(
- by_alias=True,
- exclude_unset=True,
- exclude_defaults=True,
- ) if self.config.input_audio_transcription else None
+ turn_detection = (
+ self.config.turn_detection.model_dump(
+ by_alias=True,
+ exclude_unset=True,
+ exclude_defaults=True,
+ )
+ if self.config.turn_detection
+ else None
+ )
+ input_audio_transcription = (
+ self.config.input_audio_transcription.model_dump(
+ by_alias=True,
+ exclude_unset=True,
+ exclude_defaults=True,
+ )
+ if self.config.input_audio_transcription
+ else None
+ )
 
  session_update = {
  "type": "session.update",
  "session": {
  "model": self.model,
- "instructions": self._instructions or "You are a helpful assistant that can answer questions and help with tasks.",
+ "instructions": self._instructions
+ or "You are a helpful assistant that can answer questions and help with tasks.",
  "temperature": self.config.temperature,
  "tool_choice": self.config.tool_choice,
  "tools": self._formatted_tools or [],
  "modalities": self.config.modalities,
- "max_response_output_tokens": "inf"
- }
+ "max_response_output_tokens": "inf",
+ },
  }
-
+
  if "audio" in self.config.modalities:
  session_update["session"]["voice"] = self.config.voice
  session_update["session"]["input_audio_format"] = DEFAULT_INPUT_AUDIO_FORMAT
- session_update["session"]["output_audio_format"] = DEFAULT_OUTPUT_AUDIO_FORMAT
+ session_update["session"][
+ "output_audio_format"
+ ] = DEFAULT_OUTPUT_AUDIO_FORMAT
  if turn_detection:
  session_update["session"]["turn_detection"] = turn_detection
  if input_audio_transcription:
- session_update["session"]["input_audio_transcription"] = input_audio_transcription
-
+ session_update["session"][
+ "input_audio_transcription"
+ ] = input_audio_transcription
+
  # Send the event
  await self.send_event(session_update)
 
@@ -560,27 +615,31 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  path = parsed_url.path
 
  if "model" not in query_params:
- query_params["model"] = [model]
+ query_params["model"] = [model]
 
  new_query = urlencode(query_params, doseq=True)
- new_url = urlunparse((parsed_url.scheme, parsed_url.netloc, path, "", new_query, ""))
+ new_url = urlunparse(
+ (parsed_url.scheme, parsed_url.netloc, path, "", new_query, "")
+ )
 
  return new_url
 
- def _format_tools_for_session(self, tools: List[FunctionTool]) -> List[Dict[str, Any]]:
+ def _format_tools_for_session(
+ self, tools: List[FunctionTool]
+ ) -> List[Dict[str, Any]]:
  """Format tools for OpenAI session update"""
  oai_tools = []
  for tool in tools:
  if not is_function_tool(tool):
  continue
-
+
  try:
  tool_schema = build_openai_schema(tool)
  oai_tools.append(tool_schema)
  except Exception as e:
  self.emit("error", f"Failed to format tool {tool}: {e}")
  continue
-
+
  return oai_tools
 
  async def send_text_message(self, message: str) -> None:
@@ -588,19 +647,15 @@ class OpenAIRealtime(RealtimeBaseModel[OpenAIEventTypes]):
  if not self._session:
  self.emit("error", "No active WebSocket session")
  raise RuntimeError("No active WebSocket session")
-
- await self.send_event({
- "type": "conversation.item.create",
- "item": {
- "type": "message",
- "role": "user",
- "content": [
- {
- "type": "input_text",
- "text": message
- }
- ]
+
+ await self.send_event(
+ {
+ "type": "conversation.item.create",
+ "item": {
+ "type": "message",
+ "role": "user",
+ "content": [{"type": "input_text", "text": message}],
+ },
  }
- })
+ )
  await self.create_response()
-
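
Aside on handle_audio_input above: the reflowed signal.resample call keeps the same arithmetic, downsampling captured 48 kHz audio to the 16 kHz the session expects. A standalone illustration with synthetic input (the sine-wave source is purely for demonstration; numpy and scipy are already imported by this code path):

import numpy as np
from scipy import signal

input_sample_rate = 48000   # rate of captured audio, as in the plugin
target_sample_rate = 16000  # rate sent to the realtime session

# 20 ms of synthetic 48 kHz, 16-bit mono audio (960 samples of a 440 Hz tone)
t = np.linspace(0, 0.02, 960, endpoint=False)
audio = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

# Same formula as the diff: new length = old length * target / input
resampled = signal.resample(audio, int(len(audio) * target_sample_rate / input_sample_rate))
resampled = resampled.astype(np.int16)

print(len(audio), "->", len(resampled))  # 960 -> 320 samples, a 3:1 reduction
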
@@ -6,14 +6,17 @@ import os
  import openai
  import asyncio
 
- from videosdk.agents import TTS
+ from videosdk.agents import TTS, segment_text
 
  OPENAI_TTS_SAMPLE_RATE = 24000
  OPENAI_TTS_CHANNELS = 1
 
  DEFAULT_MODEL = "gpt-4o-mini-tts"
  DEFAULT_VOICE = "ash"
- _RESPONSE_FORMATS = Union[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], str]
+ _RESPONSE_FORMATS = Union[Literal["mp3",
+ "opus", "aac", "flac", "wav", "pcm"], str]
+
+
 
  class OpenAITTS(TTS):
  def __init__(
@@ -25,10 +28,10 @@ class OpenAITTS(TTS):
  instructions: str | None = None,
  api_key: str | None = None,
  base_url: str | None = None,
- response_format: str = "pcm"
+ response_format: str = "pcm",
  ) -> None:
  super().__init__(sample_rate=OPENAI_TTS_SAMPLE_RATE, num_channels=OPENAI_TTS_CHANNELS)
-
+
  self.model = model
  self.voice = voice
  self.speed = speed
@@ -37,17 +40,21 @@ class OpenAITTS(TTS):
  self.loop = None
  self.response_format = response_format
  self._first_chunk_sent = False
-
+ self._current_synthesis_task: asyncio.Task | None = None
+ self._interrupted = False
+
  self.api_key = api_key or os.getenv("OPENAI_API_KEY")
  if not self.api_key:
- raise ValueError("OpenAI API key must be provided either through api_key parameter or OPENAI_API_KEY environment variable")
-
+ raise ValueError(
+ "OpenAI API key must be provided either through api_key parameter or OPENAI_API_KEY environment variable")
+
  self._client = openai.AsyncClient(
  max_retries=0,
  api_key=self.api_key,
  base_url=base_url or None,
  http_client=httpx.AsyncClient(
- timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
+ timeout=httpx.Timeout(
+ connect=15.0, read=5.0, write=5.0, pool=5.0),
  follow_redirects=True,
  limits=httpx.Limits(
  max_connections=50,
@@ -60,71 +67,86 @@ class OpenAITTS(TTS):
  def reset_first_audio_tracking(self) -> None:
  """Reset the first audio tracking state for next TTS task"""
  self._first_chunk_sent = False
-
+
  async def synthesize(
  self,
  text: AsyncIterator[str] | str,
  voice_id: Optional[str] = None,
- **kwargs: Any
+ **kwargs: Any,
  ) -> None:
  """
  Convert text to speech using OpenAI's TTS API and stream to audio track
-
+
  Args:
  text: Text to convert to speech
  voice_id: Optional voice override
  **kwargs: Additional provider-specific arguments
  """
  try:
- if isinstance(text, AsyncIterator):
- full_text = ""
- async for chunk in text:
- full_text += chunk
- else:
- full_text = text
-
  if not self.audio_track or not self.loop:
  self.emit("error", "Audio track or event loop not set")
  return
 
+ self._interrupted = False
+
+ if isinstance(text, AsyncIterator):
+ async for segment in segment_text(text):
+ if self._interrupted:
+ break
+ await self._synthesize_segment(segment, voice_id, **kwargs)
+ else:
+ if not self._interrupted:
+ await self._synthesize_segment(text, voice_id, **kwargs)
+
+ except Exception as e:
+ self.emit("error", f"TTS synthesis failed: {str(e)}")
+
+ async def _synthesize_segment(self, text: str, voice_id: Optional[str] = None, **kwargs: Any) -> None:
+ """Synthesize a single text segment"""
+ if not text.strip() or self._interrupted:
+ return
+
+ try:
  audio_data = b""
  async with self._client.audio.speech.with_streaming_response.create(
  model=self.model,
  voice=voice_id or self.voice,
- input=full_text,
+ input=text,
  speed=self.speed,
  response_format=self.response_format,
- **({"instructions": self.instructions} if self.instructions else {})
+ **({"instructions": self.instructions} if self.instructions else {}),
  ) as response:
  async for chunk in response.iter_bytes():
+ if self._interrupted:
+ break
  if chunk:
  audio_data += chunk
 
- if audio_data:
+ if audio_data and not self._interrupted:
  await self._stream_audio_chunks(audio_data)
 
- except openai.APIError as e:
- self.emit("error", str(e))
  except Exception as e:
- self.emit("error", f"TTS synthesis failed: {str(e)}")
+ if not self._interrupted:
+ self.emit("error", f"Segment synthesis failed: {str(e)}")
 
  async def _stream_audio_chunks(self, audio_bytes: bytes) -> None:
  """Stream audio data in chunks for smooth playback"""
- chunk_size = int(OPENAI_TTS_SAMPLE_RATE * OPENAI_TTS_CHANNELS * 2 * 20 / 1000)
-
+ chunk_size = int(OPENAI_TTS_SAMPLE_RATE *
+ OPENAI_TTS_CHANNELS * 2 * 20 / 1000)
+
  for i in range(0, len(audio_bytes), chunk_size):
  chunk = audio_bytes[i:i + chunk_size]
-
+
  if len(chunk) < chunk_size and len(chunk) > 0:
  padding_needed = chunk_size - len(chunk)
  chunk += b'\x00' * padding_needed
-
+
  if len(chunk) == chunk_size:
  if not self._first_chunk_sent and self._first_audio_callback:
  self._first_chunk_sent = True
  await self._first_audio_callback()
-
- self.loop.create_task(self.audio_track.add_new_bytes(chunk))
+
+ asyncio.create_task(self.audio_track.add_new_bytes(chunk))
  await asyncio.sleep(0.001)
 
  async def aclose(self) -> None:
@@ -133,6 +155,9 @@ class OpenAITTS(TTS):
  await super().aclose()
 
  async def interrupt(self) -> None:
- """Interrupt the TTS process"""
+ """Interrupt TTS synthesis"""
+ self._interrupted = True
+ if self._current_synthesis_task:
+ self._current_synthesis_task.cancel()
  if self.audio_track:
- self.audio_track.interrupt()
+ self.audio_track.interrupt()
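
For reference, the 20 ms chunk size used by _stream_audio_chunks (reflowed but numerically unchanged above) works out to 960 bytes. A standalone recomputation from the constants shown in the diff; BYTES_PER_SAMPLE and FRAME_MS are illustrative names for the literal 2 and 20 in the plugin's expression:

OPENAI_TTS_SAMPLE_RATE = 24000  # Hz, as defined in the plugin
OPENAI_TTS_CHANNELS = 1         # mono
BYTES_PER_SAMPLE = 2            # 16-bit PCM
FRAME_MS = 20                   # chunk duration targeted by _stream_audio_chunks

chunk_size = int(
    OPENAI_TTS_SAMPLE_RATE * OPENAI_TTS_CHANNELS * BYTES_PER_SAMPLE * FRAME_MS / 1000
)
print(chunk_size)  # 960 bytes per 20 ms chunk
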
@@ -0,0 +1 @@
+ __version__ = "0.0.28"
@@ -1 +0,0 @@
- __version__ = "0.0.26"