mirascope 1.20.1__py3-none-any.whl → 1.21.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/core/anthropic/call_response.py +6 -3
- mirascope/core/azure/call_response.py +2 -0
- mirascope/core/base/_create.py +2 -2
- mirascope/core/base/_utils/__init__.py +3 -0
- mirascope/core/base/_utils/_base_message_param_converter.py +1 -1
- mirascope/core/base/call_response.py +11 -1
- mirascope/core/bedrock/call_response.py +4 -0
- mirascope/core/cohere/call_response.py +3 -0
- mirascope/core/gemini/call_response.py +3 -0
- mirascope/core/google/call_response.py +3 -0
- mirascope/core/google/tool.py +35 -4
- mirascope/core/groq/call_response.py +3 -0
- mirascope/core/mistral/call_response.py +5 -0
- mirascope/core/openai/call_response.py +2 -0
- mirascope/core/vertex/call_response.py +3 -0
- mirascope/llm/__init__.py +6 -2
- mirascope/llm/{llm_call.py → _call.py} +114 -25
- mirascope/llm/_context.py +381 -0
- mirascope/llm/_override.py +3639 -0
- mirascope/llm/_protocols.py +3 -8
- mirascope/llm/call_response.py +8 -6
- mirascope/llm/call_response_chunk.py +4 -7
- mirascope/llm/stream.py +36 -54
- {mirascope-1.20.1.dist-info → mirascope-1.21.1.dist-info}/METADATA +1 -1
- {mirascope-1.20.1.dist-info → mirascope-1.21.1.dist-info}/RECORD +27 -26
- mirascope/llm/llm_override.py +0 -233
- {mirascope-1.20.1.dist-info → mirascope-1.21.1.dist-info}/WHEEL +0 -0
- {mirascope-1.20.1.dist-info → mirascope-1.21.1.dist-info}/licenses/LICENSE +0 -0
mirascope/core/anthropic/call_response.py CHANGED

@@ -20,9 +20,7 @@ from ..base.types import CostMetadata
 from ._utils._convert_finish_reason_to_common_finish_reasons import (
     _convert_finish_reasons_to_common_finish_reasons,
 )
-from ._utils._message_param_converter import (
-    AnthropicMessageParamConverter,
-)
+from ._utils._message_param_converter import AnthropicMessageParamConverter
 from .call_params import AnthropicCallParams
 from .dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
 from .tool import AnthropicTool

@@ -37,6 +35,7 @@ class AnthropicCallResponse(
         MessageParam,
         AnthropicCallParams,
         MessageParam,
+        AnthropicMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Anthropic `Message` response.

@@ -62,6 +61,10 @@ class AnthropicCallResponse(
     ```
     """

+    _message_converter: type[AnthropicMessageParamConverter] = (
+        AnthropicMessageParamConverter
+    )
+
     _provider = "anthropic"

     @computed_field
mirascope/core/azure/call_response.py CHANGED

@@ -38,6 +38,7 @@ class AzureCallResponse(
         ChatRequestMessage,
         AzureCallParams,
         UserMessage,
+        AzureMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Azure `ChatCompletion` response.

@@ -64,6 +65,7 @@ class AzureCallResponse(
     """

     response: SkipValidation[ChatCompletions]
+    _message_converter: type[AzureMessageParamConverter] = AzureMessageParamConverter

     _provider = "azure"

mirascope/core/base/_create.py CHANGED

@@ -187,7 +187,7 @@ def create_factory(  # noqa: ANN202
             start_time = datetime.datetime.now().timestamp() * 1000
             response = await create(stream=False, **call_kwargs)
             end_time = datetime.datetime.now().timestamp() * 1000
-            output = TCallResponse(
+            output = TCallResponse(  # pyright: ignore [reportCallIssue]
                 metadata=get_metadata(fn, dynamic_config),
                 response=response,
                 tool_types=tool_types,  # pyright: ignore [reportArgumentType]

@@ -231,7 +231,7 @@ def create_factory(  # noqa: ANN202
             start_time = datetime.datetime.now().timestamp() * 1000
             response = create(stream=False, **call_kwargs)
             end_time = datetime.datetime.now().timestamp() * 1000
-            output = TCallResponse(
+            output = TCallResponse(  # pyright: ignore [reportCallIssue]
                 metadata=get_metadata(fn, dynamic_config),
                 response=response,
                 tool_types=tool_types,  # pyright: ignore [reportArgumentType]

mirascope/core/base/_utils/__init__.py CHANGED

@@ -1,5 +1,6 @@
 """Internal Utilities."""

+from ._base_message_param_converter import BaseMessageParamConverter
 from ._base_type import BaseType, is_base_type
 from ._convert_base_model_to_base_tool import convert_base_model_to_base_tool
 from ._convert_base_type_to_base_tool import convert_base_type_to_base_tool

@@ -46,8 +47,10 @@ from ._setup_extract_tool import setup_extract_tool
 __all__ = [
     "DEFAULT_TOOL_DOCSTRING",
     "AsyncCreateFn",
+    "BaseMessageParamConverter",
     "BaseType",
     "CalculateCost",
+    "CallDecorator",
     "CreateFn",
     "GetJsonOutput",
     "HandleStream",

mirascope/core/base/call_response.py CHANGED

@@ -19,7 +19,7 @@ from pydantic import (
 )

 from ..costs import calculate_cost
-from ._utils import BaseType, get_common_usage
+from ._utils import BaseMessageParamConverter, BaseType, get_common_usage
 from .call_kwargs import BaseCallKwargs
 from .call_params import BaseCallParams
 from .dynamic_config import BaseDynamicConfig

@@ -40,6 +40,9 @@ _ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
 _CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
 _UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
 _BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
+_BaseMessageParamConverterT = TypeVar(
+    "_BaseMessageParamConverterT", bound=BaseMessageParamConverter
+)


 def transform_tool_outputs(

@@ -97,6 +100,7 @@ class BaseCallResponse(
         _MessageParamT,
         _CallParamsT,
         _UserMessageParamT,
+        _BaseMessageParamConverterT,
     ],
     ABC,
 ):

@@ -131,6 +135,7 @@ class BaseCallResponse(
     start_time: float
     end_time: float

+    _message_converter: type[_BaseMessageParamConverterT]
     _provider: ClassVar[str] = "NO PROVIDER"
     _model: str = "NO MODEL"

@@ -313,6 +318,11 @@ class BaseCallResponse(
         """Provider-agnostic user message param."""
         ...

+    @property
+    def common_messages(self) -> list[BaseMessageParam]:
+        """Provider-agnostic list of messages."""
+        return self._message_converter.from_provider(self.messages)
+
     @property
     def common_tools(self) -> list[Tool] | None:
         """Provider-agnostic tools."""

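The new `common_messages` property sits alongside the existing provider-agnostic accessors such as `common_tools`: any provider's call response can now be converted into `BaseMessageParam` objects through its `_message_converter`. A minimal usage sketch; the provider, model, and function names below are illustrative, not taken from this diff:

from mirascope.core import openai

@openai.call("gpt-4o-mini")
def answer(question: str) -> str:
    return f"Answer this question: {question}"

response = answer("What is 6 * 7?")
# `response.messages` holds OpenAI-native message params; `common_messages`
# converts them to provider-agnostic BaseMessageParam objects via the
# response class's `_message_converter`.
for message in response.common_messages:
    print(message.role, message.content)
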
mirascope/core/bedrock/call_response.py CHANGED

@@ -51,6 +51,7 @@ class BedrockCallResponse(
         InternalBedrockMessageParam,
         BedrockCallParams,
         UserMessageTypeDef,
+        BedrockMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Bedrock `ChatCompletion` response.

@@ -78,6 +79,9 @@ class BedrockCallResponse(
     """

     response: SkipValidation[SyncConverseResponseTypeDef | AsyncConverseResponseTypeDef]
+    _message_converter: type[BedrockMessageParamConverter] = (
+        BedrockMessageParamConverter
+    )

     _provider = "bedrock"

mirascope/core/cohere/call_response.py CHANGED

@@ -35,6 +35,7 @@ class CohereCallResponse(
         SkipValidation[ChatMessage],
         CohereCallParams,
         SkipValidation[ChatMessage],
+        CohereMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Cohere `ChatCompletion` response.

@@ -60,6 +61,8 @@ class CohereCallResponse(
     ```
     """

+    _message_converter: type[CohereMessageParamConverter] = CohereMessageParamConverter
+
     _provider = "cohere"

     @computed_field

mirascope/core/gemini/call_response.py CHANGED

@@ -36,6 +36,7 @@ class GeminiCallResponse(
         ContentsType,
         GeminiCallParams,
         ContentDict,
+        GeminiMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Gemini API response.

@@ -62,6 +63,8 @@ class GeminiCallResponse(
     ```
     """

+    _message_converter: type[GeminiMessageParamConverter] = GeminiMessageParamConverter
+
     _provider = "gemini"

     @computed_field

mirascope/core/google/call_response.py CHANGED

@@ -40,6 +40,7 @@ class GoogleCallResponse(
         ContentListUnion | ContentListUnionDict,
         GoogleCallParams,
         ContentDict,
+        GoogleMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Google API response.

@@ -66,6 +67,8 @@ class GoogleCallResponse(
     ```
     """

+    _message_converter: type[GoogleMessageParamConverter] = GoogleMessageParamConverter
+
     _provider = "google"

     @computed_field
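Every provider response class follows the same pattern: it pins its concrete converter as a private class attribute, and the base class only requires that the converter expose a `from_provider` method (the single call site visible in this diff, in `common_messages`). A hedged sketch of that contract; the method shape beyond `from_provider` is an assumption, and `ExampleConverter` is illustrative, not a mirascope class:

from mirascope.core.base import BaseMessageParam

class ExampleConverter:
    """Illustrative converter: maps provider-native message dicts to BaseMessageParam."""

    @staticmethod
    def from_provider(message_params: list[dict]) -> list[BaseMessageParam]:
        return [
            BaseMessageParam(role=m["role"], content=m["content"])
            for m in message_params
        ]
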
mirascope/core/google/tool.py CHANGED

@@ -70,17 +70,48 @@ class GoogleTool(BaseTool):
             fn["parameters"] = model_schema

         if "parameters" in fn:
+            # Resolve $defs and $ref
             if "$defs" in fn["parameters"]:
-
-
-
-
+                defs = fn["parameters"].pop("$defs")
+
+                def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
+                    """Recursively resolve $ref references using the $defs dictionary."""
+                    # If this is a reference, resolve it
+                    if "$ref" in schema:
+                        ref = schema["$ref"]
+                        if ref.startswith("#/$defs/"):
+                            ref_key = ref.replace("#/$defs/", "")
+                            if ref_key in defs:
+                                # Merge the definition with the current schema (excluding $ref)
+                                resolved = {
+                                    **{k: v for k, v in schema.items() if k != "$ref"},
+                                    **resolve_refs(defs[ref_key]),
+                                }
+                                return resolved
+
+                    # Process all other keys recursively
+                    result = {}
+                    for key, value in schema.items():
+                        if isinstance(value, dict):
+                            result[key] = resolve_refs(value)
+                        elif isinstance(value, list):
+                            result[key] = [
+                                resolve_refs(item) if isinstance(item, dict) else item
+                                for item in value
+                            ]
+                        else:
+                            result[key] = value
+                    return result
+
+                # Resolve all references in the parameters
+                fn["parameters"] = resolve_refs(fn["parameters"])

             def handle_enum_schema(prop_schema: dict[str, Any]) -> dict[str, Any]:
                 if "enum" in prop_schema:
                     prop_schema["format"] = "enum"
                 return prop_schema

+            # Process properties after resolving references
             fn["parameters"]["properties"] = {
                 prop: {
                     key: value
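The effect of the new `resolve_refs` logic, sketched outside the library: nested Pydantic models emit `$defs` plus `#/$defs/...` references from `model_json_schema()`, and the resolver inlines each reference so the schema handed to Google's function-calling API is self-contained. The model names below are illustrative:

from typing import Any

from pydantic import BaseModel

class Author(BaseModel):
    name: str

class Book(BaseModel):
    title: str
    author: Author

schema = Book.model_json_schema()
defs = schema.pop("$defs")  # {"Author": {...}}

def resolve_refs(node: Any) -> Any:
    # Inline "#/$defs/..." references; recurse through dicts and lists.
    if isinstance(node, dict):
        if "$ref" in node and node["$ref"].startswith("#/$defs/"):
            target = defs[node["$ref"].removeprefix("#/$defs/")]
            return {**{k: v for k, v in node.items() if k != "$ref"}, **resolve_refs(target)}
        return {key: resolve_refs(value) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve_refs(item) for item in node]
    return node

# "author" becomes an inlined object schema instead of {"$ref": "#/$defs/Author"}.
print(resolve_refs(schema)["properties"]["author"]["type"])  # "object"
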
mirascope/core/groq/call_response.py CHANGED

@@ -35,6 +35,7 @@ class GroqCallResponse(
         ChatCompletionMessageParam,
         GroqCallParams,
         ChatCompletionUserMessageParam,
+        GroqMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Groq `ChatCompletion` response.

@@ -59,6 +60,8 @@ class GroqCallResponse(
     ```
     """

+    _message_converter: type[GroqMessageParamConverter] = GroqMessageParamConverter
+
     _provider = "groq"

     @computed_field

mirascope/core/mistral/call_response.py CHANGED

@@ -38,6 +38,7 @@ class MistralCallResponse(
         AssistantMessage | SystemMessage | ToolMessage | UserMessage,
         MistralCallParams,
         UserMessage,
+        MistralMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Mistral `ChatCompletion` response.

@@ -62,6 +63,10 @@ class MistralCallResponse(
     ```
     """

+    _message_converter: type[MistralMessageParamConverter] = (
+        MistralMessageParamConverter
+    )
+
     _provider = "mistral"

     @property

mirascope/core/openai/call_response.py CHANGED

@@ -58,6 +58,7 @@ class OpenAICallResponse(
         ChatCompletionMessageParam,
         OpenAICallParams,
         ChatCompletionUserMessageParam,
+        OpenAIMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the OpenAI `ChatCompletion` response.

@@ -84,6 +85,7 @@ class OpenAICallResponse(
     """

     response: SkipValidation[ChatCompletion]
+    _message_converter: type[OpenAIMessageParamConverter] = OpenAIMessageParamConverter

     _provider = "openai"

mirascope/core/vertex/call_response.py CHANGED

@@ -30,6 +30,7 @@ class VertexCallResponse(
         Content,
         VertexCallParams,
         Content,
+        VertexMessageParamConverter,
     ]
 ):
     """A convenience wrapper around the Vertex AI `GenerateContentResponse`.

@@ -55,6 +56,8 @@ class VertexCallResponse(
     ```
     """

+    _message_converter: type[VertexMessageParamConverter] = VertexMessageParamConverter
+
     _provider = "vertex"

     @computed_field
mirascope/llm/__init__.py CHANGED

@@ -1,14 +1,18 @@
 from ..core import CostMetadata, LocalProvider, Provider, calculate_cost
+from ._call import call
+from ._context import context
+from ._override import override
 from .call_response import CallResponse
-from .llm_call import call
-from .llm_override import override
+from .stream import Stream

 __all__ = [
     "CallResponse",
     "CostMetadata",
     "LocalProvider",
     "Provider",
+    "Stream",
     "calculate_cost",
     "call",
+    "context",
     "override",
 ]

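Taken together, the re-exported surface now mirrors the new modules: `call` from `_call`, `context` from `_context`, `override` from `_override`, and `Stream` alongside `CallResponse`. A hedged usage sketch: the `override` signature follows the former `llm_override.override`, and using `context` as a context manager is an assumption based on the new `_context` module:

from mirascope import llm

@llm.call(provider="openai", model="gpt-4o-mini")
def recommend_book(genre: str) -> str:
    return f"Recommend a {genre} book"

# Per-call override of provider/model (signature assumed from 1.20's llm_override):
overridden = llm.override(
    recommend_book, provider="anthropic", model="claude-3-5-sonnet-latest"
)
print(overridden("fantasy"))

# Scoped override for every call made inside the block (assumed usage):
with llm.context(provider="anthropic", model="claude-3-5-sonnet-latest"):
    print(recommend_book("fantasy"))
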
mirascope/llm/{llm_call.py → _call.py} CHANGED

@@ -20,6 +20,7 @@ from ..core.base import (
 from ..core.base._utils import fn_is_async
 from ..core.base.stream_config import StreamConfig
 from ..core.base.types import LocalProvider, Provider
+from ._context import CallArgs, apply_context_overrides_to_call_args
 from ._protocols import (
     AsyncLLMFunctionDecorator,
     CallDecorator,

@@ -49,24 +50,36 @@ _ResultT = TypeVar("_ResultT")
 def _get_local_provider_call(
     provider: LocalProvider,
     client: Any | None,  # noqa: ANN401
+    is_async: bool,
 ) -> tuple[Callable, Any | None]:
     if provider == "ollama":
         from ..core.openai import openai_call

         if client:
             return openai_call, client
-
+        if is_async:
+            from openai import AsyncOpenAI

-
+            client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
+        else:
+            from openai import OpenAI
+
+            client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
         return openai_call, client
     else:  # provider == "vllm"
         from ..core.openai import openai_call

         if client:
             return openai_call, client
-        from openai import OpenAI

-
+        if is_async:
+            from openai import AsyncOpenAI
+
+            client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
+        else:
+            from openai import OpenAI
+
+            client = OpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
         return openai_call, client


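The new `is_async` flag only changes which OpenAI-compatible client gets constructed for the local providers; the routing itself is unchanged. A condensed sketch of the selection the hunk above implements (localhost ports per the diff: 11434 for ollama, 8000 for vllm; `pick_local_client` is an illustrative name, not the library's):

from openai import AsyncOpenAI, OpenAI

def pick_local_client(provider: str, is_async: bool) -> OpenAI | AsyncOpenAI:
    # Both local providers talk to an OpenAI-compatible server on localhost.
    base_url = (
        "http://localhost:11434/v1" if provider == "ollama" else "http://localhost:8000/v1"
    )
    client_cls = AsyncOpenAI if is_async else OpenAI
    # Both branches use the placeholder API key "ollama", matching the diff.
    return client_cls(api_key="ollama", base_url=base_url)
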
@@ -193,13 +206,9 @@ def _call(
     ]
 ):
     """Decorator for defining a function that calls a language model."""
-
-
-
-        )
-    else:
-        provider_call = _get_provider_call(cast(Provider, provider))
-    _original_args = {
+    # Store original call args that will be used for each function call
+    original_call_args: CallArgs = {
+        "provider": provider,
         "model": model,
         "stream": stream,
         "tools": tools,

@@ -214,36 +223,116 @@ def _call(
         fn: Callable[_P, _R | Awaitable[_R]],
     ) -> Callable[
         _P,
-        CallResponse
+        CallResponse
+        | Stream
+        | _ResponseModelT
+        | _ParsedOutputT
+        | (_ResponseModelT | CallResponse)
+        | Awaitable[CallResponse]
+        | Awaitable[Stream]
+        | Awaitable[_ResponseModelT]
+        | Awaitable[_ParsedOutputT]
+        | Awaitable[(_ResponseModelT | CallResponse)],
     ]:
-
-
-        if fn_is_async(decorated):
+        if fn_is_async(fn):

-            @wraps(
+            @wraps(fn)
             async def inner_async(
                 *args: _P.args, **kwargs: _P.kwargs
-            ) ->
+            ) -> (
+                CallResponse
+                | Stream
+                | _ResponseModelT
+                | _ParsedOutputT
+                | (_ResponseModelT | CallResponse)
+            ):
+                # Apply any context overrides to the original call args
+                effective_call_args = apply_context_overrides_to_call_args(
+                    original_call_args
+                )
+
+                # Get the appropriate provider call function with the possibly overridden provider
+                effective_provider = effective_call_args["provider"]
+                effective_client = effective_call_args["client"]
+
+                if effective_provider in get_args(LocalProvider):
+                    provider_call, effective_client = _get_local_provider_call(
+                        cast(LocalProvider, effective_provider),
+                        effective_client,
+                        True,
+                    )
+                    effective_call_args["client"] = effective_client
+                else:
+                    provider_call = _get_provider_call(
+                        cast(Provider, effective_provider)
+                    )
+
+                # Use the provider-specific call function with overridden args
+                call_kwargs = dict(effective_call_args)
+                del call_kwargs[
+                    "provider"
+                ]  # Remove provider as it's not a parameter to provider_call
+
+                # Get decorated function using provider_call
+                decorated = provider_call(**call_kwargs)(fn)
+
+                # Call the decorated function and wrap the result
                 result = await decorated(*args, **kwargs)
                 return _wrap_result(result)

-            inner_async.
-            inner_async._original_provider_call = provider_call  # pyright: ignore [reportAttributeAccessIssue]
+            inner_async._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
             inner_async._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
-            inner_async._original_provider = provider  # pyright: ignore [reportAttributeAccessIssue]

             return inner_async
         else:

-            @wraps(
-            def inner(
+            @wraps(fn)
+            def inner(
+                *args: _P.args, **kwargs: _P.kwargs
+            ) -> (
+                CallResponse
+                | Stream
+                | _ResponseModelT
+                | _ParsedOutputT
+                | (_ResponseModelT | CallResponse)
+            ):
+                # Apply any context overrides to the original call args
+                effective_call_args = apply_context_overrides_to_call_args(
+                    original_call_args
+                )
+
+                # Get the appropriate provider call function with the possibly overridden provider
+                effective_provider = effective_call_args["provider"]
+                effective_client = effective_call_args["client"]
+
+                if effective_provider in get_args(LocalProvider):
+                    provider_call, effective_client = _get_local_provider_call(
+                        cast(LocalProvider, effective_provider),
+                        effective_client,
+                        False,
+                    )
+                    effective_call_args["client"] = effective_client
+                else:
+                    provider_call = _get_provider_call(
+                        cast(Provider, effective_provider)
+                    )
+
+                # Use the provider-specific call function with overridden args
+                call_kwargs = dict(effective_call_args)
+                del call_kwargs[
+                    "provider"
+                ]  # Remove provider as it's not a parameter to provider_call
+
+                # Get decorated function using provider_call
+                decorated = provider_call(**call_kwargs)(fn)
+
+                # Call the decorated function and wrap the result
                 result = decorated(*args, **kwargs)
                 return _wrap_result(result)

-            inner.
-            inner._original_provider_call = provider_call  # pyright: ignore [reportAttributeAccessIssue]
+            inner._original_call_args = original_call_args  # pyright: ignore [reportAttributeAccessIssue]
             inner._original_fn = fn  # pyright: ignore [reportAttributeAccessIssue]
-
+
             return inner

     return wrapper  # pyright: ignore [reportReturnType]
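The practical consequence of this hunk: the provider-specific decorator is no longer built once at decoration time but on every invocation, after merging any active context overrides into the stored `original_call_args`. A hedged, self-contained sketch of that merge step; the real logic lives in `mirascope/llm/_context.py`, and the names here are illustrative:

def apply_overrides(original_call_args: dict, active_overrides: dict | None) -> dict:
    """Merge non-None override values over the args captured at decoration time."""
    effective = dict(original_call_args)
    if active_overrides:
        effective.update({k: v for k, v in active_overrides.items() if v is not None})
    return effective

call_args = {"provider": "openai", "model": "gpt-4o-mini", "stream": False, "client": None}
print(apply_overrides(call_args, {"provider": "anthropic", "model": "claude-3-5-sonnet-latest"}))
# {'provider': 'anthropic', 'model': 'claude-3-5-sonnet-latest', 'stream': False, 'client': None}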