mistralai 1.5.1__py3-none-any.whl → 1.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. mistralai/_hooks/types.py +15 -3
  2. mistralai/_version.py +3 -3
  3. mistralai/agents.py +32 -12
  4. mistralai/basesdk.py +8 -0
  5. mistralai/chat.py +32 -12
  6. mistralai/classifiers.py +32 -12
  7. mistralai/embeddings.py +20 -10
  8. mistralai/extra/utils/response_format.py +3 -3
  9. mistralai/files.py +36 -0
  10. mistralai/fim.py +32 -12
  11. mistralai/httpclient.py +4 -2
  12. mistralai/jobs.py +30 -0
  13. mistralai/mistral_jobs.py +24 -0
  14. mistralai/models/__init__.py +6 -1
  15. mistralai/models/documenturlchunk.py +8 -14
  16. mistralai/models/embeddingrequest.py +7 -7
  17. mistralai/models/filepurpose.py +1 -1
  18. mistralai/models_.py +66 -18
  19. mistralai/ocr.py +16 -6
  20. mistralai/sdk.py +19 -3
  21. mistralai/sdkconfiguration.py +4 -2
  22. mistralai/utils/__init__.py +2 -0
  23. mistralai/utils/serializers.py +10 -6
  24. mistralai/utils/values.py +4 -1
  25. {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/METADATA +66 -19
  26. {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/RECORD +73 -69
  27. mistralai_azure/__init__.py +10 -1
  28. mistralai_azure/_hooks/types.py +15 -3
  29. mistralai_azure/_version.py +3 -0
  30. mistralai_azure/basesdk.py +8 -0
  31. mistralai_azure/chat.py +88 -20
  32. mistralai_azure/httpclient.py +52 -0
  33. mistralai_azure/models/__init__.py +7 -0
  34. mistralai_azure/models/assistantmessage.py +2 -0
  35. mistralai_azure/models/chatcompletionrequest.py +8 -10
  36. mistralai_azure/models/chatcompletionstreamrequest.py +8 -10
  37. mistralai_azure/models/function.py +3 -0
  38. mistralai_azure/models/jsonschema.py +61 -0
  39. mistralai_azure/models/prediction.py +25 -0
  40. mistralai_azure/models/responseformat.py +42 -1
  41. mistralai_azure/models/responseformats.py +1 -1
  42. mistralai_azure/models/toolcall.py +3 -0
  43. mistralai_azure/sdk.py +56 -14
  44. mistralai_azure/sdkconfiguration.py +14 -6
  45. mistralai_azure/utils/__init__.py +2 -0
  46. mistralai_azure/utils/serializers.py +10 -6
  47. mistralai_azure/utils/values.py +4 -1
  48. mistralai_gcp/__init__.py +10 -1
  49. mistralai_gcp/_hooks/types.py +15 -3
  50. mistralai_gcp/_version.py +3 -0
  51. mistralai_gcp/basesdk.py +8 -0
  52. mistralai_gcp/chat.py +89 -21
  53. mistralai_gcp/fim.py +61 -21
  54. mistralai_gcp/httpclient.py +52 -0
  55. mistralai_gcp/models/__init__.py +7 -0
  56. mistralai_gcp/models/assistantmessage.py +2 -0
  57. mistralai_gcp/models/chatcompletionrequest.py +8 -10
  58. mistralai_gcp/models/chatcompletionstreamrequest.py +8 -10
  59. mistralai_gcp/models/fimcompletionrequest.py +2 -3
  60. mistralai_gcp/models/fimcompletionstreamrequest.py +2 -3
  61. mistralai_gcp/models/function.py +3 -0
  62. mistralai_gcp/models/jsonschema.py +61 -0
  63. mistralai_gcp/models/prediction.py +25 -0
  64. mistralai_gcp/models/responseformat.py +42 -1
  65. mistralai_gcp/models/responseformats.py +1 -1
  66. mistralai_gcp/models/toolcall.py +3 -0
  67. mistralai_gcp/sdk.py +63 -19
  68. mistralai_gcp/sdkconfiguration.py +14 -6
  69. mistralai_gcp/utils/__init__.py +2 -0
  70. mistralai_gcp/utils/serializers.py +10 -6
  71. mistralai_gcp/utils/values.py +4 -1
  72. {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/LICENSE +0 -0
  73. {mistralai-1.5.1.dist-info → mistralai-1.5.2rc1.dist-info}/WHEEL +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  from .assistantmessage import AssistantMessage, AssistantMessageTypedDict
5
+ from .prediction import Prediction, PredictionTypedDict
5
6
  from .responseformat import ResponseFormat, ResponseFormatTypedDict
6
7
  from .systemmessage import SystemMessage, SystemMessageTypedDict
7
8
  from .tool import Tool, ToolTypedDict
@@ -68,7 +69,7 @@ ChatCompletionRequestToolChoice = TypeAliasType(
68
69
 
69
70
 
70
71
  class ChatCompletionRequestTypedDict(TypedDict):
71
- model: Nullable[str]
72
+ model: str
72
73
  r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
73
74
  messages: List[ChatCompletionRequestMessagesTypedDict]
74
75
  r"""The prompt(s) to generate completions for, encoded as a list of dict with role and content."""
@@ -93,10 +94,11 @@ class ChatCompletionRequestTypedDict(TypedDict):
93
94
  r"""frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition."""
94
95
  n: NotRequired[Nullable[int]]
95
96
  r"""Number of completions to return for each request, input tokens are only billed once."""
97
+ prediction: NotRequired[PredictionTypedDict]
96
98
 
97
99
 
98
100
  class ChatCompletionRequest(BaseModel):
99
- model: Nullable[str]
101
+ model: str
100
102
  r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
101
103
 
102
104
  messages: List[ChatCompletionRequestMessages]
@@ -135,6 +137,8 @@ class ChatCompletionRequest(BaseModel):
135
137
  n: OptionalNullable[int] = UNSET
136
138
  r"""Number of completions to return for each request, input tokens are only billed once."""
137
139
 
140
+ prediction: Optional[Prediction] = None
141
+
138
142
  @model_serializer(mode="wrap")
139
143
  def serialize_model(self, handler):
140
144
  optional_fields = [
@@ -150,15 +154,9 @@ class ChatCompletionRequest(BaseModel):
150
154
  "presence_penalty",
151
155
  "frequency_penalty",
152
156
  "n",
157
+ "prediction",
153
158
  ]
154
- nullable_fields = [
155
- "model",
156
- "temperature",
157
- "max_tokens",
158
- "random_seed",
159
- "tools",
160
- "n",
161
- ]
159
+ nullable_fields = ["temperature", "max_tokens", "random_seed", "tools", "n"]
162
160
  null_default_fields = []
163
161
 
164
162
  serialized = handler(self)
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
  from .assistantmessage import AssistantMessage, AssistantMessageTypedDict
5
+ from .prediction import Prediction, PredictionTypedDict
5
6
  from .responseformat import ResponseFormat, ResponseFormatTypedDict
6
7
  from .systemmessage import SystemMessage, SystemMessageTypedDict
7
8
  from .tool import Tool, ToolTypedDict
@@ -64,7 +65,7 @@ ChatCompletionStreamRequestToolChoice = TypeAliasType(
64
65
 
65
66
 
66
67
  class ChatCompletionStreamRequestTypedDict(TypedDict):
67
- model: Nullable[str]
68
+ model: str
68
69
  r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
69
70
  messages: List[MessagesTypedDict]
70
71
  r"""The prompt(s) to generate completions for, encoded as a list of dict with role and content."""
@@ -88,10 +89,11 @@ class ChatCompletionStreamRequestTypedDict(TypedDict):
88
89
  r"""frequency_penalty penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition."""
89
90
  n: NotRequired[Nullable[int]]
90
91
  r"""Number of completions to return for each request, input tokens are only billed once."""
92
+ prediction: NotRequired[PredictionTypedDict]
91
93
 
92
94
 
93
95
  class ChatCompletionStreamRequest(BaseModel):
94
- model: Nullable[str]
96
+ model: str
95
97
  r"""ID of the model to use. You can use the [List Available Models](/api/#tag/models/operation/list_models_v1_models_get) API to see all of your available models, or see our [Model overview](/models) for model descriptions."""
96
98
 
97
99
  messages: List[Messages]
@@ -129,6 +131,8 @@ class ChatCompletionStreamRequest(BaseModel):
129
131
  n: OptionalNullable[int] = UNSET
130
132
  r"""Number of completions to return for each request, input tokens are only billed once."""
131
133
 
134
+ prediction: Optional[Prediction] = None
135
+
132
136
  @model_serializer(mode="wrap")
133
137
  def serialize_model(self, handler):
134
138
  optional_fields = [
@@ -144,15 +148,9 @@ class ChatCompletionStreamRequest(BaseModel):
144
148
  "presence_penalty",
145
149
  "frequency_penalty",
146
150
  "n",
151
+ "prediction",
147
152
  ]
148
- nullable_fields = [
149
- "model",
150
- "temperature",
151
- "max_tokens",
152
- "random_seed",
153
- "tools",
154
- "n",
155
- ]
153
+ nullable_fields = ["temperature", "max_tokens", "random_seed", "tools", "n"]
156
154
  null_default_fields = []
157
155
 
158
156
  serialized = handler(self)
@@ -26,7 +26,7 @@ r"""Stop generation if this token is detected. Or if one of these tokens is dete
26
26
 
27
27
 
28
28
  class FIMCompletionRequestTypedDict(TypedDict):
29
- model: Nullable[str]
29
+ model: str
30
30
  r"""ID of the model to use. Only compatible for now with:
31
31
  - `codestral-2405`
32
32
  - `codestral-latest`
@@ -52,7 +52,7 @@ class FIMCompletionRequestTypedDict(TypedDict):
52
52
 
53
53
 
54
54
  class FIMCompletionRequest(BaseModel):
55
- model: Nullable[str]
55
+ model: str
56
56
  r"""ID of the model to use. Only compatible for now with:
57
57
  - `codestral-2405`
58
58
  - `codestral-latest`
@@ -98,7 +98,6 @@ class FIMCompletionRequest(BaseModel):
98
98
  "min_tokens",
99
99
  ]
100
100
  nullable_fields = [
101
- "model",
102
101
  "temperature",
103
102
  "max_tokens",
104
103
  "random_seed",
@@ -26,7 +26,7 @@ r"""Stop generation if this token is detected. Or if one of these tokens is dete
26
26
 
27
27
 
28
28
  class FIMCompletionStreamRequestTypedDict(TypedDict):
29
- model: Nullable[str]
29
+ model: str
30
30
  r"""ID of the model to use. Only compatible for now with:
31
31
  - `codestral-2405`
32
32
  - `codestral-latest`
@@ -51,7 +51,7 @@ class FIMCompletionStreamRequestTypedDict(TypedDict):
51
51
 
52
52
 
53
53
  class FIMCompletionStreamRequest(BaseModel):
54
- model: Nullable[str]
54
+ model: str
55
55
  r"""ID of the model to use. Only compatible for now with:
56
56
  - `codestral-2405`
57
57
  - `codestral-latest`
@@ -96,7 +96,6 @@ class FIMCompletionStreamRequest(BaseModel):
96
96
  "min_tokens",
97
97
  ]
98
98
  nullable_fields = [
99
- "model",
100
99
  "temperature",
101
100
  "max_tokens",
102
101
  "random_seed",
@@ -10,6 +10,7 @@ class FunctionTypedDict(TypedDict):
10
10
  name: str
11
11
  parameters: Dict[str, Any]
12
12
  description: NotRequired[str]
13
+ strict: NotRequired[bool]
13
14
 
14
15
 
15
16
  class Function(BaseModel):
@@ -18,3 +19,5 @@ class Function(BaseModel):
18
19
  parameters: Dict[str, Any]
19
20
 
20
21
  description: Optional[str] = ""
22
+
23
+ strict: Optional[bool] = False
@@ -0,0 +1,61 @@
1
+ """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
+
3
+ from __future__ import annotations
4
+ from mistralai_gcp.types import (
5
+ BaseModel,
6
+ Nullable,
7
+ OptionalNullable,
8
+ UNSET,
9
+ UNSET_SENTINEL,
10
+ )
11
+ import pydantic
12
+ from pydantic import model_serializer
13
+ from typing import Any, Dict, Optional
14
+ from typing_extensions import Annotated, NotRequired, TypedDict
15
+
16
+
17
+ class JSONSchemaTypedDict(TypedDict):
18
+ name: str
19
+ schema_definition: Dict[str, Any]
20
+ description: NotRequired[Nullable[str]]
21
+ strict: NotRequired[bool]
22
+
23
+
24
+ class JSONSchema(BaseModel):
25
+ name: str
26
+
27
+ schema_definition: Annotated[Dict[str, Any], pydantic.Field(alias="schema")]
28
+
29
+ description: OptionalNullable[str] = UNSET
30
+
31
+ strict: Optional[bool] = False
32
+
33
+ @model_serializer(mode="wrap")
34
+ def serialize_model(self, handler):
35
+ optional_fields = ["description", "strict"]
36
+ nullable_fields = ["description"]
37
+ null_default_fields = []
38
+
39
+ serialized = handler(self)
40
+
41
+ m = {}
42
+
43
+ for n, f in self.model_fields.items():
44
+ k = f.alias or n
45
+ val = serialized.get(k)
46
+ serialized.pop(k, None)
47
+
48
+ optional_nullable = k in optional_fields and k in nullable_fields
49
+ is_set = (
50
+ self.__pydantic_fields_set__.intersection({n})
51
+ or k in null_default_fields
52
+ ) # pylint: disable=no-member
53
+
54
+ if val is not None and val != UNSET_SENTINEL:
55
+ m[k] = val
56
+ elif val != UNSET_SENTINEL and (
57
+ not k in optional_fields or (optional_nullable and is_set)
58
+ ):
59
+ m[k] = val
60
+
61
+ return m
@@ -0,0 +1,25 @@
1
+ """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
+
3
+ from __future__ import annotations
4
+ from mistralai_gcp.types import BaseModel
5
+ from mistralai_gcp.utils import validate_const
6
+ import pydantic
7
+ from pydantic.functional_validators import AfterValidator
8
+ from typing import Literal, Optional
9
+ from typing_extensions import Annotated, NotRequired, TypedDict
10
+
11
+
12
+ class PredictionTypedDict(TypedDict):
13
+ type: Literal["content"]
14
+ content: NotRequired[str]
15
+
16
+
17
+ class Prediction(BaseModel):
18
+ TYPE: Annotated[
19
+ Annotated[
20
+ Optional[Literal["content"]], AfterValidator(validate_const("content"))
21
+ ],
22
+ pydantic.Field(alias="type"),
23
+ ] = "content"
24
+
25
+ content: Optional[str] = ""
@@ -1,8 +1,16 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
3
  from __future__ import annotations
4
+ from .jsonschema import JSONSchema, JSONSchemaTypedDict
4
5
  from .responseformats import ResponseFormats
5
- from mistralai_gcp.types import BaseModel
6
+ from mistralai_gcp.types import (
7
+ BaseModel,
8
+ Nullable,
9
+ OptionalNullable,
10
+ UNSET,
11
+ UNSET_SENTINEL,
12
+ )
13
+ from pydantic import model_serializer
6
14
  from typing import Optional
7
15
  from typing_extensions import NotRequired, TypedDict
8
16
 
@@ -10,8 +18,41 @@ from typing_extensions import NotRequired, TypedDict
10
18
  class ResponseFormatTypedDict(TypedDict):
11
19
  type: NotRequired[ResponseFormats]
12
20
  r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
21
+ json_schema: NotRequired[Nullable[JSONSchemaTypedDict]]
13
22
 
14
23
 
15
24
  class ResponseFormat(BaseModel):
16
25
  type: Optional[ResponseFormats] = None
17
26
  r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
27
+
28
+ json_schema: OptionalNullable[JSONSchema] = UNSET
29
+
30
+ @model_serializer(mode="wrap")
31
+ def serialize_model(self, handler):
32
+ optional_fields = ["type", "json_schema"]
33
+ nullable_fields = ["json_schema"]
34
+ null_default_fields = []
35
+
36
+ serialized = handler(self)
37
+
38
+ m = {}
39
+
40
+ for n, f in self.model_fields.items():
41
+ k = f.alias or n
42
+ val = serialized.get(k)
43
+ serialized.pop(k, None)
44
+
45
+ optional_nullable = k in optional_fields and k in nullable_fields
46
+ is_set = (
47
+ self.__pydantic_fields_set__.intersection({n})
48
+ or k in null_default_fields
49
+ ) # pylint: disable=no-member
50
+
51
+ if val is not None and val != UNSET_SENTINEL:
52
+ m[k] = val
53
+ elif val != UNSET_SENTINEL and (
54
+ not k in optional_fields or (optional_nullable and is_set)
55
+ ):
56
+ m[k] = val
57
+
58
+ return m
@@ -4,5 +4,5 @@ from __future__ import annotations
4
4
  from typing import Literal
5
5
 
6
6
 
7
- ResponseFormats = Literal["text", "json_object"]
7
+ ResponseFormats = Literal["text", "json_object", "json_schema"]
8
8
  r"""An object specifying the format that the model must output. Setting to `{ \"type\": \"json_object\" }` enables JSON mode, which guarantees the message the model generates is in JSON. When using JSON mode you MUST also instruct the model to produce JSON yourself with a system or a user message."""
@@ -14,6 +14,7 @@ class ToolCallTypedDict(TypedDict):
14
14
  function: FunctionCallTypedDict
15
15
  id: NotRequired[str]
16
16
  type: NotRequired[ToolTypes]
17
+ index: NotRequired[int]
17
18
 
18
19
 
19
20
  class ToolCall(BaseModel):
@@ -24,3 +25,5 @@ class ToolCall(BaseModel):
24
25
  type: Annotated[Optional[ToolTypes], PlainValidator(validate_open_enum(False))] = (
25
26
  None
26
27
  )
28
+
29
+ index: Optional[int] = 0
mistralai_gcp/sdk.py CHANGED
@@ -1,23 +1,25 @@
1
- """Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT."""
1
+ """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
3
  import json
4
- from typing import Optional, Tuple, Union
4
+ import weakref
5
+ from typing import Any, Optional, cast
5
6
 
6
7
  import google.auth
7
8
  import google.auth.credentials
8
9
  import google.auth.transport
9
10
  import google.auth.transport.requests
10
11
  import httpx
12
+
11
13
  from mistralai_gcp import models
12
14
  from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks
13
15
  from mistralai_gcp.chat import Chat
14
16
  from mistralai_gcp.fim import Fim
15
- from mistralai_gcp.types import Nullable
17
+ from mistralai_gcp.types import UNSET, OptionalNullable
16
18
 
17
19
  from .basesdk import BaseSDK
18
- from .httpclient import AsyncHttpClient, HttpClient
20
+ from .httpclient import AsyncHttpClient, ClientOwner, HttpClient, close_clients
19
21
  from .sdkconfiguration import SDKConfiguration
20
- from .utils.logger import Logger, NoOpLogger
22
+ from .utils.logger import Logger, get_default_logger
21
23
  from .utils.retries import RetryConfig
22
24
 
23
25
  LEGACY_MODEL_ID_FORMAT = {
@@ -26,20 +28,21 @@ LEGACY_MODEL_ID_FORMAT = {
26
28
  "mistral-nemo-2407": "mistral-nemo@2407",
27
29
  }
28
30
 
29
- def get_model_info(model: str) -> Tuple[str, str]:
31
+
32
+ def get_model_info(model: str) -> tuple[str, str]:
30
33
  # if the model requires the legacy format, use it, else do nothing.
31
34
  if model in LEGACY_MODEL_ID_FORMAT:
32
35
  return "-".join(model.split("-")[:-1]), LEGACY_MODEL_ID_FORMAT[model]
33
36
  return model, model
34
37
 
35
38
 
36
-
37
39
  class MistralGoogleCloud(BaseSDK):
38
40
  r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""
39
41
 
40
42
  chat: Chat
43
+ r"""Chat Completion API."""
41
44
  fim: Fim
42
- r"""Chat Completion API"""
45
+ r"""Fill-in-the-middle API."""
43
46
 
44
47
  def __init__(
45
48
  self,
@@ -48,16 +51,20 @@ class MistralGoogleCloud(BaseSDK):
48
51
  access_token: Optional[str] = None,
49
52
  client: Optional[HttpClient] = None,
50
53
  async_client: Optional[AsyncHttpClient] = None,
51
- retry_config: Optional[Nullable[RetryConfig]] = None,
54
+ retry_config: OptionalNullable[RetryConfig] = UNSET,
55
+ timeout_ms: Optional[int] = None,
52
56
  debug_logger: Optional[Logger] = None,
53
57
  ) -> None:
54
58
  r"""Instantiates the SDK configuring it with the provided parameters.
55
59
 
56
- :param region: The Google Cloud region to use for all methods
57
- :param project_id: The project ID to use for all methods
60
+ :param api_key: The api_key required for authentication
61
+ :param server: The server by name to use for all methods
62
+ :param server_url: The server URL to use for all methods
63
+ :param url_params: Parameters to optionally template the server URL with
58
64
  :param client: The HTTP client to use for all synchronous methods
59
65
  :param async_client: The Async HTTP client to use for all asynchronous methods
60
66
  :param retry_config: The retry configuration to use for all supported methods
67
+ :param timeout_ms: Optional request timeout applied to each operation in milliseconds
61
68
  """
62
69
 
63
70
  if not access_token:
@@ -72,36 +79,42 @@ class MistralGoogleCloud(BaseSDK):
72
79
  )
73
80
 
74
81
  project_id = project_id or loaded_project_id
82
+
75
83
  if project_id is None:
76
84
  raise models.SDKError("project_id must be provided")
77
85
 
78
86
  def auth_token() -> str:
79
87
  if access_token:
80
88
  return access_token
89
+
81
90
  credentials.refresh(google.auth.transport.requests.Request())
82
91
  token = credentials.token
83
92
  if not token:
84
93
  raise models.SDKError("Failed to get token from credentials")
85
94
  return token
86
95
 
96
+ client_supplied = True
87
97
  if client is None:
88
98
  client = httpx.Client()
99
+ client_supplied = False
89
100
 
90
101
  assert issubclass(
91
102
  type(client), HttpClient
92
103
  ), "The provided client must implement the HttpClient protocol."
93
104
 
105
+ async_client_supplied = True
94
106
  if async_client is None:
95
107
  async_client = httpx.AsyncClient()
108
+ async_client_supplied = False
96
109
 
97
110
  if debug_logger is None:
98
- debug_logger = NoOpLogger()
111
+ debug_logger = get_default_logger()
99
112
 
100
113
  assert issubclass(
101
114
  type(async_client), AsyncHttpClient
102
115
  ), "The provided async_client must implement the AsyncHttpClient protocol."
103
116
 
104
- security = None
117
+ security: Any = None
105
118
  if callable(auth_token):
106
119
  security = lambda: models.Security( # pylint: disable=unnecessary-lambda-assignment
107
120
  api_key=auth_token()
@@ -113,23 +126,24 @@ class MistralGoogleCloud(BaseSDK):
113
126
  self,
114
127
  SDKConfiguration(
115
128
  client=client,
129
+ client_supplied=client_supplied,
116
130
  async_client=async_client,
131
+ async_client_supplied=async_client_supplied,
117
132
  security=security,
118
133
  server_url=f"https://{region}-aiplatform.googleapis.com",
119
134
  server=None,
120
135
  retry_config=retry_config,
136
+ timeout_ms=timeout_ms,
121
137
  debug_logger=debug_logger,
122
138
  ),
123
139
  )
124
140
 
125
141
  hooks = SDKHooks()
126
-
127
142
  hook = GoogleCloudBeforeRequestHook(region, project_id)
128
143
  hooks.register_before_request_hook(hook)
129
-
130
144
  current_server_url, *_ = self.sdk_configuration.get_server_details()
131
145
  server_url, self.sdk_configuration.client = hooks.sdk_init(
132
- current_server_url, self.sdk_configuration.client
146
+ current_server_url, client
133
147
  )
134
148
  if current_server_url != server_url:
135
149
  self.sdk_configuration.server_url = server_url
@@ -137,22 +151,53 @@ class MistralGoogleCloud(BaseSDK):
137
151
  # pylint: disable=protected-access
138
152
  self.sdk_configuration.__dict__["_hooks"] = hooks
139
153
 
154
+ weakref.finalize(
155
+ self,
156
+ close_clients,
157
+ cast(ClientOwner, self.sdk_configuration),
158
+ self.sdk_configuration.client,
159
+ self.sdk_configuration.client_supplied,
160
+ self.sdk_configuration.async_client,
161
+ self.sdk_configuration.async_client_supplied,
162
+ )
163
+
140
164
  self._init_sdks()
141
165
 
142
166
  def _init_sdks(self):
143
167
  self.chat = Chat(self.sdk_configuration)
144
168
  self.fim = Fim(self.sdk_configuration)
145
169
 
170
+ def __enter__(self):
171
+ return self
146
172
 
147
- class GoogleCloudBeforeRequestHook(BeforeRequestHook):
173
+ async def __aenter__(self):
174
+ return self
175
+
176
+ def __exit__(self, exc_type, exc_val, exc_tb):
177
+ if (
178
+ self.sdk_configuration.client is not None
179
+ and not self.sdk_configuration.client_supplied
180
+ ):
181
+ self.sdk_configuration.client.close()
182
+ self.sdk_configuration.client = None
148
183
 
184
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
185
+ if (
186
+ self.sdk_configuration.async_client is not None
187
+ and not self.sdk_configuration.async_client_supplied
188
+ ):
189
+ await self.sdk_configuration.async_client.aclose()
190
+ self.sdk_configuration.async_client = None
191
+
192
+
193
+ class GoogleCloudBeforeRequestHook(BeforeRequestHook):
149
194
  def __init__(self, region: str, project_id: str):
150
195
  self.region = region
151
196
  self.project_id = project_id
152
197
 
153
198
  def before_request(
154
199
  self, hook_ctx, request: httpx.Request
155
- ) -> Union[httpx.Request, Exception]:
200
+ ) -> httpx.Request | Exception:
156
201
  # The goal of this function is to template in the region, project and model into the URL path
157
202
  # We do this here so that the API remains more user-friendly
158
203
  model_id = None
@@ -167,7 +212,6 @@ class GoogleCloudBeforeRequestHook(BeforeRequestHook):
167
212
  if model_id == "":
168
213
  raise models.SDKError("model must be provided")
169
214
 
170
-
171
215
  stream = "streamRawPredict" in request.url.path
172
216
  specifier = "streamRawPredict" if stream else "rawPredict"
173
217
  url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"
@@ -1,6 +1,12 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
3
  from ._hooks import SDKHooks
4
+ from ._version import (
5
+ __gen_version__,
6
+ __openapi_doc_version__,
7
+ __user_agent__,
8
+ __version__,
9
+ )
4
10
  from .httpclient import AsyncHttpClient, HttpClient
5
11
  from .utils import Logger, RetryConfig, remove_suffix
6
12
  from dataclasses import dataclass
@@ -20,17 +26,19 @@ SERVERS = {
20
26
 
21
27
  @dataclass
22
28
  class SDKConfiguration:
23
- client: HttpClient
24
- async_client: AsyncHttpClient
29
+ client: Union[HttpClient, None]
30
+ client_supplied: bool
31
+ async_client: Union[AsyncHttpClient, None]
32
+ async_client_supplied: bool
25
33
  debug_logger: Logger
26
34
  security: Optional[Union[models.Security, Callable[[], models.Security]]] = None
27
35
  server_url: Optional[str] = ""
28
36
  server: Optional[str] = ""
29
37
  language: str = "python"
30
- openapi_doc_version: str = "0.0.2"
31
- sdk_version: str = "1.2.6"
32
- gen_version: str = "2.486.1"
33
- user_agent: str = "speakeasy-sdk/python 1.2.6 2.486.1 0.0.2 mistralai-gcp"
38
+ openapi_doc_version: str = __openapi_doc_version__
39
+ sdk_version: str = __version__
40
+ gen_version: str = __gen_version__
41
+ user_agent: str = __user_agent__
34
42
  retry_config: OptionalNullable[RetryConfig] = Field(default_factory=lambda: UNSET)
35
43
  timeout_ms: Optional[int] = None
36
44
 
@@ -42,6 +42,7 @@ from .values import (
42
42
  match_content_type,
43
43
  match_status_codes,
44
44
  match_response,
45
+ cast_partial,
45
46
  )
46
47
  from .logger import Logger, get_body_content, get_default_logger
47
48
 
@@ -94,4 +95,5 @@ __all__ = [
94
95
  "validate_float",
95
96
  "validate_int",
96
97
  "validate_open_enum",
98
+ "cast_partial",
97
99
  ]
@@ -7,14 +7,15 @@ import httpx
7
7
  from typing_extensions import get_origin
8
8
  from pydantic import ConfigDict, create_model
9
9
  from pydantic_core import from_json
10
- from typing_inspect import is_optional_type
10
+ from typing_inspection.typing_objects import is_union
11
11
 
12
12
  from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset
13
13
 
14
14
 
15
15
  def serialize_decimal(as_str: bool):
16
16
  def serialize(d):
17
- if is_optional_type(type(d)) and d is None:
17
+ # Optional[T] is a Union[T, None]
18
+ if is_union(type(d)) and type(None) in get_args(type(d)) and d is None:
18
19
  return None
19
20
  if isinstance(d, Unset):
20
21
  return d
@@ -42,7 +43,8 @@ def validate_decimal(d):
42
43
 
43
44
  def serialize_float(as_str: bool):
44
45
  def serialize(f):
45
- if is_optional_type(type(f)) and f is None:
46
+ # Optional[T] is a Union[T, None]
47
+ if is_union(type(f)) and type(None) in get_args(type(f)) and f is None:
46
48
  return None
47
49
  if isinstance(f, Unset):
48
50
  return f
@@ -70,7 +72,8 @@ def validate_float(f):
70
72
 
71
73
  def serialize_int(as_str: bool):
72
74
  def serialize(i):
73
- if is_optional_type(type(i)) and i is None:
75
+ # Optional[T] is a Union[T, None]
76
+ if is_union(type(i)) and type(None) in get_args(type(i)) and i is None:
74
77
  return None
75
78
  if isinstance(i, Unset):
76
79
  return i
@@ -118,7 +121,8 @@ def validate_open_enum(is_int: bool):
118
121
 
119
122
  def validate_const(v):
120
123
  def validate(c):
121
- if is_optional_type(type(c)) and c is None:
124
+ # Optional[T] is a Union[T, None]
125
+ if is_union(type(c)) and type(None) in get_args(type(c)) and c is None:
122
126
  return None
123
127
 
124
128
  if v != c:
@@ -163,7 +167,7 @@ def marshal_json(val, typ):
163
167
  if len(d) == 0:
164
168
  return ""
165
169
 
166
- return json.dumps(d[next(iter(d))], separators=(",", ":"), sort_keys=True)
170
+ return json.dumps(d[next(iter(d))], separators=(",", ":"))
167
171
 
168
172
 
169
173
  def is_nullable(field):