mistralai 1.11.1__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. mistralai/_version.py +2 -2
  2. mistralai/audio.py +20 -0
  3. mistralai/conversations.py +48 -8
  4. mistralai/extra/__init__.py +48 -0
  5. mistralai/extra/exceptions.py +49 -4
  6. mistralai/extra/realtime/__init__.py +25 -0
  7. mistralai/extra/realtime/connection.py +207 -0
  8. mistralai/extra/realtime/transcription.py +271 -0
  9. mistralai/files.py +6 -0
  10. mistralai/mistral_agents.py +391 -8
  11. mistralai/models/__init__.py +103 -0
  12. mistralai/models/agentaliasresponse.py +23 -0
  13. mistralai/models/agentconversation.py +14 -4
  14. mistralai/models/agents_api_v1_agents_create_or_update_aliasop.py +26 -0
  15. mistralai/models/agents_api_v1_agents_get_versionop.py +2 -2
  16. mistralai/models/agents_api_v1_agents_getop.py +12 -3
  17. mistralai/models/agents_api_v1_agents_list_version_aliasesop.py +16 -0
  18. mistralai/models/audiotranscriptionrequest.py +8 -0
  19. mistralai/models/audiotranscriptionrequeststream.py +8 -0
  20. mistralai/models/conversationrequest.py +8 -2
  21. mistralai/models/conversationrestartrequest.py +18 -4
  22. mistralai/models/conversationrestartstreamrequest.py +20 -4
  23. mistralai/models/conversationstreamrequest.py +12 -2
  24. mistralai/models/files_api_routes_list_filesop.py +8 -1
  25. mistralai/models/mistralpromptmode.py +4 -0
  26. mistralai/models/modelcapabilities.py +3 -0
  27. mistralai/models/realtimetranscriptionerror.py +27 -0
  28. mistralai/models/realtimetranscriptionerrordetail.py +29 -0
  29. mistralai/models/realtimetranscriptionsession.py +20 -0
  30. mistralai/models/realtimetranscriptionsessioncreated.py +30 -0
  31. mistralai/models/realtimetranscriptionsessionupdated.py +30 -0
  32. mistralai/models/timestampgranularity.py +4 -1
  33. mistralai/models/transcriptionsegmentchunk.py +41 -2
  34. mistralai/models/transcriptionstreamsegmentdelta.py +38 -2
  35. mistralai/transcriptions.py +24 -0
  36. {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/METADATA +6 -2
  37. {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/RECORD +39 -28
  38. {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/WHEEL +0 -0
  39. {mistralai-1.11.1.dist-info → mistralai-1.12.0.dist-info}/licenses/LICENSE +0 -0
mistralai/extra/realtime/transcription.py ADDED
@@ -0,0 +1,271 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import time
6
+ from typing import AsyncIterator, Mapping, Optional
7
+ from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
8
+
9
+ try:
10
+ from websockets.asyncio.client import (
11
+ ClientConnection,
12
+ connect,
13
+ ) # websockets >= 13.0
14
+ except ImportError as exc:
15
+ raise ImportError(
16
+ "The `websockets` package (>=13.0) is required for real-time transcription. "
17
+ "Install with: pip install 'mistralai[realtime]'"
18
+ ) from exc
19
+
20
+ from mistralai import models, utils
21
+ from mistralai.models import (
22
+ AudioFormat,
23
+ RealtimeTranscriptionError,
24
+ RealtimeTranscriptionSession,
25
+ RealtimeTranscriptionSessionCreated,
26
+ )
27
+ from mistralai.sdkconfiguration import SDKConfiguration
28
+ from mistralai.utils import generate_url, get_security, get_security_from_env
29
+
30
+ from ..exceptions import RealtimeTranscriptionException, RealtimeTranscriptionWSError
31
+ from .connection import (
32
+ RealtimeConnection,
33
+ RealtimeEvent,
34
+ UnknownRealtimeEvent,
35
+ parse_realtime_event,
36
+ )
37
+
38
+
39
class RealtimeTranscription:
    """Client for realtime transcription over WebSocket (websockets >= 13.0)."""

    def __init__(self, sdk_config: SDKConfiguration) -> None:
        self._sdk_config = sdk_config

    @staticmethod
    async def _close_quietly(websocket: Optional[ClientConnection]) -> None:
        """Close ``websocket`` if it was opened, ignoring errors raised by close().

        Used on failure paths only: a secondary error while closing must not
        replace the exception that caused the cleanup.
        """
        if websocket is None:
            return
        try:
            await websocket.close()
        except Exception:
            # Deliberate best-effort cleanup; the original error is what matters.
            pass

    def _build_url(
        self,
        model: str,
        *,
        server_url: Optional[str],
        query_params: Mapping[str, str],
    ) -> str:
        """Build the HTTP(S) URL of the realtime transcription endpoint.

        ``model`` is always set as a query parameter. ``query_params`` is merged
        on top of it and wins on key collisions. Query parameters already
        present in the base URL are kept unless overridden.
        """
        if server_url is not None:
            base_url = utils.remove_suffix(server_url, "/")
        else:
            base_url, _ = self._sdk_config.get_server_details()

        url = generate_url(base_url, "/v1/audio/transcriptions/realtime", None)

        parsed = urlparse(url)
        merged = dict(parse_qsl(parsed.query, keep_blank_values=True))
        merged["model"] = model
        merged.update(dict(query_params))

        return urlunparse(parsed._replace(query=urlencode(merged)))

    async def connect(
        self,
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> RealtimeConnection:
        """Open a realtime transcription WebSocket and complete the session handshake.

        Args:
            model: Model identifier, sent as the ``model`` query parameter.
            audio_format: If given, sent via ``update_session`` once the
                session has been created.
            server_url: Overrides the SDK's configured server URL.
            timeout_ms: Used both as the WebSocket opening timeout and as the
                budget for receiving ``session.created``. Defaults to the SDK
                timeout.
            http_headers: Extra headers. They override security headers on
                conflict.

        Returns:
            An open ``RealtimeConnection`` seeded with the handshake events.

        Raises:
            RealtimeTranscriptionException: On any connection or handshake
                failure. Exceptions already of this type are re-raised
                unchanged; other ``Exception``s are wrapped with their cause
                chained.

        On every failure path, including task cancellation, the partially
        opened socket is closed before the exception propagates.
        """
        if timeout_ms is None:
            timeout_ms = self._sdk_config.timeout_ms

        security = self._sdk_config.security
        if security is not None and callable(security):
            # Security may be supplied lazily as a zero-argument callable.
            security = security()

        resolved_security = get_security_from_env(security, models.Security)

        headers: dict[str, str] = {}
        query_params: dict[str, str] = {}

        if resolved_security is not None:
            security_headers, security_query = get_security(resolved_security)
            headers |= security_headers
            for key, values in security_query.items():
                # Each security query entry holds a sequence; only its last value is sent.
                if values:
                    query_params[key] = values[-1]

        if http_headers is not None:
            # Caller-supplied headers take precedence over security headers.
            headers |= dict(http_headers)

        url = self._build_url(model, server_url=server_url, query_params=query_params)

        parsed = urlparse(url)
        if parsed.scheme == "https":
            parsed = parsed._replace(scheme="wss")
        elif parsed.scheme == "http":
            parsed = parsed._replace(scheme="ws")
        ws_url = urlunparse(parsed)
        open_timeout = None if timeout_ms is None else timeout_ms / 1000.0
        user_agent = self._sdk_config.user_agent

        websocket: Optional[ClientConnection] = None
        try:
            # Bare ``connect`` is the websockets client function, not this method.
            websocket = await connect(
                ws_url,
                additional_headers=dict(headers),
                open_timeout=open_timeout,
                user_agent_header=user_agent,
            )

            session, initial_events = await _recv_handshake(
                websocket, timeout_ms=timeout_ms
            )
            connection = RealtimeConnection(
                websocket=websocket,
                session=session,
                initial_events=initial_events,
            )

            if audio_format is not None:
                await connection.update_session(audio_format)

            return connection

        except RealtimeTranscriptionException:
            await self._close_quietly(websocket)
            raise
        except Exception as exc:
            await self._close_quietly(websocket)
            raise RealtimeTranscriptionException(f"Failed to connect: {exc}") from exc
        except BaseException:
            # asyncio.CancelledError (and KeyboardInterrupt) are not Exception
            # subclasses; without this arm the socket would leak on cancellation.
            await self._close_quietly(websocket)
            raise

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        model: str,
        audio_format: Optional[AudioFormat] = None,
        server_url: Optional[str] = None,
        timeout_ms: Optional[int] = None,
        http_headers: Optional[Mapping[str, str]] = None,
    ) -> AsyncIterator[RealtimeEvent]:
        """Transcribe an async stream of audio chunks, yielding realtime events.

        Flow:
        - opens a connection via ``connect``
        - pumps ``audio_stream`` into it from a background task, then calls ``end_audio``
        - yields every event received from the connection

        Iteration stops after a ``RealtimeTranscriptionError`` event or an event
        whose ``type`` is ``"transcription.done"``. That terminating event is
        still yielded to the caller.
        """
        async with await self.connect(
            model=model,
            audio_format=audio_format,
            server_url=server_url,
            timeout_ms=timeout_ms,
            http_headers=http_headers,
        ) as connection:

            async def _send() -> None:
                async for chunk in audio_stream:
                    if connection.is_closed:
                        # Nothing more can be delivered; don't call end_audio()
                        # on a connection that is already closed.
                        return
                    await connection.send_audio(chunk)
                await connection.end_audio()

            send_task = asyncio.create_task(_send())

            try:
                async for event in connection:
                    yield event

                    # Stop early; the caller has already seen the terminating event.
                    if isinstance(event, RealtimeTranscriptionError):
                        break
                    if getattr(event, "type", None) == "transcription.done":
                        break
            finally:
                # The sender may still be waiting on audio_stream; stop it first.
                # A non-cancellation failure of the sender is re-raised by the await.
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
                await connection.close()
187
+
188
+
189
+ def _extract_error_message(payload: dict) -> str:
190
+ err = payload.get("error")
191
+ if isinstance(err, dict):
192
+ msg = err.get("message")
193
+ if isinstance(msg, str):
194
+ return msg
195
+ if isinstance(msg, dict):
196
+ detail = msg.get("detail")
197
+ if isinstance(detail, str):
198
+ return detail
199
+ return "Realtime transcription error"
200
+
201
+
202
async def _recv_handshake(
    websocket: ClientConnection,
    *,
    timeout_ms: Optional[int],
) -> tuple[RealtimeTranscriptionSession, list[RealtimeEvent]]:
    """
    Read messages until ``session.created`` arrives or the server reports an error.

    Every message read up to and including ``session.created`` is returned as
    ``initial_events``, so the caller can replay them losslessly. Frames that are
    not valid JSON are kept as ``UnknownRealtimeEvent`` entries.

    Args:
        websocket: An open client connection.
        timeout_ms: Overall budget for the whole handshake, in milliseconds.
            ``None`` waits indefinitely.

    Returns:
        The created session and the events read during the handshake.

    Raises:
        RealtimeTranscriptionWSError: The server sent a message with
            ``type == "error"``.
        RealtimeTranscriptionException: The deadline expired, or some other
            failure occurred while receiving (wrapped, cause chained).
            Exceptions already of this type pass through unchanged.
    """
    timeout_s = None if timeout_ms is None else timeout_ms / 1000.0
    deadline = None if timeout_s is None else (time.monotonic() + timeout_s)

    initial_events: list[RealtimeEvent] = []

    def remaining() -> Optional[float]:
        # One deadline shared by all recv() calls, not a per-message timeout.
        if deadline is None:
            return None
        return max(0.0, deadline - time.monotonic())

    try:
        while True:
            raw = await asyncio.wait_for(websocket.recv(), timeout=remaining())
            text = (
                raw.decode("utf-8", errors="replace")
                if isinstance(raw, (bytes, bytearray))
                else raw
            )

            try:
                payload = json.loads(text)
            except Exception as exc:
                initial_events.append(
                    UnknownRealtimeEvent(
                        type=None, content=text, error=f"invalid JSON: {exc}"
                    )
                )
                continue

            if isinstance(payload, dict) and payload.get("type") == "error":
                # The handshake failed. Events read so far are not returned, so
                # there is no point recording this one. Attach the typed event
                # only when parsing produced one.
                parsed = parse_realtime_event(payload)
                raise RealtimeTranscriptionWSError(
                    _extract_error_message(payload),
                    payload=(
                        parsed
                        if isinstance(parsed, RealtimeTranscriptionError)
                        else None
                    ),
                    raw=payload,
                )

            event = parse_realtime_event(payload)
            initial_events.append(event)

            if isinstance(event, RealtimeTranscriptionSessionCreated):
                return event.session, initial_events

    except asyncio.TimeoutError as exc:
        raise RealtimeTranscriptionException(
            "Timeout waiting for session creation."
        ) from exc
    except RealtimeTranscriptionException:
        raise
    except Exception as exc:
        raise RealtimeTranscriptionException(
            f"Unexpected websocket handshake failure: {exc}"
        ) from exc
mistralai/files.py CHANGED
@@ -241,6 +241,7 @@ class Files(BaseSDK):
241
241
  source: OptionalNullable[List[models_source.Source]] = UNSET,
242
242
  search: OptionalNullable[str] = UNSET,
243
243
  purpose: OptionalNullable[models_filepurpose.FilePurpose] = UNSET,
244
+ mimetypes: OptionalNullable[List[str]] = UNSET,
244
245
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
245
246
  server_url: Optional[str] = None,
246
247
  timeout_ms: Optional[int] = None,
@@ -257,6 +258,7 @@ class Files(BaseSDK):
257
258
  :param source:
258
259
  :param search:
259
260
  :param purpose:
261
+ :param mimetypes:
260
262
  :param retries: Override the default retry configuration for this method
261
263
  :param server_url: Override the default server URL for this method
262
264
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -280,6 +282,7 @@ class Files(BaseSDK):
280
282
  source=source,
281
283
  search=search,
282
284
  purpose=purpose,
285
+ mimetypes=mimetypes,
283
286
  )
284
287
 
285
288
  req = self._build_request(
@@ -343,6 +346,7 @@ class Files(BaseSDK):
343
346
  source: OptionalNullable[List[models_source.Source]] = UNSET,
344
347
  search: OptionalNullable[str] = UNSET,
345
348
  purpose: OptionalNullable[models_filepurpose.FilePurpose] = UNSET,
349
+ mimetypes: OptionalNullable[List[str]] = UNSET,
346
350
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
347
351
  server_url: Optional[str] = None,
348
352
  timeout_ms: Optional[int] = None,
@@ -359,6 +363,7 @@ class Files(BaseSDK):
359
363
  :param source:
360
364
  :param search:
361
365
  :param purpose:
366
+ :param mimetypes:
362
367
  :param retries: Override the default retry configuration for this method
363
368
  :param server_url: Override the default server URL for this method
364
369
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -382,6 +387,7 @@ class Files(BaseSDK):
382
387
  source=source,
383
388
  search=search,
384
389
  purpose=purpose,
390
+ mimetypes=mimetypes,
385
391
  )
386
392
 
387
393
  req = self._build_request_async(