mirascope 1.19.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +4 -0
- mirascope/beta/openai/realtime/realtime.py +7 -8
- mirascope/beta/openai/realtime/tool.py +2 -2
- mirascope/core/__init__.py +8 -1
- mirascope/core/anthropic/_utils/__init__.py +0 -2
- mirascope/core/anthropic/_utils/_convert_message_params.py +1 -7
- mirascope/core/anthropic/_utils/_message_param_converter.py +48 -31
- mirascope/core/anthropic/call_response.py +7 -9
- mirascope/core/anthropic/call_response_chunk.py +10 -0
- mirascope/core/anthropic/stream.py +6 -8
- mirascope/core/azure/_utils/__init__.py +0 -2
- mirascope/core/azure/call_response.py +7 -10
- mirascope/core/azure/call_response_chunk.py +6 -1
- mirascope/core/azure/stream.py +6 -8
- mirascope/core/base/__init__.py +2 -1
- mirascope/core/base/_utils/__init__.py +2 -0
- mirascope/core/base/_utils/_get_image_dimensions.py +39 -0
- mirascope/core/base/call_response.py +36 -6
- mirascope/core/base/call_response_chunk.py +15 -1
- mirascope/core/base/stream.py +25 -3
- mirascope/core/base/types.py +276 -2
- mirascope/core/bedrock/_utils/__init__.py +0 -2
- mirascope/core/bedrock/call_response.py +7 -10
- mirascope/core/bedrock/call_response_chunk.py +6 -0
- mirascope/core/bedrock/stream.py +6 -10
- mirascope/core/cohere/_utils/__init__.py +0 -2
- mirascope/core/cohere/call_response.py +7 -10
- mirascope/core/cohere/call_response_chunk.py +6 -0
- mirascope/core/cohere/stream.py +5 -8
- mirascope/core/costs/__init__.py +5 -0
- mirascope/core/{anthropic/_utils/_calculate_cost.py → costs/_anthropic_calculate_cost.py} +45 -14
- mirascope/core/{azure/_utils/_calculate_cost.py → costs/_azure_calculate_cost.py} +3 -3
- mirascope/core/{bedrock/_utils/_calculate_cost.py → costs/_bedrock_calculate_cost.py} +3 -3
- mirascope/core/{cohere/_utils/_calculate_cost.py → costs/_cohere_calculate_cost.py} +12 -8
- mirascope/core/{gemini/_utils/_calculate_cost.py → costs/_gemini_calculate_cost.py} +7 -7
- mirascope/core/costs/_google_calculate_cost.py +427 -0
- mirascope/core/costs/_groq_calculate_cost.py +156 -0
- mirascope/core/costs/_litellm_calculate_cost.py +11 -0
- mirascope/core/costs/_mistral_calculate_cost.py +64 -0
- mirascope/core/costs/_openai_calculate_cost.py +416 -0
- mirascope/core/{vertex/_utils/_calculate_cost.py → costs/_vertex_calculate_cost.py} +8 -7
- mirascope/core/{xai/_utils/_calculate_cost.py → costs/_xai_calculate_cost.py} +9 -9
- mirascope/core/costs/calculate_cost.py +86 -0
- mirascope/core/gemini/_utils/__init__.py +0 -2
- mirascope/core/gemini/call_response.py +7 -10
- mirascope/core/gemini/call_response_chunk.py +6 -1
- mirascope/core/gemini/stream.py +5 -8
- mirascope/core/google/_utils/__init__.py +0 -2
- mirascope/core/google/_utils/_setup_call.py +21 -2
- mirascope/core/google/call_response.py +9 -10
- mirascope/core/google/call_response_chunk.py +6 -1
- mirascope/core/google/stream.py +5 -8
- mirascope/core/groq/_utils/__init__.py +0 -2
- mirascope/core/groq/call_response.py +22 -10
- mirascope/core/groq/call_response_chunk.py +6 -0
- mirascope/core/groq/stream.py +5 -8
- mirascope/core/litellm/call_response.py +3 -4
- mirascope/core/litellm/stream.py +30 -22
- mirascope/core/mistral/_utils/__init__.py +0 -2
- mirascope/core/mistral/call_response.py +7 -10
- mirascope/core/mistral/call_response_chunk.py +6 -0
- mirascope/core/mistral/stream.py +5 -8
- mirascope/core/openai/_utils/__init__.py +0 -2
- mirascope/core/openai/_utils/_convert_message_params.py +4 -4
- mirascope/core/openai/call_response.py +30 -10
- mirascope/core/openai/call_response_chunk.py +6 -0
- mirascope/core/openai/stream.py +5 -8
- mirascope/core/vertex/_utils/__init__.py +0 -2
- mirascope/core/vertex/call_response.py +5 -10
- mirascope/core/vertex/call_response_chunk.py +6 -0
- mirascope/core/vertex/stream.py +5 -8
- mirascope/core/xai/_utils/__init__.py +1 -2
- mirascope/core/xai/call_response.py +0 -11
- mirascope/llm/__init__.py +9 -2
- mirascope/llm/_protocols.py +8 -28
- mirascope/llm/call_response.py +6 -6
- mirascope/llm/call_response_chunk.py +12 -3
- mirascope/llm/llm_call.py +21 -23
- mirascope/llm/llm_override.py +56 -27
- mirascope/llm/stream.py +7 -7
- mirascope/llm/tool.py +1 -1
- mirascope/retries/fallback.py +1 -1
- {mirascope-1.19.0.dist-info → mirascope-1.20.0.dist-info}/METADATA +1 -1
- {mirascope-1.19.0.dist-info → mirascope-1.20.0.dist-info}/RECORD +86 -82
- mirascope/core/google/_utils/_calculate_cost.py +0 -215
- mirascope/core/groq/_utils/_calculate_cost.py +0 -69
- mirascope/core/mistral/_utils/_calculate_cost.py +0 -48
- mirascope/core/openai/_utils/_calculate_cost.py +0 -246
- {mirascope-1.19.0.dist-info → mirascope-1.20.0.dist-info}/WHEEL +0 -0
- {mirascope-1.19.0.dist-info → mirascope-1.20.0.dist-info}/licenses/LICENSE +0 -0
mirascope/core/openai/stream.py
CHANGED
|
@@ -21,7 +21,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import Function
|
|
|
21
21
|
from openai.types.completion_usage import CompletionUsage
|
|
22
22
|
|
|
23
23
|
from ..base.stream import BaseStream
|
|
24
|
-
from .
|
|
24
|
+
from ..base.types import CostMetadata
|
|
25
25
|
from .call_params import OpenAICallParams
|
|
26
26
|
from .call_response import OpenAICallResponse
|
|
27
27
|
from .call_response_chunk import OpenAICallResponseChunk
|
|
@@ -101,13 +101,6 @@ class OpenAIStream(
|
|
|
101
101
|
|
|
102
102
|
return generator()
|
|
103
103
|
|
|
104
|
-
@property
|
|
105
|
-
def cost(self) -> float | None:
|
|
106
|
-
"""Returns the cost of the call."""
|
|
107
|
-
return calculate_cost(
|
|
108
|
-
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
109
|
-
)
|
|
110
|
-
|
|
111
104
|
def _construct_message_param(
|
|
112
105
|
self,
|
|
113
106
|
tool_calls: list[ChatCompletionMessageToolCall] | None = None,
|
|
@@ -186,3 +179,7 @@ class OpenAIStream(
|
|
|
186
179
|
start_time=self.start_time,
|
|
187
180
|
end_time=self.end_time,
|
|
188
181
|
)
|
|
182
|
+
|
|
183
|
+
@property
|
|
184
|
+
def cost_metadata(self) -> CostMetadata:
|
|
185
|
+
return super().cost_metadata
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
"""Vertex utilities for decorator factories."""
|
|
2
2
|
|
|
3
|
-
from ._calculate_cost import calculate_cost
|
|
4
3
|
from ._convert_message_params import convert_message_params
|
|
5
4
|
from ._get_json_output import get_json_output
|
|
6
5
|
from ._handle_stream import handle_stream, handle_stream_async
|
|
7
6
|
from ._setup_call import setup_call
|
|
8
7
|
|
|
9
8
|
__all__ = [
|
|
10
|
-
"calculate_cost",
|
|
11
9
|
"convert_message_params",
|
|
12
10
|
"get_json_output",
|
|
13
11
|
"handle_stream",
|
|
@@ -11,8 +11,7 @@ from vertexai.generative_models import Content, GenerationResponse, Part, Tool
|
|
|
11
11
|
|
|
12
12
|
from .. import BaseMessageParam
|
|
13
13
|
from ..base import BaseCallResponse, transform_tool_outputs
|
|
14
|
-
from ..base.types import FinishReason
|
|
15
|
-
from ._utils import calculate_cost
|
|
14
|
+
from ..base.types import CostMetadata, FinishReason
|
|
16
15
|
from ._utils._convert_finish_reason_to_common_finish_reasons import (
|
|
17
16
|
_convert_finish_reasons_to_common_finish_reasons,
|
|
18
17
|
)
|
|
@@ -124,14 +123,6 @@ class VertexCallResponse(
|
|
|
124
123
|
"""Returns the number of output tokens."""
|
|
125
124
|
return self.usage.candidates_token_count
|
|
126
125
|
|
|
127
|
-
@computed_field
|
|
128
|
-
@property
|
|
129
|
-
def cost(self) -> float | None:
|
|
130
|
-
"""Returns the cost of the call."""
|
|
131
|
-
return calculate_cost(
|
|
132
|
-
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
133
|
-
)
|
|
134
|
-
|
|
135
126
|
@computed_field
|
|
136
127
|
@cached_property
|
|
137
128
|
def message_param(self) -> Content:
|
|
@@ -205,3 +196,7 @@ class VertexCallResponse(
|
|
|
205
196
|
if not self.user_message_param:
|
|
206
197
|
return None
|
|
207
198
|
return VertexMessageParamConverter.from_provider([self.user_message_param])[0]
|
|
199
|
+
|
|
200
|
+
@property
|
|
201
|
+
def cost_metadata(self) -> CostMetadata:
|
|
202
|
+
return super().cost_metadata
|
|
@@ -6,6 +6,7 @@ usage docs: learn/streams.md#handling-streamed-responses
|
|
|
6
6
|
from vertexai.generative_models import FinishReason, GenerationResponse
|
|
7
7
|
|
|
8
8
|
from ..base import BaseCallResponseChunk, types
|
|
9
|
+
from ..base.types import CostMetadata
|
|
9
10
|
from ._utils._convert_finish_reason_to_common_finish_reasons import (
|
|
10
11
|
_convert_finish_reasons_to_common_finish_reasons,
|
|
11
12
|
)
|
|
@@ -86,6 +87,11 @@ class VertexCallResponseChunk(
|
|
|
86
87
|
"""Returns the number of output tokens."""
|
|
87
88
|
return None
|
|
88
89
|
|
|
90
|
+
@property
|
|
91
|
+
def cost_metadata(self) -> CostMetadata:
|
|
92
|
+
"""Returns the cost metadata."""
|
|
93
|
+
return super().cost_metadata
|
|
94
|
+
|
|
89
95
|
@property
|
|
90
96
|
def common_finish_reasons(self) -> list[types.FinishReason] | None:
|
|
91
97
|
return _convert_finish_reasons_to_common_finish_reasons(
|
mirascope/core/vertex/stream.py
CHANGED
|
@@ -16,7 +16,7 @@ from vertexai.generative_models import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
from ..base.stream import BaseStream
|
|
19
|
-
from .
|
|
19
|
+
from ..base.types import CostMetadata
|
|
20
20
|
from .call_params import VertexCallParams
|
|
21
21
|
from .call_response import VertexCallResponse
|
|
22
22
|
from .call_response_chunk import VertexCallResponseChunk
|
|
@@ -61,13 +61,6 @@ class VertexStream(
|
|
|
61
61
|
|
|
62
62
|
_provider = "vertex"
|
|
63
63
|
|
|
64
|
-
@property
|
|
65
|
-
def cost(self) -> float | None:
|
|
66
|
-
"""Returns the cost of the call."""
|
|
67
|
-
return calculate_cost(
|
|
68
|
-
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
69
|
-
)
|
|
70
|
-
|
|
71
64
|
def _construct_message_param(
|
|
72
65
|
self,
|
|
73
66
|
tool_calls: list[FunctionCall] | None = None,
|
|
@@ -120,3 +113,7 @@ class VertexStream(
|
|
|
120
113
|
start_time=self.start_time,
|
|
121
114
|
end_time=self.end_time,
|
|
122
115
|
)
|
|
116
|
+
|
|
117
|
+
@property
|
|
118
|
+
def cost_metadata(self) -> CostMetadata:
|
|
119
|
+
return super().cost_metadata
|
|
@@ -3,10 +3,7 @@
|
|
|
3
3
|
usage docs: learn/calls.md#handling-responses
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
from pydantic import computed_field
|
|
7
|
-
|
|
8
6
|
from ..openai import OpenAICallResponse
|
|
9
|
-
from ._utils import calculate_cost
|
|
10
7
|
|
|
11
8
|
|
|
12
9
|
class XAICallResponse(OpenAICallResponse):
|
|
@@ -17,11 +14,3 @@ class XAICallResponse(OpenAICallResponse):
|
|
|
17
14
|
"""
|
|
18
15
|
|
|
19
16
|
_provider = "xai"
|
|
20
|
-
|
|
21
|
-
@computed_field
|
|
22
|
-
@property
|
|
23
|
-
def cost(self) -> float | None:
|
|
24
|
-
"""Returns the cost of the call."""
|
|
25
|
-
return calculate_cost(
|
|
26
|
-
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
27
|
-
)
|
mirascope/llm/__init__.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
-
from
|
|
1
|
+
from ..core import LocalProvider, Provider, calculate_cost
|
|
2
2
|
from .call_response import CallResponse
|
|
3
3
|
from .llm_call import call
|
|
4
4
|
from .llm_override import override
|
|
5
5
|
|
|
6
|
-
__all__ = [
|
|
6
|
+
__all__ = [
|
|
7
|
+
"CallResponse",
|
|
8
|
+
"LocalProvider",
|
|
9
|
+
"Provider",
|
|
10
|
+
"calculate_cost",
|
|
11
|
+
"call",
|
|
12
|
+
"override",
|
|
13
|
+
]
|
mirascope/llm/_protocols.py
CHANGED
|
@@ -19,22 +19,23 @@ from typing import (
|
|
|
19
19
|
|
|
20
20
|
from pydantic import BaseModel
|
|
21
21
|
|
|
22
|
-
from
|
|
23
|
-
from
|
|
22
|
+
from ..core import BaseDynamicConfig, BaseTool
|
|
23
|
+
from ..core.base import (
|
|
24
24
|
BaseCallResponse,
|
|
25
25
|
BaseCallResponseChunk,
|
|
26
26
|
BaseType,
|
|
27
27
|
CommonCallParams,
|
|
28
28
|
)
|
|
29
|
-
from
|
|
29
|
+
from ..core.base._utils._protocols import (
|
|
30
30
|
AsyncLLMFunctionDecorator,
|
|
31
31
|
LLMFunctionDecorator,
|
|
32
32
|
SyncLLMFunctionDecorator,
|
|
33
33
|
)
|
|
34
|
-
from
|
|
35
|
-
from
|
|
36
|
-
from
|
|
37
|
-
from
|
|
34
|
+
from ..core.base.stream_config import StreamConfig
|
|
35
|
+
from ..core.base.types import LocalProvider, Provider
|
|
36
|
+
from .call_response import CallResponse
|
|
37
|
+
from .call_response_chunk import CallResponseChunk
|
|
38
|
+
from .stream import Stream
|
|
38
39
|
|
|
39
40
|
_BaseStreamT = TypeVar("_BaseStreamT", covariant=True)
|
|
40
41
|
_ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | BaseType | Enum)
|
|
@@ -61,27 +62,6 @@ _BaseCallResponseChunkT = TypeVar(
|
|
|
61
62
|
)
|
|
62
63
|
|
|
63
64
|
|
|
64
|
-
Provider: TypeAlias = Literal[
|
|
65
|
-
"anthropic",
|
|
66
|
-
"azure",
|
|
67
|
-
"bedrock",
|
|
68
|
-
"cohere",
|
|
69
|
-
"gemini",
|
|
70
|
-
"google",
|
|
71
|
-
"groq",
|
|
72
|
-
"litellm",
|
|
73
|
-
"mistral",
|
|
74
|
-
"openai",
|
|
75
|
-
"vertex",
|
|
76
|
-
"xai",
|
|
77
|
-
]
|
|
78
|
-
|
|
79
|
-
LocalProvider: TypeAlias = Literal[
|
|
80
|
-
"ollama",
|
|
81
|
-
"vllm",
|
|
82
|
-
]
|
|
83
|
-
|
|
84
|
-
|
|
85
65
|
class _CallDecorator(
|
|
86
66
|
Protocol[
|
|
87
67
|
_BaseCallResponseT,
|
mirascope/llm/call_response.py
CHANGED
|
@@ -7,8 +7,8 @@ from typing import Any, TypeVar
|
|
|
7
7
|
|
|
8
8
|
from pydantic import computed_field
|
|
9
9
|
|
|
10
|
-
from
|
|
11
|
-
from
|
|
10
|
+
from ..core import BaseDynamicConfig
|
|
11
|
+
from ..core.base import (
|
|
12
12
|
BaseCallParams,
|
|
13
13
|
BaseCallResponse,
|
|
14
14
|
BaseMessageParam,
|
|
@@ -16,10 +16,10 @@ from mirascope.core.base import (
|
|
|
16
16
|
Usage,
|
|
17
17
|
transform_tool_outputs,
|
|
18
18
|
)
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from
|
|
19
|
+
from ..core.base.message_param import ToolResultPart
|
|
20
|
+
from ..core.base.types import FinishReason
|
|
21
|
+
from ._response_metaclass import _ResponseMetaclass
|
|
22
|
+
from .tool import Tool
|
|
23
23
|
|
|
24
24
|
_ResponseT = TypeVar("_ResponseT")
|
|
25
25
|
|
|
@@ -4,9 +4,9 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from typing import Any, Generic, TypeVar
|
|
6
6
|
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from
|
|
7
|
+
from ..core.base.call_response_chunk import BaseCallResponseChunk
|
|
8
|
+
from ..core.base.types import CostMetadata, FinishReason, Usage
|
|
9
|
+
from ._response_metaclass import _ResponseMetaclass
|
|
10
10
|
|
|
11
11
|
_ChunkT = TypeVar("_ChunkT")
|
|
12
12
|
|
|
@@ -40,6 +40,9 @@ class CallResponseChunk(
|
|
|
40
40
|
"__pydantic_private__",
|
|
41
41
|
"__class_getitem__",
|
|
42
42
|
"_properties",
|
|
43
|
+
"cost_metadata",
|
|
44
|
+
"finish_reasons",
|
|
45
|
+
"usage",
|
|
43
46
|
} | set(object.__getattribute__(self, "_properties"))
|
|
44
47
|
|
|
45
48
|
if name in special_names:
|
|
@@ -58,3 +61,9 @@ class CallResponseChunk(
|
|
|
58
61
|
@property
|
|
59
62
|
def usage(self) -> Usage | None:
|
|
60
63
|
return self._response.common_usage
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def cost_metadata(self) -> CostMetadata:
|
|
67
|
+
"""Get metadata required for cost calculation."""
|
|
68
|
+
|
|
69
|
+
return self._response.cost_metadata
|
mirascope/llm/llm_call.py
CHANGED
|
@@ -9,27 +9,25 @@ from typing import Any, ParamSpec, TypeVar, cast, get_args
|
|
|
9
9
|
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
|
-
from
|
|
13
|
-
from
|
|
12
|
+
from ..core import BaseTool
|
|
13
|
+
from ..core.base import (
|
|
14
14
|
BaseCallResponse,
|
|
15
15
|
BaseCallResponseChunk,
|
|
16
16
|
BaseStream,
|
|
17
17
|
BaseType,
|
|
18
18
|
CommonCallParams,
|
|
19
19
|
)
|
|
20
|
-
from
|
|
21
|
-
from mirascope.llm.call_response import CallResponse
|
|
22
|
-
from mirascope.llm.stream import Stream
|
|
23
|
-
|
|
20
|
+
from ..core.base._utils import fn_is_async
|
|
24
21
|
from ..core.base.stream_config import StreamConfig
|
|
22
|
+
from ..core.base.types import LocalProvider, Provider
|
|
25
23
|
from ._protocols import (
|
|
26
24
|
AsyncLLMFunctionDecorator,
|
|
27
25
|
CallDecorator,
|
|
28
26
|
LLMFunctionDecorator,
|
|
29
|
-
LocalProvider,
|
|
30
|
-
Provider,
|
|
31
27
|
SyncLLMFunctionDecorator,
|
|
32
28
|
)
|
|
29
|
+
from .call_response import CallResponse
|
|
30
|
+
from .stream import Stream
|
|
33
31
|
|
|
34
32
|
_P = ParamSpec("_P")
|
|
35
33
|
_R = TypeVar("_R")
|
|
@@ -53,7 +51,7 @@ def _get_local_provider_call(
|
|
|
53
51
|
client: Any | None, # noqa: ANN401
|
|
54
52
|
) -> tuple[Callable, Any | None]:
|
|
55
53
|
if provider == "ollama":
|
|
56
|
-
from
|
|
54
|
+
from ..core.openai import openai_call
|
|
57
55
|
|
|
58
56
|
if client:
|
|
59
57
|
return openai_call, client
|
|
@@ -62,7 +60,7 @@ def _get_local_provider_call(
|
|
|
62
60
|
client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
|
|
63
61
|
return openai_call, client
|
|
64
62
|
else: # provider == "vllm"
|
|
65
|
-
from
|
|
63
|
+
from ..core.openai import openai_call
|
|
66
64
|
|
|
67
65
|
if client:
|
|
68
66
|
return openai_call, client
|
|
@@ -75,51 +73,51 @@ def _get_local_provider_call(
|
|
|
75
73
|
def _get_provider_call(provider: Provider) -> Callable:
|
|
76
74
|
"""Returns the provider-specific call decorator based on the provider name."""
|
|
77
75
|
if provider == "anthropic":
|
|
78
|
-
from
|
|
76
|
+
from ..core.anthropic import anthropic_call
|
|
79
77
|
|
|
80
78
|
return anthropic_call
|
|
81
79
|
elif provider == "azure":
|
|
82
|
-
from
|
|
80
|
+
from ..core.azure import azure_call
|
|
83
81
|
|
|
84
82
|
return azure_call
|
|
85
83
|
elif provider == "bedrock":
|
|
86
|
-
from
|
|
84
|
+
from ..core.bedrock import bedrock_call
|
|
87
85
|
|
|
88
86
|
return bedrock_call
|
|
89
87
|
elif provider == "cohere":
|
|
90
|
-
from
|
|
88
|
+
from ..core.cohere import cohere_call
|
|
91
89
|
|
|
92
90
|
return cohere_call
|
|
93
91
|
elif provider == "gemini":
|
|
94
|
-
from
|
|
92
|
+
from ..core.gemini import gemini_call
|
|
95
93
|
|
|
96
94
|
return gemini_call
|
|
97
95
|
elif provider == "google":
|
|
98
|
-
from
|
|
96
|
+
from ..core.google import google_call
|
|
99
97
|
|
|
100
98
|
return google_call
|
|
101
99
|
elif provider == "groq":
|
|
102
|
-
from
|
|
100
|
+
from ..core.groq import groq_call
|
|
103
101
|
|
|
104
102
|
return groq_call
|
|
105
103
|
elif provider == "litellm":
|
|
106
|
-
from
|
|
104
|
+
from ..core.litellm import litellm_call
|
|
107
105
|
|
|
108
106
|
return litellm_call
|
|
109
107
|
elif provider == "mistral":
|
|
110
|
-
from
|
|
108
|
+
from ..core.mistral import mistral_call
|
|
111
109
|
|
|
112
110
|
return mistral_call
|
|
113
111
|
elif provider == "openai":
|
|
114
|
-
from
|
|
112
|
+
from ..core.openai import openai_call
|
|
115
113
|
|
|
116
114
|
return openai_call
|
|
117
115
|
elif provider == "vertex":
|
|
118
|
-
from
|
|
116
|
+
from ..core.vertex import vertex_call
|
|
119
117
|
|
|
120
118
|
return vertex_call
|
|
121
119
|
elif provider == "xai":
|
|
122
|
-
from
|
|
120
|
+
from ..core.xai import xai_call
|
|
123
121
|
|
|
124
122
|
return xai_call
|
|
125
123
|
raise ValueError(f"Unsupported provider: {provider}")
|
|
@@ -264,7 +262,7 @@ template.
|
|
|
264
262
|
Example:
|
|
265
263
|
|
|
266
264
|
```python
|
|
267
|
-
from
|
|
265
|
+
from ..llm import call
|
|
268
266
|
|
|
269
267
|
|
|
270
268
|
@call(provider="openai", model="gpt-4o-mini")
|
mirascope/llm/llm_override.py
CHANGED
|
@@ -5,22 +5,22 @@ from __future__ import annotations
|
|
|
5
5
|
from collections.abc import Callable
|
|
6
6
|
from typing import TYPE_CHECKING, Any, Literal, ParamSpec, TypeVar, overload
|
|
7
7
|
|
|
8
|
-
from
|
|
9
|
-
from
|
|
10
|
-
from
|
|
8
|
+
from ..core.base import CommonCallParams
|
|
9
|
+
from ..core.base.types import LocalProvider, Provider
|
|
10
|
+
from .llm_call import _call
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
|
-
from
|
|
14
|
-
from
|
|
15
|
-
from
|
|
16
|
-
from
|
|
17
|
-
from
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from
|
|
23
|
-
from
|
|
13
|
+
from ..core.anthropic import AnthropicCallParams
|
|
14
|
+
from ..core.azure import AzureCallParams
|
|
15
|
+
from ..core.bedrock import BedrockCallParams
|
|
16
|
+
from ..core.cohere import CohereCallParams
|
|
17
|
+
from ..core.gemini import GeminiCallParams
|
|
18
|
+
from ..core.google import GoogleCallParams
|
|
19
|
+
from ..core.groq import GroqCallParams
|
|
20
|
+
from ..core.litellm import LiteLLMCallParams
|
|
21
|
+
from ..core.mistral import MistralCallParams
|
|
22
|
+
from ..core.openai import OpenAICallParams
|
|
23
|
+
from ..core.vertex import VertexCallParams
|
|
24
24
|
else:
|
|
25
25
|
AnthropicCallParams = AzureCallParams = BedrockCallParams = CohereCallParams = (
|
|
26
26
|
GeminiCallParams
|
|
@@ -42,6 +42,8 @@ def override(
|
|
|
42
42
|
call_params: CommonCallParams | AnthropicCallParams | None = None,
|
|
43
43
|
client: Any = None, # noqa: ANN401
|
|
44
44
|
) -> Callable[_P, _R]: ...
|
|
45
|
+
|
|
46
|
+
|
|
45
47
|
@overload
|
|
46
48
|
def override(
|
|
47
49
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -51,6 +53,8 @@ def override(
|
|
|
51
53
|
call_params: CommonCallParams | AzureCallParams | None = None,
|
|
52
54
|
client: Any = None, # noqa: ANN401
|
|
53
55
|
) -> Callable[_P, _R]: ...
|
|
56
|
+
|
|
57
|
+
|
|
54
58
|
@overload
|
|
55
59
|
def override(
|
|
56
60
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -60,6 +64,8 @@ def override(
|
|
|
60
64
|
call_params: CommonCallParams | BedrockCallParams | None = None,
|
|
61
65
|
client: Any = None, # noqa: ANN401
|
|
62
66
|
) -> Callable[_P, _R]: ...
|
|
67
|
+
|
|
68
|
+
|
|
63
69
|
@overload
|
|
64
70
|
def override(
|
|
65
71
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -69,6 +75,8 @@ def override(
|
|
|
69
75
|
call_params: CommonCallParams | CohereCallParams | None = None,
|
|
70
76
|
client: Any = None, # noqa: ANN401
|
|
71
77
|
) -> Callable[_P, _R]: ...
|
|
78
|
+
|
|
79
|
+
|
|
72
80
|
@overload
|
|
73
81
|
def override(
|
|
74
82
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -78,6 +86,8 @@ def override(
|
|
|
78
86
|
call_params: CommonCallParams | GeminiCallParams | None = None,
|
|
79
87
|
client: Any = None, # noqa: ANN401
|
|
80
88
|
) -> Callable[_P, _R]: ...
|
|
89
|
+
|
|
90
|
+
|
|
81
91
|
@overload
|
|
82
92
|
def override(
|
|
83
93
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -98,6 +108,8 @@ def override(
|
|
|
98
108
|
call_params: CommonCallParams | GroqCallParams | None = None,
|
|
99
109
|
client: Any = None, # noqa: ANN401
|
|
100
110
|
) -> Callable[_P, _R]: ...
|
|
111
|
+
|
|
112
|
+
|
|
101
113
|
@overload
|
|
102
114
|
def override(
|
|
103
115
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -107,6 +119,8 @@ def override(
|
|
|
107
119
|
call_params: CommonCallParams | MistralCallParams | None = None,
|
|
108
120
|
client: Any = None, # noqa: ANN401
|
|
109
121
|
) -> Callable[_P, _R]: ...
|
|
122
|
+
|
|
123
|
+
|
|
110
124
|
@overload
|
|
111
125
|
def override(
|
|
112
126
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -116,6 +130,8 @@ def override(
|
|
|
116
130
|
call_params: CommonCallParams | OpenAICallParams | None = None,
|
|
117
131
|
client: Any = None, # noqa: ANN401
|
|
118
132
|
) -> Callable[_P, _R]: ...
|
|
133
|
+
|
|
134
|
+
|
|
119
135
|
@overload
|
|
120
136
|
def override(
|
|
121
137
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -125,6 +141,8 @@ def override(
|
|
|
125
141
|
call_params: CommonCallParams | LiteLLMCallParams | None = None,
|
|
126
142
|
client: Any = None, # noqa: ANN401
|
|
127
143
|
) -> Callable[_P, _R]: ...
|
|
144
|
+
|
|
145
|
+
|
|
128
146
|
@overload
|
|
129
147
|
def override(
|
|
130
148
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -134,6 +152,8 @@ def override(
|
|
|
134
152
|
call_params: CommonCallParams | VertexCallParams | None = None,
|
|
135
153
|
client: Any = None, # noqa: ANN401
|
|
136
154
|
) -> Callable[_P, _R]: ...
|
|
155
|
+
|
|
156
|
+
|
|
137
157
|
@overload
|
|
138
158
|
def override(
|
|
139
159
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -143,6 +163,8 @@ def override(
|
|
|
143
163
|
call_params: CommonCallParams | None = None,
|
|
144
164
|
client: Any = None, # noqa: ANN401
|
|
145
165
|
) -> Callable[_P, _R]: ...
|
|
166
|
+
|
|
167
|
+
|
|
146
168
|
@overload
|
|
147
169
|
def override(
|
|
148
170
|
provider_agnostic_call: Callable[_P, _R],
|
|
@@ -157,7 +179,7 @@ def override(
|
|
|
157
179
|
def override(
|
|
158
180
|
provider_agnostic_call: Callable[_P, _R],
|
|
159
181
|
*,
|
|
160
|
-
provider: Provider | None = None,
|
|
182
|
+
provider: Provider | LocalProvider | None = None,
|
|
161
183
|
model: str | None = None,
|
|
162
184
|
call_params: CommonCallParams
|
|
163
185
|
| AnthropicCallParams
|
|
@@ -185,20 +207,27 @@ def override(
|
|
|
185
207
|
Returns:
|
|
186
208
|
The overridden function.
|
|
187
209
|
"""
|
|
188
|
-
|
|
189
|
-
if provider is not None and not model and call_params is None and client is None:
|
|
210
|
+
if (provider and not model) or (model and not provider):
|
|
190
211
|
raise ValueError(
|
|
191
|
-
"
|
|
212
|
+
"Provider and model must both be overridden if either is overridden."
|
|
192
213
|
)
|
|
193
214
|
|
|
215
|
+
original_provider = provider_agnostic_call._original_provider # pyright: ignore [reportFunctionMemberAccess]
|
|
216
|
+
original_args = provider_agnostic_call._original_args # pyright: ignore [reportFunctionMemberAccess]
|
|
217
|
+
|
|
218
|
+
# Note: if switching providers, we will always use `client` since `original_client`
|
|
219
|
+
# would be from a different provider and fail.
|
|
220
|
+
if provider and provider == original_provider:
|
|
221
|
+
client = client or original_args["client"]
|
|
222
|
+
|
|
194
223
|
return _call( # pyright: ignore [reportReturnType]
|
|
195
|
-
provider=provider or
|
|
196
|
-
model=model or
|
|
197
|
-
stream=
|
|
198
|
-
tools=
|
|
199
|
-
response_model=
|
|
200
|
-
output_parser=
|
|
201
|
-
json_mode=
|
|
202
|
-
client=client
|
|
203
|
-
call_params=call_params or
|
|
224
|
+
provider=provider or original_provider,
|
|
225
|
+
model=model or original_args["model"],
|
|
226
|
+
stream=original_args["stream"],
|
|
227
|
+
tools=original_args["tools"],
|
|
228
|
+
response_model=original_args["response_model"],
|
|
229
|
+
output_parser=original_args["output_parser"],
|
|
230
|
+
json_mode=original_args["json_mode"],
|
|
231
|
+
client=client,
|
|
232
|
+
call_params=call_params or original_args["call_params"],
|
|
204
233
|
)(provider_agnostic_call._original_fn) # pyright: ignore [reportFunctionMemberAccess]
|
mirascope/llm/stream.py
CHANGED
|
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|
|
5
5
|
from collections.abc import AsyncGenerator, Generator
|
|
6
6
|
from typing import Any, Generic, TypeVar
|
|
7
7
|
|
|
8
|
-
from
|
|
8
|
+
from ..core.base import (
|
|
9
9
|
BaseCallParams,
|
|
10
10
|
BaseCallResponse,
|
|
11
11
|
BaseCallResponseChunk,
|
|
@@ -13,12 +13,12 @@ from mirascope.core.base import (
|
|
|
13
13
|
BaseMessageParam,
|
|
14
14
|
BaseTool,
|
|
15
15
|
)
|
|
16
|
-
from
|
|
17
|
-
from
|
|
18
|
-
from
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from
|
|
16
|
+
from ..core.base.call_response import JsonableType
|
|
17
|
+
from ..core.base.stream import BaseStream
|
|
18
|
+
from ..core.base.types import FinishReason
|
|
19
|
+
from ..llm.call_response import CallResponse
|
|
20
|
+
from ..llm.call_response_chunk import CallResponseChunk
|
|
21
|
+
from .tool import Tool
|
|
22
22
|
|
|
23
23
|
_BaseCallResponseT = TypeVar("_BaseCallResponseT", bound=BaseCallResponse)
|
|
24
24
|
_BaseCallResponseChunkT = TypeVar(
|
mirascope/llm/tool.py
CHANGED
mirascope/retries/fallback.py
CHANGED