mistralai 1.0.0rc1__py3-none-any.whl → 1.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. mistralai/agents.py +434 -0
  2. mistralai/basesdk.py +43 -6
  3. mistralai/chat.py +29 -34
  4. mistralai/embeddings.py +4 -4
  5. mistralai/files.py +10 -10
  6. mistralai/fim.py +17 -18
  7. mistralai/fine_tuning.py +10 -849
  8. mistralai/jobs.py +854 -0
  9. mistralai/models/__init__.py +4 -2
  10. mistralai/models/agentscompletionrequest.py +96 -0
  11. mistralai/models/agentscompletionstreamrequest.py +92 -0
  12. mistralai/models/assistantmessage.py +4 -9
  13. mistralai/models/chatcompletionchoice.py +4 -15
  14. mistralai/models/chatcompletionrequest.py +11 -16
  15. mistralai/models/chatcompletionstreamrequest.py +11 -16
  16. mistralai/models/completionresponsestreamchoice.py +4 -9
  17. mistralai/models/deltamessage.py +4 -9
  18. mistralai/models/detailedjobout.py +4 -9
  19. mistralai/models/embeddingrequest.py +4 -9
  20. mistralai/models/eventout.py +4 -9
  21. mistralai/models/fileschema.py +4 -9
  22. mistralai/models/fimcompletionrequest.py +11 -16
  23. mistralai/models/fimcompletionstreamrequest.py +11 -16
  24. mistralai/models/ftmodelout.py +4 -9
  25. mistralai/models/githubrepositoryin.py +4 -9
  26. mistralai/models/githubrepositoryout.py +4 -9
  27. mistralai/models/httpvalidationerror.py +1 -1
  28. mistralai/models/jobin.py +4 -9
  29. mistralai/models/jobmetadataout.py +4 -9
  30. mistralai/models/jobout.py +4 -9
  31. mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +4 -9
  32. mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +4 -9
  33. mistralai/models/legacyjobmetadataout.py +4 -9
  34. mistralai/models/metricout.py +4 -9
  35. mistralai/models/modelcard.py +4 -9
  36. mistralai/models/retrievefileout.py +4 -9
  37. mistralai/models/security.py +4 -4
  38. mistralai/models/toolmessage.py +4 -9
  39. mistralai/models/trainingparameters.py +4 -9
  40. mistralai/models/trainingparametersin.py +4 -9
  41. mistralai/models/updateftmodelin.py +4 -9
  42. mistralai/models/uploadfileout.py +4 -9
  43. mistralai/models/wandbintegration.py +4 -9
  44. mistralai/models/wandbintegrationout.py +4 -9
  45. mistralai/models_.py +14 -14
  46. mistralai/sdk.py +14 -6
  47. mistralai/sdkconfiguration.py +5 -4
  48. mistralai/types/basemodel.py +10 -6
  49. mistralai/utils/__init__.py +4 -0
  50. mistralai/utils/eventstreaming.py +8 -9
  51. mistralai/utils/logger.py +16 -0
  52. mistralai/utils/retries.py +2 -2
  53. mistralai/utils/security.py +5 -2
  54. {mistralai-1.0.0rc1.dist-info → mistralai-1.0.0rc2.dist-info}/METADATA +121 -56
  55. {mistralai-1.0.0rc1.dist-info → mistralai-1.0.0rc2.dist-info}/RECORD +96 -89
  56. mistralai_azure/basesdk.py +42 -4
  57. mistralai_azure/chat.py +15 -20
  58. mistralai_azure/models/__init__.py +2 -2
  59. mistralai_azure/models/assistantmessage.py +4 -9
  60. mistralai_azure/models/chatcompletionchoice.py +4 -15
  61. mistralai_azure/models/chatcompletionrequest.py +7 -12
  62. mistralai_azure/models/chatcompletionstreamrequest.py +7 -12
  63. mistralai_azure/models/completionresponsestreamchoice.py +4 -9
  64. mistralai_azure/models/deltamessage.py +4 -9
  65. mistralai_azure/models/httpvalidationerror.py +1 -1
  66. mistralai_azure/models/toolmessage.py +4 -9
  67. mistralai_azure/sdk.py +7 -2
  68. mistralai_azure/sdkconfiguration.py +5 -4
  69. mistralai_azure/types/basemodel.py +10 -6
  70. mistralai_azure/utils/__init__.py +4 -0
  71. mistralai_azure/utils/eventstreaming.py +8 -9
  72. mistralai_azure/utils/logger.py +16 -0
  73. mistralai_azure/utils/retries.py +2 -2
  74. mistralai_gcp/basesdk.py +42 -4
  75. mistralai_gcp/chat.py +12 -17
  76. mistralai_gcp/fim.py +12 -13
  77. mistralai_gcp/models/__init__.py +2 -2
  78. mistralai_gcp/models/assistantmessage.py +4 -9
  79. mistralai_gcp/models/chatcompletionchoice.py +4 -15
  80. mistralai_gcp/models/chatcompletionrequest.py +9 -14
  81. mistralai_gcp/models/chatcompletionstreamrequest.py +9 -14
  82. mistralai_gcp/models/completionresponsestreamchoice.py +4 -9
  83. mistralai_gcp/models/deltamessage.py +4 -9
  84. mistralai_gcp/models/fimcompletionrequest.py +11 -16
  85. mistralai_gcp/models/fimcompletionstreamrequest.py +11 -16
  86. mistralai_gcp/models/httpvalidationerror.py +1 -1
  87. mistralai_gcp/models/toolmessage.py +4 -9
  88. mistralai_gcp/sdk.py +9 -0
  89. mistralai_gcp/sdkconfiguration.py +5 -4
  90. mistralai_gcp/types/basemodel.py +10 -6
  91. mistralai_gcp/utils/__init__.py +4 -0
  92. mistralai_gcp/utils/eventstreaming.py +8 -9
  93. mistralai_gcp/utils/logger.py +16 -0
  94. mistralai_gcp/utils/retries.py +2 -2
  95. {mistralai-1.0.0rc1.dist-info → mistralai-1.0.0rc2.dist-info}/LICENSE +0 -0
  96. {mistralai-1.0.0rc1.dist-info → mistralai-1.0.0rc2.dist-info}/WHEEL +0 -0
mistralai/models_.py CHANGED
@@ -1,10 +1,10 @@
1
1
  """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
2
 
3
3
  from .basesdk import BaseSDK
4
- from mistralai import models
4
+ from mistralai import models, utils
5
5
  from mistralai._hooks import HookContext
6
6
  from mistralai.types import OptionalNullable, UNSET
7
- import mistralai.utils as utils
7
+ from mistralai.utils import get_security_from_env
8
8
  from typing import Any, Optional
9
9
 
10
10
  class Models(BaseSDK):
@@ -62,7 +62,7 @@ class Models(BaseSDK):
62
62
  ])
63
63
 
64
64
  http_res = self.do_request(
65
- hook_ctx=HookContext(operation_id="list_models_v1_models_get", oauth2_scopes=[], security_source=self.sdk_configuration.security),
65
+ hook_ctx=HookContext(operation_id="list_models_v1_models_get", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
66
66
  request=req,
67
67
  error_status_codes=["422","4XX","5XX"],
68
68
  retry_config=retry_config
@@ -133,7 +133,7 @@ class Models(BaseSDK):
133
133
  ])
134
134
 
135
135
  http_res = await self.do_request_async(
136
- hook_ctx=HookContext(operation_id="list_models_v1_models_get", oauth2_scopes=[], security_source=self.sdk_configuration.security),
136
+ hook_ctx=HookContext(operation_id="list_models_v1_models_get", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
137
137
  request=req,
138
138
  error_status_codes=["422","4XX","5XX"],
139
139
  retry_config=retry_config
@@ -211,7 +211,7 @@ class Models(BaseSDK):
211
211
  ])
212
212
 
213
213
  http_res = self.do_request(
214
- hook_ctx=HookContext(operation_id="retrieve_model_v1_models__model_id__get", oauth2_scopes=[], security_source=self.sdk_configuration.security),
214
+ hook_ctx=HookContext(operation_id="retrieve_model_v1_models__model_id__get", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
215
215
  request=req,
216
216
  error_status_codes=["422","4XX","5XX"],
217
217
  retry_config=retry_config
@@ -289,7 +289,7 @@ class Models(BaseSDK):
289
289
  ])
290
290
 
291
291
  http_res = await self.do_request_async(
292
- hook_ctx=HookContext(operation_id="retrieve_model_v1_models__model_id__get", oauth2_scopes=[], security_source=self.sdk_configuration.security),
292
+ hook_ctx=HookContext(operation_id="retrieve_model_v1_models__model_id__get", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
293
293
  request=req,
294
294
  error_status_codes=["422","4XX","5XX"],
295
295
  retry_config=retry_config
@@ -367,7 +367,7 @@ class Models(BaseSDK):
367
367
  ])
368
368
 
369
369
  http_res = self.do_request(
370
- hook_ctx=HookContext(operation_id="delete_model_v1_models__model_id__delete", oauth2_scopes=[], security_source=self.sdk_configuration.security),
370
+ hook_ctx=HookContext(operation_id="delete_model_v1_models__model_id__delete", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
371
371
  request=req,
372
372
  error_status_codes=["422","4XX","5XX"],
373
373
  retry_config=retry_config
@@ -445,7 +445,7 @@ class Models(BaseSDK):
445
445
  ])
446
446
 
447
447
  http_res = await self.do_request_async(
448
- hook_ctx=HookContext(operation_id="delete_model_v1_models__model_id__delete", oauth2_scopes=[], security_source=self.sdk_configuration.security),
448
+ hook_ctx=HookContext(operation_id="delete_model_v1_models__model_id__delete", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
449
449
  request=req,
450
450
  error_status_codes=["422","4XX","5XX"],
451
451
  retry_config=retry_config
@@ -532,7 +532,7 @@ class Models(BaseSDK):
532
532
  ])
533
533
 
534
534
  http_res = self.do_request(
535
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
535
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
536
536
  request=req,
537
537
  error_status_codes=["4XX","5XX"],
538
538
  retry_config=retry_config
@@ -615,7 +615,7 @@ class Models(BaseSDK):
615
615
  ])
616
616
 
617
617
  http_res = await self.do_request_async(
618
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
618
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_update_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
619
619
  request=req,
620
620
  error_status_codes=["4XX","5XX"],
621
621
  retry_config=retry_config
@@ -689,7 +689,7 @@ class Models(BaseSDK):
689
689
  ])
690
690
 
691
691
  http_res = self.do_request(
692
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
692
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
693
693
  request=req,
694
694
  error_status_codes=["4XX","5XX"],
695
695
  retry_config=retry_config
@@ -763,7 +763,7 @@ class Models(BaseSDK):
763
763
  ])
764
764
 
765
765
  http_res = await self.do_request_async(
766
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
766
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_archive_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
767
767
  request=req,
768
768
  error_status_codes=["4XX","5XX"],
769
769
  retry_config=retry_config
@@ -837,7 +837,7 @@ class Models(BaseSDK):
837
837
  ])
838
838
 
839
839
  http_res = self.do_request(
840
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
840
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
841
841
  request=req,
842
842
  error_status_codes=["4XX","5XX"],
843
843
  retry_config=retry_config
@@ -911,7 +911,7 @@ class Models(BaseSDK):
911
911
  ])
912
912
 
913
913
  http_res = await self.do_request_async(
914
- hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model", oauth2_scopes=[], security_source=self.sdk_configuration.security),
914
+ hook_ctx=HookContext(operation_id="jobs_api_routes_fine_tuning_unarchive_fine_tuned_model", oauth2_scopes=[], security_source=get_security_from_env(self.sdk_configuration.security, models.Security)),
915
915
  request=req,
916
916
  error_status_codes=["4XX","5XX"],
917
917
  retry_config=retry_config
mistralai/sdk.py CHANGED
@@ -3,10 +3,12 @@
3
3
  from .basesdk import BaseSDK
4
4
  from .httpclient import AsyncHttpClient, HttpClient
5
5
  from .sdkconfiguration import SDKConfiguration
6
+ from .utils.logger import Logger, NoOpLogger
6
7
  from .utils.retries import RetryConfig
7
8
  import httpx
8
- from mistralai import models
9
+ from mistralai import models, utils
9
10
  from mistralai._hooks import SDKHooks
11
+ from mistralai.agents import Agents
10
12
  from mistralai.chat import Chat
11
13
  from mistralai.embeddings import Embeddings
12
14
  from mistralai.files import Files
@@ -14,7 +16,6 @@ from mistralai.fim import Fim
14
16
  from mistralai.fine_tuning import FineTuning
15
17
  from mistralai.models_ import Models
16
18
  from mistralai.types import OptionalNullable, UNSET
17
- import mistralai.utils as utils
18
19
  from typing import Any, Callable, Dict, Optional, Union
19
20
 
20
21
  class Mistral(BaseSDK):
@@ -24,23 +25,25 @@ class Mistral(BaseSDK):
24
25
  files: Files
25
26
  r"""Files API"""
26
27
  fine_tuning: FineTuning
27
- r"""Fine-tuning API"""
28
28
  chat: Chat
29
29
  r"""Chat Completion API."""
30
30
  fim: Fim
31
31
  r"""Fill-in-the-middle API."""
32
+ agents: Agents
33
+ r"""Agents API."""
32
34
  embeddings: Embeddings
33
35
  r"""Embeddings API."""
34
36
  def __init__(
35
37
  self,
36
- api_key: Union[str, Callable[[], str]],
38
+ api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
37
39
  server: Optional[str] = None,
38
40
  server_url: Optional[str] = None,
39
41
  url_params: Optional[Dict[str, str]] = None,
40
42
  client: Optional[HttpClient] = None,
41
43
  async_client: Optional[AsyncHttpClient] = None,
42
44
  retry_config: OptionalNullable[RetryConfig] = UNSET,
43
- timeout_ms: Optional[int] = None
45
+ timeout_ms: Optional[int] = None,
46
+ debug_logger: Optional[Logger] = None
44
47
  ) -> None:
45
48
  r"""Instantiates the SDK configuring it with the provided parameters.
46
49
 
@@ -63,6 +66,9 @@ class Mistral(BaseSDK):
63
66
  if async_client is None:
64
67
  async_client = httpx.AsyncClient()
65
68
 
69
+ if debug_logger is None:
70
+ debug_logger = NoOpLogger()
71
+
66
72
  assert issubclass(
67
73
  type(async_client), AsyncHttpClient
68
74
  ), "The provided async_client must implement the AsyncHttpClient protocol."
@@ -85,7 +91,8 @@ class Mistral(BaseSDK):
85
91
  server_url=server_url,
86
92
  server=server,
87
93
  retry_config=retry_config,
88
- timeout_ms=timeout_ms
94
+ timeout_ms=timeout_ms,
95
+ debug_logger=debug_logger
89
96
  ))
90
97
 
91
98
  hooks = SDKHooks()
@@ -107,5 +114,6 @@ class Mistral(BaseSDK):
107
114
  self.fine_tuning = FineTuning(self.sdk_configuration)
108
115
  self.chat = Chat(self.sdk_configuration)
109
116
  self.fim = Fim(self.sdk_configuration)
117
+ self.agents = Agents(self.sdk_configuration)
110
118
  self.embeddings = Embeddings(self.sdk_configuration)
111
119
 
@@ -3,7 +3,7 @@
3
3
 
4
4
  from ._hooks import SDKHooks
5
5
  from .httpclient import AsyncHttpClient, HttpClient
6
- from .utils import RetryConfig, remove_suffix
6
+ from .utils import Logger, RetryConfig, remove_suffix
7
7
  from dataclasses import dataclass
8
8
  from mistralai import models
9
9
  from mistralai.types import OptionalNullable, UNSET
@@ -23,14 +23,15 @@ SERVERS = {
23
23
  class SDKConfiguration:
24
24
  client: HttpClient
25
25
  async_client: AsyncHttpClient
26
+ debug_logger: Logger
26
27
  security: Optional[Union[models.Security,Callable[[], models.Security]]] = None
27
28
  server_url: Optional[str] = ""
28
29
  server: Optional[str] = ""
29
30
  language: str = "python"
30
31
  openapi_doc_version: str = "0.0.2"
31
- sdk_version: str = "1.0.0rc1"
32
- gen_version: str = "2.382.2"
33
- user_agent: str = "speakeasy-sdk/python 1.0.0rc1 2.382.2 0.0.2 mistralai"
32
+ sdk_version: str = "1.0.0-rc.2"
33
+ gen_version: str = "2.386.0"
34
+ user_agent: str = "speakeasy-sdk/python 1.0.0-rc.2 2.386.0 0.0.2 mistralai"
34
35
  retry_config: OptionalNullable[RetryConfig] = Field(default_factory=lambda: UNSET)
35
36
  timeout_ms: Optional[int] = None
36
37
 
@@ -2,8 +2,8 @@
2
2
 
3
3
  from pydantic import ConfigDict, model_serializer
4
4
  from pydantic import BaseModel as PydanticBaseModel
5
- from typing import Literal, Optional, TypeVar, Union, NewType
6
- from typing_extensions import TypeAliasType
5
+ from typing import TYPE_CHECKING, Literal, Optional, TypeVar, Union, NewType
6
+ from typing_extensions import TypeAliasType, TypeAlias
7
7
 
8
8
 
9
9
  class BaseModel(PydanticBaseModel):
@@ -26,10 +26,14 @@ UNSET_SENTINEL = "~?~unset~?~sentinel~?~"
26
26
 
27
27
 
28
28
  T = TypeVar("T")
29
- Nullable = TypeAliasType("Nullable", Union[T, None], type_params=(T,))
30
- OptionalNullable = TypeAliasType(
31
- "OptionalNullable", Union[Optional[Nullable[T]], Unset], type_params=(T,)
32
- )
29
+ if TYPE_CHECKING:
30
+ Nullable: TypeAlias = Union[T, None]
31
+ OptionalNullable: TypeAlias = Union[Optional[Nullable[T]], Unset]
32
+ else:
33
+ Nullable = TypeAliasType("Nullable", Union[T, None], type_params=(T,))
34
+ OptionalNullable = TypeAliasType(
35
+ "OptionalNullable", Union[Optional[Nullable[T]], Unset], type_params=(T,)
36
+ )
33
37
 
34
38
  UnrecognizedInt = NewType("UnrecognizedInt", int)
35
39
  UnrecognizedStr = NewType("UnrecognizedStr", str)
@@ -35,6 +35,7 @@ from .serializers import (
35
35
  )
36
36
  from .url import generate_url, template_url, remove_suffix
37
37
  from .values import get_global_from_env, match_content_type, match_status_codes, match_response
38
+ from .logger import Logger, get_body_content, NoOpLogger
38
39
 
39
40
  __all__ = [
40
41
  "BackoffStrategy",
@@ -42,6 +43,7 @@ __all__ = [
42
43
  "find_metadata",
43
44
  "FormMetadata",
44
45
  "generate_url",
46
+ "get_body_content",
45
47
  "get_discriminator",
46
48
  "get_global_from_env",
47
49
  "get_headers",
@@ -51,11 +53,13 @@ __all__ = [
51
53
  "get_security",
52
54
  "get_security_from_env",
53
55
  "HeaderMetadata",
56
+ "Logger",
54
57
  "marshal_json",
55
58
  "match_content_type",
56
59
  "match_status_codes",
57
60
  "match_response",
58
61
  "MultipartFormMetadata",
62
+ "NoOpLogger",
59
63
  "OpenEnumMeta",
60
64
  "PathParamMetadata",
61
65
  "QueryParamMetadata",
@@ -147,15 +147,14 @@ def _parse_event(
147
147
  data = data[:-1]
148
148
  event.data = data
149
149
 
150
- if (
151
- data.isnumeric()
152
- or data == "true"
153
- or data == "false"
154
- or data == "null"
155
- or data.startswith("{")
156
- or data.startswith("[")
157
- or data.startswith('"')
158
- ):
150
+ data_is_primitive = (
151
+ data.isnumeric() or data == "true" or data == "false" or data == "null"
152
+ )
153
+ data_is_json = (
154
+ data.startswith("{") or data.startswith("[") or data.startswith('"')
155
+ )
156
+
157
+ if data_is_primitive or data_is_json:
159
158
  try:
160
159
  event.data = json.loads(data)
161
160
  except Exception:
@@ -0,0 +1,16 @@
1
+ """Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT."""
2
+
3
+ import httpx
4
+ from typing import Any, Protocol
5
+
6
+ class Logger(Protocol):
7
+ def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
8
+ pass
9
+
10
+ class NoOpLogger:
11
+ def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
12
+ pass
13
+
14
+ def get_body_content(req: httpx.Request) -> str:
15
+ return "<streaming body>" if not hasattr(req, "_content") else str(req.content)
16
+
@@ -76,7 +76,7 @@ def retry(func, retries: Retries):
76
76
 
77
77
  status_major = res.status_code / 100
78
78
 
79
- if status_major >= code_range and status_major < code_range + 1:
79
+ if code_range <= status_major < code_range + 1:
80
80
  raise TemporaryError(res)
81
81
  else:
82
82
  parsed_code = int(code)
@@ -125,7 +125,7 @@ async def retry_async(func, retries: Retries):
125
125
 
126
126
  status_major = res.status_code / 100
127
127
 
128
- if status_major >= code_range and status_major < code_range + 1:
128
+ if code_range <= status_major < code_range + 1:
129
129
  raise TemporaryError(res)
130
130
  else:
131
131
  parsed_code = int(code)
@@ -53,11 +53,14 @@ def get_security(security: Any) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
53
53
  return headers, query_params
54
54
 
55
55
 
56
- def get_security_from_env(security_class: Any) -> Optional[BaseModel]:
56
+ def get_security_from_env(security: Any, security_class: Any) -> Optional[BaseModel]:
57
+ if security is not None:
58
+ return security
59
+
57
60
  if not issubclass(security_class, BaseModel):
58
61
  raise TypeError("security_class must be a pydantic model class")
59
62
 
60
- security_dict = {}
63
+ security_dict: Any = {}
61
64
 
62
65
  if os.getenv("MISTRAL_API_KEY"):
63
66
  security_dict["api_key"] = os.getenv("MISTRAL_API_KEY")