mistralai-1.11.1-py3-none-any.whl → mistralai-1.12.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +2 -2
- mistralai/audio.py +20 -0
- mistralai/conversations.py +48 -8
- mistralai/extra/__init__.py +48 -0
- mistralai/extra/exceptions.py +49 -4
- mistralai/extra/realtime/__init__.py +25 -0
- mistralai/extra/realtime/connection.py +207 -0
- mistralai/extra/realtime/transcription.py +271 -0
- mistralai/files.py +6 -0
- mistralai/mistral_agents.py +391 -8
- mistralai/models/__init__.py +103 -0
- mistralai/models/agentaliasresponse.py +23 -0
- mistralai/models/agentconversation.py +14 -4
- mistralai/models/agents_api_v1_agents_create_or_update_aliasop.py +26 -0
- mistralai/models/agents_api_v1_agents_get_versionop.py +2 -2
- mistralai/models/agents_api_v1_agents_getop.py +12 -3
- mistralai/models/agents_api_v1_agents_list_version_aliasesop.py +16 -0
- mistralai/models/audiotranscriptionrequest.py +8 -0
- mistralai/models/audiotranscriptionrequeststream.py +8 -0
- mistralai/models/conversationrequest.py +8 -2
- mistralai/models/conversationrestartrequest.py +18 -4
- mistralai/models/conversationrestartstreamrequest.py +20 -4
- mistralai/models/conversationstreamrequest.py +12 -2
- mistralai/models/files_api_routes_list_filesop.py +8 -1
- mistralai/models/mistralpromptmode.py +4 -0
- mistralai/models/modelcapabilities.py +3 -0
- mistralai/models/realtimetranscriptionerror.py +27 -0
- mistralai/models/realtimetranscriptionerrordetail.py +29 -0
- mistralai/models/realtimetranscriptionsession.py +20 -0
- mistralai/models/realtimetranscriptionsessioncreated.py +30 -0
- mistralai/models/realtimetranscriptionsessionupdated.py +30 -0
- mistralai/models/timestampgranularity.py +4 -1
- mistralai/models/transcriptionsegmentchunk.py +41 -2
- mistralai/models/transcriptionstreamsegmentdelta.py +38 -2
- mistralai/transcriptions.py +24 -0
- {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/METADATA +6 -2
- {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/RECORD +39 -28
- {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/WHEEL +0 -0
- {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/licenses/LICENSE +0 -0
mistralai/_version.py
CHANGED
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import importlib.metadata
|
|
4
4
|
|
|
5
5
|
__title__: str = "mistralai"
|
|
6
|
-
__version__: str = "1.
|
|
6
|
+
__version__: str = "1.12.0"
|
|
7
7
|
__openapi_doc_version__: str = "1.0.0"
|
|
8
8
|
__gen_version__: str = "2.794.1"
|
|
9
|
-
__user_agent__: str = "speakeasy-sdk/python 1.
|
|
9
|
+
__user_agent__: str = "speakeasy-sdk/python 1.12.0 2.794.1 1.0.0 mistralai"
|
|
10
10
|
|
|
11
11
|
try:
|
|
12
12
|
if __package__ is not None:
|
mistralai/audio.py
CHANGED
|
@@ -5,6 +5,13 @@ from .sdkconfiguration import SDKConfiguration
|
|
|
5
5
|
from mistralai.transcriptions import Transcriptions
|
|
6
6
|
from typing import Optional
|
|
7
7
|
|
|
8
|
+
# region imports
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from mistralai.extra.realtime import RealtimeTranscription
|
|
13
|
+
# endregion imports
|
|
14
|
+
|
|
8
15
|
|
|
9
16
|
class Audio(BaseSDK):
|
|
10
17
|
transcriptions: Transcriptions
|
|
@@ -21,3 +28,16 @@ class Audio(BaseSDK):
|
|
|
21
28
|
self.transcriptions = Transcriptions(
|
|
22
29
|
self.sdk_configuration, parent_ref=self.parent_ref
|
|
23
30
|
)
|
|
31
|
+
|
|
32
|
+
# region sdk-class-body
|
|
33
|
+
@property
def realtime(self) -> "RealtimeTranscription":
    """Returns a client for real-time audio transcription via WebSocket.

    The client is built from this SDK's configuration on first access and
    cached on the instance, so every later access returns the same object.
    The import is deferred so that merely constructing the SDK does not load
    the realtime subpackage.
    """
    try:
        return self._realtime
    except AttributeError:
        pass  # first access: build and cache the client below

    from mistralai.extra.realtime import RealtimeTranscription  # pylint: disable=import-outside-toplevel

    self._realtime = RealtimeTranscription(self.sdk_configuration)  # pylint: disable=attribute-defined-outside-init
    return self._realtime
|
|
42
|
+
|
|
43
|
+
# endregion sdk-class-body
|
mistralai/conversations.py
CHANGED
|
@@ -259,7 +259,12 @@ class Conversations(BaseSDK):
|
|
|
259
259
|
description: OptionalNullable[str] = UNSET,
|
|
260
260
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
261
261
|
agent_id: OptionalNullable[str] = UNSET,
|
|
262
|
-
agent_version: OptionalNullable[
|
|
262
|
+
agent_version: OptionalNullable[
|
|
263
|
+
Union[
|
|
264
|
+
models_conversationrequest.AgentVersion,
|
|
265
|
+
models_conversationrequest.AgentVersionTypedDict,
|
|
266
|
+
]
|
|
267
|
+
] = UNSET,
|
|
263
268
|
model: OptionalNullable[str] = UNSET,
|
|
264
269
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
265
270
|
server_url: Optional[str] = None,
|
|
@@ -405,7 +410,12 @@ class Conversations(BaseSDK):
|
|
|
405
410
|
description: OptionalNullable[str] = UNSET,
|
|
406
411
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
407
412
|
agent_id: OptionalNullable[str] = UNSET,
|
|
408
|
-
agent_version: OptionalNullable[
|
|
413
|
+
agent_version: OptionalNullable[
|
|
414
|
+
Union[
|
|
415
|
+
models_conversationrequest.AgentVersion,
|
|
416
|
+
models_conversationrequest.AgentVersionTypedDict,
|
|
417
|
+
]
|
|
418
|
+
] = UNSET,
|
|
409
419
|
model: OptionalNullable[str] = UNSET,
|
|
410
420
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
411
421
|
server_url: Optional[str] = None,
|
|
@@ -1711,7 +1721,12 @@ class Conversations(BaseSDK):
|
|
|
1711
1721
|
]
|
|
1712
1722
|
] = None,
|
|
1713
1723
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
1714
|
-
agent_version: OptionalNullable[
|
|
1724
|
+
agent_version: OptionalNullable[
|
|
1725
|
+
Union[
|
|
1726
|
+
models_conversationrestartrequest.ConversationRestartRequestAgentVersion,
|
|
1727
|
+
models_conversationrestartrequest.ConversationRestartRequestAgentVersionTypedDict,
|
|
1728
|
+
]
|
|
1729
|
+
] = UNSET,
|
|
1715
1730
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
1716
1731
|
server_url: Optional[str] = None,
|
|
1717
1732
|
timeout_ms: Optional[int] = None,
|
|
@@ -1846,7 +1861,12 @@ class Conversations(BaseSDK):
|
|
|
1846
1861
|
]
|
|
1847
1862
|
] = None,
|
|
1848
1863
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
1849
|
-
agent_version: OptionalNullable[
|
|
1864
|
+
agent_version: OptionalNullable[
|
|
1865
|
+
Union[
|
|
1866
|
+
models_conversationrestartrequest.ConversationRestartRequestAgentVersion,
|
|
1867
|
+
models_conversationrestartrequest.ConversationRestartRequestAgentVersionTypedDict,
|
|
1868
|
+
]
|
|
1869
|
+
] = UNSET,
|
|
1850
1870
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
1851
1871
|
server_url: Optional[str] = None,
|
|
1852
1872
|
timeout_ms: Optional[int] = None,
|
|
@@ -1991,7 +2011,12 @@ class Conversations(BaseSDK):
|
|
|
1991
2011
|
description: OptionalNullable[str] = UNSET,
|
|
1992
2012
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
1993
2013
|
agent_id: OptionalNullable[str] = UNSET,
|
|
1994
|
-
agent_version: OptionalNullable[
|
|
2014
|
+
agent_version: OptionalNullable[
|
|
2015
|
+
Union[
|
|
2016
|
+
models_conversationstreamrequest.ConversationStreamRequestAgentVersion,
|
|
2017
|
+
models_conversationstreamrequest.ConversationStreamRequestAgentVersionTypedDict,
|
|
2018
|
+
]
|
|
2019
|
+
] = UNSET,
|
|
1995
2020
|
model: OptionalNullable[str] = UNSET,
|
|
1996
2021
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
1997
2022
|
server_url: Optional[str] = None,
|
|
@@ -2148,7 +2173,12 @@ class Conversations(BaseSDK):
|
|
|
2148
2173
|
description: OptionalNullable[str] = UNSET,
|
|
2149
2174
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
2150
2175
|
agent_id: OptionalNullable[str] = UNSET,
|
|
2151
|
-
agent_version: OptionalNullable[
|
|
2176
|
+
agent_version: OptionalNullable[
|
|
2177
|
+
Union[
|
|
2178
|
+
models_conversationstreamrequest.ConversationStreamRequestAgentVersion,
|
|
2179
|
+
models_conversationstreamrequest.ConversationStreamRequestAgentVersionTypedDict,
|
|
2180
|
+
]
|
|
2181
|
+
] = UNSET,
|
|
2152
2182
|
model: OptionalNullable[str] = UNSET,
|
|
2153
2183
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
2154
2184
|
server_url: Optional[str] = None,
|
|
@@ -2561,7 +2591,12 @@ class Conversations(BaseSDK):
|
|
|
2561
2591
|
]
|
|
2562
2592
|
] = None,
|
|
2563
2593
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
2564
|
-
agent_version: OptionalNullable[
|
|
2594
|
+
agent_version: OptionalNullable[
|
|
2595
|
+
Union[
|
|
2596
|
+
models_conversationrestartstreamrequest.ConversationRestartStreamRequestAgentVersion,
|
|
2597
|
+
models_conversationrestartstreamrequest.ConversationRestartStreamRequestAgentVersionTypedDict,
|
|
2598
|
+
]
|
|
2599
|
+
] = UNSET,
|
|
2565
2600
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
2566
2601
|
server_url: Optional[str] = None,
|
|
2567
2602
|
timeout_ms: Optional[int] = None,
|
|
@@ -2703,7 +2738,12 @@ class Conversations(BaseSDK):
|
|
|
2703
2738
|
]
|
|
2704
2739
|
] = None,
|
|
2705
2740
|
metadata: OptionalNullable[Dict[str, Any]] = UNSET,
|
|
2706
|
-
agent_version: OptionalNullable[
|
|
2741
|
+
agent_version: OptionalNullable[
|
|
2742
|
+
Union[
|
|
2743
|
+
models_conversationrestartstreamrequest.ConversationRestartStreamRequestAgentVersion,
|
|
2744
|
+
models_conversationrestartstreamrequest.ConversationRestartStreamRequestAgentVersionTypedDict,
|
|
2745
|
+
]
|
|
2746
|
+
] = UNSET,
|
|
2707
2747
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
2708
2748
|
server_url: Optional[str] = None,
|
|
2709
2749
|
timeout_ms: Optional[int] = None,
|
mistralai/extra/__init__.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
1
3
|
from .struct_chat import (
|
|
2
4
|
ParsedChatCompletionResponse,
|
|
3
5
|
convert_to_parsed_chat_completion_response,
|
|
@@ -5,9 +7,55 @@ from .struct_chat import (
|
|
|
5
7
|
from .utils import response_format_from_pydantic_model
|
|
6
8
|
from .utils.response_format import CustomPydanticModel
|
|
7
9
|
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from .realtime import (
|
|
12
|
+
AudioEncoding,
|
|
13
|
+
AudioFormat,
|
|
14
|
+
RealtimeConnection,
|
|
15
|
+
RealtimeTranscriptionError,
|
|
16
|
+
RealtimeTranscriptionErrorDetail,
|
|
17
|
+
RealtimeTranscriptionSession,
|
|
18
|
+
RealtimeTranscriptionSessionCreated,
|
|
19
|
+
RealtimeTranscriptionSessionUpdated,
|
|
20
|
+
RealtimeTranscription,
|
|
21
|
+
UnknownRealtimeEvent,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Names re-exported lazily from ``mistralai.extra.realtime``. That subpackage
# imports the optional ``websockets`` dependency (and raises ImportError when
# it is missing), so it is only imported when one of these names is accessed.
_REALTIME_EXPORTS = {
    "AudioEncoding",
    "AudioFormat",
    "RealtimeConnection",
    "RealtimeTranscription",
    "RealtimeTranscriptionError",
    "RealtimeTranscriptionErrorDetail",
    "RealtimeTranscriptionSession",
    "RealtimeTranscriptionSessionCreated",
    "RealtimeTranscriptionSessionUpdated",
    "UnknownRealtimeEvent",
}


def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) for the lazy realtime exports.

    Args:
        name: Attribute requested on this module that normal lookup missed.

    Returns:
        The object of that name from ``mistralai.extra.realtime``.

    Raises:
        AttributeError: if ``name`` is not one of the lazy realtime exports.
    """
    if name not in _REALTIME_EXPORTS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from . import realtime  # deferred: pulls in ``websockets``

    return getattr(realtime, name)
|
|
44
|
+
|
|
45
|
+
|
|
8
46
|
# Public API of ``mistralai.extra``. The realtime names below are not imported
# at module load; the module-level ``__getattr__`` resolves them from the
# ``.realtime`` subpackage on first access.
# NOTE(review): ``from mistralai.extra import *`` looks up every name listed
# here, so it triggers that lazy import and therefore requires ``websockets``.
__all__ = [
    "convert_to_parsed_chat_completion_response",
    "response_format_from_pydantic_model",
    "CustomPydanticModel",
    "ParsedChatCompletionResponse",
    "RealtimeTranscription",
    "RealtimeConnection",
    "AudioEncoding",
    "AudioFormat",
    "UnknownRealtimeEvent",
    "RealtimeTranscriptionError",
    "RealtimeTranscriptionErrorDetail",
    "RealtimeTranscriptionSession",
    "RealtimeTranscriptionSessionCreated",
    "RealtimeTranscriptionSessionUpdated",
]
|
mistralai/extra/exceptions.py
CHANGED
|
@@ -1,14 +1,59 @@
|
|
|
1
|
+
from typing import Optional, TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
if TYPE_CHECKING:
|
|
4
|
+
from mistralai.models import RealtimeTranscriptionError
|
|
5
|
+
|
|
6
|
+
|
|
1
7
|
class MistralClientException(Exception):
    """Base exception for client errors; root of this module's hierarchy."""


class RunException(MistralClientException):
    """Raised when a conversation run fails."""


class MCPException(MistralClientException):
    """Raised when an MCP operation fails."""


class MCPAuthException(MCPException):
    """Raised when MCP authentication fails."""


class RealtimeTranscriptionException(MistralClientException):
    """Base class for real-time transcription failures.

    Attributes:
        code: Optional numeric error code.
        payload: Optional object carrying additional error details.
    """

    def __init__(
        self,
        message: str,
        *,
        code: Optional[int] = None,
        payload: Optional[object] = None,
    ) -> None:
        super().__init__(message)
        self.payload = payload
        self.code = code


class RealtimeTranscriptionWSError(RealtimeTranscriptionException):
    """Error reported over the real-time transcription WebSocket.

    ``code`` is read from ``payload.error.code`` when that value is an ``int``
    and is ``None`` otherwise. ``payload`` (inherited attribute) holds the
    typed payload when one was given, else ``raw``. Both inputs remain
    available unchanged as ``payload_typed`` and ``payload_raw``.
    """

    def __init__(
        self,
        message: str,
        *,
        payload: Optional["RealtimeTranscriptionError"] = None,
        raw: Optional[object] = None,
    ) -> None:
        super().__init__(
            message,
            code=self._extract_code(payload),
            payload=raw if payload is None else payload,
        )
        self.payload_typed = payload
        self.payload_raw = raw

    @staticmethod
    def _extract_code(payload: Optional[object]) -> Optional[int]:
        """Best-effort read of ``payload.error.code``; any failure yields None."""
        if payload is None:
            return None
        try:
            candidate = getattr(payload.error, "code", None)
        except Exception:  # deliberate best-effort: malformed payloads give no code
            return None
        return candidate if isinstance(candidate, int) else None
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from mistralai.models import (
|
|
2
|
+
AudioEncoding,
|
|
3
|
+
AudioFormat,
|
|
4
|
+
RealtimeTranscriptionError,
|
|
5
|
+
RealtimeTranscriptionErrorDetail,
|
|
6
|
+
RealtimeTranscriptionSession,
|
|
7
|
+
RealtimeTranscriptionSessionCreated,
|
|
8
|
+
RealtimeTranscriptionSessionUpdated,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from .connection import UnknownRealtimeEvent, RealtimeConnection
|
|
12
|
+
from .transcription import RealtimeTranscription
|
|
13
|
+
|
|
14
|
+
# Public API of ``mistralai.extra.realtime``: the session/audio models
# re-exported from ``mistralai.models``, plus the connection, fallback-event
# and client classes defined in this subpackage.
__all__ = [
    "AudioEncoding",
    "AudioFormat",
    "RealtimeTranscriptionError",
    "RealtimeTranscriptionErrorDetail",
    "RealtimeTranscriptionSession",
    "RealtimeTranscriptionSessionCreated",
    "RealtimeTranscriptionSessionUpdated",
    "RealtimeConnection",
    "RealtimeTranscription",
    "UnknownRealtimeEvent",
]
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
from asyncio import CancelledError
|
|
6
|
+
from collections import deque
|
|
7
|
+
from typing import Any, AsyncIterator, Deque, Optional, Union
|
|
8
|
+
|
|
9
|
+
from pydantic import ValidationError, BaseModel
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from websockets.asyncio.client import ClientConnection # websockets >= 13.0
|
|
13
|
+
except ImportError as exc:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"The `websockets` package (>=13.0) is required for real-time transcription. "
|
|
16
|
+
"Install with: pip install 'mistralai[realtime]'"
|
|
17
|
+
) from exc
|
|
18
|
+
|
|
19
|
+
from mistralai.models import (
|
|
20
|
+
AudioFormat,
|
|
21
|
+
RealtimeTranscriptionError,
|
|
22
|
+
RealtimeTranscriptionSession,
|
|
23
|
+
RealtimeTranscriptionSessionCreated,
|
|
24
|
+
RealtimeTranscriptionSessionUpdated,
|
|
25
|
+
TranscriptionStreamDone,
|
|
26
|
+
TranscriptionStreamLanguage,
|
|
27
|
+
TranscriptionStreamSegmentDelta,
|
|
28
|
+
TranscriptionStreamTextDelta,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UnknownRealtimeEvent(BaseModel):
    """
    Forward-compat fallback event, produced instead of raising when a server
    message cannot be mapped to a typed event:
    - unknown message type
    - invalid JSON payload
    - schema validation failure
    """
    # The message's ``type`` string when one could be read, otherwise None.
    type: Optional[str]
    # The original message: the decoded JSON value, or the raw text when the
    # frame was not valid JSON.
    content: Any
    # Human-readable reason the message was not mapped to a typed event.
    error: Optional[str] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Every event type RealtimeConnection.events() can yield. In this module it
# appears only in type annotations; parse_realtime_event() does the dispatch.
RealtimeEvent = Union[
    # session lifecycle
    RealtimeTranscriptionSessionCreated,
    RealtimeTranscriptionSessionUpdated,
    # server errors
    RealtimeTranscriptionError,
    # transcription events
    TranscriptionStreamLanguage,
    TranscriptionStreamSegmentDelta,
    TranscriptionStreamTextDelta,
    TranscriptionStreamDone,
    # forward-compat fallback
    UnknownRealtimeEvent,
]


# Maps the ``type`` field of a server message to the model used to validate
# it (via ``model_validate``). Types missing from this table are surfaced as
# UnknownRealtimeEvent by parse_realtime_event().
_MESSAGE_MODELS: dict[str, Any] = {
    "session.created": RealtimeTranscriptionSessionCreated,
    "session.updated": RealtimeTranscriptionSessionUpdated,
    "error": RealtimeTranscriptionError,
    "transcription.language": TranscriptionStreamLanguage,
    "transcription.segment": TranscriptionStreamSegmentDelta,
    "transcription.text.delta": TranscriptionStreamTextDelta,
    "transcription.done": TranscriptionStreamDone,
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def parse_realtime_event(payload: Any) -> RealtimeEvent:
    """Turn a decoded server message into a typed realtime event.

    The parser never raises for bad input; instead it returns an
    UnknownRealtimeEvent carrying the original payload and a reason:
    - payload is not a dict -> ``"expected JSON object"``
    - ``type`` missing or not a string -> ``"missing/invalid 'type'"``
    - ``type`` not in the dispatch table -> ``"unknown event type"``
    - schema validation fails -> the validation error text

    Args:
        payload: Result of ``json.loads`` on one WebSocket message.

    Returns:
        The validated model instance, or an UnknownRealtimeEvent.
    """
    if not isinstance(payload, dict):
        return UnknownRealtimeEvent(
            type=None, content=payload, error="expected JSON object"
        )

    event_type = payload.get("type")
    if not isinstance(event_type, str):
        return UnknownRealtimeEvent(
            type=None, content=payload, error="missing/invalid 'type'"
        )

    event_model = _MESSAGE_MODELS.get(event_type)
    if event_model is None:
        return UnknownRealtimeEvent(
            type=event_type, content=payload, error="unknown event type"
        )

    try:
        return event_model.model_validate(payload)
    except ValidationError as validation_error:
        return UnknownRealtimeEvent(
            type=event_type, content=payload, error=str(validation_error)
        )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RealtimeConnection:
    """An open real-time transcription session over a WebSocket.

    Wraps a ``websockets`` client connection together with the most recent
    session state. Audio is sent with :meth:`send_audio` and terminated with
    :meth:`end_audio`. Server events are consumed by iterating the connection
    (``async for event in connection``). Session state is refreshed from
    ``session.created`` / ``session.updated`` events as they are yielded.

    The class is also an async context manager: leaving ``async with``
    closes the socket.
    """

    def __init__(
        self,
        websocket: ClientConnection,
        session: RealtimeTranscriptionSession,
        *,
        initial_events: Optional[list[RealtimeEvent]] = None,
    ) -> None:
        """
        Args:
            websocket: Already-open client connection.
            session: Starting session description; its ``audio_format``
                seeds :attr:`audio_format`.
            initial_events: Events already read from the socket (handshake /
                prelude, e.g. ``session.created``) that :meth:`events`
                yields before reading new messages.
        """
        self._websocket = websocket
        self._session = session
        self._audio_format = session.audio_format
        self._closed = False
        # A deque so events() can pop replayed events from the left cheaply.
        self._initial_events: Deque[RealtimeEvent] = deque(initial_events or [])

    @property
    def request_id(self) -> str:
        """Request id of the current session."""
        return self._session.request_id

    @property
    def session(self) -> RealtimeTranscriptionSession:
        """Latest session state (refreshed by session.created/updated events)."""
        return self._session

    @property
    def audio_format(self) -> AudioFormat:
        """Audio format currently associated with this connection."""
        return self._audio_format

    @property
    def is_closed(self) -> bool:
        """True once :meth:`close` has run (directly or when iteration ends)."""
        return self._closed

    async def send_audio(
        self, audio_bytes: Union[bytes, bytearray, memoryview]
    ) -> None:
        """Send one audio chunk as a base64 ``input_audio.append`` message.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        message = {
            "type": "input_audio.append",
            # bytes() normalizes bytearray/memoryview input before encoding.
            "audio": base64.b64encode(bytes(audio_bytes)).decode("ascii"),
        }
        await self._websocket.send(json.dumps(message))

    async def update_session(self, audio_format: AudioFormat) -> None:
        """Request a new audio format with a ``session.update`` message.

        The local :attr:`audio_format` changes immediately; a subsequent
        ``session.updated`` event yielded by :meth:`events` replaces it with
        the server's view.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        self._audio_format = audio_format
        message = {
            "type": "session.update",
            "session": {"audio_format": audio_format.model_dump(mode="json")},
        }
        await self._websocket.send(json.dumps(message))

    async def end_audio(self) -> None:
        """Signal end of input with ``input_audio.end``; no-op when closed."""
        if self._closed:
            return
        await self._websocket.send(json.dumps({"type": "input_audio.end"}))

    async def close(self, *, code: int = 1000, reason: str = "") -> None:
        """Close the WebSocket once; later calls return immediately."""
        if self._closed:
            return
        # Flag first so a second close() started while this one awaits is a no-op.
        self._closed = True
        await self._websocket.close(code=code, reason=reason)

    async def __aenter__(self) -> "RealtimeConnection":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def __aiter__(self) -> AsyncIterator[RealtimeEvent]:
        return self.events()

    async def events(self) -> AsyncIterator[RealtimeEvent]:
        """Yield server events until the socket stops delivering messages.

        Processing order:
        - Buffered ``initial_events`` are yielded first.
        - Each incoming message is then decoded and parsed with
          :func:`parse_realtime_event`.
        - Frames that are not valid JSON are yielded as
          :class:`UnknownRealtimeEvent` rather than raised.

        Once live messages are being read, the connection is closed when
        iteration stops for any reason. Task cancellation is propagated to
        the caller rather than silently ending the stream.
        """
        # Replay any handshake/prelude events (including session.created).
        while self._initial_events:
            ev = self._initial_events.popleft()
            self._apply_session_updates(ev)
            yield ev

        try:
            async for msg in self._websocket:
                # Binary frames are decoded leniently so one bad byte does not
                # abort the stream.
                text = (
                    msg.decode("utf-8", errors="replace")
                    if isinstance(msg, (bytes, bytearray))
                    else msg
                )
                try:
                    data = json.loads(text)
                except Exception as exc:
                    # Deliberately tolerant: undecodable frames become events.
                    yield UnknownRealtimeEvent(
                        type=None, content=text, error=f"invalid JSON: {exc}"
                    )
                    continue

                ev = parse_realtime_event(data)
                self._apply_session_updates(ev)
                yield ev
        except CancelledError:  # pylint: disable=try-except-raise
            # Never swallow cancellation. Swallowing it would end iteration as
            # if the stream had finished, breaking Task.cancel(), wait_for()
            # and asyncio.timeout() for the caller. `finally` still closes.
            raise
        finally:
            await self.close()

    def _apply_session_updates(self, ev: RealtimeEvent) -> None:
        """Track session state announced by session.created/updated events."""
        if isinstance(
            ev,
            (RealtimeTranscriptionSessionCreated, RealtimeTranscriptionSessionUpdated),
        ):
            self._session = ev.session
            self._audio_format = ev.session.audio_format