mirascope 1.20.1__py3-none-any.whl → 1.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/core/anthropic/call_response.py +6 -3
- mirascope/core/azure/call_response.py +2 -0
- mirascope/core/base/_create.py +2 -2
- mirascope/core/base/_utils/__init__.py +3 -0
- mirascope/core/base/_utils/_base_message_param_converter.py +1 -1
- mirascope/core/base/call_response.py +11 -1
- mirascope/core/bedrock/call_response.py +4 -0
- mirascope/core/cohere/call_response.py +3 -0
- mirascope/core/gemini/call_response.py +3 -0
- mirascope/core/google/call_response.py +3 -0
- mirascope/core/google/tool.py +35 -4
- mirascope/core/groq/call_response.py +3 -0
- mirascope/core/mistral/call_response.py +5 -0
- mirascope/core/openai/call_response.py +2 -0
- mirascope/core/vertex/call_response.py +3 -0
- mirascope/llm/__init__.py +6 -2
- mirascope/llm/{llm_call.py → _call.py} +94 -21
- mirascope/llm/_context.py +381 -0
- mirascope/llm/_override.py +3639 -0
- mirascope/llm/_protocols.py +3 -8
- mirascope/llm/call_response.py +8 -6
- mirascope/llm/call_response_chunk.py +4 -7
- mirascope/llm/stream.py +36 -54
- {mirascope-1.20.1.dist-info → mirascope-1.21.0.dist-info}/METADATA +1 -1
- {mirascope-1.20.1.dist-info → mirascope-1.21.0.dist-info}/RECORD +27 -26
- mirascope/llm/llm_override.py +0 -233
- {mirascope-1.20.1.dist-info → mirascope-1.21.0.dist-info}/WHEEL +0 -0
- {mirascope-1.20.1.dist-info → mirascope-1.21.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,9 +20,7 @@ from ..base.types import CostMetadata
|
|
|
20
20
|
from ._utils._convert_finish_reason_to_common_finish_reasons import (
|
|
21
21
|
_convert_finish_reasons_to_common_finish_reasons,
|
|
22
22
|
)
|
|
23
|
-
from ._utils._message_param_converter import
|
|
24
|
-
AnthropicMessageParamConverter,
|
|
25
|
-
)
|
|
23
|
+
from ._utils._message_param_converter import AnthropicMessageParamConverter
|
|
26
24
|
from .call_params import AnthropicCallParams
|
|
27
25
|
from .dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
|
|
28
26
|
from .tool import AnthropicTool
|
|
@@ -37,6 +35,7 @@ class AnthropicCallResponse(
|
|
|
37
35
|
MessageParam,
|
|
38
36
|
AnthropicCallParams,
|
|
39
37
|
MessageParam,
|
|
38
|
+
AnthropicMessageParamConverter,
|
|
40
39
|
]
|
|
41
40
|
):
|
|
42
41
|
"""A convenience wrapper around the Anthropic `Message` response.
|
|
@@ -62,6 +61,10 @@ class AnthropicCallResponse(
|
|
|
62
61
|
```
|
|
63
62
|
"""
|
|
64
63
|
|
|
64
|
+
_message_converter: type[AnthropicMessageParamConverter] = (
|
|
65
|
+
AnthropicMessageParamConverter
|
|
66
|
+
)
|
|
67
|
+
|
|
65
68
|
_provider = "anthropic"
|
|
66
69
|
|
|
67
70
|
@computed_field
|
|
@@ -38,6 +38,7 @@ class AzureCallResponse(
|
|
|
38
38
|
ChatRequestMessage,
|
|
39
39
|
AzureCallParams,
|
|
40
40
|
UserMessage,
|
|
41
|
+
AzureMessageParamConverter,
|
|
41
42
|
]
|
|
42
43
|
):
|
|
43
44
|
"""A convenience wrapper around the Azure `ChatCompletion` response.
|
|
@@ -64,6 +65,7 @@ class AzureCallResponse(
|
|
|
64
65
|
"""
|
|
65
66
|
|
|
66
67
|
response: SkipValidation[ChatCompletions]
|
|
68
|
+
_message_converter: type[AzureMessageParamConverter] = AzureMessageParamConverter
|
|
67
69
|
|
|
68
70
|
_provider = "azure"
|
|
69
71
|
|
mirascope/core/base/_create.py
CHANGED
|
@@ -187,7 +187,7 @@ def create_factory( # noqa: ANN202
|
|
|
187
187
|
start_time = datetime.datetime.now().timestamp() * 1000
|
|
188
188
|
response = await create(stream=False, **call_kwargs)
|
|
189
189
|
end_time = datetime.datetime.now().timestamp() * 1000
|
|
190
|
-
output = TCallResponse(
|
|
190
|
+
output = TCallResponse( # pyright: ignore [reportCallIssue]
|
|
191
191
|
metadata=get_metadata(fn, dynamic_config),
|
|
192
192
|
response=response,
|
|
193
193
|
tool_types=tool_types, # pyright: ignore [reportArgumentType]
|
|
@@ -231,7 +231,7 @@ def create_factory( # noqa: ANN202
|
|
|
231
231
|
start_time = datetime.datetime.now().timestamp() * 1000
|
|
232
232
|
response = create(stream=False, **call_kwargs)
|
|
233
233
|
end_time = datetime.datetime.now().timestamp() * 1000
|
|
234
|
-
output = TCallResponse(
|
|
234
|
+
output = TCallResponse( # pyright: ignore [reportCallIssue]
|
|
235
235
|
metadata=get_metadata(fn, dynamic_config),
|
|
236
236
|
response=response,
|
|
237
237
|
tool_types=tool_types, # pyright: ignore [reportArgumentType]
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Internal Utilities."""
|
|
2
2
|
|
|
3
|
+
from ._base_message_param_converter import BaseMessageParamConverter
|
|
3
4
|
from ._base_type import BaseType, is_base_type
|
|
4
5
|
from ._convert_base_model_to_base_tool import convert_base_model_to_base_tool
|
|
5
6
|
from ._convert_base_type_to_base_tool import convert_base_type_to_base_tool
|
|
@@ -46,8 +47,10 @@ from ._setup_extract_tool import setup_extract_tool
|
|
|
46
47
|
__all__ = [
|
|
47
48
|
"DEFAULT_TOOL_DOCSTRING",
|
|
48
49
|
"AsyncCreateFn",
|
|
50
|
+
"BaseMessageParamConverter",
|
|
49
51
|
"BaseType",
|
|
50
52
|
"CalculateCost",
|
|
53
|
+
"CallDecorator",
|
|
51
54
|
"CreateFn",
|
|
52
55
|
"GetJsonOutput",
|
|
53
56
|
"HandleStream",
|
|
@@ -19,7 +19,7 @@ from pydantic import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
from ..costs import calculate_cost
|
|
22
|
-
from ._utils import BaseType, get_common_usage
|
|
22
|
+
from ._utils import BaseMessageParamConverter, BaseType, get_common_usage
|
|
23
23
|
from .call_kwargs import BaseCallKwargs
|
|
24
24
|
from .call_params import BaseCallParams
|
|
25
25
|
from .dynamic_config import BaseDynamicConfig
|
|
@@ -40,6 +40,9 @@ _ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
|
|
|
40
40
|
_CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
|
|
41
41
|
_UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
|
|
42
42
|
_BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
|
|
43
|
+
_BaseMessageParamConverterT = TypeVar(
|
|
44
|
+
"_BaseMessageParamConverterT", bound=BaseMessageParamConverter
|
|
45
|
+
)
|
|
43
46
|
|
|
44
47
|
|
|
45
48
|
def transform_tool_outputs(
|
|
@@ -97,6 +100,7 @@ class BaseCallResponse(
|
|
|
97
100
|
_MessageParamT,
|
|
98
101
|
_CallParamsT,
|
|
99
102
|
_UserMessageParamT,
|
|
103
|
+
_BaseMessageParamConverterT,
|
|
100
104
|
],
|
|
101
105
|
ABC,
|
|
102
106
|
):
|
|
@@ -131,6 +135,7 @@ class BaseCallResponse(
|
|
|
131
135
|
start_time: float
|
|
132
136
|
end_time: float
|
|
133
137
|
|
|
138
|
+
_message_converter: type[_BaseMessageParamConverterT]
|
|
134
139
|
_provider: ClassVar[str] = "NO PROVIDER"
|
|
135
140
|
_model: str = "NO MODEL"
|
|
136
141
|
|
|
@@ -313,6 +318,11 @@ class BaseCallResponse(
|
|
|
313
318
|
"""Provider-agnostic user message param."""
|
|
314
319
|
...
|
|
315
320
|
|
|
321
|
+
@property
|
|
322
|
+
def common_messages(self) -> list[BaseMessageParam]:
|
|
323
|
+
"""Provider-agnostic list of messages."""
|
|
324
|
+
return self._message_converter.from_provider(self.messages)
|
|
325
|
+
|
|
316
326
|
@property
|
|
317
327
|
def common_tools(self) -> list[Tool] | None:
|
|
318
328
|
"""Provider-agnostic tools."""
|
|
@@ -51,6 +51,7 @@ class BedrockCallResponse(
|
|
|
51
51
|
InternalBedrockMessageParam,
|
|
52
52
|
BedrockCallParams,
|
|
53
53
|
UserMessageTypeDef,
|
|
54
|
+
BedrockMessageParamConverter,
|
|
54
55
|
]
|
|
55
56
|
):
|
|
56
57
|
"""A convenience wrapper around the Bedrock `ChatCompletion` response.
|
|
@@ -78,6 +79,9 @@ class BedrockCallResponse(
|
|
|
78
79
|
"""
|
|
79
80
|
|
|
80
81
|
response: SkipValidation[SyncConverseResponseTypeDef | AsyncConverseResponseTypeDef]
|
|
82
|
+
_message_converter: type[BedrockMessageParamConverter] = (
|
|
83
|
+
BedrockMessageParamConverter
|
|
84
|
+
)
|
|
81
85
|
|
|
82
86
|
_provider = "bedrock"
|
|
83
87
|
|
|
@@ -35,6 +35,7 @@ class CohereCallResponse(
|
|
|
35
35
|
SkipValidation[ChatMessage],
|
|
36
36
|
CohereCallParams,
|
|
37
37
|
SkipValidation[ChatMessage],
|
|
38
|
+
CohereMessageParamConverter,
|
|
38
39
|
]
|
|
39
40
|
):
|
|
40
41
|
"""A convenience wrapper around the Cohere `ChatCompletion` response.
|
|
@@ -60,6 +61,8 @@ class CohereCallResponse(
|
|
|
60
61
|
```
|
|
61
62
|
"""
|
|
62
63
|
|
|
64
|
+
_message_converter: type[CohereMessageParamConverter] = CohereMessageParamConverter
|
|
65
|
+
|
|
63
66
|
_provider = "cohere"
|
|
64
67
|
|
|
65
68
|
@computed_field
|
|
@@ -36,6 +36,7 @@ class GeminiCallResponse(
|
|
|
36
36
|
ContentsType,
|
|
37
37
|
GeminiCallParams,
|
|
38
38
|
ContentDict,
|
|
39
|
+
GeminiMessageParamConverter,
|
|
39
40
|
]
|
|
40
41
|
):
|
|
41
42
|
"""A convenience wrapper around the Gemini API response.
|
|
@@ -62,6 +63,8 @@ class GeminiCallResponse(
|
|
|
62
63
|
```
|
|
63
64
|
"""
|
|
64
65
|
|
|
66
|
+
_message_converter: type[GeminiMessageParamConverter] = GeminiMessageParamConverter
|
|
67
|
+
|
|
65
68
|
_provider = "gemini"
|
|
66
69
|
|
|
67
70
|
@computed_field
|
|
@@ -40,6 +40,7 @@ class GoogleCallResponse(
|
|
|
40
40
|
ContentListUnion | ContentListUnionDict,
|
|
41
41
|
GoogleCallParams,
|
|
42
42
|
ContentDict,
|
|
43
|
+
GoogleMessageParamConverter,
|
|
43
44
|
]
|
|
44
45
|
):
|
|
45
46
|
"""A convenience wrapper around the Google API response.
|
|
@@ -66,6 +67,8 @@ class GoogleCallResponse(
|
|
|
66
67
|
```
|
|
67
68
|
"""
|
|
68
69
|
|
|
70
|
+
_message_converter: type[GoogleMessageParamConverter] = GoogleMessageParamConverter
|
|
71
|
+
|
|
69
72
|
_provider = "google"
|
|
70
73
|
|
|
71
74
|
@computed_field
|
mirascope/core/google/tool.py
CHANGED
|
@@ -70,17 +70,48 @@ class GoogleTool(BaseTool):
|
|
|
70
70
|
fn["parameters"] = model_schema
|
|
71
71
|
|
|
72
72
|
if "parameters" in fn:
|
|
73
|
+
# Resolve $defs and $ref
|
|
73
74
|
if "$defs" in fn["parameters"]:
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
75
|
+
defs = fn["parameters"].pop("$defs")
|
|
76
|
+
|
|
77
|
+
def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
|
|
78
|
+
"""Recursively resolve $ref references using the $defs dictionary."""
|
|
79
|
+
# If this is a reference, resolve it
|
|
80
|
+
if "$ref" in schema:
|
|
81
|
+
ref = schema["$ref"]
|
|
82
|
+
if ref.startswith("#/$defs/"):
|
|
83
|
+
ref_key = ref.replace("#/$defs/", "")
|
|
84
|
+
if ref_key in defs:
|
|
85
|
+
# Merge the definition with the current schema (excluding $ref)
|
|
86
|
+
resolved = {
|
|
87
|
+
**{k: v for k, v in schema.items() if k != "$ref"},
|
|
88
|
+
**resolve_refs(defs[ref_key]),
|
|
89
|
+
}
|
|
90
|
+
return resolved
|
|
91
|
+
|
|
92
|
+
# Process all other keys recursively
|
|
93
|
+
result = {}
|
|
94
|
+
for key, value in schema.items():
|
|
95
|
+
if isinstance(value, dict):
|
|
96
|
+
result[key] = resolve_refs(value)
|
|
97
|
+
elif isinstance(value, list):
|
|
98
|
+
result[key] = [
|
|
99
|
+
resolve_refs(item) if isinstance(item, dict) else item
|
|
100
|
+
for item in value
|
|
101
|
+
]
|
|
102
|
+
else:
|
|
103
|
+
result[key] = value
|
|
104
|
+
return result
|
|
105
|
+
|
|
106
|
+
# Resolve all references in the parameters
|
|
107
|
+
fn["parameters"] = resolve_refs(fn["parameters"])
|
|
78
108
|
|
|
79
109
|
def handle_enum_schema(prop_schema: dict[str, Any]) -> dict[str, Any]:
|
|
80
110
|
if "enum" in prop_schema:
|
|
81
111
|
prop_schema["format"] = "enum"
|
|
82
112
|
return prop_schema
|
|
83
113
|
|
|
114
|
+
# Process properties after resolving references
|
|
84
115
|
fn["parameters"]["properties"] = {
|
|
85
116
|
prop: {
|
|
86
117
|
key: value
|
|
@@ -35,6 +35,7 @@ class GroqCallResponse(
|
|
|
35
35
|
ChatCompletionMessageParam,
|
|
36
36
|
GroqCallParams,
|
|
37
37
|
ChatCompletionUserMessageParam,
|
|
38
|
+
GroqMessageParamConverter,
|
|
38
39
|
]
|
|
39
40
|
):
|
|
40
41
|
"""A convenience wrapper around the Groq `ChatCompletion` response.
|
|
@@ -59,6 +60,8 @@ class GroqCallResponse(
|
|
|
59
60
|
```
|
|
60
61
|
"""
|
|
61
62
|
|
|
63
|
+
_message_converter: type[GroqMessageParamConverter] = GroqMessageParamConverter
|
|
64
|
+
|
|
62
65
|
_provider = "groq"
|
|
63
66
|
|
|
64
67
|
@computed_field
|
|
@@ -38,6 +38,7 @@ class MistralCallResponse(
|
|
|
38
38
|
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
|
|
39
39
|
MistralCallParams,
|
|
40
40
|
UserMessage,
|
|
41
|
+
MistralMessageParamConverter,
|
|
41
42
|
]
|
|
42
43
|
):
|
|
43
44
|
"""A convenience wrapper around the Mistral `ChatCompletion` response.
|
|
@@ -62,6 +63,10 @@ class MistralCallResponse(
|
|
|
62
63
|
```
|
|
63
64
|
"""
|
|
64
65
|
|
|
66
|
+
_message_converter: type[MistralMessageParamConverter] = (
|
|
67
|
+
MistralMessageParamConverter
|
|
68
|
+
)
|
|
69
|
+
|
|
65
70
|
_provider = "mistral"
|
|
66
71
|
|
|
67
72
|
@property
|
|
@@ -58,6 +58,7 @@ class OpenAICallResponse(
|
|
|
58
58
|
ChatCompletionMessageParam,
|
|
59
59
|
OpenAICallParams,
|
|
60
60
|
ChatCompletionUserMessageParam,
|
|
61
|
+
OpenAIMessageParamConverter,
|
|
61
62
|
]
|
|
62
63
|
):
|
|
63
64
|
"""A convenience wrapper around the OpenAI `ChatCompletion` response.
|
|
@@ -84,6 +85,7 @@ class OpenAICallResponse(
|
|
|
84
85
|
"""
|
|
85
86
|
|
|
86
87
|
response: SkipValidation[ChatCompletion]
|
|
88
|
+
_message_converter: type[OpenAIMessageParamConverter] = OpenAIMessageParamConverter
|
|
87
89
|
|
|
88
90
|
_provider = "openai"
|
|
89
91
|
|
|
@@ -30,6 +30,7 @@ class VertexCallResponse(
|
|
|
30
30
|
Content,
|
|
31
31
|
VertexCallParams,
|
|
32
32
|
Content,
|
|
33
|
+
VertexMessageParamConverter,
|
|
33
34
|
]
|
|
34
35
|
):
|
|
35
36
|
"""A convenience wrapper around the Vertex AI `GenerateContentResponse`.
|
|
@@ -55,6 +56,8 @@ class VertexCallResponse(
|
|
|
55
56
|
```
|
|
56
57
|
"""
|
|
57
58
|
|
|
59
|
+
_message_converter: type[VertexMessageParamConverter] = VertexMessageParamConverter
|
|
60
|
+
|
|
58
61
|
_provider = "vertex"
|
|
59
62
|
|
|
60
63
|
@computed_field
|
mirascope/llm/__init__.py
CHANGED
|
@@ -1,14 +1,18 @@
|
|
|
1
1
|
from ..core import CostMetadata, LocalProvider, Provider, calculate_cost
|
|
2
|
+
from ._call import call
|
|
3
|
+
from ._context import context
|
|
4
|
+
from ._override import override
|
|
2
5
|
from .call_response import CallResponse
|
|
3
|
-
from .
|
|
4
|
-
from .llm_override import override
|
|
6
|
+
from .stream import Stream
|
|
5
7
|
|
|
6
8
|
__all__ = [
|
|
7
9
|
"CallResponse",
|
|
8
10
|
"CostMetadata",
|
|
9
11
|
"LocalProvider",
|
|
10
12
|
"Provider",
|
|
13
|
+
"Stream",
|
|
11
14
|
"calculate_cost",
|
|
12
15
|
"call",
|
|
16
|
+
"context",
|
|
13
17
|
"override",
|
|
14
18
|
]
|
|
@@ -20,6 +20,7 @@ from ..core.base import (
|
|
|
20
20
|
from ..core.base._utils import fn_is_async
|
|
21
21
|
from ..core.base.stream_config import StreamConfig
|
|
22
22
|
from ..core.base.types import LocalProvider, Provider
|
|
23
|
+
from ._context import CallArgs, apply_context_overrides_to_call_args
|
|
23
24
|
from ._protocols import (
|
|
24
25
|
AsyncLLMFunctionDecorator,
|
|
25
26
|
CallDecorator,
|
|
@@ -193,13 +194,9 @@ def _call(
|
|
|
193
194
|
]
|
|
194
195
|
):
|
|
195
196
|
"""Decorator for defining a function that calls a language model."""
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
)
|
|
200
|
-
else:
|
|
201
|
-
provider_call = _get_provider_call(cast(Provider, provider))
|
|
202
|
-
_original_args = {
|
|
197
|
+
# Store original call args that will be used for each function call
|
|
198
|
+
original_call_args: CallArgs = {
|
|
199
|
+
"provider": provider,
|
|
203
200
|
"model": model,
|
|
204
201
|
"stream": stream,
|
|
205
202
|
"tools": tools,
|
|
@@ -214,36 +211,112 @@ def _call(
|
|
|
214
211
|
fn: Callable[_P, _R | Awaitable[_R]],
|
|
215
212
|
) -> Callable[
|
|
216
213
|
_P,
|
|
217
|
-
CallResponse
|
|
214
|
+
CallResponse
|
|
215
|
+
| Stream
|
|
216
|
+
| _ResponseModelT
|
|
217
|
+
| _ParsedOutputT
|
|
218
|
+
| (_ResponseModelT | CallResponse)
|
|
219
|
+
| Awaitable[CallResponse]
|
|
220
|
+
| Awaitable[Stream]
|
|
221
|
+
| Awaitable[_ResponseModelT]
|
|
222
|
+
| Awaitable[_ParsedOutputT]
|
|
223
|
+
| Awaitable[(_ResponseModelT | CallResponse)],
|
|
218
224
|
]:
|
|
219
|
-
|
|
225
|
+
if fn_is_async(fn):
|
|
220
226
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
@wraps(decorated)
|
|
227
|
+
@wraps(fn)
|
|
224
228
|
async def inner_async(
|
|
225
229
|
*args: _P.args, **kwargs: _P.kwargs
|
|
226
|
-
) ->
|
|
230
|
+
) -> (
|
|
231
|
+
CallResponse
|
|
232
|
+
| Stream
|
|
233
|
+
| _ResponseModelT
|
|
234
|
+
| _ParsedOutputT
|
|
235
|
+
| (_ResponseModelT | CallResponse)
|
|
236
|
+
):
|
|
237
|
+
# Apply any context overrides to the original call args
|
|
238
|
+
effective_call_args = apply_context_overrides_to_call_args(
|
|
239
|
+
original_call_args
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
# Get the appropriate provider call function with the possibly overridden provider
|
|
243
|
+
effective_provider = effective_call_args["provider"]
|
|
244
|
+
effective_client = effective_call_args["client"]
|
|
245
|
+
|
|
246
|
+
if effective_provider in get_args(LocalProvider):
|
|
247
|
+
provider_call, effective_client = _get_local_provider_call(
|
|
248
|
+
cast(LocalProvider, effective_provider), effective_client
|
|
249
|
+
)
|
|
250
|
+
effective_call_args["client"] = effective_client
|
|
251
|
+
else:
|
|
252
|
+
provider_call = _get_provider_call(
|
|
253
|
+
cast(Provider, effective_provider)
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
# Use the provider-specific call function with overridden args
|
|
257
|
+
call_kwargs = dict(effective_call_args)
|
|
258
|
+
del call_kwargs[
|
|
259
|
+
"provider"
|
|
260
|
+
] # Remove provider as it's not a parameter to provider_call
|
|
261
|
+
|
|
262
|
+
# Get decorated function using provider_call
|
|
263
|
+
decorated = provider_call(**call_kwargs)(fn)
|
|
264
|
+
|
|
265
|
+
# Call the decorated function and wrap the result
|
|
227
266
|
result = await decorated(*args, **kwargs)
|
|
228
267
|
return _wrap_result(result)
|
|
229
268
|
|
|
230
|
-
inner_async.
|
|
231
|
-
inner_async._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
|
|
269
|
+
inner_async._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
|
|
232
270
|
inner_async._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
|
|
233
|
-
inner_async._original_provider = provider # pyright: ignore [reportAttributeAccessIssue]
|
|
234
271
|
|
|
235
272
|
return inner_async
|
|
236
273
|
else:
|
|
237
274
|
|
|
238
|
-
@wraps(
|
|
239
|
-
def inner(
|
|
275
|
+
@wraps(fn)
|
|
276
|
+
def inner(
|
|
277
|
+
*args: _P.args, **kwargs: _P.kwargs
|
|
278
|
+
) -> (
|
|
279
|
+
CallResponse
|
|
280
|
+
| Stream
|
|
281
|
+
| _ResponseModelT
|
|
282
|
+
| _ParsedOutputT
|
|
283
|
+
| (_ResponseModelT | CallResponse)
|
|
284
|
+
):
|
|
285
|
+
# Apply any context overrides to the original call args
|
|
286
|
+
effective_call_args = apply_context_overrides_to_call_args(
|
|
287
|
+
original_call_args
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
# Get the appropriate provider call function with the possibly overridden provider
|
|
291
|
+
effective_provider = effective_call_args["provider"]
|
|
292
|
+
effective_client = effective_call_args["client"]
|
|
293
|
+
|
|
294
|
+
if effective_provider in get_args(LocalProvider):
|
|
295
|
+
provider_call, effective_client = _get_local_provider_call(
|
|
296
|
+
cast(LocalProvider, effective_provider), effective_client
|
|
297
|
+
)
|
|
298
|
+
effective_call_args["client"] = effective_client
|
|
299
|
+
else:
|
|
300
|
+
provider_call = _get_provider_call(
|
|
301
|
+
cast(Provider, effective_provider)
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
# Use the provider-specific call function with overridden args
|
|
305
|
+
call_kwargs = dict(effective_call_args)
|
|
306
|
+
del call_kwargs[
|
|
307
|
+
"provider"
|
|
308
|
+
] # Remove provider as it's not a parameter to provider_call
|
|
309
|
+
|
|
310
|
+
# Get decorated function using provider_call
|
|
311
|
+
decorated = provider_call(**call_kwargs)(fn)
|
|
312
|
+
|
|
313
|
+
# Call the decorated function and wrap the result
|
|
240
314
|
result = decorated(*args, **kwargs)
|
|
241
315
|
return _wrap_result(result)
|
|
242
316
|
|
|
243
|
-
inner.
|
|
244
|
-
inner._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
|
|
317
|
+
inner._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
|
|
245
318
|
inner._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
|
|
246
|
-
|
|
319
|
+
|
|
247
320
|
return inner
|
|
248
321
|
|
|
249
322
|
return wrapper # pyright: ignore [reportReturnType]
|