mistralai 1.9.10__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- mistralai/_hooks/registration.py +5 -0
- mistralai/_hooks/tracing.py +50 -0
- mistralai/_version.py +3 -3
- mistralai/accesses.py +51 -116
- mistralai/agents.py +58 -85
- mistralai/audio.py +8 -3
- mistralai/basesdk.py +15 -5
- mistralai/batch.py +6 -3
- mistralai/beta.py +10 -5
- mistralai/chat.py +70 -97
- mistralai/classifiers.py +57 -144
- mistralai/conversations.py +435 -412
- mistralai/documents.py +156 -359
- mistralai/embeddings.py +21 -42
- mistralai/extra/observability/__init__.py +15 -0
- mistralai/extra/observability/otel.py +393 -0
- mistralai/extra/run/tools.py +28 -16
- mistralai/files.py +53 -176
- mistralai/fim.py +46 -73
- mistralai/fine_tuning.py +6 -3
- mistralai/jobs.py +49 -158
- mistralai/libraries.py +71 -178
- mistralai/mistral_agents.py +298 -179
- mistralai/mistral_jobs.py +51 -138
- mistralai/models/__init__.py +94 -5
- mistralai/models/agent.py +15 -2
- mistralai/models/agentconversation.py +11 -3
- mistralai/models/agentcreationrequest.py +6 -2
- mistralai/models/agents_api_v1_agents_deleteop.py +16 -0
- mistralai/models/agents_api_v1_agents_getop.py +40 -3
- mistralai/models/agents_api_v1_agents_listop.py +72 -2
- mistralai/models/agents_api_v1_conversations_deleteop.py +18 -0
- mistralai/models/agents_api_v1_conversations_listop.py +39 -2
- mistralai/models/agentscompletionrequest.py +21 -6
- mistralai/models/agentscompletionstreamrequest.py +21 -6
- mistralai/models/agentupdaterequest.py +18 -2
- mistralai/models/audiotranscriptionrequest.py +2 -0
- mistralai/models/batchjobin.py +10 -0
- mistralai/models/chatcompletionrequest.py +22 -5
- mistralai/models/chatcompletionstreamrequest.py +22 -5
- mistralai/models/conversationrequest.py +15 -4
- mistralai/models/conversationrestartrequest.py +50 -2
- mistralai/models/conversationrestartstreamrequest.py +50 -2
- mistralai/models/conversationstreamrequest.py +15 -4
- mistralai/models/documentout.py +26 -10
- mistralai/models/documentupdatein.py +24 -3
- mistralai/models/embeddingrequest.py +8 -8
- mistralai/models/files_api_routes_list_filesop.py +7 -0
- mistralai/models/fimcompletionrequest.py +8 -9
- mistralai/models/fimcompletionstreamrequest.py +8 -9
- mistralai/models/httpvalidationerror.py +11 -6
- mistralai/models/libraries_documents_list_v1op.py +15 -2
- mistralai/models/libraryout.py +10 -7
- mistralai/models/listfilesout.py +35 -4
- mistralai/models/mistralerror.py +26 -0
- mistralai/models/modelcapabilities.py +13 -4
- mistralai/models/modelconversation.py +8 -2
- mistralai/models/no_response_error.py +13 -0
- mistralai/models/ocrpageobject.py +26 -5
- mistralai/models/ocrrequest.py +17 -1
- mistralai/models/ocrtableobject.py +31 -0
- mistralai/models/prediction.py +4 -0
- mistralai/models/requestsource.py +7 -0
- mistralai/models/responseformat.py +4 -2
- mistralai/models/responseformats.py +0 -1
- mistralai/models/responsevalidationerror.py +25 -0
- mistralai/models/sdkerror.py +30 -14
- mistralai/models/sharingdelete.py +36 -5
- mistralai/models/sharingin.py +36 -5
- mistralai/models/sharingout.py +3 -3
- mistralai/models/toolexecutiondeltaevent.py +13 -4
- mistralai/models/toolexecutiondoneevent.py +13 -4
- mistralai/models/toolexecutionentry.py +9 -4
- mistralai/models/toolexecutionstartedevent.py +13 -4
- mistralai/models_.py +67 -212
- mistralai/ocr.py +33 -36
- mistralai/sdk.py +15 -2
- mistralai/transcriptions.py +21 -60
- mistralai/utils/__init__.py +18 -5
- mistralai/utils/eventstreaming.py +10 -0
- mistralai/utils/serializers.py +3 -2
- mistralai/utils/unmarshal_json_response.py +24 -0
- {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/METADATA +89 -40
- {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/RECORD +86 -75
- {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info}/WHEEL +1 -1
- {mistralai-1.9.10.dist-info → mistralai-1.10.0.dist-info/licenses}/LICENSE +0 -0
mistralai/embeddings.py
CHANGED
@@ -5,6 +5,7 @@ from mistralai import models, utils
 from mistralai._hooks import HookContext
 from mistralai.types import OptionalNullable, UNSET
 from mistralai.utils import get_security_from_env
+from mistralai.utils.unmarshal_json_response import unmarshal_json_response
 from typing import Any, Mapping, Optional, Union
 
 
@@ -30,9 +31,9 @@ class Embeddings(BaseSDK):
 
         Embeddings
 
-        :param model: ID of the model to
-        :param inputs:
-        :param output_dimension: The dimension of the output embeddings.
+        :param model: The ID of the model to be used for embedding.
+        :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
+        :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
         :param output_dtype:
         :param encoding_format:
         :param retries: Override the default retry configuration for this method
@@ -102,31 +103,20 @@ class Embeddings(BaseSDK):
 
         response_data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return
+            return unmarshal_json_response(models.EmbeddingResponse, http_res)
         if utils.match_response(http_res, "422", "application/json"):
-            response_data =
-
+            response_data = unmarshal_json_response(
+                models.HTTPValidationErrorData, http_res
             )
-            raise models.HTTPValidationError(
+            raise models.HTTPValidationError(response_data, http_res)
         if utils.match_response(http_res, "4XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
-            raise models.SDKError(
-                "API error occurred", http_res.status_code, http_res_text, http_res
-            )
+            raise models.SDKError("API error occurred", http_res, http_res_text)
         if utils.match_response(http_res, "5XX", "*"):
             http_res_text = utils.stream_to_text(http_res)
-            raise models.SDKError(
-                "API error occurred", http_res.status_code, http_res_text, http_res
-            )
+            raise models.SDKError("API error occurred", http_res, http_res_text)
 
-
-        http_res_text = utils.stream_to_text(http_res)
-        raise models.SDKError(
-            f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
-            http_res.status_code,
-            http_res_text,
-            http_res,
-        )
+        raise models.SDKError("Unexpected response received", http_res)
 
     async def create_async(
         self,
@@ -147,9 +137,9 @@ class Embeddings(BaseSDK):
 
         Embeddings
 
-        :param model: ID of the model to
-        :param inputs:
-        :param output_dimension: The dimension of the output embeddings.
+        :param model: The ID of the model to be used for embedding.
+        :param inputs: The text content to be embedded, can be a string or an array of strings for fast processing in bulk.
+        :param output_dimension: The dimension of the output embeddings when feature available. If not provided, a default output dimension will be used.
         :param output_dtype:
         :param encoding_format:
         :param retries: Override the default retry configuration for this method
@@ -219,28 +209,17 @@ class Embeddings(BaseSDK):
 
         response_data: Any = None
         if utils.match_response(http_res, "200", "application/json"):
-            return
+            return unmarshal_json_response(models.EmbeddingResponse, http_res)
         if utils.match_response(http_res, "422", "application/json"):
-            response_data =
-
+            response_data = unmarshal_json_response(
+                models.HTTPValidationErrorData, http_res
            )
-            raise models.HTTPValidationError(
+            raise models.HTTPValidationError(response_data, http_res)
         if utils.match_response(http_res, "4XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
-            raise models.SDKError(
-                "API error occurred", http_res.status_code, http_res_text, http_res
-            )
+            raise models.SDKError("API error occurred", http_res, http_res_text)
         if utils.match_response(http_res, "5XX", "*"):
             http_res_text = await utils.stream_to_text_async(http_res)
-            raise models.SDKError(
-                "API error occurred", http_res.status_code, http_res_text, http_res
-            )
+            raise models.SDKError("API error occurred", http_res, http_res_text)
 
-
-        http_res_text = await utils.stream_to_text_async(http_res)
-        raise models.SDKError(
-            f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
-            http_res.status_code,
-            http_res_text,
-            http_res,
-        )
+        raise models.SDKError("Unexpected response received", http_res)
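For SDK callers, the visible change in this file is the error surface: success and 422 bodies are now decoded through the new `unmarshal_json_response` helper, and `SDKError` receives the `httpx` response object as its second argument instead of a bare status code. A minimal caller-side sketch of what this means in practice (the `create` call follows the SDK's public surface; exception attributes beyond what this diff shows are not assumed):

from mistralai import Mistral, models

client = Mistral(api_key="...")

try:
    resp = client.embeddings.create(
        model="mistral-embed",
        inputs=["Hello, world!"],
    )
except models.HTTPValidationError as err:
    # 422: raised as HTTPValidationError(response_data, http_res) in 1.10.0
    print("validation error:", err)
except models.SDKError as err:
    # 4XX/5XX: raised as SDKError("API error occurred", http_res, http_res_text),
    # replacing SDKError(message, status_code, body, raw_response)
    print("API error:", err)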
mistralai/extra/observability/__init__.py
ADDED
@@ -0,0 +1,15 @@
+from contextlib import contextmanager
+
+from opentelemetry import trace as otel_trace
+
+from .otel import MISTRAL_SDK_OTEL_TRACER_NAME
+
+
+@contextmanager
+def trace(name: str, **kwargs):
+    tracer = otel_trace.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+    with tracer.start_as_current_span(name, **kwargs) as span:
+        yield span
+
+
+__all__ = ["trace"]
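The new `trace` helper is a thin wrapper over `start_as_current_span` on the SDK's named tracer, so application code can group work under a custom span without importing OpenTelemetry directly. A minimal usage sketch (the span name and attributes are illustrative; extra kwargs are forwarded to `start_as_current_span`):

from mistralai.extra.observability import trace

# Opens a span on the SDK tracer ("mistralai_sdk_tracer"); spans created
# by SDK calls inside the block become children of this span.
with trace("embed-and-index", attributes={"app.job_id": "demo-123"}) as span:
    span.add_event("starting batch")
    ...  # SDK calls here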
mistralai/extra/observability/otel.py
ADDED
@@ -0,0 +1,393 @@
+import copy
+import json
+import logging
+import os
+import traceback
+from datetime import datetime, timezone
+from enum import Enum
+from typing import Optional, Tuple
+
+import httpx
+import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
+import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
+import opentelemetry.semconv.attributes.server_attributes as server_attributes
+from opentelemetry import propagate, trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult
+from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
+
+logger = logging.getLogger(__name__)
+
+
+OTEL_SERVICE_NAME: str = "mistralai_sdk"
+OTEL_EXPORTER_OTLP_ENDPOINT: str = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "")
+OTEL_EXPORTER_OTLP_TIMEOUT: int = int(os.getenv("OTEL_EXPORTER_OTLP_TIMEOUT", "2"))
+OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE", "512"))
+OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS", "1000"))
+OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE", "2048"))
+OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS", "5000"))
+
+MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
+
+MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
+DEBUG_HINT: str = "To see detailed exporter logs, set MISTRAL_SDK_DEBUG_TRACING=true."
+
+
+class MistralAIAttributes:
+    MISTRAL_AI_TOTAL_TOKENS = "mistral_ai.request.total_tokens"
+    MISTRAL_AI_TOOL_CALL_ARGUMENTS = "mistral_ai.tool.call.arguments"
+    MISTRAL_AI_MESSAGE_ID = "mistral_ai.message.id"
+    MISTRAL_AI_OPERATION_NAME = "mistral_ai.operation.name"
+    MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED = "mistral_ai.ocr.usage.pages_processed"
+    MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES = "mistral_ai.ocr.usage.doc_size_bytes"
+    MISTRAL_AI_OPERATION_ID = "mistral_ai.operation.id"
+    MISTRAL_AI_ERROR_TYPE = "mistral_ai.error.type"
+    MISTRAL_AI_ERROR_MESSAGE = "mistral_ai.error.message"
+    MISTRAL_AI_ERROR_CODE = "mistral_ai.error.code"
+    MISTRAL_AI_FUNCTION_CALL_ARGUMENTS = "mistral_ai.function.call.arguments"
+
+class MistralAINameValues(Enum):
+    OCR = "ocr"
+
+class TracingErrors(Exception, Enum):
+    FAILED_TO_EXPORT_OTEL_SPANS = "Failed to export OpenTelemetry (OTEL) spans."
+    FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING = "Failed to initialize OpenTelemetry tracing."
+    FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
+    FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
+    FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
+    FAILED_TO_END_SPAN = "Failed to end span."
+
+    def __str__(self):
+        return str(self.value)
+
+class GenAISpanEnum(str, Enum):
+    CONVERSATION = "conversation"
+    CONV_REQUEST = "POST /v1/conversations"
+    EXECUTE_TOOL = "execute_tool"
+    VALIDATE_RUN = "validate_run"
+
+    @staticmethod
+    def function_call(func_name: str):
+        return f"function_call[{func_name}]"
+
+
+def parse_time_to_nanos(ts: str) -> int:
+    dt = datetime.fromisoformat(ts.replace("Z", "+00:00")).astimezone(timezone.utc)
+    return int(dt.timestamp() * 1e9)
+
+def set_available_attributes(span: Span, attributes: dict) -> None:
+    for attribute, value in attributes.items():
+        if value:
+            span.set_attribute(attribute, value)
+
+
+def enrich_span_from_request(span: Span, request: httpx.Request) -> Span:
+    if not request.url.port:
+        # From httpx doc:
+        # Note that the URL class performs port normalization as per the WHATWG spec.
+        # Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always treated as None.
+        # Handling default ports since most of the time we are using https
+        if request.url.scheme == "https":
+            port = 443
+        elif request.url.scheme == "http":
+            port = 80
+        else:
+            port = -1
+    else:
+        port = request.url.port
+
+    span.set_attributes({
+        http_attributes.HTTP_REQUEST_METHOD: request.method,
+        http_attributes.HTTP_URL: str(request.url),
+        server_attributes.SERVER_ADDRESS: request.headers.get("host", ""),
+        server_attributes.SERVER_PORT: port
+    })
+    if request._content:
+        request_body = json.loads(request._content)
+
+        attributes = {
+            gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT: request_body.get("n", None),
+            gen_ai_attributes.GEN_AI_REQUEST_ENCODING_FORMATS: request_body.get("encoding_formats", None),
+            gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: request_body.get("frequency_penalty", None),
+            gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS: request_body.get("max_tokens", None),
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: request_body.get("model", None),
+            gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY: request_body.get("presence_penalty", None),
+            gen_ai_attributes.GEN_AI_REQUEST_SEED: request_body.get("random_seed", None),
+            gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES: request_body.get("stop", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: request_body.get("temperature", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TOP_P: request_body.get("top_p", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TOP_K: request_body.get("top_k", None),
+            # Input messages are likely to be large, containing user/PII data and other sensitive information.
+            # Also structured attributes are not yet supported on spans in Python.
+            # For those reasons, we will not record the input messages for now.
+            gen_ai_attributes.GEN_AI_INPUT_MESSAGES: None,
+        }
+        # Set attributes only if they are not None.
+        # From OpenTelemetry documentation: None is not a valid attribute value per spec / is not a permitted value type for an attribute.
+        set_available_attributes(span, attributes)
+    return span
+
+
+def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: str, response: httpx.Response) -> None:
+    span.set_status(Status(StatusCode.OK))
+    response_data = json.loads(response.content)
+
+    # Base attributes
+    attributes: dict[str, str | int] = {
+        http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+        MistralAIAttributes.MISTRAL_AI_OPERATION_ID: operation_id,
+        gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value
+    }
+
+    # Add usage attributes if available
+    usage = response_data.get("usage", {})
+    if usage:
+        attributes.update({
+            gen_ai_attributes.GEN_AI_USAGE_PROMPT_TOKENS: usage.get("prompt_tokens", 0),
+            gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS: usage.get("completion_tokens", 0),
+            MistralAIAttributes.MISTRAL_AI_TOTAL_TOKENS: usage.get("total_tokens", 0)
+        })
+
+    span.set_attributes(attributes)
+    if operation_id == "agents_api_v1_agents_create":
+        # Semantics from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/#create-agent-span
+        agent_attributes = {
+            gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CREATE_AGENT.value,
+            gen_ai_attributes.GEN_AI_AGENT_DESCRIPTION: response_data.get("description", ""),
+            gen_ai_attributes.GEN_AI_AGENT_ID: response_data.get("id", ""),
+            gen_ai_attributes.GEN_AI_AGENT_NAME: response_data.get("name", ""),
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", ""),
+            gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS: response_data.get("instructions", "")
+        }
+        span.set_attributes(agent_attributes)
+    if operation_id in ["agents_api_v1_conversations_start", "agents_api_v1_conversations_append"]:
+        outputs = response_data.get("outputs", [])
+        conversation_attributes = {
+            gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.INVOKE_AGENT.value,
+            gen_ai_attributes.GEN_AI_CONVERSATION_ID: response_data.get("conversation_id", "")
+        }
+        span.set_attributes(conversation_attributes)
+        parent_context = set_span_in_context(span)
+
+        for output in outputs:
+            # TODO: Only enrich the spans if it's a single turn conversation.
+            # Multi turn conversations are handled in the extra.run.tools.create_function_result function
+            if output["type"] == "function.call":
+                pass
+            if output["type"] == "tool.execution":
+                start_ns = parse_time_to_nanos(output["created_at"])
+                end_ns = parse_time_to_nanos(output["completed_at"])
+                child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+                tool_attributes = {
+                    gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                    gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
+                    MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: output.get("arguments", ""),
+                    gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name", "")
+                }
+                child_span.set_attributes(tool_attributes)
+                child_span.end(end_time=end_ns)
+            if output["type"] == "message.output":
+                start_ns = parse_time_to_nanos(output["created_at"])
+                end_ns = parse_time_to_nanos(output["completed_at"])
+                child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+                message_attributes = {
+                    gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
+                    gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
+                    MistralAIAttributes.MISTRAL_AI_MESSAGE_ID: output.get("id", ""),
+                    gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id", ""),
+                    gen_ai_attributes.GEN_AI_REQUEST_MODEL: output.get("model", "")
+                }
+                child_span.set_attributes(message_attributes)
+                child_span.end(end_time=end_ns)
+    if operation_id == "ocr_v1_ocr_post":
+        usage_info = response_data.get("usage_info", "")
+        ocr_attributes = {
+            MistralAIAttributes.MISTRAL_AI_OPERATION_NAME: MistralAINameValues.OCR.value,
+            MistralAIAttributes.MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED: usage_info.get("pages_processed", "") if usage_info else "",
+            MistralAIAttributes.MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES: usage_info.get("doc_size_bytes", "") if usage_info else "",
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", "")
+        }
+        span.set_attributes(ocr_attributes)
+
+
+class GenAISpanProcessor(SpanProcessor):
+    def on_start(self, span, parent_context = None):
+        span.set_attributes({"agent.trace.public": ""})
+
+
+class QuietOTLPSpanExporter(OTLPSpanExporter):
+    def export(self, spans):
+        try:
+            return super().export(spans)
+        except Exception:
+            logger.warning(f"{TracingErrors.FAILED_TO_EXPORT_OTEL_SPANS} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
+            return SpanExportResult.FAILURE
+
+
+def get_or_create_otel_tracer() -> Tuple[bool, Tracer]:
+    """
+    3 possible cases:
+
+    -> [SDK in a Workflow / App] If there is already a tracer provider set -> use that one
+
+    -> [SDK standalone] If no tracer provider is set but the OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
+
+    -> Else tracing is disabled
+    """
+    tracing_enabled = True
+    tracer_provider = trace.get_tracer_provider()
+
+    if isinstance(tracer_provider, trace.ProxyTracerProvider):
+        if OTEL_EXPORTER_OTLP_ENDPOINT:
+            # SDK standalone: No tracer provider but OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT
+            try:
+                exporter = QuietOTLPSpanExporter(
+                    endpoint=OTEL_EXPORTER_OTLP_ENDPOINT,
+                    timeout=OTEL_EXPORTER_OTLP_TIMEOUT
+                )
+                resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME})
+                tracer_provider = TracerProvider(resource=resource)
+
+                span_processor = BatchSpanProcessor(
+                    exporter,
+                    export_timeout_millis=OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS,
+                    max_export_batch_size=OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE,
+                    schedule_delay_millis=OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS,
+                    max_queue_size=OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE
+                )
+
+                tracer_provider.add_span_processor(span_processor)
+                tracer_provider.add_span_processor(GenAISpanProcessor())
+                trace.set_tracer_provider(tracer_provider)
+
+            except Exception:
+                logger.warning(f"{TracingErrors.FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}")
+                tracing_enabled = False
+        else:
+            # No tracer provider nor OTEL_EXPORTER_OTLP_ENDPOINT set -> tracing is disabled
+            tracing_enabled = False
+
+    tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+
+    return tracing_enabled, tracer
+
+def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, request: httpx.Request) -> Tuple[httpx.Request, Optional[Span]]:
+    if not tracing_enabled:
+        return request, span
+
+    try:
+        span = tracer.start_span(name=operation_id)
+        # Inject the span context into the request headers to be used by the backend service to continue the trace
+        propagate.inject(request.headers)
+        span = enrich_span_from_request(span, request)
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_CREATE_SPAN_FOR_REQUEST,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+        if span:
+            end_span(span=span)
+            span = None
+
+    return request, span
+
+
+def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response) -> httpx.Response:
+    if not tracing_enabled or not span:
+        return response
+    try:
+        is_stream_response = not response.is_closed and not response.is_stream_consumed
+        if is_stream_response:
+            return TracedResponse.from_response(resp=response, span=span)
+        enrich_span_from_response(
+            tracer, span, operation_id, response
+        )
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_ENRICH_SPAN_WITH_RESPONSE,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+    if span:
+        end_span(span=span)
+    return response
+
+def get_response_and_error(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response, error: Optional[Exception]) -> Tuple[httpx.Response, Optional[Exception]]:
+    if not tracing_enabled or not span:
+        return response, error
+    try:
+        if error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+            if hasattr(response, "_content") and response._content:
+                response_body = json.loads(response._content)
+                if response_body.get("object", "") == "error":
+                    if error_msg := response_body.get("message", ""):
+                        attributes = {
+                            http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+                            MistralAIAttributes.MISTRAL_AI_ERROR_TYPE: response_body.get("type", ""),
+                            MistralAIAttributes.MISTRAL_AI_ERROR_MESSAGE: error_msg,
+                            MistralAIAttributes.MISTRAL_AI_ERROR_CODE: response_body.get("code", ""),
+                        }
+                        for attribute, value in attributes.items():
+                            if value:
+                                span.set_attribute(attribute, value)
+            span.end()
+            span = None
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_HANDLE_ERROR_IN_SPAN,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+
+    if span:
+        span.end()
+        span = None
+    return response, error
+
+
+def end_span(span: Span) -> None:
+    try:
+        span.end()
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_END_SPAN,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+
+class TracedResponse(httpx.Response):
+    """
+    TracedResponse is a subclass of httpx.Response that ends the span when the response is closed.
+
+    This hack allows ending the span only once the stream is fully consumed.
+    """
+    def __init__(self, *args, span: Optional[Span], **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.span = span
+
+    def close(self) -> None:
+        if self.span:
+            end_span(span=self.span)
+        super().close()
+
+    async def aclose(self) -> None:
+        if self.span:
+            end_span(span=self.span)
+        await super().aclose()
+
+    @classmethod
+    def from_response(cls, resp: httpx.Response, span: Optional[Span]) -> "TracedResponse":
+        traced_resp = cls.__new__(cls)
+        traced_resp.__dict__ = copy.copy(resp.__dict__)
+        traced_resp.span = span
+
+        # Warning: this syntax bypasses the __init__ method.
+        # If you add init logic in the TracedResponse.__init__ method, you will need to add the following line for it to execute:
+        # traced_resp.__init__(your_arguments)
+
+        return traced_resp
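As the docstring of `get_or_create_otel_tracer` spells out, the SDK reuses an application-configured tracer provider when one exists, builds its own OTLP export pipeline when only `OTEL_EXPORTER_OTLP_ENDPOINT` is set, and disables tracing otherwise. A standalone-mode sketch (the endpoint and tuning values are illustrative; since the module reads these variables at import time, they must be set before the SDK is imported):

import os

# Standalone mode: no tracer provider configured by the host application.
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://localhost:4318/v1/traces"
os.environ["OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS"] = "1000"
os.environ["MISTRAL_SDK_DEBUG_TRACING"] = "true"  # log full export tracebacks

from mistralai import Mistral  # import after the env vars are in place

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
# Each request now gets a span named after its operation id, the trace
# context is injected into the request headers, and QuietOTLPSpanExporter
# ships batches to the endpoint above.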
mistralai/extra/run/tools.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic.fields import FieldInfo
 import json
 from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union
 
+from opentelemetry import trace
 from griffe import (
     Docstring,
     DocstringSectionKind,
@@ -15,9 +16,11 @@ from griffe import (
     DocstringParameter,
     DocstringSection,
 )
+import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
 
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
+from mistralai.extra.observability.otel import GenAISpanEnum, MistralAIAttributes, set_available_attributes
 from mistralai.extra.run.result import RunOutputEntries
 from mistralai.models import (
     FunctionResultEntry,
@@ -191,22 +194,31 @@ async def create_function_result(
         if isinstance(function_call.arguments, str)
         else function_call.arguments
     )
-[16 removed lines: the previous, untraced tool-execution body; content not captured in this view]
+    tracer = trace.get_tracer(__name__)
+    with tracer.start_as_current_span(GenAISpanEnum.function_call(function_call.name)) as span:
+        try:
+            if isinstance(run_tool, RunFunction):
+                res = run_tool.callable(**arguments)
+            elif isinstance(run_tool, RunCoroutine):
+                res = await run_tool.awaitable(**arguments)
+            elif isinstance(run_tool, RunMCPTool):
+                res = await run_tool.mcp_client.execute_tool(function_call.name, arguments)
+            function_call_attributes = {
+                gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                gen_ai_attributes.GEN_AI_TOOL_CALL_ID: function_call.id,
+                MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: str(function_call.arguments),
+                gen_ai_attributes.GEN_AI_TOOL_NAME: function_call.name
+            }
+            set_available_attributes(span, function_call_attributes)
+        except Exception as e:
+            if continue_on_fn_error is True:
+                return FunctionResultEntry(
+                    tool_call_id=function_call.tool_call_id,
+                    result=f"Error while executing {function_call.name}: {str(e)}",
+                )
+            raise RunException(
+                f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'"
+            ) from e
 
     return FunctionResultEntry(
         tool_call_id=function_call.tool_call_id,