mistralai 1.10.1__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mistralai/_version.py +3 -3
  2. mistralai/accesses.py +22 -12
  3. mistralai/agents.py +88 -44
  4. mistralai/audio.py +20 -0
  5. mistralai/basesdk.py +6 -0
  6. mistralai/chat.py +96 -40
  7. mistralai/classifiers.py +35 -22
  8. mistralai/conversations.py +234 -72
  9. mistralai/documents.py +72 -26
  10. mistralai/embeddings.py +17 -8
  11. mistralai/extra/__init__.py +48 -0
  12. mistralai/extra/exceptions.py +49 -4
  13. mistralai/extra/realtime/__init__.py +25 -0
  14. mistralai/extra/realtime/connection.py +207 -0
  15. mistralai/extra/realtime/transcription.py +271 -0
  16. mistralai/files.py +64 -24
  17. mistralai/fim.py +20 -12
  18. mistralai/httpclient.py +0 -1
  19. mistralai/jobs.py +65 -26
  20. mistralai/libraries.py +20 -10
  21. mistralai/mistral_agents.py +825 -34
  22. mistralai/mistral_jobs.py +33 -14
  23. mistralai/models/__init__.py +119 -0
  24. mistralai/models/agent.py +1 -1
  25. mistralai/models/agentaliasresponse.py +23 -0
  26. mistralai/models/agentconversation.py +15 -5
  27. mistralai/models/agenthandoffdoneevent.py +1 -1
  28. mistralai/models/agenthandoffentry.py +3 -2
  29. mistralai/models/agenthandoffstartedevent.py +1 -1
  30. mistralai/models/agents_api_v1_agents_create_or_update_aliasop.py +26 -0
  31. mistralai/models/agents_api_v1_agents_get_versionop.py +21 -0
  32. mistralai/models/agents_api_v1_agents_getop.py +12 -3
  33. mistralai/models/agents_api_v1_agents_list_version_aliasesop.py +16 -0
  34. mistralai/models/agents_api_v1_agents_list_versionsop.py +33 -0
  35. mistralai/models/agents_api_v1_agents_listop.py +4 -0
  36. mistralai/models/agentscompletionrequest.py +2 -5
  37. mistralai/models/agentscompletionstreamrequest.py +2 -5
  38. mistralai/models/archiveftmodelout.py +1 -1
  39. mistralai/models/assistantmessage.py +1 -1
  40. mistralai/models/audiochunk.py +1 -1
  41. mistralai/models/audioencoding.py +6 -1
  42. mistralai/models/audioformat.py +2 -4
  43. mistralai/models/audiotranscriptionrequest.py +8 -0
  44. mistralai/models/audiotranscriptionrequeststream.py +8 -0
  45. mistralai/models/basemodelcard.py +1 -1
  46. mistralai/models/batchjobin.py +2 -4
  47. mistralai/models/batchjobout.py +1 -1
  48. mistralai/models/batchjobsout.py +1 -1
  49. mistralai/models/chatcompletionchoice.py +10 -5
  50. mistralai/models/chatcompletionrequest.py +2 -5
  51. mistralai/models/chatcompletionstreamrequest.py +2 -5
  52. mistralai/models/classifierdetailedjobout.py +4 -2
  53. mistralai/models/classifierftmodelout.py +3 -2
  54. mistralai/models/classifierjobout.py +4 -2
  55. mistralai/models/codeinterpretertool.py +1 -1
  56. mistralai/models/completiondetailedjobout.py +5 -2
  57. mistralai/models/completionftmodelout.py +3 -2
  58. mistralai/models/completionjobout.py +5 -2
  59. mistralai/models/completionresponsestreamchoice.py +9 -8
  60. mistralai/models/conversationappendrequest.py +4 -1
  61. mistralai/models/conversationappendstreamrequest.py +4 -1
  62. mistralai/models/conversationhistory.py +2 -1
  63. mistralai/models/conversationmessages.py +1 -1
  64. mistralai/models/conversationrequest.py +13 -3
  65. mistralai/models/conversationresponse.py +2 -1
  66. mistralai/models/conversationrestartrequest.py +22 -5
  67. mistralai/models/conversationrestartstreamrequest.py +24 -5
  68. mistralai/models/conversationstreamrequest.py +17 -3
  69. mistralai/models/documentlibrarytool.py +1 -1
  70. mistralai/models/documenturlchunk.py +1 -1
  71. mistralai/models/embeddingdtype.py +7 -1
  72. mistralai/models/encodingformat.py +4 -1
  73. mistralai/models/entitytype.py +8 -1
  74. mistralai/models/filepurpose.py +8 -1
  75. mistralai/models/files_api_routes_list_filesop.py +12 -12
  76. mistralai/models/files_api_routes_upload_fileop.py +2 -6
  77. mistralai/models/fileschema.py +3 -5
  78. mistralai/models/finetuneablemodeltype.py +4 -1
  79. mistralai/models/ftclassifierlossfunction.py +4 -1
  80. mistralai/models/ftmodelcard.py +1 -1
  81. mistralai/models/functioncallentry.py +3 -2
  82. mistralai/models/functioncallevent.py +1 -1
  83. mistralai/models/functionresultentry.py +3 -2
  84. mistralai/models/functiontool.py +1 -1
  85. mistralai/models/githubrepositoryin.py +1 -1
  86. mistralai/models/githubrepositoryout.py +1 -1
  87. mistralai/models/httpvalidationerror.py +4 -2
  88. mistralai/models/imagegenerationtool.py +1 -1
  89. mistralai/models/imageurlchunk.py +1 -1
  90. mistralai/models/jobsout.py +1 -1
  91. mistralai/models/legacyjobmetadataout.py +1 -1
  92. mistralai/models/messageinputentry.py +9 -3
  93. mistralai/models/messageoutputentry.py +6 -3
  94. mistralai/models/messageoutputevent.py +4 -2
  95. mistralai/models/mistralerror.py +11 -7
  96. mistralai/models/mistralpromptmode.py +5 -1
  97. mistralai/models/modelcapabilities.py +3 -0
  98. mistralai/models/modelconversation.py +1 -1
  99. mistralai/models/no_response_error.py +5 -1
  100. mistralai/models/ocrrequest.py +11 -1
  101. mistralai/models/ocrtableobject.py +4 -1
  102. mistralai/models/realtimetranscriptionerror.py +27 -0
  103. mistralai/models/realtimetranscriptionerrordetail.py +29 -0
  104. mistralai/models/realtimetranscriptionsession.py +20 -0
  105. mistralai/models/realtimetranscriptionsessioncreated.py +30 -0
  106. mistralai/models/realtimetranscriptionsessionupdated.py +30 -0
  107. mistralai/models/referencechunk.py +1 -1
  108. mistralai/models/requestsource.py +5 -1
  109. mistralai/models/responsedoneevent.py +1 -1
  110. mistralai/models/responseerrorevent.py +1 -1
  111. mistralai/models/responseformats.py +5 -1
  112. mistralai/models/responsestartedevent.py +1 -1
  113. mistralai/models/responsevalidationerror.py +2 -0
  114. mistralai/models/retrievefileout.py +3 -5
  115. mistralai/models/sampletype.py +7 -1
  116. mistralai/models/sdkerror.py +2 -0
  117. mistralai/models/shareenum.py +7 -1
  118. mistralai/models/sharingdelete.py +2 -4
  119. mistralai/models/sharingin.py +3 -5
  120. mistralai/models/source.py +8 -1
  121. mistralai/models/systemmessage.py +1 -1
  122. mistralai/models/textchunk.py +1 -1
  123. mistralai/models/thinkchunk.py +1 -1
  124. mistralai/models/timestampgranularity.py +4 -1
  125. mistralai/models/tool.py +2 -6
  126. mistralai/models/toolcall.py +2 -6
  127. mistralai/models/toolchoice.py +2 -6
  128. mistralai/models/toolchoiceenum.py +6 -1
  129. mistralai/models/toolexecutiondeltaevent.py +2 -1
  130. mistralai/models/toolexecutiondoneevent.py +2 -1
  131. mistralai/models/toolexecutionentry.py +4 -2
  132. mistralai/models/toolexecutionstartedevent.py +2 -1
  133. mistralai/models/toolfilechunk.py +2 -1
  134. mistralai/models/toolmessage.py +1 -1
  135. mistralai/models/toolreferencechunk.py +2 -1
  136. mistralai/models/tooltypes.py +1 -1
  137. mistralai/models/transcriptionsegmentchunk.py +42 -3
  138. mistralai/models/transcriptionstreamdone.py +1 -1
  139. mistralai/models/transcriptionstreamlanguage.py +1 -1
  140. mistralai/models/transcriptionstreamsegmentdelta.py +39 -3
  141. mistralai/models/transcriptionstreamtextdelta.py +1 -1
  142. mistralai/models/unarchiveftmodelout.py +1 -1
  143. mistralai/models/uploadfileout.py +3 -5
  144. mistralai/models/usermessage.py +1 -1
  145. mistralai/models/wandbintegration.py +1 -1
  146. mistralai/models/wandbintegrationout.py +1 -1
  147. mistralai/models/websearchpremiumtool.py +1 -1
  148. mistralai/models/websearchtool.py +1 -1
  149. mistralai/models_.py +24 -12
  150. mistralai/ocr.py +38 -10
  151. mistralai/sdk.py +2 -2
  152. mistralai/transcriptions.py +52 -12
  153. mistralai/types/basemodel.py +41 -3
  154. mistralai/utils/__init__.py +0 -3
  155. mistralai/utils/annotations.py +32 -8
  156. mistralai/utils/enums.py +60 -0
  157. mistralai/utils/forms.py +21 -10
  158. mistralai/utils/queryparams.py +14 -2
  159. mistralai/utils/requestbodies.py +3 -3
  160. mistralai/utils/retries.py +69 -5
  161. mistralai/utils/serializers.py +0 -20
  162. mistralai/utils/unmarshal_json_response.py +15 -1
  163. {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/METADATA +28 -31
  164. {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/RECORD +251 -237
  165. mistralai_azure/_version.py +3 -3
  166. mistralai_azure/basesdk.py +6 -0
  167. mistralai_azure/chat.py +27 -15
  168. mistralai_azure/httpclient.py +0 -1
  169. mistralai_azure/models/__init__.py +16 -1
  170. mistralai_azure/models/assistantmessage.py +1 -1
  171. mistralai_azure/models/chatcompletionchoice.py +10 -7
  172. mistralai_azure/models/chatcompletionrequest.py +8 -6
  173. mistralai_azure/models/chatcompletionstreamrequest.py +8 -6
  174. mistralai_azure/models/completionresponsestreamchoice.py +11 -7
  175. mistralai_azure/models/documenturlchunk.py +1 -1
  176. mistralai_azure/models/httpvalidationerror.py +4 -2
  177. mistralai_azure/models/imageurlchunk.py +1 -1
  178. mistralai_azure/models/mistralazureerror.py +11 -7
  179. mistralai_azure/models/mistralpromptmode.py +1 -1
  180. mistralai_azure/models/no_response_error.py +5 -1
  181. mistralai_azure/models/ocrpageobject.py +32 -5
  182. mistralai_azure/models/ocrrequest.py +20 -1
  183. mistralai_azure/models/ocrtableobject.py +34 -0
  184. mistralai_azure/models/referencechunk.py +1 -1
  185. mistralai_azure/models/responseformats.py +5 -1
  186. mistralai_azure/models/responsevalidationerror.py +2 -0
  187. mistralai_azure/models/sdkerror.py +2 -0
  188. mistralai_azure/models/systemmessage.py +1 -1
  189. mistralai_azure/models/textchunk.py +1 -1
  190. mistralai_azure/models/thinkchunk.py +1 -1
  191. mistralai_azure/models/tool.py +2 -6
  192. mistralai_azure/models/toolcall.py +2 -6
  193. mistralai_azure/models/toolchoice.py +2 -6
  194. mistralai_azure/models/toolchoiceenum.py +6 -1
  195. mistralai_azure/models/toolmessage.py +1 -1
  196. mistralai_azure/models/tooltypes.py +1 -1
  197. mistralai_azure/models/usermessage.py +1 -1
  198. mistralai_azure/ocr.py +26 -6
  199. mistralai_azure/types/basemodel.py +41 -3
  200. mistralai_azure/utils/__init__.py +0 -3
  201. mistralai_azure/utils/annotations.py +32 -8
  202. mistralai_azure/utils/enums.py +60 -0
  203. mistralai_azure/utils/forms.py +21 -10
  204. mistralai_azure/utils/queryparams.py +14 -2
  205. mistralai_azure/utils/requestbodies.py +3 -3
  206. mistralai_azure/utils/retries.py +69 -5
  207. mistralai_azure/utils/serializers.py +0 -20
  208. mistralai_azure/utils/unmarshal_json_response.py +15 -1
  209. mistralai_gcp/_version.py +3 -3
  210. mistralai_gcp/basesdk.py +6 -0
  211. mistralai_gcp/chat.py +27 -15
  212. mistralai_gcp/fim.py +27 -15
  213. mistralai_gcp/httpclient.py +0 -1
  214. mistralai_gcp/models/assistantmessage.py +1 -1
  215. mistralai_gcp/models/chatcompletionchoice.py +10 -7
  216. mistralai_gcp/models/chatcompletionrequest.py +8 -6
  217. mistralai_gcp/models/chatcompletionstreamrequest.py +8 -6
  218. mistralai_gcp/models/completionresponsestreamchoice.py +11 -7
  219. mistralai_gcp/models/fimcompletionrequest.py +6 -1
  220. mistralai_gcp/models/fimcompletionstreamrequest.py +6 -1
  221. mistralai_gcp/models/httpvalidationerror.py +4 -2
  222. mistralai_gcp/models/imageurlchunk.py +1 -1
  223. mistralai_gcp/models/mistralgcperror.py +11 -7
  224. mistralai_gcp/models/mistralpromptmode.py +1 -1
  225. mistralai_gcp/models/no_response_error.py +5 -1
  226. mistralai_gcp/models/referencechunk.py +1 -1
  227. mistralai_gcp/models/responseformats.py +5 -1
  228. mistralai_gcp/models/responsevalidationerror.py +2 -0
  229. mistralai_gcp/models/sdkerror.py +2 -0
  230. mistralai_gcp/models/systemmessage.py +1 -1
  231. mistralai_gcp/models/textchunk.py +1 -1
  232. mistralai_gcp/models/thinkchunk.py +1 -1
  233. mistralai_gcp/models/tool.py +2 -6
  234. mistralai_gcp/models/toolcall.py +2 -6
  235. mistralai_gcp/models/toolchoice.py +2 -6
  236. mistralai_gcp/models/toolchoiceenum.py +6 -1
  237. mistralai_gcp/models/toolmessage.py +1 -1
  238. mistralai_gcp/models/tooltypes.py +1 -1
  239. mistralai_gcp/models/usermessage.py +1 -1
  240. mistralai_gcp/types/basemodel.py +41 -3
  241. mistralai_gcp/utils/__init__.py +0 -3
  242. mistralai_gcp/utils/annotations.py +32 -8
  243. mistralai_gcp/utils/enums.py +60 -0
  244. mistralai_gcp/utils/forms.py +21 -10
  245. mistralai_gcp/utils/queryparams.py +14 -2
  246. mistralai_gcp/utils/requestbodies.py +3 -3
  247. mistralai_gcp/utils/retries.py +69 -5
  248. mistralai_gcp/utils/serializers.py +0 -20
  249. mistralai_gcp/utils/unmarshal_json_response.py +15 -1
  250. {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/WHEEL +0 -0
  251. {mistralai-1.10.1.dist-info → mistralai-1.12.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,207 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import json
5
+ from asyncio import CancelledError
6
+ from collections import deque
7
+ from typing import Any, AsyncIterator, Deque, Optional, Union
8
+
9
+ from pydantic import ValidationError, BaseModel
10
+
11
+ try:
12
+ from websockets.asyncio.client import ClientConnection # websockets >= 13.0
13
+ except ImportError as exc:
14
+ raise ImportError(
15
+ "The `websockets` package (>=13.0) is required for real-time transcription. "
16
+ "Install with: pip install 'mistralai[realtime]'"
17
+ ) from exc
18
+
19
+ from mistralai.models import (
20
+ AudioFormat,
21
+ RealtimeTranscriptionError,
22
+ RealtimeTranscriptionSession,
23
+ RealtimeTranscriptionSessionCreated,
24
+ RealtimeTranscriptionSessionUpdated,
25
+ TranscriptionStreamDone,
26
+ TranscriptionStreamLanguage,
27
+ TranscriptionStreamSegmentDelta,
28
+ TranscriptionStreamTextDelta,
29
+ )
30
+
31
+
32
class UnknownRealtimeEvent(BaseModel):
    """Fallback event for server messages that could not be typed.

    Used in place of a typed event when a message:
    - is not valid JSON,
    - has a missing, non-string or unrecognised ``type``,
    - fails validation against the model registered for its ``type``.
    """

    # The message's ``type`` when it was a string; None otherwise.
    type: Optional[str]
    # The offending message: raw text or the decoded JSON value.
    content: Any
    # Short explanation of why the message was not turned into a typed event.
    error: Optional[str] = None
42
+
43
+
44
# Union of every event object a realtime connection can hand to the caller.
RealtimeEvent = Union[
    RealtimeTranscriptionSessionCreated,  # session lifecycle
    RealtimeTranscriptionSessionUpdated,  # session lifecycle
    RealtimeTranscriptionError,  # server errors
    TranscriptionStreamLanguage,  # transcription events
    TranscriptionStreamSegmentDelta,
    TranscriptionStreamTextDelta,
    TranscriptionStreamDone,
    UnknownRealtimeEvent,  # forward-compat fallback
]


# Lookup from a message's wire-level ``type`` string to the model that
# validates it; types missing from this table are reported as unknown events.
_MESSAGE_MODELS: dict[str, Any] = {
    "session.created": RealtimeTranscriptionSessionCreated,
    "session.updated": RealtimeTranscriptionSessionUpdated,
    "error": RealtimeTranscriptionError,
    "transcription.language": TranscriptionStreamLanguage,
    "transcription.segment": TranscriptionStreamSegmentDelta,
    "transcription.text.delta": TranscriptionStreamTextDelta,
    "transcription.done": TranscriptionStreamDone,
}
69
+
70
+
71
def parse_realtime_event(payload: Any) -> RealtimeEvent:
    """Turn a decoded JSON message into a typed realtime event.

    Malformed input is reported as an ``UnknownRealtimeEvent`` instead of
    an exception; its ``error`` field says what went wrong:

    - payload is not a dict             -> "expected JSON object"
    - ``type`` is missing or not a str  -> "missing/invalid 'type'"
    - no model registered for ``type``  -> "unknown event type"
    - model validation fails            -> text of the ``ValidationError``
    """

    def _fallback(event_type: Optional[str], reason: str) -> UnknownRealtimeEvent:
        # Every failure path keeps the original payload for inspection.
        return UnknownRealtimeEvent(type=event_type, content=payload, error=reason)

    if not isinstance(payload, dict):
        return _fallback(None, "expected JSON object")

    event_type = payload.get("type")
    if not isinstance(event_type, str):
        return _fallback(None, "missing/invalid 'type'")

    model = _MESSAGE_MODELS.get(event_type)
    if model is None:
        return _fallback(event_type, "unknown event type")

    try:
        return model.model_validate(payload)
    except ValidationError as exc:
        return _fallback(event_type, str(exc))
99
+
100
+
101
class RealtimeConnection:
    """Wrapper around an open WebSocket carrying a realtime transcription session.

    Outgoing audio and control messages are sent as JSON text frames;
    incoming frames are parsed into ``RealtimeEvent`` objects by ``events()``
    (also reachable through ``async for event in connection``).  The object
    is an async context manager that closes the socket on exit.
    """

    def __init__(
        self,
        websocket: ClientConnection,
        session: RealtimeTranscriptionSession,
        *,
        initial_events: Optional[list[RealtimeEvent]] = None,
    ) -> None:
        """Wrap ``websocket`` for the given ``session``.

        ``initial_events`` are events already read from the socket (for
        example during the handshake); ``events()`` yields them first.
        """
        self._websocket = websocket
        self._session = session
        self._audio_format = session.audio_format
        self._closed = False
        self._initial_events: Deque[RealtimeEvent] = deque(initial_events or [])

    @property
    def request_id(self) -> str:
        """Request id of the current session object."""
        return self._session.request_id

    @property
    def session(self) -> RealtimeTranscriptionSession:
        """Most recent session object seen (from session.created/updated)."""
        return self._session

    @property
    def audio_format(self) -> AudioFormat:
        """Audio format last requested locally or reported by the server."""
        return self._audio_format

    @property
    def is_closed(self) -> bool:
        """True once ``close()`` has been called on this object."""
        return self._closed

    async def send_audio(
        self, audio_bytes: Union[bytes, bytearray, memoryview]
    ) -> None:
        """Send one chunk of audio as a base64 ``input_audio.append`` message.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        message = {
            "type": "input_audio.append",
            "audio": base64.b64encode(bytes(audio_bytes)).decode("ascii"),
        }
        await self._websocket.send(json.dumps(message))

    async def update_session(self, audio_format: AudioFormat) -> None:
        """Send a ``session.update`` message carrying a new audio format.

        Raises:
            RuntimeError: if the connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Connection is closed")

        self._audio_format = audio_format
        message = {
            "type": "session.update",
            "session": {"audio_format": audio_format.model_dump(mode="json")},
        }
        await self._websocket.send(json.dumps(message))

    async def end_audio(self) -> None:
        """Send ``input_audio.end``; silently does nothing when already closed."""
        if self._closed:
            return
        await self._websocket.send(json.dumps({"type": "input_audio.end"}))

    async def close(self, *, code: int = 1000, reason: str = "") -> None:
        """Close the underlying websocket once; later calls are no-ops."""
        if self._closed:
            return
        # Flag first so concurrent callers do not issue a second close.
        self._closed = True
        await self._websocket.close(code=code, reason=reason)

    async def __aenter__(self) -> "RealtimeConnection":
        return self

    async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        await self.close()

    def __aiter__(self) -> AsyncIterator[RealtimeEvent]:
        return self.events()

    async def events(self) -> AsyncIterator[RealtimeEvent]:
        """Yield events: buffered handshake events first, then live messages.

        Frames that are not valid JSON are yielded as ``UnknownRealtimeEvent``
        rather than raising.  The connection is closed when iteration ends
        for any reason.  Task cancellation is NOT swallowed: the socket is
        closed in ``finally`` and ``CancelledError`` propagates, so the
        enclosing task (or ``asyncio.timeout``/``wait_for``) is cancelled as
        asyncio expects.
        """
        # replay any handshake/prelude events (including session.created)
        while self._initial_events:
            ev = self._initial_events.popleft()
            self._apply_session_updates(ev)
            yield ev

        try:
            async for msg in self._websocket:
                text = (
                    msg.decode("utf-8", errors="replace")
                    if isinstance(msg, (bytes, bytearray))
                    else msg
                )
                try:
                    data = json.loads(text)
                except Exception as exc:
                    # Stay tolerant: surface the bad frame as an event.
                    yield UnknownRealtimeEvent(
                        type=None, content=text, error=f"invalid JSON: {exc}"
                    )
                    continue

                ev = parse_realtime_event(data)
                self._apply_session_updates(ev)
                yield ev
        finally:
            await self.close()

    def _apply_session_updates(self, ev: RealtimeEvent) -> None:
        """Refresh cached session/audio format from session lifecycle events."""
        if isinstance(
            ev,
            (RealtimeTranscriptionSessionCreated, RealtimeTranscriptionSessionUpdated),
        ):
            self._session = ev.session
            self._audio_format = ev.session.audio_format
@@ -0,0 +1,271 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import time
6
+ from typing import AsyncIterator, Mapping, Optional
7
+ from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
8
+
9
+ try:
10
+ from websockets.asyncio.client import (
11
+ ClientConnection,
12
+ connect,
13
+ ) # websockets >= 13.0
14
+ except ImportError as exc:
15
+ raise ImportError(
16
+ "The `websockets` package (>=13.0) is required for real-time transcription. "
17
+ "Install with: pip install 'mistralai[realtime]'"
18
+ ) from exc
19
+
20
+ from mistralai import models, utils
21
+ from mistralai.models import (
22
+ AudioFormat,
23
+ RealtimeTranscriptionError,
24
+ RealtimeTranscriptionSession,
25
+ RealtimeTranscriptionSessionCreated,
26
+ )
27
+ from mistralai.sdkconfiguration import SDKConfiguration
28
+ from mistralai.utils import generate_url, get_security, get_security_from_env
29
+
30
+ from ..exceptions import RealtimeTranscriptionException, RealtimeTranscriptionWSError
31
+ from .connection import (
32
+ RealtimeConnection,
33
+ RealtimeEvent,
34
+ UnknownRealtimeEvent,
35
+ parse_realtime_event,
36
+ )
37
+
38
+
39
class RealtimeTranscription:
    """Client for realtime transcription over WebSocket (websockets >= 13.0)."""

    def __init__(self, sdk_config: SDKConfiguration) -> None:
        self._sdk_config = sdk_config

    def _build_url(
        self,
        model: str,
        *,
        server_url: Optional[str],
        query_params: Mapping[str, str],
    ) -> str:
        """Return the realtime endpoint URL with ``model`` and extra query params.

        ``server_url`` (trailing "/" stripped) overrides the server from the
        SDK configuration.  Query parameters already present on the URL are
        kept; ``model`` and then ``query_params`` are applied on top of them.
        """
        if server_url is not None:
            base_url = utils.remove_suffix(server_url, "/")
        else:
            base_url, _ = self._sdk_config.get_server_details()

        url = generate_url(base_url, "/v1/audio/transcriptions/realtime", None)

        parsed = urlparse(url)
        merged = dict(parse_qsl(parsed.query, keep_blank_values=True))
        merged["model"] = model
        merged.update(dict(query_params))

        return urlunparse(parsed._replace(query=urlencode(merged)))

    async def connect(
        self,
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> RealtimeConnection:
        """Open a websocket, wait for ``session.created`` and return the connection.

        Args:
            model: model name, sent as the ``model`` query parameter.
            audio_format: when given, a ``session.update`` with this format is
                sent right after the handshake.
            server_url: overrides the server URL from the SDK configuration.
            timeout_ms: used both as the websocket open timeout and as the
                handshake deadline; defaults to the SDK configuration value.
            http_headers: extra headers; they override security headers with
                the same name.

        Raises:
            RealtimeTranscriptionException: on any connection or handshake
                failure (other ``Exception`` types are wrapped in it).

        The websocket is closed before any error leaves this method,
        including task cancellation.
        """
        if timeout_ms is None:
            timeout_ms = self._sdk_config.timeout_ms

        security = self._sdk_config.security
        if security is not None and callable(security):
            security = security()

        resolved_security = get_security_from_env(security, models.Security)

        headers: dict[str, str] = {}
        query_params: dict[str, str] = {}

        if resolved_security is not None:
            security_headers, security_query = get_security(resolved_security)
            headers |= security_headers
            for key, values in security_query.items():
                if values:
                    # Keep only the last value of a multi-valued parameter.
                    query_params[key] = values[-1]

        if http_headers is not None:
            headers |= dict(http_headers)

        url = self._build_url(model, server_url=server_url, query_params=query_params)

        # Map the HTTP scheme to its websocket counterpart; others pass through.
        parsed = urlparse(url)
        if parsed.scheme == "https":
            parsed = parsed._replace(scheme="wss")
        elif parsed.scheme == "http":
            parsed = parsed._replace(scheme="ws")
        ws_url = urlunparse(parsed)
        open_timeout = None if timeout_ms is None else timeout_ms / 1000.0
        user_agent = self._sdk_config.user_agent

        websocket: Optional[ClientConnection] = None
        try:
            websocket = await connect(
                ws_url,
                additional_headers=dict(headers),
                open_timeout=open_timeout,
                user_agent_header=user_agent,
            )

            session, initial_events = await _recv_handshake(
                websocket, timeout_ms=timeout_ms
            )
            connection = RealtimeConnection(
                websocket=websocket,
                session=session,
                initial_events=initial_events,
            )

            if audio_format is not None:
                await connection.update_session(audio_format)

            return connection

        except RealtimeTranscriptionException:
            if websocket is not None:
                await websocket.close()
            raise
        except Exception as exc:
            if websocket is not None:
                await websocket.close()
            raise RealtimeTranscriptionException(f"Failed to connect: {exc}") from exc
        except BaseException:
            # asyncio.CancelledError (and KeyboardInterrupt/SystemExit) are not
            # Exception subclasses; without this branch a cancelled connect()
            # would leave the freshly opened websocket dangling.
            if websocket is not None:
                await websocket.close()
            raise

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[RealtimeEvent]:
        """
        Flow
        - opens connection
        - streams audio in background
        - yields events from the connection

        Iteration stops after the first ``RealtimeTranscriptionError`` or the
        ``transcription.done`` event (both are still yielded).  On exit the
        background sender is cancelled and the connection is closed; an
        exception raised by the sender surfaces when it is awaited here.
        """
        async with await self.connect(
            model=model,
            audio_format=audio_format,
            server_url=server_url,
            timeout_ms=timeout_ms,
            http_headers=http_headers,
        ) as connection:

            async def _send() -> None:
                # Forward chunks until the source ends or the socket closes.
                async for chunk in audio_stream:
                    if connection.is_closed:
                        break
                    await connection.send_audio(chunk)
                await connection.end_audio()

            send_task = asyncio.create_task(_send())

            try:
                async for event in connection:
                    yield event

                    # stop early (caller still sees the terminating event)
                    if isinstance(event, RealtimeTranscriptionError):
                        break
                    if getattr(event, "type", None) == "transcription.done":
                        break
            finally:
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
                await connection.close()
187
+
188
+
189
def _extract_error_message(payload: dict) -> str:
    """Pull a human-readable message out of a server ``error`` payload.

    Looks at ``payload["error"]["message"]``: a string is returned as is; a
    dict contributes its string ``"detail"`` entry.  Any other shape yields
    the generic text "Realtime transcription error".
    """
    fallback = "Realtime transcription error"

    error_obj = payload.get("error")
    if not isinstance(error_obj, dict):
        return fallback

    message = error_obj.get("message")
    if isinstance(message, dict):
        # Structured message: the text lives under "detail".
        message = message.get("detail")

    return message if isinstance(message, str) else fallback
200
+
201
+
202
async def _recv_handshake(
    websocket: ClientConnection,
    *,
    timeout_ms: Optional[int],
) -> tuple[RealtimeTranscriptionSession, list[RealtimeEvent]]:
    """
    Consume messages until the server sends session.created or an error.

    Returns the created session together with every event read on the way
    (in arrival order, session.created included) so the caller can replay
    them without loss.  ``timeout_ms`` bounds the whole handshake, not each
    individual receive; ``None`` means wait indefinitely.

    Raises:
        RealtimeTranscriptionWSError: the server sent a message of type
            "error".
        RealtimeTranscriptionException: the deadline passed, or any other
            failure occurred while reading from the websocket.
    """
    if timeout_ms is None:
        deadline = None
    else:
        deadline = time.monotonic() + timeout_ms / 1000.0

    seen: list[RealtimeEvent] = []

    try:
        while True:
            # Time left until the overall deadline, never negative.
            time_left = (
                None if deadline is None else max(0.0, deadline - time.monotonic())
            )
            frame = await asyncio.wait_for(websocket.recv(), timeout=time_left)

            if isinstance(frame, (bytes, bytearray)):
                text = frame.decode("utf-8", errors="replace")
            else:
                text = frame

            try:
                payload = json.loads(text)
            except Exception as exc:
                # Undecodable frames are recorded and the wait continues.
                seen.append(
                    UnknownRealtimeEvent(
                        type=None, content=text, error=f"invalid JSON: {exc}"
                    )
                )
                continue

            event = parse_realtime_event(payload)
            seen.append(event)

            if isinstance(payload, dict) and payload.get("type") == "error":
                # Attach the typed error only when it validated as one.
                typed_error = (
                    event if isinstance(event, RealtimeTranscriptionError) else None
                )
                raise RealtimeTranscriptionWSError(
                    _extract_error_message(payload),
                    payload=typed_error,
                    raw=payload,
                )

            if isinstance(event, RealtimeTranscriptionSessionCreated):
                return event.session, seen

    except asyncio.TimeoutError as exc:
        raise RealtimeTranscriptionException(
            "Timeout waiting for session creation."
        ) from exc
    except RealtimeTranscriptionException:
        raise
    except Exception as exc:
        raise RealtimeTranscriptionException(
            f"Unexpected websocket handshake failure: {exc}"
        ) from exc