mistralai 1.9.11__py3-none-any.whl → 1.10.1__py3-none-any.whl
This diff compares the published contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- mistralai/_hooks/registration.py +5 -0
- mistralai/_hooks/tracing.py +75 -0
- mistralai/_version.py +2 -2
- mistralai/accesses.py +8 -8
- mistralai/agents.py +29 -17
- mistralai/chat.py +41 -29
- mistralai/classifiers.py +13 -1
- mistralai/conversations.py +294 -62
- mistralai/documents.py +19 -3
- mistralai/embeddings.py +13 -7
- mistralai/extra/README.md +1 -1
- mistralai/extra/mcp/auth.py +10 -11
- mistralai/extra/mcp/base.py +17 -16
- mistralai/extra/mcp/sse.py +13 -15
- mistralai/extra/mcp/stdio.py +5 -6
- mistralai/extra/observability/__init__.py +15 -0
- mistralai/extra/observability/otel.py +372 -0
- mistralai/extra/run/context.py +33 -43
- mistralai/extra/run/result.py +29 -30
- mistralai/extra/run/tools.py +34 -23
- mistralai/extra/struct_chat.py +15 -8
- mistralai/extra/utils/response_format.py +5 -3
- mistralai/files.py +6 -0
- mistralai/fim.py +17 -5
- mistralai/mistral_agents.py +229 -1
- mistralai/mistral_jobs.py +39 -13
- mistralai/models/__init__.py +99 -3
- mistralai/models/agent.py +15 -2
- mistralai/models/agentconversation.py +11 -3
- mistralai/models/agentcreationrequest.py +6 -2
- mistralai/models/agents_api_v1_agents_deleteop.py +16 -0
- mistralai/models/agents_api_v1_agents_getop.py +40 -3
- mistralai/models/agents_api_v1_agents_listop.py +72 -2
- mistralai/models/agents_api_v1_conversations_deleteop.py +18 -0
- mistralai/models/agents_api_v1_conversations_listop.py +39 -2
- mistralai/models/agentscompletionrequest.py +21 -6
- mistralai/models/agentscompletionstreamrequest.py +21 -6
- mistralai/models/agentupdaterequest.py +18 -2
- mistralai/models/audioencoding.py +13 -0
- mistralai/models/audioformat.py +19 -0
- mistralai/models/audiotranscriptionrequest.py +2 -0
- mistralai/models/batchjobin.py +26 -5
- mistralai/models/batchjobout.py +5 -0
- mistralai/models/batchrequest.py +48 -0
- mistralai/models/chatcompletionrequest.py +22 -5
- mistralai/models/chatcompletionstreamrequest.py +22 -5
- mistralai/models/classificationrequest.py +37 -3
- mistralai/models/conversationrequest.py +15 -4
- mistralai/models/conversationrestartrequest.py +50 -2
- mistralai/models/conversationrestartstreamrequest.py +50 -2
- mistralai/models/conversationstreamrequest.py +15 -4
- mistralai/models/documentout.py +26 -10
- mistralai/models/documentupdatein.py +24 -3
- mistralai/models/embeddingrequest.py +19 -11
- mistralai/models/files_api_routes_list_filesop.py +7 -0
- mistralai/models/fimcompletionrequest.py +8 -9
- mistralai/models/fimcompletionstreamrequest.py +8 -9
- mistralai/models/jobs_api_routes_batch_get_batch_jobop.py +40 -3
- mistralai/models/libraries_documents_list_v1op.py +15 -2
- mistralai/models/libraryout.py +10 -7
- mistralai/models/listfilesout.py +35 -4
- mistralai/models/modelcapabilities.py +13 -4
- mistralai/models/modelconversation.py +8 -2
- mistralai/models/ocrpageobject.py +26 -5
- mistralai/models/ocrrequest.py +17 -1
- mistralai/models/ocrtableobject.py +31 -0
- mistralai/models/prediction.py +4 -0
- mistralai/models/requestsource.py +7 -0
- mistralai/models/responseformat.py +4 -2
- mistralai/models/responseformats.py +0 -1
- mistralai/models/sharingdelete.py +36 -5
- mistralai/models/sharingin.py +36 -5
- mistralai/models/sharingout.py +3 -3
- mistralai/models/toolexecutiondeltaevent.py +13 -4
- mistralai/models/toolexecutiondoneevent.py +13 -4
- mistralai/models/toolexecutionentry.py +9 -4
- mistralai/models/toolexecutionstartedevent.py +13 -4
- mistralai/models/toolfilechunk.py +11 -4
- mistralai/models/toolreferencechunk.py +13 -4
- mistralai/models_.py +2 -14
- mistralai/ocr.py +18 -0
- mistralai/transcriptions.py +4 -4
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/METADATA +162 -152
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/RECORD +168 -144
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/WHEEL +1 -1
- mistralai_azure/_version.py +3 -3
- mistralai_azure/basesdk.py +15 -5
- mistralai_azure/chat.py +59 -98
- mistralai_azure/models/__init__.py +50 -3
- mistralai_azure/models/chatcompletionrequest.py +16 -4
- mistralai_azure/models/chatcompletionstreamrequest.py +16 -4
- mistralai_azure/models/httpvalidationerror.py +11 -6
- mistralai_azure/models/mistralazureerror.py +26 -0
- mistralai_azure/models/no_response_error.py +13 -0
- mistralai_azure/models/prediction.py +4 -0
- mistralai_azure/models/responseformat.py +4 -2
- mistralai_azure/models/responseformats.py +0 -1
- mistralai_azure/models/responsevalidationerror.py +25 -0
- mistralai_azure/models/sdkerror.py +30 -14
- mistralai_azure/models/systemmessage.py +7 -3
- mistralai_azure/models/systemmessagecontentchunks.py +21 -0
- mistralai_azure/models/thinkchunk.py +35 -0
- mistralai_azure/ocr.py +15 -36
- mistralai_azure/utils/__init__.py +18 -5
- mistralai_azure/utils/eventstreaming.py +10 -0
- mistralai_azure/utils/serializers.py +3 -2
- mistralai_azure/utils/unmarshal_json_response.py +24 -0
- mistralai_gcp/_hooks/types.py +7 -0
- mistralai_gcp/_version.py +4 -4
- mistralai_gcp/basesdk.py +27 -25
- mistralai_gcp/chat.py +75 -98
- mistralai_gcp/fim.py +39 -74
- mistralai_gcp/httpclient.py +6 -16
- mistralai_gcp/models/__init__.py +321 -116
- mistralai_gcp/models/assistantmessage.py +1 -1
- mistralai_gcp/models/chatcompletionrequest.py +36 -7
- mistralai_gcp/models/chatcompletionresponse.py +6 -6
- mistralai_gcp/models/chatcompletionstreamrequest.py +36 -7
- mistralai_gcp/models/completionresponsestreamchoice.py +1 -1
- mistralai_gcp/models/deltamessage.py +1 -1
- mistralai_gcp/models/fimcompletionrequest.py +3 -9
- mistralai_gcp/models/fimcompletionresponse.py +6 -6
- mistralai_gcp/models/fimcompletionstreamrequest.py +3 -9
- mistralai_gcp/models/httpvalidationerror.py +11 -6
- mistralai_gcp/models/imageurl.py +1 -1
- mistralai_gcp/models/jsonschema.py +1 -1
- mistralai_gcp/models/mistralgcperror.py +26 -0
- mistralai_gcp/models/mistralpromptmode.py +8 -0
- mistralai_gcp/models/no_response_error.py +13 -0
- mistralai_gcp/models/prediction.py +4 -0
- mistralai_gcp/models/responseformat.py +5 -3
- mistralai_gcp/models/responseformats.py +0 -1
- mistralai_gcp/models/responsevalidationerror.py +25 -0
- mistralai_gcp/models/sdkerror.py +30 -14
- mistralai_gcp/models/systemmessage.py +7 -3
- mistralai_gcp/models/systemmessagecontentchunks.py +21 -0
- mistralai_gcp/models/thinkchunk.py +35 -0
- mistralai_gcp/models/toolmessage.py +1 -1
- mistralai_gcp/models/usageinfo.py +71 -8
- mistralai_gcp/models/usermessage.py +1 -1
- mistralai_gcp/sdk.py +12 -10
- mistralai_gcp/sdkconfiguration.py +0 -7
- mistralai_gcp/types/basemodel.py +3 -3
- mistralai_gcp/utils/__init__.py +143 -45
- mistralai_gcp/utils/datetimes.py +23 -0
- mistralai_gcp/utils/enums.py +67 -27
- mistralai_gcp/utils/eventstreaming.py +10 -0
- mistralai_gcp/utils/forms.py +49 -28
- mistralai_gcp/utils/serializers.py +33 -3
- mistralai_gcp/utils/unmarshal_json_response.py +24 -0
- {mistralai-1.9.11.dist-info → mistralai-1.10.1.dist-info}/licenses/LICENSE +0 -0
mistralai/extra/observability/otel.py
ADDED

```diff
@@ -0,0 +1,372 @@
+import copy
+import json
+import logging
+import os
+import traceback
+from datetime import datetime, timezone
+from enum import Enum
+
+import httpx
+import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes
+import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes
+import opentelemetry.semconv.attributes.server_attributes as server_attributes
+from opentelemetry import propagate, trace
+from opentelemetry.sdk.trace import SpanProcessor
+from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context
+
+logger = logging.getLogger(__name__)
+
+
+OTEL_SERVICE_NAME: str = "mistralai_sdk"
+MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer"
+
+MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true"
+DEBUG_HINT: str = "To see detailed tracing logs, set MISTRAL_SDK_DEBUG_TRACING=true."
+
+
+class MistralAIAttributes:
+    MISTRAL_AI_TOTAL_TOKENS = "mistral_ai.request.total_tokens"
+    MISTRAL_AI_TOOL_CALL_ARGUMENTS = "mistral_ai.tool.call.arguments"
+    MISTRAL_AI_MESSAGE_ID = "mistral_ai.message.id"
+    MISTRAL_AI_OPERATION_NAME = "mistral_ai.operation.name"
+    MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED = "mistral_ai.ocr.usage.pages_processed"
+    MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES = "mistral_ai.ocr.usage.doc_size_bytes"
+    MISTRAL_AI_OPERATION_ID = "mistral_ai.operation.id"
+    MISTRAL_AI_ERROR_TYPE = "mistral_ai.error.type"
+    MISTRAL_AI_ERROR_MESSAGE = "mistral_ai.error.message"
+    MISTRAL_AI_ERROR_CODE = "mistral_ai.error.code"
+    MISTRAL_AI_FUNCTION_CALL_ARGUMENTS = "mistral_ai.function.call.arguments"
+
+class MistralAINameValues(Enum):
+    OCR = "ocr"
+
+class TracingErrors(Exception, Enum):
+    FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request."
+    FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response."
+    FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span."
+    FAILED_TO_END_SPAN = "Failed to end span."
+
+    def __str__(self):
+        return str(self.value)
+
+class GenAISpanEnum(str, Enum):
+    CONVERSATION = "conversation"
+    CONV_REQUEST = "POST /v1/conversations"
+    EXECUTE_TOOL = "execute_tool"
+    VALIDATE_RUN = "validate_run"
+
+    @staticmethod
+    def function_call(func_name: str):
+        return f"function_call[{func_name}]"
+
+
+def parse_time_to_nanos(ts: str) -> int:
+    dt = datetime.fromisoformat(ts.replace("Z", "+00:00")).astimezone(timezone.utc)
+    return int(dt.timestamp() * 1e9)
+
+def set_available_attributes(span: Span, attributes: dict) -> None:
+    for attribute, value in attributes.items():
+        if value:
+            span.set_attribute(attribute, value)
+
+
+def enrich_span_from_request(span: Span, request: httpx.Request) -> Span:
+    if not request.url.port:
+        # From httpx doc:
+        # Note that the URL class performs port normalization as per the WHATWG spec.
+        # Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always treated as None.
+        # Handling default ports since most of the time we are using https
+        if request.url.scheme == "https":
+            port = 443
+        elif request.url.scheme == "http":
+            port = 80
+        else:
+            port = -1
+    else:
+        port = request.url.port
+
+    span.set_attributes({
+        http_attributes.HTTP_REQUEST_METHOD: request.method,
+        http_attributes.HTTP_URL: str(request.url),
+        server_attributes.SERVER_ADDRESS: request.headers.get("host", ""),
+        server_attributes.SERVER_PORT: port
+    })
+    if request._content:
+        request_body = json.loads(request._content)
+
+        attributes = {
+            gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT: request_body.get("n", None),
+            gen_ai_attributes.GEN_AI_REQUEST_ENCODING_FORMATS: request_body.get("encoding_formats", None),
+            gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: request_body.get("frequency_penalty", None),
+            gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS: request_body.get("max_tokens", None),
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: request_body.get("model", None),
+            gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY: request_body.get("presence_penalty", None),
+            gen_ai_attributes.GEN_AI_REQUEST_SEED: request_body.get("random_seed", None),
+            gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES: request_body.get("stop", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: request_body.get("temperature", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TOP_P: request_body.get("top_p", None),
+            gen_ai_attributes.GEN_AI_REQUEST_TOP_K: request_body.get("top_k", None),
+            # Input messages are likely to be large, containing user/PII data and other sensitive information.
+            # Also structured attributes are not yet supported on spans in Python.
+            # For those reasons, we will not record the input messages for now.
+            gen_ai_attributes.GEN_AI_INPUT_MESSAGES: None,
+        }
+        # Set attributes only if they are not None.
+        # From OpenTelemetry documentation: None is not a valid attribute value per spec / is not a permitted value type for an attribute.
+        set_available_attributes(span, attributes)
+    return span
+
+
+def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: str, response: httpx.Response) -> None:
+    span.set_status(Status(StatusCode.OK))
+    response_data = json.loads(response.content)
+
+    # Base attributes
+    attributes: dict[str, str | int] = {
+        http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+        MistralAIAttributes.MISTRAL_AI_OPERATION_ID: operation_id,
+        gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value
+    }
+
+    # Add usage attributes if available
+    usage = response_data.get("usage", {})
+    if usage:
+        attributes.update({
+            gen_ai_attributes.GEN_AI_USAGE_PROMPT_TOKENS: usage.get("prompt_tokens", 0),
+            gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS: usage.get("completion_tokens", 0),
+            MistralAIAttributes.MISTRAL_AI_TOTAL_TOKENS: usage.get("total_tokens", 0)
+        })
+
+    span.set_attributes(attributes)
+    if operation_id == "agents_api_v1_agents_create":
+        # Semantics from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/#create-agent-span
+        agent_attributes = {
+            gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CREATE_AGENT.value,
+            gen_ai_attributes.GEN_AI_AGENT_DESCRIPTION: response_data.get("description", ""),
+            gen_ai_attributes.GEN_AI_AGENT_ID: response_data.get("id", ""),
+            gen_ai_attributes.GEN_AI_AGENT_NAME: response_data.get("name", ""),
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", ""),
+            gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS: response_data.get("instructions", "")
+        }
+        span.set_attributes(agent_attributes)
+    if operation_id in ["agents_api_v1_conversations_start", "agents_api_v1_conversations_append"]:
+        outputs = response_data.get("outputs", [])
+        conversation_attributes = {
+            gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.INVOKE_AGENT.value,
+            gen_ai_attributes.GEN_AI_CONVERSATION_ID: response_data.get("conversation_id", "")
+        }
+        span.set_attributes(conversation_attributes)
+        parent_context = set_span_in_context(span)
+
+        for output in outputs:
+            # TODO: Only enrich the spans if it's a single turn conversation.
+            # Multi turn conversations are handled in the extra.run.tools.create_function_result function
+            if output["type"] == "function.call":
+                pass
+            if output["type"] == "tool.execution":
+                start_ns = parse_time_to_nanos(output["created_at"])
+                end_ns = parse_time_to_nanos(output["completed_at"])
+                child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context)
+                child_span.set_attributes({"agent.trace.public": ""})
+                tool_attributes = {
+                    gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value,
+                    gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""),
+                    MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: output.get("arguments", ""),
+                    gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name", "")
+                }
+                child_span.set_attributes(tool_attributes)
+                child_span.end(end_time=end_ns)
+            if output["type"] == "message.output":
+                start_ns = parse_time_to_nanos(output["created_at"])
+                end_ns = parse_time_to_nanos(output["completed_at"])
+                child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context)
+                child_span.set_attributes({"agent.trace.public": ""})
+                message_attributes = {
+                    gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value,
+                    gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value,
+                    MistralAIAttributes.MISTRAL_AI_MESSAGE_ID: output.get("id", ""),
+                    gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id", ""),
+                    gen_ai_attributes.GEN_AI_REQUEST_MODEL: output.get("model", "")
+                }
+                child_span.set_attributes(message_attributes)
+                child_span.end(end_time=end_ns)
+    if operation_id == "ocr_v1_ocr_post":
+        usage_info = response_data.get("usage_info", "")
+        ocr_attributes = {
+            MistralAIAttributes.MISTRAL_AI_OPERATION_NAME: MistralAINameValues.OCR.value,
+            MistralAIAttributes.MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED: usage_info.get("pages_processed", "") if usage_info else "",
+            MistralAIAttributes.MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES: usage_info.get("doc_size_bytes", "") if usage_info else "",
+            gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", "")
+        }
+        span.set_attributes(ocr_attributes)
+
+
+class GenAISpanProcessor(SpanProcessor):
+    def on_start(self, span, parent_context = None):
+        span.set_attributes({"agent.trace.public": ""})
+
+
+def get_or_create_otel_tracer() -> tuple[bool, Tracer]:
+    """
+    Get a tracer from the current TracerProvider.
+
+    The SDK does not set up its own TracerProvider - it relies on the application
+    to configure OpenTelemetry. This follows OTEL best practices where:
+    - Libraries/SDKs get tracers from the global provider
+    - Applications configure the TracerProvider
+
+    If no TracerProvider is configured, the ProxyTracerProvider (default) will
+    return a NoOp tracer, effectively disabling tracing. Once the application
+    sets up a real TracerProvider, subsequent spans will be recorded.
+
+    Returns:
+        Tuple[bool, Tracer]: (tracing_enabled, tracer)
+        - tracing_enabled is True if a real TracerProvider is configured
+        - tracer is always valid (may be NoOp if no provider configured)
+    """
+    tracer_provider = trace.get_tracer_provider()
+    tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME)
+
+    # Tracing is considered enabled if we have a real TracerProvider (not the default proxy)
+    tracing_enabled = not isinstance(tracer_provider, trace.ProxyTracerProvider)
+
+    return tracing_enabled, tracer
+
+def get_traced_request_and_span(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    request: httpx.Request,
+) -> tuple[httpx.Request, Span | None]:
+    if not tracing_enabled:
+        return request, span
+
+    try:
+        span = tracer.start_span(name=operation_id)
+        span.set_attributes({"agent.trace.public": ""})
+        # Inject the span context into the request headers to be used by the backend service to continue the trace
+        propagate.inject(request.headers, context=set_span_in_context(span))
+        span = enrich_span_from_request(span, request)
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_CREATE_SPAN_FOR_REQUEST,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+        if span:
+            end_span(span=span)
+            span = None
+
+    return request, span
+
+
+def get_traced_response(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+) -> httpx.Response:
+    if not tracing_enabled or not span:
+        return response
+    try:
+        is_stream_response = not response.is_closed and not response.is_stream_consumed
+        if is_stream_response:
+            return TracedResponse.from_response(resp=response, span=span)
+        enrich_span_from_response(
+            tracer, span, operation_id, response
+        )
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_ENRICH_SPAN_WITH_RESPONSE,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+    if span:
+        end_span(span=span)
+    return response
+
+def get_response_and_error(
+    tracing_enabled: bool,
+    tracer: Tracer,
+    span: Span | None,
+    operation_id: str,
+    response: httpx.Response,
+    error: Exception | None,
+) -> tuple[httpx.Response, Exception | None]:
+    if not tracing_enabled or not span:
+        return response, error
+    try:
+        if error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+        if hasattr(response, "_content") and response._content:
+            response_body = json.loads(response._content)
+            if response_body.get("object", "") == "error":
+                if error_msg := response_body.get("message", ""):
+                    attributes = {
+                        http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code,
+                        MistralAIAttributes.MISTRAL_AI_ERROR_TYPE: response_body.get("type", ""),
+                        MistralAIAttributes.MISTRAL_AI_ERROR_MESSAGE: error_msg,
+                        MistralAIAttributes.MISTRAL_AI_ERROR_CODE: response_body.get("code", ""),
+                    }
+                    for attribute, value in attributes.items():
+                        if value:
+                            span.set_attribute(attribute, value)
+        span.end()
+        span = None
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_HANDLE_ERROR_IN_SPAN,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+
+    if span:
+        span.end()
+        span = None
+    return response, error
+
+
+def end_span(span: Span) -> None:
+    try:
+        span.end()
+    except Exception:
+        logger.warning(
+            "%s %s",
+            TracingErrors.FAILED_TO_END_SPAN,
+            traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT,
+        )
+
+class TracedResponse(httpx.Response):
+    """
+    TracedResponse is a subclass of httpx.Response that ends the span when the response is closed.
+
+    This hack allows ending the span only once the stream is fully consumed.
+    """
+    def __init__(self, *args, span: Span | None, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.span = span
+
+    def close(self) -> None:
+        if self.span:
+            end_span(span=self.span)
+        super().close()
+
+    async def aclose(self) -> None:
+        if self.span:
+            end_span(span=self.span)
+        await super().aclose()
+
+    @classmethod
+    def from_response(cls, resp: httpx.Response, span: Span | None) -> "TracedResponse":
+        traced_resp = cls.__new__(cls)
+        traced_resp.__dict__ = copy.copy(resp.__dict__)
+        traced_resp.span = span
+
+        # Warning: this syntax bypasses the __init__ method.
+        # If you add init logic in the TracedResponse.__init__ method, you will need to add the following line for it to execute:
+        # traced_resp.__init__(your_arguments)
+
+        return traced_resp
```
mistralai/extra/run/context.py
CHANGED
```diff
@@ -1,47 +1,41 @@
 import asyncio
 import inspect
 import typing
-from contextlib import AsyncExitStack
-from functools import wraps
 from collections.abc import Callable
-
+from contextlib import AsyncExitStack
 from dataclasses import dataclass, field
-from
+from functools import wraps
+from logging import getLogger
 
 import pydantic
 
-from mistralai.extra import
-    response_format_from_pydantic_model,
-)
+from mistralai.extra import response_format_from_pydantic_model
 from mistralai.extra.exceptions import RunException
 from mistralai.extra.mcp.base import MCPClientProtocol
 from mistralai.extra.run.result import RunResult
-from mistralai.
+from mistralai.extra.run.tools import (
+    RunCoroutine,
+    RunFunction,
+    RunMCPTool,
+    RunTool,
+    create_function_result,
+    create_tool_call,
+)
 from mistralai.models import (
-    ResponseFormat,
-    FunctionCallEntry,
-    Tools,
-    ToolsTypedDict,
     CompletionArgs,
     CompletionArgsTypedDict,
-    FunctionResultEntry,
     ConversationInputs,
     ConversationInputsTypedDict,
+    FunctionCallEntry,
+    FunctionResultEntry,
     FunctionTool,
-    MessageInputEntry,
     InputEntries,
+    MessageInputEntry,
+    ResponseFormat,
+    Tools,
+    ToolsTypedDict,
 )
-
-from logging import getLogger
-
-from mistralai.extra.run.tools import (
-    create_function_result,
-    RunFunction,
-    create_tool_call,
-    RunTool,
-    RunMCPTool,
-    RunCoroutine,
-)
+from mistralai.types.basemodel import BaseModel, OptionalNullable, UNSET
 
 if typing.TYPE_CHECKING:
     from mistralai import Beta, OptionalNullable
@@ -56,8 +50,8 @@ class AgentRequestKwargs(typing.TypedDict):
 class ModelRequestKwargs(typing.TypedDict):
     model: str
     instructions: OptionalNullable[str]
-    tools: OptionalNullable[
-    completion_args: OptionalNullable[
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]]
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict]
 
 
 @dataclass
@@ -72,7 +66,7 @@ class RunContext:
            passed if the user wants to continue an existing conversation.
        model (Options[str]): The model name to be used for the conversation. Can't be used along with 'agent_id'.
        agent_id (Options[str]): The agent id to be used for the conversation. Can't be used along with 'model'.
-       output_format (
+       output_format (type[BaseModel] | None): The output format expected from the conversation. It represents
            the `response_format` which is part of the `CompletionArgs`.
        request_count (int): The number of requests made in the current `RunContext`.
        continue_on_fn_error (bool): Flag to determine if the conversation should continue when function execution
@@ -83,10 +77,10 @@ class RunContext:
     _callable_tools: dict[str, RunTool] = field(init=False, default_factory=dict)
     _mcp_clients: list[MCPClientProtocol] = field(init=False, default_factory=list)
 
-    conversation_id:
-    model:
-    agent_id:
-    output_format:
+    conversation_id: str | None = field(default=None)
+    model: str | None = field(default=None)
+    agent_id: str | None = field(default=None)
+    output_format: type[BaseModel] | None = field(default=None)
     request_count: int = field(default=0)
     continue_on_fn_error: bool = field(default=False)
 
@@ -215,10 +209,8 @@ class RunContext:
 
     async def prepare_model_request(
         self,
-        tools: OptionalNullable[
-        completion_args: OptionalNullable[
-            Union[CompletionArgs, CompletionArgsTypedDict]
-        ] = UNSET,
+        tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+        completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
         instructions: OptionalNullable[str] = None,
     ) -> ModelRequestKwargs:
        if self.model is None:
@@ -254,14 +246,12 @@ async def _validate_run(
     *,
     beta_client: "Beta",
     run_ctx: RunContext,
-    inputs:
+    inputs: ConversationInputs | ConversationInputsTypedDict,
     instructions: OptionalNullable[str] = UNSET,
-    tools: OptionalNullable[
-    completion_args: OptionalNullable[
-        Union[CompletionArgs, CompletionArgsTypedDict]
-    ] = UNSET,
+    tools: OptionalNullable[list[Tools] | list[ToolsTypedDict]] = UNSET,
+    completion_args: OptionalNullable[CompletionArgs | CompletionArgsTypedDict] = UNSET,
 ) -> tuple[
-
+    AgentRequestKwargs | ModelRequestKwargs, RunResult, list[InputEntries]
 ]:
     input_entries: list[InputEntries] = []
     if isinstance(inputs, str):
@@ -277,7 +267,7 @@ async def _validate_run(
         output_model=run_ctx.output_format,
         conversation_id=run_ctx.conversation_id,
     )
-    req:
+    req: AgentRequestKwargs | ModelRequestKwargs
     if run_ctx.agent_id:
         if tools or completion_args:
             raise RunException("Can't set tools or completion_args when using an agent")
```
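The retyped `RunContext` fields above are plain dataclass fields with defaults, so a run can be configured with keywords. A hypothetical sketch based only on the fields and docstring shown in this diff (the surrounding run invocation is not part of this diff and is an assumption):

```python
# Sketch: configuring a RunContext with the fields retyped above.
# How the context is then passed to a run call is an assumption, not shown here.
from pydantic import BaseModel

from mistralai.extra.run.context import RunContext

class CityAnswer(BaseModel):
    city: str

ctx = RunContext(
    model="mistral-medium-latest",  # mutually exclusive with agent_id, per the docstring
    output_format=CityAnswer,       # surfaced as response_format inside CompletionArgs
    continue_on_fn_error=True,      # keep the run going when a function tool raises
)
```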
mistralai/extra/run/result.py
CHANGED
```diff
@@ -1,9 +1,10 @@
 import datetime
 import json
 import typing
-from typing import Union, Annotated, Optional, Literal
 from dataclasses import dataclass, field
-from
+from typing import Annotated, Literal
+
+from pydantic import BaseModel, Discriminator, Tag
 
 from mistralai.extra.utils.response_format import pydantic_model_from_json
 from mistralai.models import (
@@ -35,15 +36,15 @@ from mistralai.models import (
 )
 from mistralai.utils import get_discriminator
 
-RunOutputEntries =
-    MessageOutputEntry
-    FunctionCallEntry
-    FunctionResultEntry
-    AgentHandoffEntry
-    ToolExecutionEntry
-
+RunOutputEntries = (
+    MessageOutputEntry
+    | FunctionCallEntry
+    | FunctionResultEntry
+    | AgentHandoffEntry
+    | ToolExecutionEntry
+)
 
-RunEntries =
+RunEntries = RunOutputEntries | MessageInputEntry
 
 
 def as_text(entry: RunOutputEntries) -> str:
@@ -140,12 +141,12 @@ class RunFiles:
 @dataclass
 class RunResult:
     input_entries: list[InputEntries]
-    conversation_id:
+    conversation_id: str | None = field(default=None)
     output_entries: list[RunOutputEntries] = field(default_factory=list)
     files: dict[str, RunFiles] = field(default_factory=dict)
-    output_model:
+    output_model: type[BaseModel] | None = field(default=None)
 
-    def get_file(self, file_id: str) ->
+    def get_file(self, file_id: str) -> RunFiles | None:
        return self.files.get(file_id)
 
    @property
@@ -172,36 +173,34 @@ class RunResult:
 
 
 class FunctionResultEvent(BaseModel):
-    id:
+    id: str | None = None
 
-    type:
+    type: Literal["function.result"] | None = "function.result"
 
     result: str
 
     tool_call_id: str
 
-    created_at:
+    created_at: datetime.datetime | None = datetime.datetime.now(
         tz=datetime.timezone.utc
     )
 
-    output_index:
+    output_index: int | None = 0
 
 
-RunResultEventsType =
+RunResultEventsType = SSETypes | Literal["function.result"]
 
 RunResultEventsData = typing.Annotated[
-
-
-
-
-
-
-
-
-
-
-    Annotated[FunctionResultEvent, Tag("function.result")],
-    ],
+    Annotated[AgentHandoffDoneEvent, Tag("agent.handoff.done")]
+    | Annotated[AgentHandoffStartedEvent, Tag("agent.handoff.started")]
+    | Annotated[ResponseDoneEvent, Tag("conversation.response.done")]
+    | Annotated[ResponseErrorEvent, Tag("conversation.response.error")]
+    | Annotated[ResponseStartedEvent, Tag("conversation.response.started")]
+    | Annotated[FunctionCallEvent, Tag("function.call.delta")]
+    | Annotated[MessageOutputEvent, Tag("message.output.delta")]
+    | Annotated[ToolExecutionDoneEvent, Tag("tool.execution.done")]
+    | Annotated[ToolExecutionStartedEvent, Tag("tool.execution.started")]
+    | Annotated[FunctionResultEvent, Tag("function.result")],
     Discriminator(lambda m: get_discriminator(m, "type", "type")),
 ]
```
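The rewritten `RunResultEventsData` alias is pydantic v2's callable-discriminator pattern: each union member carries a `Tag`, and the `Discriminator` callable reads the `type` field of the incoming payload to pick the matching event model. A self-contained sketch of the same pattern, using stand-in event classes rather than the SDK's real models:

```python
# Sketch of the callable-Discriminator pattern used by RunResultEventsData.
# The event classes here are simplified stand-ins, not the SDK's models.
from typing import Annotated, Literal

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter

class MessageOutputEvent(BaseModel):
    type: Literal["message.output.delta"] = "message.output.delta"
    content: str

class FunctionResultEvent(BaseModel):
    type: Literal["function.result"] = "function.result"
    result: str

def event_type(value) -> str | None:
    # Mirrors get_discriminator(m, "type", "type"): read the tag from a raw
    # dict during validation, or from a model instance during serialization.
    return value.get("type") if isinstance(value, dict) else getattr(value, "type", None)

Event = Annotated[
    Annotated[MessageOutputEvent, Tag("message.output.delta")]
    | Annotated[FunctionResultEvent, Tag("function.result")],
    Discriminator(event_type),
]

# The tag in the payload selects the concrete event class.
event = TypeAdapter(Event).validate_python({"type": "function.result", "result": "42"})
assert isinstance(event, FunctionResultEvent)
```

The `Tag` strings double as the SSE `type` values, which is why the new `FunctionResultEvent` defaults its `type` field to `"function.result"`.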