dv-pipecat-ai 0.0.85.dev824__py3-none-any.whl → 0.0.85.dev858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dv-pipecat-ai might be problematic.

Files changed (31)
  1. {dv_pipecat_ai-0.0.85.dev824.dist-info → dv_pipecat_ai-0.0.85.dev858.dist-info}/METADATA +2 -1
  2. {dv_pipecat_ai-0.0.85.dev824.dist-info → dv_pipecat_ai-0.0.85.dev858.dist-info}/RECORD +31 -29
  3. pipecat/audio/turn/smart_turn/local_smart_turn_v3.py +5 -1
  4. pipecat/frames/frames.py +22 -0
  5. pipecat/metrics/connection_metrics.py +45 -0
  6. pipecat/processors/aggregators/llm_response.py +15 -9
  7. pipecat/processors/dtmf_aggregator.py +17 -21
  8. pipecat/processors/frame_processor.py +44 -1
  9. pipecat/processors/metrics/frame_processor_metrics.py +108 -0
  10. pipecat/processors/transcript_processor.py +2 -1
  11. pipecat/serializers/__init__.py +2 -0
  12. pipecat/serializers/asterisk.py +16 -2
  13. pipecat/serializers/convox.py +2 -2
  14. pipecat/serializers/custom.py +2 -2
  15. pipecat/serializers/vi.py +326 -0
  16. pipecat/services/cartesia/tts.py +75 -10
  17. pipecat/services/deepgram/stt.py +317 -17
  18. pipecat/services/elevenlabs/stt.py +487 -19
  19. pipecat/services/elevenlabs/tts.py +28 -4
  20. pipecat/services/google/llm.py +26 -11
  21. pipecat/services/openai/base_llm.py +79 -14
  22. pipecat/services/salesforce/llm.py +64 -59
  23. pipecat/services/sarvam/tts.py +0 -1
  24. pipecat/services/soniox/stt.py +45 -10
  25. pipecat/services/vistaar/llm.py +97 -6
  26. pipecat/transcriptions/language.py +50 -0
  27. pipecat/transports/base_input.py +15 -11
  28. pipecat/transports/base_output.py +26 -3
  29. {dv_pipecat_ai-0.0.85.dev824.dist-info → dv_pipecat_ai-0.0.85.dev858.dist-info}/WHEEL +0 -0
  30. {dv_pipecat_ai-0.0.85.dev824.dist-info → dv_pipecat_ai-0.0.85.dev858.dist-info}/licenses/LICENSE +0 -0
  31. {dv_pipecat_ai-0.0.85.dev824.dist-info → dv_pipecat_ai-0.0.85.dev858.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,7 @@ from openai import (
     APITimeoutError,
     AsyncOpenAI,
     AsyncStream,
+    BadRequestError,
     DefaultAsyncHttpxClient,
 )
 from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
@@ -32,6 +33,7 @@ from pipecat.frames.frames import (
     LLMMessagesFrame,
     LLMTextFrame,
     LLMUpdateSettingsFrame,
+    WarmupLLMFrame,
 )
 from pipecat.metrics.metrics import LLMTokenUsage
 from pipecat.processors.aggregators.llm_context import LLMContext
@@ -99,6 +101,7 @@ class BaseOpenAILLMService(LLMService):
         params: Optional[InputParams] = None,
         retry_timeout_secs: Optional[float] = 5.0,
         retry_on_timeout: Optional[bool] = False,
+        enable_warmup: bool = False,
         **kwargs,
     ):
         """Initialize the BaseOpenAILLMService.
@@ -113,6 +116,7 @@ class BaseOpenAILLMService(LLMService):
             params: Input parameters for model configuration and behavior.
             retry_timeout_secs: Request timeout in seconds. Defaults to 5.0 seconds.
             retry_on_timeout: Whether to retry the request once if it times out.
+            enable_warmup: Whether to enable LLM cache warmup. Defaults to False.
             **kwargs: Additional arguments passed to the parent LLMService.
         """
         super().__init__(**kwargs)
@@ -132,6 +136,7 @@ class BaseOpenAILLMService(LLMService):
         }
         self._retry_timeout_secs = retry_timeout_secs
         self._retry_on_timeout = retry_on_timeout
+        self._enable_warmup = enable_warmup
         self.set_model_name(model)
         self._client = self.create_client(
             api_key=api_key,
@@ -200,20 +205,29 @@ class BaseOpenAILLMService(LLMService):
         """
         params = self.build_chat_completion_params(params_from_context)
 
-        if self._retry_on_timeout:
-            try:
-                chunks = await asyncio.wait_for(
-                    self._client.chat.completions.create(**params), timeout=self._retry_timeout_secs
-                )
-                return chunks
-            except (APITimeoutError, asyncio.TimeoutError):
-                # Retry, this time without a timeout so we get a response
-                logger.debug(f"{self}: Retrying chat completion due to timeout")
+        await self.start_connection_metrics()
+
+        try:
+            if self._retry_on_timeout:
+                try:
+                    chunks = await asyncio.wait_for(
+                        self._client.chat.completions.create(**params), timeout=self._retry_timeout_secs
+                    )
+                    await self.stop_connection_metrics(success=True, connection_type="http")
+                    return chunks
+                except (APITimeoutError, asyncio.TimeoutError):
+                    # Retry, this time without a timeout so we get a response
+                    logger.debug(f"{self}: Retrying chat completion due to timeout")
+                    chunks = await self._client.chat.completions.create(**params)
+                    await self.stop_connection_metrics(success=True, connection_type="http")
+                    return chunks
+            else:
                 chunks = await self._client.chat.completions.create(**params)
+                await self.stop_connection_metrics(success=True, connection_type="http")
                 return chunks
-        else:
-            chunks = await self._client.chat.completions.create(**params)
-            return chunks
+        except Exception as e:
+            await self.stop_connection_metrics(success=False, error=str(e), connection_type="http")
+            raise
 
     def build_chat_completion_params(self, params_from_context: OpenAILLMInvocationParams) -> dict:
         """Build parameters for chat completion request.
@@ -438,14 +452,19 @@ class BaseOpenAILLMService(LLMService):
         completions and manage settings.
 >>>>>>> dv-stage
 
-        Args:
+        Args:
             frame: The frame to process.
             direction: The direction of frame processing.
         """
         await super().process_frame(frame, direction)
 
         context = None
-        if isinstance(frame, OpenAILLMContextFrame):
+        if isinstance(frame, WarmupLLMFrame):
+            # Handle warmup frame - prime cache without emitting response
+            # Run in background to avoid blocking the pipeline
+            asyncio.create_task(self._handle_warmup_frame(frame))
+            return  # Don't process further, warmup is silent
+        elif isinstance(frame, OpenAILLMContextFrame):
             # Handle OpenAI-specific context frames
             context = frame.context
         elif isinstance(frame, LLMContextFrame):
@@ -470,3 +489,49 @@ class BaseOpenAILLMService(LLMService):
         finally:
             await self.stop_processing_metrics()
             await self.push_frame(LLMFullResponseEndFrame())
+
+    def _is_gpt5_model(self) -> bool:
+        """Check if the current model is a GPT-5 series model that requires max_completion_tokens."""
+        model = (self.model_name or "").lower()
+        return model.startswith("gpt-5")
+
+    async def _handle_warmup_frame(self, frame: WarmupLLMFrame):
+        """Handle WarmupLLMFrame to prime the LLM cache without emitting responses.
+
+        This method sends a minimal request to the LLM to warm up any provider-side
+        caches (like prompt caching). The response is discarded and no frames are emitted.
+
+        Args:
+            frame: WarmupLLMFrame containing the messages to cache.
+        """
+        # Skip warmup if disabled
+        if not self._enable_warmup:
+            self.logger.debug("LLM warmup is disabled, skipping")
+            return
+
+        try:
+            # Use the provided messages for warmup
+            messages: List[ChatCompletionMessageParam] = frame.messages  # type: ignore
+
+            # Make a non-streaming call to warm the cache
+            # We use a minimal token limit to reduce latency and cost
+            # GPT-5 series models require max_completion_tokens instead of max_tokens
+            warmup_params = {
+                "model": self.model_name,
+                "messages": messages,
+                "stream": False,
+            }
+
+            if self._is_gpt5_model():
+                warmup_params["max_completion_tokens"] = 10
+            else:
+                warmup_params["max_tokens"] = 10
+
+            await self._client.chat.completions.create(**warmup_params)
+
+            self.logger.info("LLM cache warmed successfully")
+            # Intentionally don't emit any frames - this is a silent warmup
+
+        except Exception as e:
+            self.logger.error(f"Failed to warm LLM cache: {e}")
+            # Don't propagate error - warmup failure shouldn't break the bot
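
Putting the pieces together, the new `enable_warmup` flag and `WarmupLLMFrame` are meant to be used before the first user turn. The sketch below is illustrative only: it assumes `WarmupLLMFrame` accepts the message list it later exposes as `frame.messages`, and that `OpenAILLMService` forwards `enable_warmup` to `BaseOpenAILLMService`; the API key and prompt are placeholders.

```python
# Illustrative warmup sketch (not taken from the diff). Assumes WarmupLLMFrame
# accepts the message list it exposes as frame.messages, and that
# OpenAILLMService forwards enable_warmup to BaseOpenAILLMService.
import asyncio

from pipecat.frames.frames import WarmupLLMFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.task import PipelineTask
from pipecat.services.openai.llm import OpenAILLMService


async def main():
    llm = OpenAILLMService(api_key="sk-...", model="gpt-4o", enable_warmup=True)
    task = PipelineTask(Pipeline([llm]))

    # Queue the system prompt before the first user turn so any provider-side
    # prompt cache is primed; the service handles the frame silently in a
    # background task and emits no response frames.
    await task.queue_frame(
        WarmupLLMFrame(
            messages=[{"role": "system", "content": "You are a helpful voice agent."}]
        )
    )


asyncio.run(main())
```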
@@ -13,6 +13,7 @@ from dataclasses import dataclass
 from typing import AsyncGenerator, Dict, Optional
 
 import httpx
+from env_config import api_config
 from loguru import logger
 
 from pipecat.frames.frames import (
@@ -23,6 +24,10 @@ from pipecat.frames.frames import (
     LLMTextFrame,
     LLMUpdateSettingsFrame,
 )
+from pipecat.processors.aggregators.llm_response import (
+    LLMAssistantAggregatorParams,
+    LLMUserAggregatorParams,
+)
 from pipecat.processors.aggregators.openai_llm_context import (
     OpenAILLMContext,
     OpenAILLMContextFrame,
@@ -34,11 +39,6 @@ from pipecat.services.openai.llm import (
     OpenAIContextAggregatorPair,
     OpenAIUserContextAggregator,
 )
-from pipecat.processors.aggregators.llm_response import (
-    LLMAssistantAggregatorParams,
-    LLMUserAggregatorParams,
-)
-from env_config import api_config
 from pipecat.utils.redis import create_async_redis_client
 
 
@@ -96,12 +96,11 @@ class SalesforceAgentLLMService(LLMService):
         # Initialize parent LLM service
         super().__init__(**kwargs)
         self._agent_id = agent_id
-        self._org_domain = org_domain
+        self._org_domain = org_domain
         self._client_id = client_id
         self._client_secret = client_secret
         self._api_host = api_host
 
-
         # Validate required environment variables
         required_vars = {
             "SALESFORCE_AGENT_ID": self._agent_id,
@@ -145,7 +144,6 @@ class SalesforceAgentLLMService(LLMService):
         )
 
         self._schedule_session_warmup()
-
 
     async def __aenter__(self):
         """Async context manager entry."""
@@ -237,7 +235,7 @@ class SalesforceAgentLLMService(LLMService):
             return
 
         ttl_seconds = 3600  # Default fallback
-
+
         # Try to get expiration from expires_in parameter first
         if expires_in is not None:
             try:
@@ -246,7 +244,7 @@ class SalesforceAgentLLMService(LLMService):
             except (TypeError, ValueError):
                 logger.debug("Unable to parse expires_in parameter")
                 expires_in = None
-
+
         # If no expires_in available, use default TTL
         if expires_in is None:
             logger.debug("No expiration info found, using default TTL")
@@ -271,7 +269,7 @@ class SalesforceAgentLLMService(LLMService):
 
     async def _get_access_token(self, *, force_refresh: bool = False) -> str:
         """Get OAuth access token using client credentials.
-
+
         Args:
             force_refresh: If True, skip cache and fetch fresh token from Salesforce.
         """
@@ -301,15 +299,15 @@ class SalesforceAgentLLMService(LLMService):
 
     async def _make_authenticated_request(self, method: str, url: str, **kwargs):
         """Make an authenticated HTTP request with automatic token refresh on auth errors.
-
+
         Args:
             method: HTTP method (GET, POST, DELETE, etc.)
             url: Request URL
             **kwargs: Additional arguments passed to httpx request
-
+
         Returns:
             httpx.Response: The HTTP response
-
+
         Raises:
             Exception: If request fails after token refresh attempt
         """
@@ -318,7 +316,7 @@ class SalesforceAgentLLMService(LLMService):
         headers = kwargs.get("headers", {})
         headers["Authorization"] = f"Bearer {access_token}"
         kwargs["headers"] = headers
-
+
         try:
             response = await self._http_client.request(method, url, **kwargs)
             response.raise_for_status()
@@ -326,14 +324,16 @@ class SalesforceAgentLLMService(LLMService):
         except httpx.HTTPStatusError as e:
             # If authentication error, clear cache and retry with fresh token
             if e.response.status_code in (401, 403):
-                logger.warning(f"Salesforce authentication error ({e.response.status_code}), refreshing token")
+                logger.warning(
+                    f"Salesforce authentication error ({e.response.status_code}), refreshing token"
+                )
                 await self._clear_cached_access_token()
-
+
                 # Retry with fresh token
                 fresh_token = await self._get_access_token(force_refresh=True)
                 headers["Authorization"] = f"Bearer {fresh_token}"
                 kwargs["headers"] = headers
-
+
                 response = await self._http_client.request(method, url, **kwargs)
                 response.raise_for_status()
                 return response
@@ -359,9 +359,7 @@ class SalesforceAgentLLMService(LLMService):
 
         try:
             response = await self._make_authenticated_request(
-                "POST", session_url,
-                headers={"Content-Type": "application/json"},
-                json=payload
+                "POST", session_url, headers={"Content-Type": "application/json"}, json=payload
             )
             session_data = response.json()
             session_id = session_data["sessionId"]
@@ -419,8 +417,7 @@ class SalesforceAgentLLMService(LLMService):
             # End the session via API
             url = f"{self._api_host}/einstein/ai-agent/v1/sessions/{session_id}"
             await self._make_authenticated_request(
-                "DELETE", url,
-                headers={"x-session-end-reason": "UserRequest"}
+                "DELETE", url, headers={"x-session-end-reason": "UserRequest"}
             )
         except Exception as e:
             logger.warning(f"Failed to end session {session_id}: {e}")
@@ -431,32 +428,32 @@ class SalesforceAgentLLMService(LLMService):
 
     def _extract_user_message(self, context: OpenAILLMContext) -> str:
         """Extract the last user message from context.
-
+
         Similar to Vistaar pattern - extract only the most recent user message.
-
+
         Args:
             context: The OpenAI LLM context containing messages.
-
+
         Returns:
             The last user message as a string.
         """
         messages = context.get_messages()
-
+
         # Find the last user message (iterate in reverse for efficiency)
        for message in reversed(messages):
            if message.get("role") == "user":
                content = message.get("content", "")
-
+
                # Handle content that might be a list (for multimodal messages)
                if isinstance(content, list):
                    text_parts = [
                        item.get("text", "") for item in content if item.get("type") == "text"
                    ]
                    content = " ".join(text_parts)
-
+
                if isinstance(content, str):
                    return content.strip()
-
+
        return ""
 
     def _generate_sequence_id(self) -> int:
@@ -464,7 +461,9 @@ class SalesforceAgentLLMService(LLMService):
         self._sequence_counter += 1
         return self._sequence_counter
 
-    async def _stream_salesforce_response(self, session_id: str, user_message: str) -> AsyncGenerator[str, None]:
+    async def _stream_salesforce_response(
+        self, session_id: str, user_message: str
+    ) -> AsyncGenerator[str, None]:
         """Stream response from Salesforce Agent API."""
         url = f"{self._api_host}/einstein/ai-agent/v1/sessions/{session_id}/messages/stream"
 
@@ -472,15 +471,9 @@ class SalesforceAgentLLMService(LLMService):
             "message": {
                 "sequenceId": self._generate_sequence_id(),
                 "type": "Text",
-                "text": user_message
+                "text": user_message,
             },
-            "variables": [
-                {
-                    "name": "$Context.EndUserLanguage",
-                    "type": "Text",
-                    "value": "en_US"
-                }
-            ]
+            "variables": [{"name": "$Context.EndUserLanguage", "type": "Text", "value": "en_US"}],
         }
 
         # First attempt with current token
@@ -493,9 +486,11 @@ class SalesforceAgentLLMService(LLMService):
 
         try:
             logger.info(f"🌐 Salesforce API request: {user_message[:50]}...")
-            async with self._http_client.stream("POST", url, headers=headers, json=message_data) as response:
+            async with self._http_client.stream(
+                "POST", url, headers=headers, json=message_data
+            ) as response:
                 response.raise_for_status()
-
+
                 async for line in response.aiter_lines():
                     if not line:
                         continue
@@ -525,17 +520,23 @@ class SalesforceAgentLLMService(LLMService):
         except httpx.HTTPStatusError as e:
             # If authentication error, retry with fresh token
             if e.response.status_code in (401, 403):
-                logger.warning(f"Salesforce streaming authentication error ({e.response.status_code}), refreshing token")
+                logger.warning(
+                    f"Salesforce streaming authentication error ({e.response.status_code}), refreshing token"
+                )
                 await self._clear_cached_access_token()
-
+
                 # Retry with fresh token
                 fresh_token = await self._get_access_token(force_refresh=True)
                 headers["Authorization"] = f"Bearer {fresh_token}"
-
-                logger.info(f"🔄 Retrying Salesforce stream with fresh token: {user_message[:50]}...")
-                async with self._http_client.stream("POST", url, headers=headers, json=message_data) as response:
+
+                logger.info(
+                    f"🔄 Retrying Salesforce stream with fresh token: {user_message[:50]}..."
+                )
+                async with self._http_client.stream(
+                    "POST", url, headers=headers, json=message_data
+                ) as response:
                     response.raise_for_status()
-
+
                     async for line in response.aiter_lines():
                         if not line:
                             continue
@@ -576,40 +577,41 @@ class SalesforceAgentLLMService(LLMService):
             context: The OpenAI LLM context containing messages to process.
         """
         logger.info(f"🔄 Salesforce processing context with {len(context.get_messages())} messages")
-
+
         # Extract user message from context first
         user_message = self._extract_user_message(context)
-
+
         if not user_message:
             logger.warning("Salesforce: No user message found in context")
             return
-
+
         try:
             logger.info(f"🎯 Salesforce extracted query: {user_message}")
-
-            # Start response
+
+            # Start response
             await self.push_frame(LLMFullResponseStartFrame())
-            await self.push_frame(LLMFullResponseStartFrame(),FrameDirection.UPSTREAM)
+            await self.push_frame(LLMFullResponseStartFrame(), FrameDirection.UPSTREAM)
             await self.start_processing_metrics()
             await self.start_ttfb_metrics()
-
+
             # Get or create session
             session_id = await self._get_or_create_session()
-
+
             first_chunk = True
-
+
             # Stream the response
             async for text_chunk in self._stream_salesforce_response(session_id, user_message):
                 if first_chunk:
                     await self.stop_ttfb_metrics()
                     first_chunk = False
-
+
                 # Push each text chunk as it arrives
                 await self.push_frame(LLMTextFrame(text=text_chunk))
-
+
         except Exception as e:
             logger.error(f"Salesforce context processing error: {type(e).__name__}: {str(e)}")
             import traceback
+
             logger.error(f"Salesforce traceback: {traceback.format_exc()}")
             raise
         finally:
@@ -627,7 +629,9 @@ class SalesforceAgentLLMService(LLMService):
         context = None
         if isinstance(frame, OpenAILLMContextFrame):
             context = frame.context
-            logger.info(f"🔍 Received OpenAILLMContextFrame with {len(context.get_messages())} messages")
+            logger.info(
+                f"🔍 Received OpenAILLMContextFrame with {len(context.get_messages())} messages"
+            )
         elif isinstance(frame, LLMMessagesFrame):
             context = OpenAILLMContext.from_messages(frame.messages)
             logger.info(f"🔍 Received LLMMessagesFrame with {len(frame.messages)} messages")
@@ -680,6 +684,7 @@ class SalesforceAgentLLMService(LLMService):
     def get_llm_adapter(self):
         """Get the LLM adapter for this service."""
         from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
+
         return OpenAILLMAdapter()
 
     async def close(self):
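
Most of the changes above are formatting, but the refresh-and-retry behaviour in `_make_authenticated_request` is worth calling out: on a 401/403 the cached token is cleared, a fresh one is fetched, and the request is retried exactly once. A standalone restatement of that pattern with httpx is sketched below; the `get_token`/`clear_token` callables are hypothetical stand-ins for the service's token helpers.

```python
# Generic restatement of the 401/403 refresh-and-retry flow used by
# _make_authenticated_request above; get_token/clear_token are hypothetical
# stand-ins for the service's cached-token helpers.
from typing import Awaitable, Callable

import httpx


async def authenticated_request(
    client: httpx.AsyncClient,
    get_token: Callable[[bool], Awaitable[str]],  # called as get_token(force_refresh)
    clear_token: Callable[[], Awaitable[None]],
    method: str,
    url: str,
    **kwargs,
) -> httpx.Response:
    headers = kwargs.pop("headers", {})
    headers["Authorization"] = f"Bearer {await get_token(False)}"
    try:
        response = await client.request(method, url, headers=headers, **kwargs)
        response.raise_for_status()
        return response
    except httpx.HTTPStatusError as e:
        if e.response.status_code not in (401, 403):
            raise
        # The token is stale or revoked: drop the cached value, mint a fresh
        # one, and retry the request a single time.
        await clear_token()
        headers["Authorization"] = f"Bearer {await get_token(True)}"
        response = await client.request(method, url, headers=headers, **kwargs)
        response.raise_for_status()
        return response
```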
@@ -23,7 +23,6 @@ from pipecat.frames.frames import (
     InterruptionFrame,
     LLMFullResponseEndFrame,
     StartFrame,
-    StartInterruptionFrame,
     TTSAudioRawFrame,
     TTSStartedFrame,
     TTSStoppedFrame,
@@ -49,6 +49,33 @@ END_TOKEN = "<end>"
 FINALIZED_TOKEN = "<fin>"
 
 
+class SonioxContextGeneralItem(BaseModel):
+    """Represents a key-value pair for structured general context information."""
+
+    key: str
+    value: str
+
+
+class SonioxContextTranslationTerm(BaseModel):
+    """Represents a custom translation mapping for ambiguous or domain-specific terms."""
+
+    source: str
+    target: str
+
+
+class SonioxContextObject(BaseModel):
+    """Context object for models with context_version 2, for Soniox stt-rt-v3-preview and higher.
+
+    Learn more about context in the documentation:
+    https://soniox.com/docs/stt/concepts/context
+    """
+
+    general: Optional[List[SonioxContextGeneralItem]] = None
+    text: Optional[str] = None
+    terms: Optional[List[str]] = None
+    translation_terms: Optional[List[SonioxContextTranslationTerm]] = None
+
+
 class SonioxInputParams(BaseModel):
     """Real-time transcription settings.
 
@@ -60,9 +87,9 @@ class SonioxInputParams(BaseModel):
         audio_format: Audio format to use for transcription.
         num_channels: Number of channels to use for transcription.
         language_hints: List of language hints to use for transcription.
-        context: Customization for transcription.
-        enable_non_final_tokens: Whether to enable non-final tokens. If false, only final tokens will be returned.
-        max_non_final_tokens_duration_ms: Maximum duration of non-final tokens.
+        context: Customization for transcription. String for models with context_version 1 and ContextObject for models with context_version 2.
+        enable_speaker_diarization: Whether to enable speaker diarization. Tokens are annotated with speaker IDs.
+        enable_language_identification: Whether to enable language identification. Tokens are annotated with language IDs.
         client_reference_id: Client reference ID to use for transcription.
     """
 
@@ -72,10 +99,10 @@ class SonioxInputParams(BaseModel):
     num_channels: Optional[int] = 1
 
     language_hints: Optional[List[Language]] = None
-    context: Optional[str] = None
+    context: Optional[SonioxContextObject | str] = None
 
-    enable_non_final_tokens: Optional[bool] = True
-    max_non_final_tokens_duration_ms: Optional[int] = None
+    enable_speaker_diarization: Optional[bool] = False
+    enable_language_identification: Optional[bool] = False
 
     client_reference_id: Optional[str] = None
 
@@ -173,6 +200,10 @@ class SonioxSTTService(STTService):
         # Either one or the other is required.
         enable_endpoint_detection = not self._vad_force_turn_endpoint
 
+        context = self._params.context
+        if isinstance(context, SonioxContextObject):
+            context = context.model_dump()
+
         # Send the initial configuration message.
         config = {
             "api_key": self._api_key,
@@ -182,9 +213,9 @@ class SonioxSTTService(STTService):
             "enable_endpoint_detection": enable_endpoint_detection,
             "sample_rate": self.sample_rate,
             "language_hints": _prepare_language_hints(self._params.language_hints),
-            "context": self._params.context,
-            "enable_non_final_tokens": self._params.enable_non_final_tokens,
-            "max_non_final_tokens_duration_ms": self._params.max_non_final_tokens_duration_ms,
+            "context": context,
+            "enable_speaker_diarization": self._params.enable_speaker_diarization,
+            "enable_language_identification": self._params.enable_language_identification,
             "client_reference_id": self._params.client_reference_id,
         }
 
@@ -210,6 +241,7 @@ class SonioxSTTService(STTService):
             if self._receive_task != asyncio.current_task():
                 await self._receive_task
             self._receive_task = None
+        self.logger.debug("Disconnected from Soniox STT")
 
     async def stop(self, frame: EndFrame):
         """Stop the Soniox STT websocket connection.
@@ -351,7 +383,10 @@ class SonioxSTTService(STTService):
 
         if self._final_transcription_buffer or non_final_transcription:
             final_text = "".join(
-                map(lambda token: token["text"], self._final_transcription_buffer)
+                map(
+                    lambda token: token["text"],
+                    self._final_transcription_buffer,
+                )
             )
             non_final_text = "".join(
                 map(lambda token: token["text"], non_final_transcription)
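
The newer Soniox models accept structured context instead of a plain string, which is what the config change above forwards to the websocket. The sketch below shows how the models defined in this diff could be wired into `SonioxInputParams`; the model classes and field names come from the diff, while the API key, context values, and the `api_key`/`params` constructor keywords are placeholders and assumptions.

```python
# Sketch of passing structured context to the Soniox STT service. The model
# classes and parameter names come from the diff above; the api_key value,
# context content, and constructor keywords are assumptions for illustration.
from pipecat.services.soniox.stt import (
    SonioxContextGeneralItem,
    SonioxContextObject,
    SonioxContextTranslationTerm,
    SonioxInputParams,
    SonioxSTTService,
)

context = SonioxContextObject(
    general=[SonioxContextGeneralItem(key="domain", value="healthcare")],
    text="The caller is booking a cardiology appointment.",
    terms=["Dr. Amara Patel", "echocardiogram"],
    translation_terms=[SonioxContextTranslationTerm(source="ECG", target="electrocardiogram")],
)

stt = SonioxSTTService(
    api_key="YOUR_SONIOX_API_KEY",
    params=SonioxInputParams(
        context=context,  # context_version 1 models still take a plain string here
        enable_speaker_diarization=True,
        enable_language_identification=True,
    ),
)
```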