mirascope 2.0.0a4__py3-none-any.whl → 2.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/api/_generated/__init__.py +17 -1
- mirascope/api/_generated/api_keys/__init__.py +7 -0
- mirascope/api/_generated/api_keys/client.py +453 -0
- mirascope/api/_generated/api_keys/raw_client.py +853 -0
- mirascope/api/_generated/api_keys/types/__init__.py +9 -0
- mirascope/api/_generated/api_keys/types/api_keys_create_response.py +36 -0
- mirascope/api/_generated/api_keys/types/api_keys_get_response.py +35 -0
- mirascope/api/_generated/api_keys/types/api_keys_list_response_item.py +35 -0
- mirascope/api/_generated/client.py +6 -0
- mirascope/api/_generated/environments/__init__.py +17 -0
- mirascope/api/_generated/environments/client.py +532 -0
- mirascope/api/_generated/environments/raw_client.py +1088 -0
- mirascope/api/_generated/environments/types/__init__.py +15 -0
- mirascope/api/_generated/environments/types/environments_create_response.py +26 -0
- mirascope/api/_generated/environments/types/environments_get_response.py +26 -0
- mirascope/api/_generated/environments/types/environments_list_response_item.py +26 -0
- mirascope/api/_generated/environments/types/environments_update_response.py +26 -0
- mirascope/api/_generated/organizations/client.py +36 -12
- mirascope/api/_generated/organizations/raw_client.py +32 -6
- mirascope/api/_generated/organizations/types/organizations_create_response.py +1 -0
- mirascope/api/_generated/organizations/types/organizations_get_response.py +1 -0
- mirascope/api/_generated/organizations/types/organizations_list_response_item.py +1 -0
- mirascope/api/_generated/organizations/types/organizations_update_response.py +1 -0
- mirascope/api/_generated/projects/client.py +34 -10
- mirascope/api/_generated/projects/raw_client.py +46 -4
- mirascope/api/_generated/projects/types/projects_create_response.py +1 -0
- mirascope/api/_generated/projects/types/projects_get_response.py +1 -0
- mirascope/api/_generated/projects/types/projects_list_response_item.py +1 -0
- mirascope/api/_generated/projects/types/projects_update_response.py +1 -0
- mirascope/api/_generated/reference.md +729 -4
- mirascope/llm/__init__.py +2 -2
- mirascope/llm/exceptions.py +28 -0
- mirascope/llm/providers/__init__.py +6 -4
- mirascope/llm/providers/anthropic/_utils/__init__.py +2 -0
- mirascope/llm/providers/anthropic/_utils/errors.py +46 -0
- mirascope/llm/providers/anthropic/beta_provider.py +6 -0
- mirascope/llm/providers/anthropic/provider.py +5 -0
- mirascope/llm/providers/base/__init__.py +2 -1
- mirascope/llm/providers/base/base_provider.py +173 -58
- mirascope/llm/providers/google/_utils/__init__.py +2 -0
- mirascope/llm/providers/google/_utils/errors.py +49 -0
- mirascope/llm/providers/google/provider.py +5 -4
- mirascope/llm/providers/mlx/_utils.py +8 -1
- mirascope/llm/providers/mlx/provider.py +8 -0
- mirascope/llm/providers/openai/__init__.py +10 -1
- mirascope/llm/providers/openai/_utils/__init__.py +5 -0
- mirascope/llm/providers/openai/_utils/errors.py +46 -0
- mirascope/llm/providers/openai/completions/base_provider.py +6 -6
- mirascope/llm/providers/openai/provider.py +14 -1
- mirascope/llm/providers/openai/responses/provider.py +13 -7
- mirascope/llm/providers/provider_registry.py +56 -3
- mirascope/ops/_internal/closure.py +62 -11
- {mirascope-2.0.0a4.dist-info → mirascope-2.0.0a5.dist-info}/METADATA +1 -1
- {mirascope-2.0.0a4.dist-info → mirascope-2.0.0a5.dist-info}/RECORD +56 -38
- mirascope/llm/providers/load_provider.py +0 -54
- {mirascope-2.0.0a4.dist-info → mirascope-2.0.0a5.dist-info}/WHEEL +0 -0
- {mirascope-2.0.0a4.dist-info → mirascope-2.0.0a5.dist-info}/licenses/LICENSE +0 -0
|
@@ -39,6 +39,7 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
39
39
|
|
|
40
40
|
id = "google"
|
|
41
41
|
default_scope = "google/"
|
|
42
|
+
error_map = _utils.GOOGLE_ERROR_MAP
|
|
42
43
|
|
|
43
44
|
def __init__(
|
|
44
45
|
self, *, api_key: str | None = None, base_url: str | None = None
|
|
@@ -50,6 +51,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
50
51
|
|
|
51
52
|
self.client = Client(api_key=api_key, http_options=http_options)
|
|
52
53
|
|
|
54
|
+
def get_error_status(self, e: Exception) -> int | None:
|
|
55
|
+
"""Extract HTTP status code from Google exception."""
|
|
56
|
+
return getattr(e, "code", None)
|
|
57
|
+
|
|
53
58
|
def _call(
|
|
54
59
|
self,
|
|
55
60
|
*,
|
|
@@ -78,7 +83,6 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
78
83
|
format=format,
|
|
79
84
|
params=params,
|
|
80
85
|
)
|
|
81
|
-
|
|
82
86
|
google_response = self.client.models.generate_content(**kwargs)
|
|
83
87
|
|
|
84
88
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -131,7 +135,6 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
131
135
|
format=format,
|
|
132
136
|
params=params,
|
|
133
137
|
)
|
|
134
|
-
|
|
135
138
|
google_response = self.client.models.generate_content(**kwargs)
|
|
136
139
|
|
|
137
140
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -180,7 +183,6 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
180
183
|
format=format,
|
|
181
184
|
params=params,
|
|
182
185
|
)
|
|
183
|
-
|
|
184
186
|
google_response = await self.client.aio.models.generate_content(**kwargs)
|
|
185
187
|
|
|
186
188
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -233,7 +235,6 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
233
235
|
format=format,
|
|
234
236
|
params=params,
|
|
235
237
|
)
|
|
236
|
-
|
|
237
238
|
google_response = await self.client.aio.models.generate_content(**kwargs)
|
|
238
239
|
|
|
239
240
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -2,14 +2,21 @@ from collections.abc import Callable
|
|
|
2
2
|
from typing import TypeAlias, TypedDict
|
|
3
3
|
|
|
4
4
|
import mlx.core as mx
|
|
5
|
+
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
5
6
|
from mlx_lm.generate import GenerationResponse
|
|
6
7
|
from mlx_lm.sample_utils import make_sampler
|
|
7
8
|
|
|
9
|
+
from ...exceptions import NotFoundError
|
|
8
10
|
from ...responses import FinishReason, Usage
|
|
9
|
-
from ..base import Params, _utils as _base_utils
|
|
11
|
+
from ..base import Params, ProviderErrorMap, _utils as _base_utils
|
|
10
12
|
|
|
11
13
|
Sampler: TypeAlias = Callable[[mx.array], mx.array]
|
|
12
14
|
|
|
15
|
+
# Error mapping for the MLX provider.
# Translates huggingface_hub's `LocalEntryNotFoundError` into mirascope's
# provider-agnostic `NotFoundError`. NOTE(review): presumably this is raised when
# the requested model is not available locally / in the HF cache — confirm.
MLX_ERROR_MAP: ProviderErrorMap = {
    LocalEntryNotFoundError: NotFoundError,
}
|
|
19
|
+
|
|
13
20
|
|
|
14
21
|
class MakeSamplerKwargs(TypedDict, total=False):
|
|
15
22
|
"""Keyword arguments to be used for `mlx_lm`-s `make_sampler` function.
|
|
@@ -70,6 +70,14 @@ class MLXProvider(BaseProvider[None]):
|
|
|
70
70
|
|
|
71
71
|
id = "mlx"
|
|
72
72
|
default_scope = "mlx-community/"
|
|
73
|
+
error_map = _utils.MLX_ERROR_MAP
|
|
74
|
+
|
|
75
|
+
def get_error_status(self, e: Exception) -> int | None:
|
|
76
|
+
"""Extract HTTP status code from MLX exception.
|
|
77
|
+
|
|
78
|
+
MLX/HuggingFace Hub exceptions don't have status codes.
|
|
79
|
+
"""
|
|
80
|
+
return None
|
|
73
81
|
|
|
74
82
|
def _call(
|
|
75
83
|
self,
|
|
@@ -1,6 +1,15 @@
|
|
|
1
1
|
"""OpenAI client implementation."""
|
|
2
2
|
|
|
3
|
+
from .completions.base_provider import BaseOpenAICompletionsProvider
|
|
4
|
+
from .completions.provider import OpenAICompletionsProvider
|
|
3
5
|
from .model_id import OpenAIModelId
|
|
4
6
|
from .provider import OpenAIProvider
|
|
7
|
+
from .responses.provider import OpenAIResponsesProvider
|
|
5
8
|
|
|
6
|
-
__all__ = [
|
|
9
|
+
__all__ = [
|
|
10
|
+
"BaseOpenAICompletionsProvider",
|
|
11
|
+
"OpenAICompletionsProvider",
|
|
12
|
+
"OpenAIModelId",
|
|
13
|
+
"OpenAIProvider",
|
|
14
|
+
"OpenAIResponsesProvider",
|
|
15
|
+
]
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""OpenAI error handling utilities."""
|
|
2
|
+
|
|
3
|
+
from openai import (
|
|
4
|
+
APIConnectionError as OpenAIAPIConnectionError,
|
|
5
|
+
APIResponseValidationError as OpenAIAPIResponseValidationError,
|
|
6
|
+
APITimeoutError as OpenAIAPITimeoutError,
|
|
7
|
+
AuthenticationError as OpenAIAuthenticationError,
|
|
8
|
+
BadRequestError as OpenAIBadRequestError,
|
|
9
|
+
ConflictError as OpenAIConflictError,
|
|
10
|
+
InternalServerError as OpenAIInternalServerError,
|
|
11
|
+
NotFoundError as OpenAINotFoundError,
|
|
12
|
+
OpenAIError,
|
|
13
|
+
PermissionDeniedError as OpenAIPermissionDeniedError,
|
|
14
|
+
RateLimitError as OpenAIRateLimitError,
|
|
15
|
+
UnprocessableEntityError as OpenAIUnprocessableEntityError,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from ....exceptions import (
|
|
19
|
+
APIError,
|
|
20
|
+
AuthenticationError,
|
|
21
|
+
BadRequestError,
|
|
22
|
+
ConnectionError,
|
|
23
|
+
NotFoundError,
|
|
24
|
+
PermissionError,
|
|
25
|
+
RateLimitError,
|
|
26
|
+
ResponseValidationError,
|
|
27
|
+
ServerError,
|
|
28
|
+
TimeoutError,
|
|
29
|
+
)
|
|
30
|
+
from ...base import ProviderErrorMap
|
|
31
|
+
|
|
32
|
+
# Shared error mapping used by OpenAI Responses and Completions providers.
# Keys are openai-SDK exception classes; values are the corresponding mirascope
# exception classes.
# NOTE(review): if the base provider resolves errors by walking this dict in order
# with isinstance checks, entry order matters: specific classes must precede the
# `OpenAIError` catch-all, and `OpenAIAPITimeoutError` (a subclass of
# `APIConnectionError` in the openai SDK) must stay above the connection entry.
OPENAI_ERROR_MAP: ProviderErrorMap = {
    OpenAIAuthenticationError: AuthenticationError,
    OpenAIPermissionDeniedError: PermissionError,
    OpenAINotFoundError: NotFoundError,
    OpenAIBadRequestError: BadRequestError,
    # 422 and 409 responses are surfaced as generic bad requests.
    OpenAIUnprocessableEntityError: BadRequestError,
    OpenAIConflictError: BadRequestError,
    OpenAIRateLimitError: RateLimitError,
    OpenAIInternalServerError: ServerError,
    OpenAIAPITimeoutError: TimeoutError,
    OpenAIAPIConnectionError: ConnectionError,
    OpenAIAPIResponseValidationError: ResponseValidationError,
    OpenAIError: APIError,  # Catch-all for unknown OpenAI errors
}
|
|
@@ -31,6 +31,7 @@ from ....tools import (
|
|
|
31
31
|
Toolkit,
|
|
32
32
|
)
|
|
33
33
|
from ...base import BaseProvider, Params
|
|
34
|
+
from .. import _utils as _shared_utils
|
|
34
35
|
from ..model_id import model_name as openai_model_name
|
|
35
36
|
from . import _utils
|
|
36
37
|
|
|
@@ -44,6 +45,7 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
44
45
|
api_key_env_var: ClassVar[str]
|
|
45
46
|
api_key_required: ClassVar[bool] = True
|
|
46
47
|
provider_name: ClassVar[str | None] = None
|
|
48
|
+
error_map = _shared_utils.OPENAI_ERROR_MAP
|
|
47
49
|
|
|
48
50
|
def __init__(
|
|
49
51
|
self,
|
|
@@ -77,6 +79,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
77
79
|
base_url=resolved_base_url,
|
|
78
80
|
)
|
|
79
81
|
|
|
82
|
+
def get_error_status(self, e: Exception) -> int | None:
|
|
83
|
+
"""Extract HTTP status code from OpenAI exception."""
|
|
84
|
+
return getattr(e, "status_code", None)
|
|
85
|
+
|
|
80
86
|
def _model_name(self, model_id: str) -> str:
|
|
81
87
|
"""Extract the model name to send to the API."""
|
|
82
88
|
return openai_model_name(model_id, None)
|
|
@@ -114,7 +120,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
114
120
|
params=params,
|
|
115
121
|
)
|
|
116
122
|
kwargs["model"] = self._model_name(model_id)
|
|
117
|
-
|
|
118
123
|
openai_response = self.client.chat.completions.create(**kwargs)
|
|
119
124
|
|
|
120
125
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -171,7 +176,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
171
176
|
params=params,
|
|
172
177
|
)
|
|
173
178
|
kwargs["model"] = self._model_name(model_id)
|
|
174
|
-
|
|
175
179
|
openai_response = self.client.chat.completions.create(**kwargs)
|
|
176
180
|
|
|
177
181
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -224,7 +228,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
224
228
|
format=format,
|
|
225
229
|
)
|
|
226
230
|
kwargs["model"] = self._model_name(model_id)
|
|
227
|
-
|
|
228
231
|
openai_response = await self.async_client.chat.completions.create(**kwargs)
|
|
229
232
|
|
|
230
233
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -281,7 +284,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
281
284
|
format=format,
|
|
282
285
|
)
|
|
283
286
|
kwargs["model"] = self._model_name(model_id)
|
|
284
|
-
|
|
285
287
|
openai_response = await self.async_client.chat.completions.create(**kwargs)
|
|
286
288
|
|
|
287
289
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -334,7 +336,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
334
336
|
params=params,
|
|
335
337
|
)
|
|
336
338
|
kwargs["model"] = self._model_name(model_id)
|
|
337
|
-
|
|
338
339
|
openai_stream = self.client.chat.completions.create(
|
|
339
340
|
**kwargs,
|
|
340
341
|
stream=True,
|
|
@@ -436,7 +437,6 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
|
|
|
436
437
|
params=params,
|
|
437
438
|
)
|
|
438
439
|
kwargs["model"] = self._model_name(model_id)
|
|
439
|
-
|
|
440
440
|
openai_stream = await self.async_client.chat.completions.create(
|
|
441
441
|
**kwargs,
|
|
442
442
|
stream=True,
|
|
@@ -3,9 +3,10 @@
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
from typing_extensions import Unpack
|
|
5
5
|
|
|
6
|
-
from openai import OpenAI
|
|
6
|
+
from openai import BadRequestError as OpenAIBadRequestError, OpenAI
|
|
7
7
|
|
|
8
8
|
from ...context import Context, DepsT
|
|
9
|
+
from ...exceptions import BadRequestError, NotFoundError
|
|
9
10
|
from ...formatting import Format, FormattableT
|
|
10
11
|
from ...messages import Message
|
|
11
12
|
from ...responses import (
|
|
@@ -29,6 +30,7 @@ from ...tools import (
|
|
|
29
30
|
Toolkit,
|
|
30
31
|
)
|
|
31
32
|
from ..base import BaseProvider, Params
|
|
33
|
+
from . import _utils
|
|
32
34
|
from .completions import OpenAICompletionsProvider
|
|
33
35
|
from .model_id import OPENAI_KNOWN_MODELS, OpenAIModelId
|
|
34
36
|
from .responses import OpenAIResponsesProvider
|
|
@@ -107,6 +109,13 @@ class OpenAIProvider(BaseProvider[OpenAI]):
|
|
|
107
109
|
|
|
108
110
|
id = "openai"
|
|
109
111
|
default_scope = "openai/"
|
|
112
|
+
# Include special handling for model_not_found from Responses API: a bad-request
# error whose `code` is "model_not_found" is reported as NotFoundError; every
# other bad request stays a BadRequestError. All remaining entries come from the
# shared OpenAI error map.
error_map = {
    **_utils.OPENAI_ERROR_MAP,
    OpenAIBadRequestError: lambda e: (
        NotFoundError
        if getattr(e, "code", None) == "model_not_found"
        else BadRequestError
    ),
}
|
|
110
119
|
|
|
111
120
|
def __init__(
|
|
112
121
|
self, *, api_key: str | None = None, base_url: str | None = None
|
|
@@ -121,6 +130,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
|
|
|
121
130
|
# Use completions client's underlying OpenAI client as the main one
|
|
122
131
|
self.client = self._completions_provider.client
|
|
123
132
|
|
|
133
|
+
def get_error_status(self, e: Exception) -> int | None:
|
|
134
|
+
"""Extract HTTP status code from OpenAI exception."""
|
|
135
|
+
return getattr(e, "status_code", None) # pragma: no cover
|
|
136
|
+
|
|
124
137
|
def _choose_subprovider(
|
|
125
138
|
self, model_id: OpenAIModelId, messages: Sequence[Message]
|
|
126
139
|
) -> OpenAICompletionsProvider | OpenAIResponsesProvider:
|
|
@@ -3,9 +3,10 @@
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
4
|
from typing_extensions import Unpack
|
|
5
5
|
|
|
6
|
-
from openai import AsyncOpenAI, OpenAI
|
|
6
|
+
from openai import AsyncOpenAI, BadRequestError as OpenAIBadRequestError, OpenAI
|
|
7
7
|
|
|
8
8
|
from ....context import Context, DepsT
|
|
9
|
+
from ....exceptions import BadRequestError, NotFoundError
|
|
9
10
|
from ....formatting import Format, FormattableT
|
|
10
11
|
from ....messages import Message
|
|
11
12
|
from ....responses import (
|
|
@@ -29,6 +30,7 @@ from ....tools import (
|
|
|
29
30
|
Toolkit,
|
|
30
31
|
)
|
|
31
32
|
from ...base import BaseProvider, Params
|
|
33
|
+
from .. import _utils as _shared_utils
|
|
32
34
|
from ..model_id import OpenAIModelId, model_name
|
|
33
35
|
from . import _utils
|
|
34
36
|
|
|
@@ -38,6 +40,12 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
38
40
|
|
|
39
41
|
id = "openai:responses"
|
|
40
42
|
default_scope = "openai/"
|
|
43
|
+
# Shared OpenAI error mapping, plus a refinement for bad requests: an error whose
# `code` is "model_not_found" becomes NotFoundError, any other bad request stays
# a BadRequestError.
error_map = {
    **_shared_utils.OPENAI_ERROR_MAP,
    OpenAIBadRequestError: lambda e: (
        NotFoundError
        if getattr(e, "code", None) == "model_not_found"
        else BadRequestError
    ),
}
|
|
41
49
|
|
|
42
50
|
def __init__(
|
|
43
51
|
self,
|
|
@@ -49,6 +57,10 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
49
57
|
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
|
50
58
|
self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
|
51
59
|
|
|
60
|
+
def get_error_status(self, e: Exception) -> int | None:
|
|
61
|
+
"""Extract HTTP status code from OpenAI exception."""
|
|
62
|
+
return getattr(e, "status_code", None)
|
|
63
|
+
|
|
52
64
|
def _call(
|
|
53
65
|
self,
|
|
54
66
|
*,
|
|
@@ -77,7 +89,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
77
89
|
format=format,
|
|
78
90
|
params=params,
|
|
79
91
|
)
|
|
80
|
-
|
|
81
92
|
openai_response = self.client.responses.create(**kwargs)
|
|
82
93
|
|
|
83
94
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -127,7 +138,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
127
138
|
format=format,
|
|
128
139
|
params=params,
|
|
129
140
|
)
|
|
130
|
-
|
|
131
141
|
openai_response = await self.async_client.responses.create(**kwargs)
|
|
132
142
|
|
|
133
143
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -177,7 +187,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
177
187
|
format=format,
|
|
178
188
|
params=params,
|
|
179
189
|
)
|
|
180
|
-
|
|
181
190
|
openai_stream = self.client.responses.create(
|
|
182
191
|
**kwargs,
|
|
183
192
|
stream=True,
|
|
@@ -227,7 +236,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
227
236
|
format=format,
|
|
228
237
|
params=params,
|
|
229
238
|
)
|
|
230
|
-
|
|
231
239
|
openai_stream = await self.async_client.responses.create(
|
|
232
240
|
**kwargs,
|
|
233
241
|
stream=True,
|
|
@@ -281,7 +289,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
281
289
|
format=format,
|
|
282
290
|
params=params,
|
|
283
291
|
)
|
|
284
|
-
|
|
285
292
|
openai_response = self.client.responses.create(**kwargs)
|
|
286
293
|
|
|
287
294
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -335,7 +342,6 @@ class OpenAIResponsesProvider(BaseProvider[OpenAI]):
|
|
|
335
342
|
format=format,
|
|
336
343
|
params=params,
|
|
337
344
|
)
|
|
338
|
-
|
|
339
345
|
openai_response = await self.async_client.responses.create(**kwargs)
|
|
340
346
|
|
|
341
347
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
@@ -1,16 +1,31 @@
|
|
|
1
1
|
"""Provider registry for managing provider instances and scopes."""
|
|
2
2
|
|
|
3
|
+
from functools import lru_cache
|
|
3
4
|
from typing import overload
|
|
4
5
|
|
|
5
6
|
from ..exceptions import NoRegisteredProviderError
|
|
7
|
+
from .anthropic import AnthropicProvider
|
|
6
8
|
from .base import Provider
|
|
7
|
-
from .
|
|
9
|
+
from .google import GoogleProvider
|
|
10
|
+
from .mlx import MLXProvider
|
|
11
|
+
from .ollama import OllamaProvider
|
|
12
|
+
from .openai import OpenAIProvider
|
|
13
|
+
from .openai.completions.provider import OpenAICompletionsProvider
|
|
14
|
+
from .openai.responses.provider import OpenAIResponsesProvider
|
|
8
15
|
from .provider_id import ProviderId
|
|
16
|
+
from .together import TogetherProvider
|
|
9
17
|
|
|
10
18
|
# Global registry mapping scopes to providers
|
|
11
19
|
# Scopes are matched by prefix (longest match wins)
|
|
12
20
|
PROVIDER_REGISTRY: dict[str, Provider] = {}
|
|
13
21
|
|
|
22
|
+
|
|
23
|
+
def reset_provider_registry() -> None:
    """Drop all cached provider instances and clear every registered provider.

    Clears the ``provider_singleton`` cache and empties ``PROVIDER_REGISTRY``, so
    subsequent lookups construct and register providers from scratch.
    """
    provider_singleton.cache_clear()
    PROVIDER_REGISTRY.clear()
|
|
27
|
+
|
|
28
|
+
|
|
14
29
|
# Default auto-registration mapping for built-in providers
|
|
15
30
|
# These providers will be automatically registered on first use
|
|
16
31
|
DEFAULT_AUTO_REGISTER_SCOPES: dict[str, ProviderId] = {
|
|
@@ -23,6 +38,44 @@ DEFAULT_AUTO_REGISTER_SCOPES: dict[str, ProviderId] = {
|
|
|
23
38
|
}
|
|
24
39
|
|
|
25
40
|
|
|
41
|
+
@lru_cache(maxsize=256)
|
|
42
|
+
def provider_singleton(
|
|
43
|
+
provider_id: ProviderId, *, api_key: str | None = None, base_url: str | None = None
|
|
44
|
+
) -> Provider:
|
|
45
|
+
"""Create a cached provider instance for the specified provider id.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
provider_id: The provider name ("openai", "anthropic", or "google").
|
|
49
|
+
api_key: API key for authentication. If None, uses provider-specific env var.
|
|
50
|
+
base_url: Base URL for the API. If None, uses provider-specific env var.
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
A cached provider instance for the specified provider with the given parameters.
|
|
54
|
+
|
|
55
|
+
Raises:
|
|
56
|
+
ValueError: If the provider_id is not supported.
|
|
57
|
+
"""
|
|
58
|
+
match provider_id:
|
|
59
|
+
case "anthropic":
|
|
60
|
+
return AnthropicProvider(api_key=api_key, base_url=base_url)
|
|
61
|
+
case "google":
|
|
62
|
+
return GoogleProvider(api_key=api_key, base_url=base_url)
|
|
63
|
+
case "mlx": # pragma: no cover (MLX is only available on macOS)
|
|
64
|
+
return MLXProvider()
|
|
65
|
+
case "ollama":
|
|
66
|
+
return OllamaProvider(api_key=api_key, base_url=base_url)
|
|
67
|
+
case "openai":
|
|
68
|
+
return OpenAIProvider(api_key=api_key, base_url=base_url)
|
|
69
|
+
case "openai:completions":
|
|
70
|
+
return OpenAICompletionsProvider(api_key=api_key, base_url=base_url)
|
|
71
|
+
case "openai:responses":
|
|
72
|
+
return OpenAIResponsesProvider(api_key=api_key, base_url=base_url)
|
|
73
|
+
case "together":
|
|
74
|
+
return TogetherProvider(api_key=api_key, base_url=base_url)
|
|
75
|
+
case _: # pragma: no cover
|
|
76
|
+
raise ValueError(f"Unknown provider: '{provider_id}'")
|
|
77
|
+
|
|
78
|
+
|
|
26
79
|
@overload
|
|
27
80
|
def register_provider(
|
|
28
81
|
provider: Provider,
|
|
@@ -100,7 +153,7 @@ def register_provider(
|
|
|
100
153
|
"""
|
|
101
154
|
|
|
102
155
|
if isinstance(provider, str):
|
|
103
|
-
provider =
|
|
156
|
+
provider = provider_singleton(provider, api_key=api_key, base_url=base_url)
|
|
104
157
|
|
|
105
158
|
if scope is None:
|
|
106
159
|
scope = provider.default_scope
|
|
@@ -160,7 +213,7 @@ def get_provider_for_model(model_id: str) -> Provider:
|
|
|
160
213
|
if matching_defaults:
|
|
161
214
|
best_scope = max(matching_defaults, key=len)
|
|
162
215
|
provider_id = DEFAULT_AUTO_REGISTER_SCOPES[best_scope]
|
|
163
|
-
provider =
|
|
216
|
+
provider = provider_singleton(provider_id)
|
|
164
217
|
# Auto-register for future calls
|
|
165
218
|
PROVIDER_REGISTRY[best_scope] = provider
|
|
166
219
|
return provider
|
|
@@ -50,6 +50,55 @@ def _is_third_party(module: ModuleType, site_packages: set[str]) -> bool:
|
|
|
50
50
|
)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
+
class _RemoveVersionDecoratorTransformer(cst.CSTTransformer):
    """libcst transformer that strips `@version` / `@ops.version` decorators.

    Both bare and called forms are recognized (`@version`, `@version(...)`,
    `@ops.version`, `@ops.version(...)`); every other decorator is kept. Only
    function definitions are rewritten; class decorators are left untouched.
    """

    @classmethod
    def _is_version_decorator(cls, decorator: cst.Decorator) -> bool:
        """Returns True if the decorator is @version or @ops.version, called or not."""
        target = decorator.decorator
        # For the called forms, judge the callee rather than the call node itself.
        if isinstance(target, cst.Call):
            target = target.func
        if isinstance(target, cst.Name):
            return target.value == "version"
        if isinstance(target, cst.Attribute):
            owner = target.value
            return (
                isinstance(owner, cst.Name)
                and owner.value == "ops"
                and target.attr.value == "version"
            )
        return False

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Returns the function definition minus any @version/@ops.version decorators."""
        kept_decorators = [
            candidate
            for candidate in updated_node.decorators
            if not self._is_version_decorator(candidate)
        ]
        return updated_node.with_changes(decorators=kept_decorators)
|
|
100
|
+
|
|
101
|
+
|
|
53
102
|
class _RemoveDocstringTransformer(cst.CSTTransformer):
|
|
54
103
|
"""CST transformer to remove docstrings from functions and classes."""
|
|
55
104
|
|
|
@@ -125,14 +174,9 @@ def _clean_source_code(
|
|
|
125
174
|
if docstr_flag in ("1", "true", "yes"):
|
|
126
175
|
return source.rstrip()
|
|
127
176
|
module = cst.parse_module(source)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
code = new_module.code
|
|
133
|
-
code = code.rstrip()
|
|
134
|
-
|
|
135
|
-
return code
|
|
177
|
+
module = module.visit(_RemoveVersionDecoratorTransformer())
|
|
178
|
+
module = module.visit(_RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body))
|
|
179
|
+
return module.code.rstrip()
|
|
136
180
|
|
|
137
181
|
|
|
138
182
|
@dataclass(frozen=True)
|
|
@@ -596,9 +640,9 @@ def _clean_source_from_string(source: str, exclude_fn_body: bool = False) -> str
|
|
|
596
640
|
"""Returns cleaned source code string with optional docstring removal."""
|
|
597
641
|
source = dedent(source)
|
|
598
642
|
module = cst.parse_module(source)
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
return
|
|
643
|
+
module = module.visit(_RemoveVersionDecoratorTransformer())
|
|
644
|
+
module = module.visit(_RemoveDocstringTransformer(exclude_fn_body=exclude_fn_body))
|
|
645
|
+
return module.code.rstrip()
|
|
602
646
|
|
|
603
647
|
|
|
604
648
|
def _get_class_source_from_method(method: Callable[..., Any]) -> str:
|
|
@@ -694,6 +738,13 @@ class _DependencyCollector:
|
|
|
694
738
|
# For Python 3.13+
|
|
695
739
|
return definition.func # pyright: ignore[reportFunctionMemberAccess] # pragma: no cover
|
|
696
740
|
|
|
741
|
+
if (
|
|
742
|
+
(wrapped_function := getattr(definition, "fn", None)) is not None
|
|
743
|
+
and not hasattr(definition, "__qualname__")
|
|
744
|
+
and callable(wrapped_function)
|
|
745
|
+
):
|
|
746
|
+
return wrapped_function
|
|
747
|
+
|
|
697
748
|
return definition
|
|
698
749
|
|
|
699
750
|
def _get_source_code(self, definition: Callable[..., Any] | type) -> str | None:
|