mistralai 1.10.1__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +3 -3
- mistralai/accesses.py +22 -12
- mistralai/agents.py +88 -44
- mistralai/audio.py +20 -0
- mistralai/basesdk.py +6 -0
- mistralai/chat.py +96 -40
- mistralai/classifiers.py +35 -22
- mistralai/conversations.py +234 -72
- mistralai/documents.py +72 -26
- mistralai/embeddings.py +17 -8
- mistralai/extra/__init__.py +48 -0
- mistralai/extra/exceptions.py +49 -4
- mistralai/extra/realtime/__init__.py +25 -0
- mistralai/extra/realtime/connection.py +207 -0
- mistralai/extra/realtime/transcription.py +271 -0
- mistralai/files.py +64 -24
- mistralai/fim.py +20 -12
- mistralai/httpclient.py +0 -1
- mistralai/jobs.py +65 -26
- mistralai/libraries.py +20 -10
- mistralai/mistral_agents.py +825 -34
- mistralai/mistral_jobs.py +33 -14
- mistralai/models/__init__.py +119 -0
- mistralai/models/agent.py +1 -1
- mistralai/models/agentaliasresponse.py +23 -0
- mistralai/models/agentconversation.py +15 -5
- mistralai/models/agenthandoffdoneevent.py +1 -1
- mistralai/models/agenthandoffentry.py +3 -2
- mistralai/models/agenthandoffstartedevent.py +1 -1
- mistralai/models/agents_api_v1_agents_create_or_update_aliasop.py +26 -0
- mistralai/models/agents_api_v1_agents_get_versionop.py +21 -0
- mistralai/models/agents_api_v1_agents_getop.py +12 -3
- mistralai/models/agents_api_v1_agents_list_version_aliasesop.py +16 -0
- mistralai/models/agents_api_v1_agents_list_versionsop.py +33 -0
- mistralai/models/agents_api_v1_agents_listop.py +4 -0
- mistralai/models/agentscompletionrequest.py +2 -5
- mistralai/models/agentscompletionstreamrequest.py +2 -5
- mistralai/models/archiveftmodelout.py +1 -1
- mistralai/models/assistantmessage.py +1 -1
- mistralai/models/audiochunk.py +1 -1
- mistralai/models/audioencoding.py +6 -1
- mistralai/models/audioformat.py +2 -4
- mistralai/models/audiotranscriptionrequest.py +8 -0
- mistralai/models/audiotranscriptionrequeststream.py +8 -0
- mistralai/models/basemodelcard.py +1 -1
- mistralai/models/batchjobin.py +2 -4
- mistralai/models/batchjobout.py +1 -1
- mistralai/models/batchjobsout.py +1 -1
- mistralai/models/chatcompletionchoice.py +10 -5
- mistralai/models/chatcompletionrequest.py +2 -5
- mistralai/models/chatcompletionstreamrequest.py +2 -5
- mistralai/models/classifierdetailedjobout.py +4 -2
- mistralai/models/classifierftmodelout.py +3 -2
- mistralai/models/classifierjobout.py +4 -2
- mistralai/models/codeinterpretertool.py +1 -1
- mistralai/models/completiondetailedjobout.py +5 -2
- mistralai/models/completionftmodelout.py +3 -2
- mistralai/models/completionjobout.py +5 -2
- mistralai/models/completionresponsestreamchoice.py +9 -8
- mistralai/models/conversationappendrequest.py +4 -1
- mistralai/models/conversationappendstreamrequest.py +4 -1
- mistralai/models/conversationhistory.py +2 -1
- mistralai/models/conversationmessages.py +1 -1
- mistralai/models/conversationrequest.py +13 -3
- mistralai/models/conversationresponse.py +2 -1
- mistralai/models/conversationrestartrequest.py +22 -5
- mistralai/models/conversationrestartstreamrequest.py +24 -5
- mistralai/models/conversationstreamrequest.py +17 -3
- mistralai/models/documentlibrarytool.py +1 -1
- mistralai/models/documenturlchunk.py +1 -1
- mistralai/models/embeddingdtype.py +7 -1
- mistralai/models/encodingformat.py +4 -1
- mistralai/models/entitytype.py +8 -1
- mistralai/models/filepurpose.py +8 -1
- mistralai/models/files_api_routes_list_filesop.py +12 -12
- mistralai/models/files_api_routes_upload_fileop.py +2 -6
- mistralai/models/fileschema.py +3 -5
- mistralai/models/finetuneablemodeltype.py +4 -1
- mistralai/models/ftclassifierlossfunction.py +4 -1
- mistralai/models/ftmodelcard.py +1 -1
- mistralai/models/functioncallentry.py +3 -2
- mistralai/models/functioncallevent.py +1 -1
- mistralai/models/functionresultentry.py +3 -2
- mistralai/models/functiontool.py +1 -1
- mistralai/models/githubrepositoryin.py +1 -1
- mistralai/models/githubrepositoryout.py +1 -1
- mistralai/models/httpvalidationerror.py +4 -2
- mistralai/models/imagegenerationtool.py +1 -1
- mistralai/models/imageurlchunk.py +1 -1
- mistralai/models/jobsout.py +1 -1
- mistralai/models/legacyjobmetadataout.py +1 -1
- mistralai/models/messageinputentry.py +9 -3
- mistralai/models/messageoutputentry.py +6 -3
- mistralai/models/messageoutputevent.py +4 -2
- mistralai/models/mistralerror.py +11 -7
- mistralai/models/mistralpromptmode.py +5 -1
- mistralai/models/modelcapabilities.py +3 -0
- mistralai/models/modelconversation.py +1 -1
- mistralai/models/no_response_error.py +5 -1
- mistralai/models/ocrrequest.py +11 -1
- mistralai/models/ocrtableobject.py +4 -1
- mistralai/models/realtimetranscriptionerror.py +27 -0
- mistralai/models/realtimetranscriptionerrordetail.py +29 -0
- mistralai/models/realtimetranscriptionsession.py +20 -0
- mistralai/models/realtimetranscriptionsessioncreated.py +30 -0
- mistralai/models/realtimetranscriptionsessionupdated.py +30 -0
- mistralai/models/referencechunk.py +1 -1
- mistralai/models/requestsource.py +5 -1
- mistralai/models/responsedoneevent.py +1 -1
- mistralai/models/responseerrorevent.py +1 -1
- mistralai/models/responseformats.py +5 -1
- mistralai/models/responsestartedevent.py +1 -1
- mistralai/models/responsevalidationerror.py +2 -0
- mistralai/models/retrievefileout.py +3 -5
- mistralai/models/sampletype.py +7 -1
- mistralai/models/sdkerror.py +2 -0
- mistralai/models/shareenum.py +7 -1
- mistralai/models/sharingdelete.py +2 -4
- mistralai/models/sharingin.py +3 -5
- mistralai/models/source.py +8 -1
- mistralai/models/systemmessage.py +1 -1
- mistralai/models/textchunk.py +1 -1
- mistralai/models/thinkchunk.py +1 -1
- mistralai/models/timestampgranularity.py +4 -1
- mistralai/models/tool.py +2 -6
- mistralai/models/toolcall.py +2 -6
- mistralai/models/toolchoice.py +2 -6
- mistralai/models/toolchoiceenum.py +6 -1
- mistralai/models/toolexecutiondeltaevent.py +2 -1
- mistralai/models/toolexecutiondoneevent.py +2 -1
- mistralai/models/toolexecutionentry.py +4 -2
- mistralai/models/toolexecutionstartedevent.py +2 -1
- mistralai/models/toolfilechunk.py +2 -1
- mistralai/models/toolmessage.py +1 -1
- mistralai/models/toolreferencechunk.py +2 -1
- mistralai/models/tooltypes.py +1 -1
- mistralai/models/transcriptionsegmentchunk.py +42 -3
- mistralai/models/transcriptionstreamdone.py +1 -1
- mistralai/models/transcriptionstreamlanguage.py +1 -1
- mistralai/models/transcriptionstreamsegmentdelta.py +39 -3
- mistralai/models/transcriptionstreamtextdelta.py +1 -1
- mistralai/models/unarchiveftmodelout.py +1 -1
- mistralai/models/uploadfileout.py +3 -5
- mistralai/models/usermessage.py +1 -1
- mistralai/models/wandbintegration.py +1 -1
- mistralai/models/wandbintegrationout.py +1 -1
- mistralai/models/websearchpremiumtool.py +1 -1
- mistralai/models/websearchtool.py +1 -1
- mistralai/models_.py +24 -12
- mistralai/ocr.py +38 -10
- mistralai/sdk.py +2 -2
- mistralai/transcriptions.py +52 -12
- mistralai/types/basemodel.py +41 -3
- mistralai/utils/__init__.py +0 -3
- mistralai/utils/annotations.py +32 -8
- mistralai/utils/enums.py +60 -0
- mistralai/utils/forms.py +21 -10
- mistralai/utils/queryparams.py +14 -2
- mistralai/utils/requestbodies.py +3 -3
- mistralai/utils/retries.py +69 -5
- mistralai/utils/serializers.py +0 -20
- mistralai/utils/unmarshal_json_response.py +15 -1
- {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/METADATA +28 -31
- {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/RECORD +251 -237
- mistralai_azure/_version.py +3 -3
- mistralai_azure/basesdk.py +6 -0
- mistralai_azure/chat.py +27 -15
- mistralai_azure/httpclient.py +0 -1
- mistralai_azure/models/__init__.py +16 -1
- mistralai_azure/models/assistantmessage.py +1 -1
- mistralai_azure/models/chatcompletionchoice.py +10 -7
- mistralai_azure/models/chatcompletionrequest.py +8 -6
- mistralai_azure/models/chatcompletionstreamrequest.py +8 -6
- mistralai_azure/models/completionresponsestreamchoice.py +11 -7
- mistralai_azure/models/documenturlchunk.py +1 -1
- mistralai_azure/models/httpvalidationerror.py +4 -2
- mistralai_azure/models/imageurlchunk.py +1 -1
- mistralai_azure/models/mistralazureerror.py +11 -7
- mistralai_azure/models/mistralpromptmode.py +1 -1
- mistralai_azure/models/no_response_error.py +5 -1
- mistralai_azure/models/ocrpageobject.py +32 -5
- mistralai_azure/models/ocrrequest.py +20 -1
- mistralai_azure/models/ocrtableobject.py +34 -0
- mistralai_azure/models/referencechunk.py +1 -1
- mistralai_azure/models/responseformats.py +5 -1
- mistralai_azure/models/responsevalidationerror.py +2 -0
- mistralai_azure/models/sdkerror.py +2 -0
- mistralai_azure/models/systemmessage.py +1 -1
- mistralai_azure/models/textchunk.py +1 -1
- mistralai_azure/models/thinkchunk.py +1 -1
- mistralai_azure/models/tool.py +2 -6
- mistralai_azure/models/toolcall.py +2 -6
- mistralai_azure/models/toolchoice.py +2 -6
- mistralai_azure/models/toolchoiceenum.py +6 -1
- mistralai_azure/models/toolmessage.py +1 -1
- mistralai_azure/models/tooltypes.py +1 -1
- mistralai_azure/models/usermessage.py +1 -1
- mistralai_azure/ocr.py +26 -6
- mistralai_azure/types/basemodel.py +41 -3
- mistralai_azure/utils/__init__.py +0 -3
- mistralai_azure/utils/annotations.py +32 -8
- mistralai_azure/utils/enums.py +60 -0
- mistralai_azure/utils/forms.py +21 -10
- mistralai_azure/utils/queryparams.py +14 -2
- mistralai_azure/utils/requestbodies.py +3 -3
- mistralai_azure/utils/retries.py +69 -5
- mistralai_azure/utils/serializers.py +0 -20
- mistralai_azure/utils/unmarshal_json_response.py +15 -1
- mistralai_gcp/_version.py +3 -3
- mistralai_gcp/basesdk.py +6 -0
- mistralai_gcp/chat.py +27 -15
- mistralai_gcp/fim.py +27 -15
- mistralai_gcp/httpclient.py +0 -1
- mistralai_gcp/models/assistantmessage.py +1 -1
- mistralai_gcp/models/chatcompletionchoice.py +10 -7
- mistralai_gcp/models/chatcompletionrequest.py +8 -6
- mistralai_gcp/models/chatcompletionstreamrequest.py +8 -6
- mistralai_gcp/models/completionresponsestreamchoice.py +11 -7
- mistralai_gcp/models/fimcompletionrequest.py +6 -1
- mistralai_gcp/models/fimcompletionstreamrequest.py +6 -1
- mistralai_gcp/models/httpvalidationerror.py +4 -2
- mistralai_gcp/models/imageurlchunk.py +1 -1
- mistralai_gcp/models/mistralgcperror.py +11 -7
- mistralai_gcp/models/mistralpromptmode.py +1 -1
- mistralai_gcp/models/no_response_error.py +5 -1
- mistralai_gcp/models/referencechunk.py +1 -1
- mistralai_gcp/models/responseformats.py +5 -1
- mistralai_gcp/models/responsevalidationerror.py +2 -0
- mistralai_gcp/models/sdkerror.py +2 -0
- mistralai_gcp/models/systemmessage.py +1 -1
- mistralai_gcp/models/textchunk.py +1 -1
- mistralai_gcp/models/thinkchunk.py +1 -1
- mistralai_gcp/models/tool.py +2 -6
- mistralai_gcp/models/toolcall.py +2 -6
- mistralai_gcp/models/toolchoice.py +2 -6
- mistralai_gcp/models/toolchoiceenum.py +6 -1
- mistralai_gcp/models/toolmessage.py +1 -1
- mistralai_gcp/models/tooltypes.py +1 -1
- mistralai_gcp/models/usermessage.py +1 -1
- mistralai_gcp/types/basemodel.py +41 -3
- mistralai_gcp/utils/__init__.py +0 -3
- mistralai_gcp/utils/annotations.py +32 -8
- mistralai_gcp/utils/enums.py +60 -0
- mistralai_gcp/utils/forms.py +21 -10
- mistralai_gcp/utils/queryparams.py +14 -2
- mistralai_gcp/utils/requestbodies.py +3 -3
- mistralai_gcp/utils/retries.py +69 -5
- mistralai_gcp/utils/serializers.py +0 -20
- mistralai_gcp/utils/unmarshal_json_response.py +15 -1
- {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/WHEEL +0 -0
- {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
from asyncio import CancelledError
|
|
6
|
+
from collections import deque
|
|
7
|
+
from typing import Any, AsyncIterator, Deque, Optional, Union
|
|
8
|
+
|
|
9
|
+
from pydantic import ValidationError, BaseModel
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from websockets.asyncio.client import ClientConnection # websockets >= 13.0
|
|
13
|
+
except ImportError as exc:
|
|
14
|
+
raise ImportError(
|
|
15
|
+
"The `websockets` package (>=13.0) is required for real-time transcription. "
|
|
16
|
+
"Install with: pip install 'mistralai[realtime]'"
|
|
17
|
+
) from exc
|
|
18
|
+
|
|
19
|
+
from mistralai.models import (
|
|
20
|
+
AudioFormat,
|
|
21
|
+
RealtimeTranscriptionError,
|
|
22
|
+
RealtimeTranscriptionSession,
|
|
23
|
+
RealtimeTranscriptionSessionCreated,
|
|
24
|
+
RealtimeTranscriptionSessionUpdated,
|
|
25
|
+
TranscriptionStreamDone,
|
|
26
|
+
TranscriptionStreamLanguage,
|
|
27
|
+
TranscriptionStreamSegmentDelta,
|
|
28
|
+
TranscriptionStreamTextDelta,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UnknownRealtimeEvent(BaseModel):
    """
    Forward-compat fallback event, returned instead of raising whenever a
    server message cannot be turned into one of the typed realtime events:
    - unknown message type
    - invalid JSON payload
    - schema validation failure
    """

    # The message's "type" string, or None when the message was not a JSON
    # object, had no string "type", or was not valid JSON at all.
    type: Optional[str]
    # The original message: the decoded JSON value, or the raw text when the
    # message could not be decoded as JSON.
    content: Any
    # Human-readable reason why the message fell back to this event.
    error: Optional[str] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Every value a realtime connection can yield: the typed server messages
# (see _MESSAGE_MODELS for the "type" string each one is parsed from) plus
# UnknownRealtimeEvent for anything that could not be parsed into one of them.
RealtimeEvent = Union[
    # session lifecycle
    RealtimeTranscriptionSessionCreated,
    RealtimeTranscriptionSessionUpdated,
    # server errors
    RealtimeTranscriptionError,
    # transcription events
    TranscriptionStreamLanguage,
    TranscriptionStreamSegmentDelta,
    TranscriptionStreamTextDelta,
    TranscriptionStreamDone,
    # forward-compat fallback
    UnknownRealtimeEvent,
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Dispatch table used by parse_realtime_event(): maps the "type" field of a
# server message to the model class it is validated against. Any type not
# listed here is reported as an UnknownRealtimeEvent.
_MESSAGE_MODELS: dict[str, Any] = {
    "session.created": RealtimeTranscriptionSessionCreated,
    "session.updated": RealtimeTranscriptionSessionUpdated,
    "error": RealtimeTranscriptionError,
    "transcription.language": TranscriptionStreamLanguage,
    "transcription.segment": TranscriptionStreamSegmentDelta,
    "transcription.text.delta": TranscriptionStreamTextDelta,
    "transcription.done": TranscriptionStreamDone,
}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def parse_realtime_event(payload: Any) -> RealtimeEvent:
    """
    Convert one decoded server message into a typed realtime event.

    Never raises for bad input; instead it returns an UnknownRealtimeEvent
    carrying the original payload when:
    - the payload is not a JSON object,
    - its "type" is missing or not a string,
    - the type has no registered model,
    - the registered model rejects the payload (the validation message is
      kept in ``error``).
    """

    def fallback(event_type: Optional[str], reason: str) -> UnknownRealtimeEvent:
        # Every failure path keeps the untouched payload for the caller.
        return UnknownRealtimeEvent(type=event_type, content=payload, error=reason)

    if not isinstance(payload, dict):
        return fallback(None, "expected JSON object")

    event_type = payload.get("type")
    if not isinstance(event_type, str):
        return fallback(None, "missing/invalid 'type'")

    model_cls = _MESSAGE_MODELS.get(event_type)
    if model_cls is None:
        return fallback(event_type, "unknown event type")

    try:
        return model_cls.model_validate(payload)
    except ValidationError as exc:
        return fallback(event_type, str(exc))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RealtimeConnection:
    """
    An open realtime-transcription WebSocket session.

    Wraps a ``websockets`` client connection and offers helpers to stream
    audio, change the session audio format and iterate over parsed server
    events. Events captured during the handshake (``initial_events``) are
    replayed before anything is read from the socket.

    Usable as an async context manager (closes on exit) and as an async
    iterator of ``RealtimeEvent`` values.
    """

    def __init__(
        self,
        websocket: ClientConnection,
        session: RealtimeTranscriptionSession,
        *,
        initial_events: Optional[list[RealtimeEvent]] = None,
    ) -> None:
        """
        Args:
            websocket: an already-open client connection.
            session: the session description; ``request_id`` and
                ``audio_format`` are read from it.
            initial_events: events to replay (in order) before reading from
                the socket, e.g. those received during the handshake.
        """
        self._websocket = websocket
        self._session = session
        self._audio_format = session.audio_format
        self._closed = False
        # deque so that replay can pop from the left in O(1).
        self._initial_events: Deque[RealtimeEvent] = deque(initial_events or [])

    @property
    def request_id(self) -> str:
        """Request id of the current session."""
        return self._session.request_id

    @property
    def session(self) -> RealtimeTranscriptionSession:
        """Latest session description (refreshed by session.created/updated events)."""
        return self._session

    @property
    def audio_format(self) -> AudioFormat:
        """Audio format currently associated with this connection."""
        return self._audio_format

    @property
    def is_closed(self) -> bool:
        """True once ``close()`` has run (including the close done when event iteration ends)."""
        return self._closed

    async def send_audio(
        self, audio_bytes: Union[bytes, bytearray, memoryview]
    ) -> None:
        """
        Send one chunk of raw audio as a base64-encoded ``input_audio.append`` message.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        message = {
            "type": "input_audio.append",
            "audio": base64.b64encode(bytes(audio_bytes)).decode("ascii"),
        }
        await self._websocket.send(json.dumps(message))

    async def update_session(self, audio_format: AudioFormat) -> None:
        """
        Send a ``session.update`` message switching the session to ``audio_format``.

        The locally reported ``audio_format`` only changes once the message
        has actually been sent, so a failed send does not leave it claiming
        a format the server never heard about.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        message = {
            "type": "session.update",
            "session": {"audio_format": audio_format.model_dump(mode="json")},
        }
        await self._websocket.send(json.dumps(message))
        self._audio_format = audio_format

    async def end_audio(self) -> None:
        """Send ``input_audio.end``; silently does nothing if already closed."""
        if self._closed:
            return
        await self._websocket.send(json.dumps({"type": "input_audio.end"}))

    async def close(self, *, code: int = 1000, reason: str = "") -> None:
        """Close the WebSocket with ``code``/``reason``; later calls are no-ops."""
        if self._closed:
            return
        self._closed = True
        await self._websocket.close(code=code, reason=reason)

    async def __aenter__(self) -> "RealtimeConnection":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def __aiter__(self) -> AsyncIterator[RealtimeEvent]:
        return self.events()

    async def events(self) -> AsyncIterator[RealtimeEvent]:
        """
        Yield realtime events: first the replayed handshake events, then every
        message received from the socket.

        Messages that are not valid JSON become UnknownRealtimeEvent; the rest
        go through parse_realtime_event(). Session lifecycle events update
        ``session``/``audio_format`` as they pass through.

        The connection is closed whenever iteration stops (socket closed,
        error, ``aclose()`` or cancellation). Cancellation is propagated to
        the caller rather than being turned into a normal end of stream.
        """
        # replay any handshake/prelude events (including session.created)
        while self._initial_events:
            ev = self._initial_events.popleft()
            self._apply_session_updates(ev)
            yield ev

        try:
            async for msg in self._websocket:
                text = (
                    msg.decode("utf-8", errors="replace")
                    if isinstance(msg, (bytes, bytearray))
                    else msg
                )
                try:
                    data = json.loads(text)
                except Exception as exc:
                    # Tolerant by design: report undecodable frames instead of failing.
                    yield UnknownRealtimeEvent(
                        type=None, content=text, error=f"invalid JSON: {exc}"
                    )
                    continue

                ev = parse_realtime_event(data)
                self._apply_session_updates(ev)
                yield ev
        finally:
            # No `except CancelledError: pass` here: swallowing it would make a
            # cancelled consumer (or an asyncio timeout) look like a clean end
            # of stream. Cleanup still runs on every exit path.
            await self.close()

    def _apply_session_updates(self, ev: RealtimeEvent) -> None:
        """Track the latest session state from session.created / session.updated events."""
        if isinstance(
            ev,
            (RealtimeTranscriptionSessionCreated, RealtimeTranscriptionSessionUpdated),
        ):
            self._session = ev.session
            self._audio_format = ev.session.audio_format
|
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
from typing import AsyncIterator, Mapping, Optional
|
|
7
|
+
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from websockets.asyncio.client import (
|
|
11
|
+
ClientConnection,
|
|
12
|
+
connect,
|
|
13
|
+
) # websockets >= 13.0
|
|
14
|
+
except ImportError as exc:
|
|
15
|
+
raise ImportError(
|
|
16
|
+
"The `websockets` package (>=13.0) is required for real-time transcription. "
|
|
17
|
+
"Install with: pip install 'mistralai[realtime]'"
|
|
18
|
+
) from exc
|
|
19
|
+
|
|
20
|
+
from mistralai import models, utils
|
|
21
|
+
from mistralai.models import (
|
|
22
|
+
AudioFormat,
|
|
23
|
+
RealtimeTranscriptionError,
|
|
24
|
+
RealtimeTranscriptionSession,
|
|
25
|
+
RealtimeTranscriptionSessionCreated,
|
|
26
|
+
)
|
|
27
|
+
from mistralai.sdkconfiguration import SDKConfiguration
|
|
28
|
+
from mistralai.utils import generate_url, get_security, get_security_from_env
|
|
29
|
+
|
|
30
|
+
from ..exceptions import RealtimeTranscriptionException, RealtimeTranscriptionWSError
|
|
31
|
+
from .connection import (
|
|
32
|
+
RealtimeConnection,
|
|
33
|
+
RealtimeEvent,
|
|
34
|
+
UnknownRealtimeEvent,
|
|
35
|
+
parse_realtime_event,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RealtimeTranscription:
    """Client for realtime transcription over WebSocket (websockets >= 13.0)."""

    def __init__(self, sdk_config: SDKConfiguration) -> None:
        self._sdk_config = sdk_config

    def _build_url(
        self,
        model: str,
        *,
        server_url: Optional[str],
        query_params: Mapping[str, str],
    ) -> str:
        """
        Build the HTTP(S) URL of the realtime transcription endpoint.

        ``model`` is added as a query parameter; entries of ``query_params``
        are merged last and therefore override existing parameters.
        """
        if server_url is not None:
            base_url = utils.remove_suffix(server_url, "/")
        else:
            base_url, _ = self._sdk_config.get_server_details()

        url = generate_url(base_url, "/v1/audio/transcriptions/realtime", None)

        parsed = urlparse(url)
        merged = dict(parse_qsl(parsed.query, keep_blank_values=True))
        merged["model"] = model
        merged.update(dict(query_params))

        return urlunparse(parsed._replace(query=urlencode(merged)))

    async def connect(
        self,
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> RealtimeConnection:
        """
        Open a realtime transcription session.

        Resolves credentials from the SDK configuration (or the environment),
        opens the WebSocket, waits for ``session.created`` and, if
        ``audio_format`` is given, sends a ``session.update`` for it.

        Args:
            model: model name, sent as the ``model`` query parameter.
            audio_format: optional audio format to apply right after the handshake.
            server_url: overrides the SDK's configured server URL.
            timeout_ms: used both as the WebSocket open timeout and as the
                handshake budget; defaults to the SDK configuration value.
            http_headers: extra headers, applied after (and overriding) the
                security headers.

        Returns:
            An open RealtimeConnection whose iterator replays the handshake events.

        Raises:
            RealtimeTranscriptionException: any failure is re-raised as-is when
                it already is one, otherwise wrapped as "Failed to connect: ...".
        """
        if timeout_ms is None:
            timeout_ms = self._sdk_config.timeout_ms

        security = self._sdk_config.security
        if security is not None and callable(security):
            # Security may be configured as a factory; resolve it per connection.
            security = security()

        resolved_security = get_security_from_env(security, models.Security)

        headers: dict[str, str] = {}
        query_params: dict[str, str] = {}

        if resolved_security is not None:
            security_headers, security_query = get_security(resolved_security)
            headers |= security_headers
            for key, values in security_query.items():
                # Only one value per query key is forwarded: the last one.
                if values:
                    query_params[key] = values[-1]

        if http_headers is not None:
            headers |= dict(http_headers)

        url = self._build_url(model, server_url=server_url, query_params=query_params)

        # The endpoint URL is http(s); switch to the matching ws(s) scheme.
        parsed = urlparse(url)
        if parsed.scheme == "https":
            parsed = parsed._replace(scheme="wss")
        elif parsed.scheme == "http":
            parsed = parsed._replace(scheme="ws")
        ws_url = urlunparse(parsed)
        open_timeout = None if timeout_ms is None else timeout_ms / 1000.0
        user_agent = self._sdk_config.user_agent

        websocket: Optional[ClientConnection] = None
        try:
            websocket = await connect(
                ws_url,
                additional_headers=dict(headers),
                open_timeout=open_timeout,
                user_agent_header=user_agent,
            )

            session, initial_events = await _recv_handshake(
                websocket, timeout_ms=timeout_ms
            )
            connection = RealtimeConnection(
                websocket=websocket,
                session=session,
                initial_events=initial_events,
            )

            if audio_format is not None:
                await connection.update_session(audio_format)

            return connection

        except BaseException as exc:
            # BaseException so the socket is also released when the caller's
            # task is cancelled mid-handshake (CancelledError is not an Exception).
            if websocket is not None:
                await websocket.close()
            if isinstance(exc, RealtimeTranscriptionException) or not isinstance(
                exc, Exception
            ):
                raise
            raise RealtimeTranscriptionException(f"Failed to connect: {exc}") from exc

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[RealtimeEvent]:
        """
        Stream audio to the server and yield the resulting events.

        Flow
        - opens connection
        - streams audio in background
        - yields events from the connection

        Iteration stops after a ``RealtimeTranscriptionError`` or an event
        whose ``type`` is ``"transcription.done"`` (that event is still
        yielded). If sending audio fails (the audio iterator or the socket
        raises), the connection is closed so the event loop ends, and the
        sender's exception is re-raised instead of waiting indefinitely for
        events the server will never send.
        """
        async with await self.connect(
            model=model,
            audio_format=audio_format,
            server_url=server_url,
            timeout_ms=timeout_ms,
            http_headers=http_headers,
        ) as connection:

            async def _send() -> None:
                try:
                    async for chunk in audio_stream:
                        if connection.is_closed:
                            break
                        await connection.send_audio(chunk)
                    await connection.end_audio()
                except Exception:
                    # input_audio.end was never sent; closing unblocks the
                    # receive loop below so this error reaches the caller.
                    await connection.close()
                    raise

            send_task = asyncio.create_task(_send())
            event_stream = connection.events()

            try:
                async for event in event_stream:
                    yield event

                    # stop early (caller still sees the terminating event)
                    if isinstance(event, RealtimeTranscriptionError):
                        break
                    if getattr(event, "type", None) == "transcription.done":
                        break
            finally:
                # Finalize the event generator now instead of at garbage collection.
                await event_stream.aclose()
                send_task.cancel()
                try:
                    # Re-raises the sender's failure, if any.
                    await send_task
                except asyncio.CancelledError:
                    pass
                await connection.close()
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _extract_error_message(payload: dict) -> str:
|
|
190
|
+
err = payload.get("error")
|
|
191
|
+
if isinstance(err, dict):
|
|
192
|
+
msg = err.get("message")
|
|
193
|
+
if isinstance(msg, str):
|
|
194
|
+
return msg
|
|
195
|
+
if isinstance(msg, dict):
|
|
196
|
+
detail = msg.get("detail")
|
|
197
|
+
if isinstance(detail, str):
|
|
198
|
+
return detail
|
|
199
|
+
return "Realtime transcription error"
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
async def _recv_handshake(
    websocket: ClientConnection,
    *,
    timeout_ms: Optional[int],
) -> tuple[RealtimeTranscriptionSession, list[RealtimeEvent]]:
    """
    Consume server messages until ``session.created`` arrives.

    Every message read is kept, in arrival order, so the caller can replay
    them later without losing anything (lossless handshake).

    Args:
        websocket: a freshly opened client connection.
        timeout_ms: budget for the whole handshake (not per message);
            ``None`` waits indefinitely.

    Returns:
        The session from ``session.created`` and all events read so far,
        that event included (last).

    Raises:
        RealtimeTranscriptionWSError: when the server sends an ``error``
            message (assumed to subclass RealtimeTranscriptionException so it
            passes through unwrapped -- TODO confirm).
        RealtimeTranscriptionException: when the budget runs out, or for any
            other failure while reading the handshake.
    """
    deadline: Optional[float] = None
    if timeout_ms is not None:
        deadline = time.monotonic() + timeout_ms / 1000.0

    def time_left() -> Optional[float]:
        # Shrinking per-recv timeout so the whole handshake respects the budget.
        if deadline is None:
            return None
        return max(0.0, deadline - time.monotonic())

    replay: list[RealtimeEvent] = []

    try:
        while True:
            frame = await asyncio.wait_for(websocket.recv(), timeout=time_left())
            if isinstance(frame, (bytes, bytearray)):
                frame = frame.decode("utf-8", errors="replace")

            try:
                payload = json.loads(frame)
            except Exception as exc:
                replay.append(
                    UnknownRealtimeEvent(
                        type=None, content=frame, error=f"invalid JSON: {exc}"
                    )
                )
                continue

            event = parse_realtime_event(payload)
            replay.append(event)

            if isinstance(payload, dict) and payload.get("type") == "error":
                # Attach the typed error only when it validated successfully.
                raise RealtimeTranscriptionWSError(
                    _extract_error_message(payload),
                    payload=(
                        event if isinstance(event, RealtimeTranscriptionError) else None
                    ),
                    raw=payload,
                )

            if isinstance(event, RealtimeTranscriptionSessionCreated):
                return event.session, replay

    except asyncio.TimeoutError as exc:
        raise RealtimeTranscriptionException(
            "Timeout waiting for session creation."
        ) from exc
    except RealtimeTranscriptionException:
        raise
    except Exception as exc:
        raise RealtimeTranscriptionException(
            f"Unexpected websocket handshake failure: {exc}"
        ) from exc
|