mistralai-1.10.0-py3-none-any.whl → mistralai-1.10.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_hooks/tracing.py +28 -3
- mistralai/_version.py +2 -2
- mistralai/classifiers.py +13 -1
- mistralai/embeddings.py +7 -1
- mistralai/extra/README.md +1 -1
- mistralai/extra/mcp/auth.py +10 -11
- mistralai/extra/mcp/base.py +17 -16
- mistralai/extra/mcp/sse.py +13 -15
- mistralai/extra/mcp/stdio.py +5 -6
- mistralai/extra/observability/otel.py +47 -68
- mistralai/extra/run/context.py +33 -43
- mistralai/extra/run/result.py +29 -30
- mistralai/extra/run/tools.py +8 -9
- mistralai/extra/struct_chat.py +15 -8
- mistralai/extra/utils/response_format.py +5 -3
- mistralai/mistral_jobs.py +31 -5
- mistralai/models/__init__.py +30 -1
- mistralai/models/agents_api_v1_agents_listop.py +1 -1
- mistralai/models/agents_api_v1_conversations_listop.py +1 -1
- mistralai/models/audioencoding.py +13 -0
- mistralai/models/audioformat.py +19 -0
- mistralai/models/batchjobin.py +17 -6
- mistralai/models/batchjobout.py +5 -0
- mistralai/models/batchrequest.py +48 -0
- mistralai/models/classificationrequest.py +37 -3
- mistralai/models/embeddingrequest.py +11 -3
- mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
- mistralai/models/toolfilechunk.py +11 -4
- mistralai/models/toolreferencechunk.py +13 -4
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/METADATA +142 -150
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/RECORD +122 -105
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
- mistralai_azure/_version.py +3 -3
- mistralai_azure/basesdk.py +15 -5
- mistralai_azure/chat.py +59 -98
- mistralai_azure/models/__init__.py +50 -3
- mistralai_azure/models/chatcompletionrequest.py +16 -4
- mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
- mistralai_azure/models/httpvalidationerror.py +11 -6
- mistralai_azure/models/mistralazureerror.py +26 -0
- mistralai_azure/models/no_response_error.py +13 -0
- mistralai_azure/models/prediction.py +4 -0
- mistralai_azure/models/responseformat.py +4 -2
- mistralai_azure/models/responseformats.py +0 -1
- mistralai_azure/models/responsevalidationerror.py +25 -0
- mistralai_azure/models/sdkerror.py +30 -14
- mistralai_azure/models/systemmessage.py +7 -3
- mistralai_azure/models/systemmessagecontentchunks.py +21 -0
- mistralai_azure/models/thinkchunk.py +35 -0
- mistralai_azure/ocr.py +15 -36
- mistralai_azure/utils/__init__.py +18 -5
- mistralai_azure/utils/eventstreaming.py +10 -0
- mistralai_azure/utils/serializers.py +3 -2
- mistralai_azure/utils/unmarshal_json_response.py +24 -0
- mistralai_gcp/_hooks/types.py +7 -0
- mistralai_gcp/_version.py +4 -4
- mistralai_gcp/basesdk.py +27 -25
- mistralai_gcp/chat.py +75 -98
- mistralai_gcp/fim.py +39 -74
- mistralai_gcp/httpclient.py +6 -16
- mistralai_gcp/models/__init__.py +321 -116
- mistralai_gcp/models/assistantmessage.py +1 -1
- mistralai_gcp/models/chatcompletionrequest.py +36 -7
- mistralai_gcp/models/chatcompletionresponse.py +6 -6
- mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
- mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
- mistralai_gcp/models/deltamessage.py +1 -1
- mistralai_gcp/models/fimcompletionrequest.py +3 -9
- mistralai_gcp/models/fimcompletionresponse.py +6 -6
- mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
- mistralai_gcp/models/httpvalidationerror.py +11 -6
- mistralai_gcp/models/imageurl.py +1 -1
- mistralai_gcp/models/jsonschema.py +1 -1
- mistralai_gcp/models/mistralgcperror.py +26 -0
- mistralai_gcp/models/mistralpromptmode.py +8 -0
- mistralai_gcp/models/no_response_error.py +13 -0
- mistralai_gcp/models/prediction.py +4 -0
- mistralai_gcp/models/responseformat.py +5 -3
- mistralai_gcp/models/responseformats.py +0 -1
- mistralai_gcp/models/responsevalidationerror.py +25 -0
- mistralai_gcp/models/sdkerror.py +30 -14
- mistralai_gcp/models/systemmessage.py +7 -3
- mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
- mistralai_gcp/models/thinkchunk.py +35 -0
- mistralai_gcp/models/toolmessage.py +1 -1
- mistralai_gcp/models/usageinfo.py +71 -8
- mistralai_gcp/models/usermessage.py +1 -1
- mistralai_gcp/sdk.py +12 -10
- mistralai_gcp/sdkconfiguration.py +0 -7
- mistralai_gcp/types/basemodel.py +3 -3
- mistralai_gcp/utils/__init__.py +143 -45
- mistralai_gcp/utils/datetimes.py +23 -0
- mistralai_gcp/utils/enums.py +67 -27
- mistralai_gcp/utils/eventstreaming.py +10 -0
- mistralai_gcp/utils/forms.py +49 -28
- mistralai_gcp/utils/serializers.py +33 -3
- mistralai_gcp/utils/unmarshal_json_response.py +24 -0
- {mistralai-1.10.0.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
mistralai/extra/observability/otel.py
CHANGED

```diff
@@ -5,34 +5,23 @@ import os
 import traceback
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Optional, Tuple
 
 import httpx
 import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
 import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
 import opentelemetry.semconv.attributes.server_attributes as server_attributes
 from opentelemetry import propagate, trace
-from opentelemetry.
-from opentelemetry.sdk.resources import SERVICE_NAME, Resource
-from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
-from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult
+from opentelemetry.sdk.trace import SpanProcessor
 from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
 
 logger = logging.getLogger(__name__)
 
 
 OTEL_SERVICE_NAME: str = "mistralai_sdk"
-OTEL_EXPORTER_OTLP_ENDPOINT: str = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "")
-OTEL_EXPORTER_OTLP_TIMEOUT: int = int(os.getenv("OTEL_EXPORTER_OTLP_TIMEOUT", "2"))
-OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE", "512"))
-OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS", "1000"))
-OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE", "2048"))
-OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS", "5000"))
-
 MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
 
 MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
-DEBUG_HINT: str = "To see detailed
+DEBUG_HINT: str = "To see detailed tracing logs, set MISTRAL_SDK_DEBUG_TRACING=true."
 
 
 class MistralAIAttributes:
@@ -52,13 +41,11 @@ class MistralAINameValues(Enum):
     OCR = "ocr"
 
 class TracingErrors(Exception, Enum):
-    FAILED_TO_EXPORT_OTEL_SPANS = "Failed to export OpenTelemetry (OTEL) spans."
-    FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING = "Failed to initialize OpenTelemetry tracing."
     FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
     FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
     FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
     FAILED_TO_END_SPAN = "Failed to end span."
-
+
     def __str__(self):
         return str(self.value)
 
@@ -180,6 +167,7 @@ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: st
             start_ns = parse_time_to_nanos(output["created_at"])
             end_ns = parse_time_to_nanos(output["completed_at"])
             child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+            child_span.set_attributes({"agent.trace.public": ""})
             tool_attributes = {
                 gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
                 gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
@@ -192,6 +180,7 @@ def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: st
             start_ns = parse_time_to_nanos(output["created_at"])
             end_ns = parse_time_to_nanos(output["completed_at"])
             child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+            child_span.set_attributes({"agent.trace.public": ""})
             message_attributes = {
                 gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
                 gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
@@ -217,70 +206,47 @@ class GenAISpanProcessor(SpanProcessor):
         span.set_attributes({"agent.trace.public": ""})
 
 
-
-    def export(self, spans):
-        try:
-            return super().export(spans)
-        except Exception:
-            logger.warning(f"{TracingErrors.FAILED_TO_EXPORT_OTEL_SPANS} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
-            return SpanExportResult.FAILURE
-
-
-def get_or_create_otel_tracer() -> Tuple[bool, Tracer]:
+def get_or_create_otel_tracer() -> tuple[bool, Tracer]:
     """
-
+    Get a tracer from the current TracerProvider.
 
-
+    The SDK does not set up its own TracerProvider - it relies on the application
+    to configure OpenTelemetry. This follows OTEL best practices where:
+    - Libraries/SDKs get tracers from the global provider
+    - Applications configure the TracerProvider
 
-
+    If no TracerProvider is configured, the ProxyTracerProvider (default) will
+    return a NoOp tracer, effectively disabling tracing. Once the application
+    sets up a real TracerProvider, subsequent spans will be recorded.
 
-
+    Returns:
+        Tuple[bool, Tracer]: (tracing_enabled, tracer)
+        - tracing_enabled is True if a real TracerProvider is configured
+        - tracer is always valid (may be NoOp if no provider configured)
     """
-    tracing_enabled = True
     tracer_provider = trace.get_tracer_provider()
-
-    if isinstance(tracer_provider, trace.ProxyTracerProvider):
-        if OTEL_EXPORTER_OTLP_ENDPOINT:
-            # SDK standalone: No tracer provider but OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
-            try:
-                exporter = QuietOTLPSpanExporter(
-                    endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
-                    timeout=OTEL_EXPORTER_OTLP_TIMEOUT
-                )
-                resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
-                tracer_provider = TracerProvider(resource=resource)
-
-                span_processor = BatchSpanProcessor(
-                    exporter,
-                    export_timeout_millis=OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS,
-                    max_export_batch_size=OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE,
-                    schedule_delay_millis=OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS,
-                    max_queue_size=OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE
-                )
-
-                tracer_provider.add_span_processor(span_processor)
-                tracer_provider.add_span_processor(GenAISpanProcessor())
-                trace.set_tracer_provider(tracer_provider)
-
-            except Exception:
-                logger.warning(f"{TracingErrors.FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
-                tracing_enabled = False
-        else:
-            # No tracer provider nor OTEL_EXPORTER_OTLP_ENDPOINT set -> tracing is disabled
-            tracing_enabled = False
-
     tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
 
+    # Tracing is considered enabled if we have a real TracerProvider (not the default proxy)
+    tracing_enabled = not isinstance(tracer_provider, trace.ProxyTracerProvider)
+
     return tracing_enabled, tracer
 
-def get_traced_request_and_span(
+def get_traced_request_and_span(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    request: httpx.Request,
+) -> tuple[httpx.Request, Span | None]:
     if not tracing_enabled:
         return request, span
 
     try:
         span = tracer.start_span(name=operation_id)
+        span.set_attributes({"agent.trace.public": ""})
         # Inject the span context into the request headers to be used by the backend service to continue the trace
-        propagate.inject(request.headers)
+        propagate.inject(request.headers, context=set_span_in_context(span))
         span = enrich_span_from_request(span, request)
     except Exception:
         logger.warning(
@@ -295,7 +261,13 @@ def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Opt
     return request, span
 
 
-def get_traced_response(
+def get_traced_response(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+) -> httpx.Response:
     if not tracing_enabled or not span:
         return response
     try:
@@ -315,7 +287,14 @@ def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Sp
     end_span(span=span)
     return response
 
-def get_response_and_error(
+def get_response_and_error(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+    error: Exception | None,
+) -> tuple[httpx.Response, Exception | None]:
     if not tracing_enabled or not span:
         return response, error
     try:
@@ -366,7 +345,7 @@ class TracedResponse(httpx.Response):
 
     This hack allows ending the span only once the stream is fully consumed.
     """
-    def __init__(self, *args, span:
+    def __init__(self, *args, span: Span | None, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.span = span
 
@@ -381,7 +360,7 @@ class TracedResponse(httpx.Response):
         await super().aclose()
 
     @classmethod
-    def from_response(cls, resp: httpx.Response, span:
+    def from_response(cls, resp: httpx.Response, span: Span | None) -> "TracedResponse":
         traced_resp = cls.__new__(cls)
         traced_resp.__dict__ = copy.copy(resp.__dict__)
         traced_resp.span = span
```
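The net effect of this otel.py change is that the SDK no longer builds its own TracerProvider from the `OTEL_EXPORTER_OTLP_*` environment variables; it only reads the global provider and treats tracing as enabled once the application has installed a real one. A minimal sketch of what an application would now do to capture the SDK's spans, using only standard opentelemetry-sdk APIs (the ConsoleSpanExporter and service name here are illustrative, not part of the SDK):

```python
from opentelemetry import trace
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

# The application, not the SDK, now owns provider setup.
provider = TracerProvider(resource=Resource.create({SERVICE_NAME: "my-app"}))
provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)

# From here on, the SDK's get_or_create_otel_tracer() sees a real
# (non-proxy) provider and reports tracing_enabled=True; without this
# setup, the default ProxyTracerProvider hands back a no-op tracer
# and the SDK's spans are silently dropped.
```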
mistralai/extra/run/context.py
CHANGED

```diff
@@ -1,47 +1,41 @@
 import asyncio
 import inspect
 import typing
-from contextlib import AsyncExitStack
-from functools import wraps
 from collections.abc import Callable
-
+from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
-from
+from functools import wraps
+from logging import getLogger
 
 import pydantic
 
-from mistralai.extra import
-    response_format_from_pydantic_model,
-)
+from mistralai.extra import response_format_from_pydantic_model
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
 from mistralai.extra.run.result import RunResult
-from mistralai.
+from mistralai.extra.run.tools import (
+    RunCoroutine,
+    RunFunction,
+    RunMCPTool,
+    RunTool,
+    create_function_result,
+    create_tool_call,
+)
 from mistralai.models import (
-    ResponseFormat,
-    FunctionCallEntry,
-    Tools,
-    ToolsTypedDict,
     CompletionArgs,
     CompletionArgsTypedDict,
-    FunctionResultEntry,
     ConversationInputs,
     ConversationInputsTypedDict,
+    FunctionCallEntry,
+    FunctionResultEntry,
     FunctionTool,
-    MessageInputEntry,
     InputEntries,
+    MessageInputEntry,
+    ResponseFormat,
+    Tools,
+    ToolsTypedDict,
 )
-
-from logging import getLogger
-
-from mistralai.extra.run.tools import (
-    create_function_result,
-    RunFunction,
-    create_tool_call,
-    RunTool,
-    RunMCPTool,
-    RunCoroutine,
-)
+from mistralai.types.basemodel import BaseModel, OptionalNullable, UNSET
 
 if typing.TYPE_CHECKING:
     from mistralai import Beta, OptionalNullable
@@ -56,8 +50,8 @@ class AgentRequestKwargs(typing.TypedDict):
 class ModelRequestKwargs(typing.TypedDict):
     model: str
     instructions: OptionalNullable[str]
-    tools: OptionalNullable[
-    completion_args: OptionalNullable[
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]]
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict]
 
 
 @dataclass
@@ -72,7 +66,7 @@ class RunContext:
             passed if the user wants to continue an existing conversation.
         model (Options[str]): The model name to be used for the conversation. Can't be used along with 'agent_id'.
         agent_id (Options[str]): The agent id to be used for the conversation. Can't be used along with 'model'.
-        output_format (
+        output_format (type[BaseModel] | None): The output format expected from the conversation. It represents
            the `response_format` which is part of the `CompletionArgs`.
        request_count (int): The number of requests made in the current `RunContext`.
        continue_on_fn_error (bool): Flag to determine if the conversation should continue when function execution
@@ -83,10 +77,10 @@ class RunContext:
     _callable_tools: dict[str, RunTool] = field(init=False, default_factory=dict)
     _mcp_clients: list[MCPClientProtocol] = field(init=False, default_factory=list)
 
-    conversation_id:
-    model:
-    agent_id:
-    output_format:
+    conversation_id: str | None = field(default=None)
+    model: str | None = field(default=None)
+    agent_id: str | None = field(default=None)
+    output_format: type[BaseModel] | None = field(default=None)
     request_count: int = field(default=0)
     continue_on_fn_error: bool = field(default=False)
 
@@ -215,10 +209,8 @@ class RunContext:
 
     async def prepare_model_request(
         self,
-        tools: OptionalNullable[
-        completion_args: OptionalNullable[
-            Union[CompletionArgs, CompletionArgsTypedDict]
-        ] = UNSET,
+        tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+        completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
         instructions: OptionalNullable[str] = None,
     ) -> ModelRequestKwargs:
         if self.model is None:
@@ -254,14 +246,12 @@ async def _validate_run(
     *,
     beta_client: "Beta",
     run_ctx: RunContext,
-    inputs:
+    inputs: ConversationInputs | ConversationInputsTypedDict,
     instructions: OptionalNullable[str] = UNSET,
-    tools: OptionalNullable[
-    completion_args: OptionalNullable[
-        Union[CompletionArgs, CompletionArgsTypedDict]
-    ] = UNSET,
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
 ) -> tuple[
-
+    AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries]
 ]:
     input_entries: list[InputEntries] = []
     if isinstance(inputs, str):
@@ -277,7 +267,7 @@ async def _validate_run(
         output_model=run_ctx.output_format,
         conversation_id=run_ctx.conversation_id,
     )
-    req:
+    req: AgentRequestKwargs | ModelRequestKwargs
    if run_ctx.agent_id:
        if tools or completion_args:
            raise RunException("Can't set tools or completion_args when using an agent")
```
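Most of the annotation edits in context.py (and in the files below) are the same mechanical modernization: `typing.Optional`/`Union`/`Tuple` spellings replaced by PEP 604/585 unions. A small self-contained sketch of the equivalence; worth noting that these unions are evaluated at runtime in dataclass and TypedDict annotations, so the new spellings assume Python 3.10+ (the function names here are illustrative):

```python
from typing import Optional, Tuple, Union

# Old-style spellings, as removed by this diff...
def before(conversation_id: Optional[str], args: Union[int, str]) -> Tuple[bool, str]:
    return conversation_id is not None, str(args)

# ...and the PEP 604/585 equivalents the diff introduces. Dataclass and
# TypedDict annotations are evaluated when the module is imported, so
# `str | None` here relies on Python 3.10+ rather than a
# `from __future__ import annotations` shim.
def after(conversation_id: str | None, args: int | str) -> tuple[bool, str]:
    return conversation_id is not None, str(args)
```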
mistralai/extra/run/result.py
CHANGED

```diff
@@ -1,9 +1,10 @@
 import datetime
 import json
 import typing
-from typing import Union, Annotated, Optional, Literal
 from dataclasses import dataclass, field
-from
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Discriminator, Tag
 
 from mistralai.extra.utils.response_format import pydantic_model_from_json
 from mistralai.models import (
@@ -35,15 +36,15 @@ from mistralai.models import (
 )
 from mistralai.utils import get_discriminator
 
-RunOutputEntries =
-    MessageOutputEntry
-    FunctionCallEntry
-    FunctionResultEntry
-    AgentHandoffEntry
-    ToolExecutionEntry
-
+RunOutputEntries = (
+    MessageOutputEntry
+    | FunctionCallEntry
+    | FunctionResultEntry
+    | AgentHandoffEntry
+    | ToolExecutionEntry
+)
 
-RunEntries =
+RunEntries = RunOutputEntries | MessageInputEntry
 
 
 def as_text(entry: RunOutputEntries) -> str:
@@ -140,12 +141,12 @@ class RunFiles:
 @dataclass
 class RunResult:
     input_entries: list[InputEntries]
-    conversation_id:
+    conversation_id: str | None = field(default=None)
     output_entries: list[RunOutputEntries] = field(default_factory=list)
     files: dict[str, RunFiles] = field(default_factory=dict)
-    output_model:
+    output_model: type[BaseModel] | None = field(default=None)
 
-    def get_file(self, file_id: str) ->
+    def get_file(self, file_id: str) -> RunFiles | None:
         return self.files.get(file_id)
 
     @property
@@ -172,36 +173,34 @@ class RunResult:
 
 
 class FunctionResultEvent(BaseModel):
-    id:
+    id: str | None = None
 
-    type:
+    type: Literal["function.result"] | None = "function.result"
 
     result: str
 
     tool_call_id: str
 
-    created_at:
+    created_at: datetime.datetime | None = datetime.datetime.now(
        tz=datetime.timezone.utc
    )
 
-    output_index:
+    output_index: int | None = 0
 
 
-RunResultEventsType =
+RunResultEventsType = SSETypes | Literal["function.result"]
 
 RunResultEventsData = typing.Annotated[
-
-
-
-
-
-
-
-
-
-
-    Annotated[FunctionResultEvent, Tag("function.result")],
-    ],
+    Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")]
+    | Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")]
+    | Annotated[ResponseDoneEvent, Tag("conversation.response.done")]
+    | Annotated[ResponseErrorEvent, Tag("conversation.response.error")]
+    | Annotated[ResponseStartedEvent, Tag("conversation.response.started")]
+    | Annotated[FunctionCallEvent, Tag("function.call.delta")]
+    | Annotated[MessageOutputEvent, Tag("message.output.delta")]
+    | Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")]
+    | Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")]
+    | Annotated[FunctionResultEvent, Tag("function.result")],
     Discriminator(lambda m: get_discriminator(m, "type", "type")),
 ]
 
```
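The rebuilt `RunResultEventsData` keeps the same pydantic pattern as before: a callable `Discriminator` routing on each event's `type` field to `Tag`-annotated union members. A minimal standalone sketch of that pattern with two hypothetical event models (not the SDK's actual event types):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class ToolStarted(BaseModel):
    type: Literal["tool.execution.started"] = "tool.execution.started"
    name: str

class ToolDone(BaseModel):
    type: Literal["tool.execution.done"] = "tool.execution.done"
    name: str

def _event_type(value: object):
    # The discriminator sees raw dicts during validation and model
    # instances during serialization, so it handles both.
    return value.get("type") if isinstance(value, dict) else getattr(value, "type", None)

Event = Annotated[
    Union[
        Annotated[ToolStarted, Tag("tool.execution.started")],
        Annotated[ToolDone, Tag("tool.execution.done")],
    ],
    Discriminator(_event_type),
]

event = TypeAdapter(Event).validate_python(
    {"type": "tool.execution.done", "name": "web_search"}
)
assert isinstance(event, ToolDone)
```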
mistralai/extra/run/tools.py
CHANGED

```diff
@@ -1,14 +1,11 @@
+import inspect
 import itertools
+import json
 import logging
 from dataclasses import dataclass
-import
-
-from pydantic import Field, create_model
-from pydantic.fields import FieldInfo
-import json
-from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union
+from typing import Any, Callable, ForwardRef, Sequence, cast, get_type_hints
 
-
+import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
 from griffe import (
     Docstring,
     DocstringSectionKind,
@@ -16,7 +13,9 @@ from griffe import (
     DocstringParameter,
     DocstringSection,
 )
-
+from opentelemetry import trace
+from pydantic import Field, create_model
+from pydantic.fields import FieldInfo
 
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
@@ -54,7 +53,7 @@ class RunMCPTool:
     mcp_client: MCPClientProtocol
 
 
-RunTool =
+RunTool = RunFunction | RunCoroutine | RunMCPTool
 
 
 def _get_function_description(docstring_sections: list[DocstringSection]) -> str:
```
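tools.py continues to derive tool descriptions from docstrings via griffe; this diff only reorders the imports. For context, a rough sketch of how griffe's `Docstring` parsing surfaces the parameter descriptions that code like `_get_function_description` can map onto a function tool's schema; the docstring text is illustrative and the exact section shapes follow griffe's docstring parser:

```python
from griffe import Docstring, DocstringSectionKind

doc = Docstring(
    """Look up the current weather.

    Parameters:
        city: Name of the city to query.
        unit: Either "celsius" or "fahrenheit".
    """
)

# parse() returns DocstringSection objects; the parameters section's
# value holds DocstringParameter entries with .name and .description.
for section in doc.parse("google"):
    if section.kind is DocstringSectionKind.parameters:
        for param in section.value:
            print(param.name, "->", param.description)
```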
mistralai/extra/struct_chat.py
CHANGED

```diff
@@ -1,19 +1,26 @@
-from ..models import ChatCompletionResponse, ChatCompletionChoice, AssistantMessage
-from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
-from typing import List, Optional, Type, Generic
-from pydantic import BaseModel
 import json
+from typing import Generic
+
+from ..models import AssistantMessage, ChatCompletionChoice, ChatCompletionResponse
+from .utils.response_format import CustomPydanticModel, pydantic_model_from_json
+
 
 class ParsedAssistantMessage(AssistantMessage, Generic[CustomPydanticModel]):
-    parsed:
+    parsed: CustomPydanticModel | None
+
 
 class ParsedChatCompletionChoice(ChatCompletionChoice, Generic[CustomPydanticModel]):
-    message:
+    message: ParsedAssistantMessage[CustomPydanticModel] | None  # type: ignore
+
 
 class ParsedChatCompletionResponse(ChatCompletionResponse, Generic[CustomPydanticModel]):
-    choices:
+    choices: list[ParsedChatCompletionChoice[CustomPydanticModel]] | None  # type: ignore
+
 
-def convert_to_parsed_chat_completion_response(
+def convert_to_parsed_chat_completion_response(
+    response: ChatCompletionResponse,
+    response_format: type[CustomPydanticModel],
+) -> ParsedChatCompletionResponse[CustomPydanticModel]:
     parsed_choices = []
 
     if response.choices:
```
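The `Parsed*` classes are generic pydantic models over `CustomPydanticModel`; this change only tightens their annotations to the new union syntax. A minimal standalone sketch of the same generic-model pattern (`Envelope` and `Weather` are illustrative names, not SDK types):

```python
from typing import Generic, TypeVar

from pydantic import BaseModel

M = TypeVar("M", bound=BaseModel)

class Envelope(BaseModel, Generic[M]):
    # Mirrors ParsedAssistantMessage.parsed: None until the payload
    # validates against the requested model.
    parsed: M | None = None

class Weather(BaseModel):
    city: str
    temperature_c: float

# Parametrizing the generic model validates the raw dict into Weather.
env = Envelope[Weather](parsed={"city": "Paris", "temperature_c": 21.5})
assert env.parsed is not None and env.parsed.city == "Paris"
```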
mistralai/extra/utils/response_format.py
CHANGED

```diff
@@ -1,5 +1,6 @@
+from typing import Any, TypeVar
+
 from pydantic import BaseModel
-from typing import TypeVar, Any, Type, Dict
 from ...models import JSONSchema, ResponseFormat
 from ._pydantic_helper import rec_strict_json_schema
 
@@ -7,7 +8,7 @@ CustomPydanticModel = TypeVar("CustomPydanticModel", bound=BaseModel)
 
 
 def response_format_from_pydantic_model(
-    model:
+    model: type[CustomPydanticModel],
 ) -> ResponseFormat:
     """Generate a strict JSON schema from a pydantic model."""
     model_schema = rec_strict_json_schema(model.model_json_schema())
@@ -18,7 +19,8 @@ def response_format_from_pydantic_model(
 
 
 def pydantic_model_from_json(
-    json_data:
+    json_data: dict[str, Any],
+    pydantic_model: type[CustomPydanticModel],
 ) -> CustomPydanticModel:
     """Parse a JSON schema into a pydantic model."""
     return pydantic_model.model_validate(json_data)
```