mirascope 1.18.3__py3-none-any.whl → 1.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +20 -2
- mirascope/beta/openai/__init__.py +1 -1
- mirascope/beta/openai/realtime/__init__.py +1 -1
- mirascope/beta/openai/realtime/tool.py +1 -1
- mirascope/beta/rag/__init__.py +2 -2
- mirascope/beta/rag/base/__init__.py +2 -2
- mirascope/beta/rag/weaviate/__init__.py +1 -1
- mirascope/core/__init__.py +26 -8
- mirascope/core/anthropic/__init__.py +3 -3
- mirascope/core/anthropic/_utils/_calculate_cost.py +114 -47
- mirascope/core/anthropic/call_response.py +9 -1
- mirascope/core/anthropic/call_response_chunk.py +7 -0
- mirascope/core/anthropic/stream.py +3 -1
- mirascope/core/azure/__init__.py +2 -2
- mirascope/core/azure/_utils/_calculate_cost.py +4 -1
- mirascope/core/azure/call_response.py +9 -1
- mirascope/core/azure/call_response_chunk.py +5 -0
- mirascope/core/azure/stream.py +3 -1
- mirascope/core/base/__init__.py +11 -9
- mirascope/core/base/_utils/__init__.py +10 -10
- mirascope/core/base/_utils/_get_common_usage.py +8 -4
- mirascope/core/base/_utils/_get_create_fn_or_async_create_fn.py +2 -2
- mirascope/core/base/_utils/_protocols.py +9 -8
- mirascope/core/base/call_response.py +22 -20
- mirascope/core/base/call_response_chunk.py +12 -1
- mirascope/core/base/stream.py +24 -21
- mirascope/core/base/tool.py +7 -5
- mirascope/core/base/types.py +22 -5
- mirascope/core/bedrock/__init__.py +3 -3
- mirascope/core/bedrock/_utils/_calculate_cost.py +4 -1
- mirascope/core/bedrock/call_response.py +8 -1
- mirascope/core/bedrock/call_response_chunk.py +5 -0
- mirascope/core/bedrock/stream.py +3 -1
- mirascope/core/cohere/__init__.py +2 -2
- mirascope/core/cohere/_utils/_calculate_cost.py +4 -3
- mirascope/core/cohere/call_response.py +9 -1
- mirascope/core/cohere/call_response_chunk.py +5 -0
- mirascope/core/cohere/stream.py +3 -1
- mirascope/core/gemini/__init__.py +2 -2
- mirascope/core/gemini/_utils/_calculate_cost.py +4 -1
- mirascope/core/gemini/_utils/_convert_message_params.py +1 -1
- mirascope/core/gemini/call_response.py +9 -1
- mirascope/core/gemini/call_response_chunk.py +5 -0
- mirascope/core/gemini/stream.py +3 -1
- mirascope/core/google/__init__.py +2 -2
- mirascope/core/google/_utils/_calculate_cost.py +141 -14
- mirascope/core/google/_utils/_convert_message_params.py +23 -51
- mirascope/core/google/_utils/_message_param_converter.py +34 -33
- mirascope/core/google/_utils/_validate_media_type.py +34 -0
- mirascope/core/google/call_response.py +26 -4
- mirascope/core/google/call_response_chunk.py +17 -9
- mirascope/core/google/stream.py +20 -2
- mirascope/core/groq/__init__.py +2 -2
- mirascope/core/groq/_utils/_calculate_cost.py +12 -11
- mirascope/core/groq/call_response.py +9 -1
- mirascope/core/groq/call_response_chunk.py +5 -0
- mirascope/core/groq/stream.py +3 -1
- mirascope/core/litellm/__init__.py +1 -1
- mirascope/core/litellm/_utils/_setup_call.py +7 -3
- mirascope/core/mistral/__init__.py +2 -2
- mirascope/core/mistral/_utils/_calculate_cost.py +10 -9
- mirascope/core/mistral/call_response.py +9 -1
- mirascope/core/mistral/call_response_chunk.py +5 -0
- mirascope/core/mistral/stream.py +3 -1
- mirascope/core/openai/__init__.py +2 -2
- mirascope/core/openai/_utils/_calculate_cost.py +78 -37
- mirascope/core/openai/call_params.py +13 -0
- mirascope/core/openai/call_response.py +14 -1
- mirascope/core/openai/call_response_chunk.py +12 -0
- mirascope/core/openai/stream.py +6 -4
- mirascope/core/vertex/__init__.py +1 -1
- mirascope/core/vertex/_utils/_calculate_cost.py +1 -0
- mirascope/core/vertex/_utils/_convert_message_params.py +1 -1
- mirascope/core/vertex/call_response.py +9 -1
- mirascope/core/vertex/call_response_chunk.py +5 -0
- mirascope/core/vertex/stream.py +3 -1
- mirascope/core/xai/__init__.py +28 -0
- mirascope/core/xai/_call.py +67 -0
- mirascope/core/xai/_utils/__init__.py +6 -0
- mirascope/core/xai/_utils/_calculate_cost.py +104 -0
- mirascope/core/xai/_utils/_setup_call.py +113 -0
- mirascope/core/xai/call_params.py +10 -0
- mirascope/core/xai/call_response.py +27 -0
- mirascope/core/xai/call_response_chunk.py +14 -0
- mirascope/core/xai/dynamic_config.py +8 -0
- mirascope/core/xai/py.typed +0 -0
- mirascope/core/xai/stream.py +57 -0
- mirascope/core/xai/tool.py +13 -0
- mirascope/integrations/_middleware_factory.py +6 -6
- mirascope/integrations/logfire/_utils.py +1 -1
- mirascope/llm/__init__.py +2 -2
- mirascope/llm/_protocols.py +34 -28
- mirascope/llm/call_response.py +16 -7
- mirascope/llm/llm_call.py +50 -46
- mirascope/llm/stream.py +43 -31
- mirascope/retries/__init__.py +1 -1
- mirascope/tools/__init__.py +2 -2
- {mirascope-1.18.3.dist-info → mirascope-1.19.0.dist-info}/METADATA +3 -1
- {mirascope-1.18.3.dist-info → mirascope-1.19.0.dist-info}/RECORD +101 -88
- {mirascope-1.18.3.dist-info → mirascope-1.19.0.dist-info}/WHEEL +0 -0
- {mirascope-1.18.3.dist-info → mirascope-1.19.0.dist-info}/licenses/LICENSE +0 -0
mirascope/core/xai/_utils/_calculate_cost.py
ADDED
@@ -0,0 +1,104 @@
+"""Calculate the cost of a Grok API call."""
+
+
+def calculate_cost(
+    input_tokens: int | float | None,
+    cached_tokens: int | float | None,
+    output_tokens: int | float | None,
+    model: str,
+) -> float | None:
+    """Calculate the cost of an xAI Grok API call.
+
+    https://docs.x.ai
+
+    Pricing (per 1M tokens):
+
+        Model                 Input     Cached    Output
+        grok-3                $3.50     $0.875    $10.50
+        grok-3-latest         $3.50     $0.875    $10.50
+        grok-2                $2.00     $0.50     $6.00
+        grok-2-latest         $2.00     $0.50     $6.00
+        grok-2-1212           $2.00     $0.50     $6.00
+        grok-2-mini           $0.33     $0.083    $1.00
+        grok-2-vision-1212    $2.00     $0.50     $6.00
+        grok-vision-beta      $5.00     $1.25     $15.00
+        grok-beta             $5.00     $1.25     $15.00
+
+    Args:
+        input_tokens: Number of input tokens
+        cached_tokens: Number of cached tokens
+        output_tokens: Number of output tokens
+        model: Model name to use for pricing calculation
+
+    Returns:
+        Total cost in USD or None if invalid input
+    """
+    pricing = {
+        "grok-3": {
+            "prompt": 0.000_003_5,
+            "cached": 0.000_000_875,
+            "completion": 0.000_010_5,
+        },
+        "grok-3-latest": {
+            "prompt": 0.000_003_5,
+            "cached": 0.000_000_875,
+            "completion": 0.000_010_5,
+        },
+        "grok-2": {
+            "prompt": 0.000_002,
+            "cached": 0.000_000_5,
+            "completion": 0.000_006,
+        },
+        "grok-latest": {
+            "prompt": 0.000_002,
+            "cached": 0.000_000_5,
+            "completion": 0.000_006,
+        },
+        "grok-2-1212": {
+            "prompt": 0.000_002,
+            "cached": 0.000_000_5,
+            "completion": 0.000_006,
+        },
+        "grok-2-mini": {
+            "prompt": 0.000_000_33,
+            "cached": 0.000_000_083,
+            "completion": 0.000_001,
+        },
+        "grok-2-vision-1212": {
+            "prompt": 0.000_002,
+            "cached": 0.000_000_5,
+            "completion": 0.000_006,
+        },
+        "grok-vision-beta": {
+            "prompt": 0.000_005,
+            "cached": 0.000_001_25,
+            "completion": 0.000_015,
+        },
+        "grok-beta": {
+            "prompt": 0.000_005,
+            "cached": 0.000_001_25,
+            "completion": 0.000_015,
+        },
+    }
+
+    if input_tokens is None or output_tokens is None:
+        return None
+
+    if cached_tokens is None:
+        cached_tokens = 0
+
+    try:
+        model_pricing = pricing[model]
+    except KeyError:
+        return None
+
+    prompt_price = model_pricing["prompt"]
+    cached_price = model_pricing["cached"]
+    completion_price = model_pricing["completion"]
+
+    prompt_cost = input_tokens * prompt_price
+    cached_cost = cached_tokens * cached_price
+    completion_cost = output_tokens * completion_price
+    total_cost = prompt_cost + cached_cost + completion_cost
+
+    return total_cost
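As a sanity check of the pricing table above, here is a minimal sketch (not part of the package) exercising the new function. The per-token rates in the dict are the per-1M prices divided by 1,000,000, and the import path assumes `_utils/__init__.py` re-exports `calculate_cost`, as the `from ._utils import calculate_cost` line in call_response.py below suggests.

from mirascope.core.xai._utils import calculate_cost

# 1M input tokens at $3.50/1M plus 100k output tokens at $10.50/1M:
cost = calculate_cost(
    input_tokens=1_000_000, cached_tokens=0, output_tokens=100_000, model="grok-3"
)
assert cost is not None and abs(cost - 4.55) < 1e-9  # $3.50 + $1.05

# Missing token counts or an unknown model name both return None:
print(calculate_cost(None, None, 500, "grok-3"))  # None
print(calculate_cost(1_000, None, 500, "grok-unknown"))  # None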
mirascope/core/xai/_utils/_setup_call.py
ADDED
@@ -0,0 +1,113 @@
+"""This module contains the setup_call function for OpenAI tools."""
+
+import os
+from collections.abc import Awaitable, Callable
+from typing import Any, overload
+
+from openai import AsyncOpenAI, OpenAI
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessageParam,
+)
+from pydantic import BaseModel
+
+from ...base import BaseTool
+from ...base._utils import AsyncCreateFn, CreateFn, fn_is_async
+from ...base.call_params import CommonCallParams
+from ...base.stream_config import StreamConfig
+from ...openai import (
+    AsyncOpenAIDynamicConfig,
+    OpenAICallParams,
+    OpenAIDynamicConfig,
+    OpenAITool,
+)
+from ...openai._call_kwargs import OpenAICallKwargs
+from ...openai._utils import setup_call as setup_call_openai
+
+
+@overload
+def setup_call(
+    *,
+    model: str,
+    client: AsyncOpenAI | None,
+    fn: Callable[..., Awaitable[AsyncOpenAIDynamicConfig]],
+    fn_args: dict[str, Any],
+    dynamic_config: AsyncOpenAIDynamicConfig,
+    tools: list[type[BaseTool] | Callable] | None,
+    json_mode: bool,
+    call_params: OpenAICallParams | CommonCallParams,
+    response_model: type[BaseModel] | None,
+    stream: bool | StreamConfig,
+) -> tuple[
+    AsyncCreateFn[ChatCompletion, ChatCompletionChunk],
+    str | None,
+    list[ChatCompletionMessageParam],
+    list[type[OpenAITool]] | None,
+    OpenAICallKwargs,
+]: ...
+
+
+@overload
+def setup_call(
+    *,
+    model: str,
+    client: OpenAI | None,
+    fn: Callable[..., OpenAIDynamicConfig],
+    fn_args: dict[str, Any],
+    dynamic_config: OpenAIDynamicConfig,
+    tools: list[type[BaseTool] | Callable] | None,
+    json_mode: bool,
+    call_params: OpenAICallParams | CommonCallParams,
+    response_model: type[BaseModel] | None,
+    stream: bool | StreamConfig,
+) -> tuple[
+    CreateFn[ChatCompletion, ChatCompletionChunk],
+    str | None,
+    list[ChatCompletionMessageParam],
+    list[type[OpenAITool]] | None,
+    OpenAICallKwargs,
+]: ...
+
+
+def setup_call(
+    *,
+    model: str,
+    client: OpenAI | AsyncOpenAI | None,
+    fn: Callable[..., OpenAIDynamicConfig]
+    | Callable[..., Awaitable[AsyncOpenAIDynamicConfig]],
+    fn_args: dict[str, Any],
+    dynamic_config: OpenAIDynamicConfig | AsyncOpenAIDynamicConfig,
+    tools: list[type[BaseTool] | Callable] | None,
+    json_mode: bool,
+    call_params: OpenAICallParams | CommonCallParams,
+    response_model: type[BaseModel] | None,
+    stream: bool | StreamConfig,
+) -> tuple[
+    CreateFn[ChatCompletion, ChatCompletionChunk]
+    | AsyncCreateFn[ChatCompletion, ChatCompletionChunk],
+    str | None,
+    list[ChatCompletionMessageParam],
+    list[type[OpenAITool]] | None,
+    OpenAICallKwargs,
+]:
+    if not client:
+        api_key = os.environ.get("XAI_API_KEY")
+        client = (
+            AsyncOpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
+            if fn_is_async(fn)
+            else OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
+        )
+    create, prompt_template, messages, tool_types, call_kwargs = setup_call_openai(
+        model=model,  # pyright: ignore [reportCallIssue]
+        client=client,  # pyright: ignore [reportArgumentType]
+        fn=fn,  # pyright: ignore [reportArgumentType]
+        fn_args=fn_args,  # pyright: ignore [reportArgumentType]
+        dynamic_config=dynamic_config,
+        tools=tools,
+        json_mode=json_mode,
+        call_params=call_params,
+        response_model=response_model,
+        stream=stream,
+    )
+    return create, prompt_template, messages, tool_types, call_kwargs
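The client fallback above means a call works with only `XAI_API_KEY` set in the environment. A usage sketch, assuming the new `xai` module exposes the same `call` decorator pattern as mirascope's other providers (its `_call.py` is added in this release but not shown in this diff):

import os

from mirascope.core import xai

os.environ.setdefault("XAI_API_KEY", "xai-...")  # placeholder key


@xai.call("grok-3")
def recommend_book(genre: str) -> str:
    return f"Recommend a {genre} book"


# With no explicit client, setup_call builds an OpenAI-compatible client
# pointed at https://api.x.ai/v1 using XAI_API_KEY.
response = recommend_book("fantasy")
print(response.content)
print(response.cost)  # computed from the new xAI pricing table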
mirascope/core/xai/call_params.py
ADDED
@@ -0,0 +1,10 @@
+"""usage docs: learn/calls.md#provider-specific-parameters"""
+
+from ..openai import OpenAICallParams
+
+
+class XAICallParams(OpenAICallParams):
+    """A simple wrapper around `OpenAICallParams.`
+
+    Since xAI supports the OpenAI spec, we change nothing here.
+    """
mirascope/core/xai/call_response.py
ADDED
@@ -0,0 +1,27 @@
+"""This module contains the `XAICallResponse` class.
+
+usage docs: learn/calls.md#handling-responses
+"""
+
+from pydantic import computed_field
+
+from ..openai import OpenAICallResponse
+from ._utils import calculate_cost
+
+
+class XAICallResponse(OpenAICallResponse):
+    """A simpler wrapper around `OpenAICallResponse`.
+
+    Everything is the same except the `cost` property, which has been updated to use
+    xAI's cost calculations so that cost tracking works for non-OpenAI models.
+    """
+
+    _provider = "xai"
+
+    @computed_field
+    @property
+    def cost(self) -> float | None:
+        """Returns the cost of the call."""
+        return calculate_cost(
+            self.input_tokens, self.cached_tokens, self.output_tokens, self.model
+        )
mirascope/core/xai/call_response_chunk.py
ADDED
@@ -0,0 +1,14 @@
+"""This module contains the `XAICallResponseChunk` class.
+
+usage docs: learn/streams.md#handling-streamed-responses
+"""
+
+from ..openai import OpenAICallResponseChunk
+
+
+class XAICallResponseChunk(OpenAICallResponseChunk):
+    """A simpler wrapper around `OpenAICallResponse`.
+
+    Everything is the same except the `cost` property, which has been updated to use
+    xAI's cost calculations so that cost tracking works for non-OpenAI models.
+    """
mirascope/core/xai/dynamic_config.py
ADDED
@@ -0,0 +1,8 @@
+"""This module defines the function return type for functions as LLM calls."""
+
+from typing import TypeAlias
+
+from ..openai import AsyncOpenAIDynamicConfig, OpenAIDynamicConfig
+
+AsyncXAIDynamicConfig: TypeAlias = AsyncOpenAIDynamicConfig
+XAIDynamicConfig: TypeAlias = OpenAIDynamicConfig

mirascope/core/xai/py.typed
File without changes
mirascope/core/xai/stream.py
ADDED
@@ -0,0 +1,57 @@
+"""The `XAIStream` class for convenience around streaming xAI LLM calls.
+
+usage docs: learn/streams.md
+"""
+
+from collections.abc import AsyncGenerator, Generator
+
+from ..openai import OpenAIStream
+from .call_response import XAICallResponse
+from .call_response_chunk import XAICallResponseChunk
+from .tool import XAITool
+
+
+class XAIStream(OpenAIStream):
+    """A simple wrapper around `OpenAIStream`.
+
+    Everything is the same except updates to the `construct_call_response` method and
+    the `cost` property so that cost is properly calculated using xAI's cost
+    calculation method. This ensures cost calculation works for non-OpenAI models.
+    """
+
+    _provider = "xai"
+
+    def __iter__(
+        self,
+    ) -> Generator[tuple[XAICallResponseChunk, XAITool | None], None, None]:
+        yield from super().__iter__()  # pyright: ignore [reportReturnType]
+
+    def __aiter__(
+        self,
+    ) -> AsyncGenerator[tuple[XAICallResponseChunk, XAITool | None], None]:
+        return super().__aiter__()  # pyright: ignore [reportReturnType] # pragma: no cover
+
+    @property
+    def cost(self) -> float | None:
+        """Returns the cost of the call."""
+        response = self.construct_call_response()
+        return response.cost
+
+    def construct_call_response(self) -> XAICallResponse:
+        openai_call_response = super().construct_call_response()
+        response = XAICallResponse(
+            metadata=openai_call_response.metadata,
+            response=openai_call_response.response,
+            tool_types=openai_call_response.tool_types,
+            prompt_template=openai_call_response.prompt_template,
+            fn_args=openai_call_response.fn_args,
+            dynamic_config=openai_call_response.dynamic_config,
+            messages=openai_call_response.messages,
+            call_params=openai_call_response.call_params,
+            call_kwargs=openai_call_response.call_kwargs,
+            user_message_param=openai_call_response.user_message_param,
+            start_time=openai_call_response.start_time,
+            end_time=openai_call_response.end_time,
+        )
+        response._model = self.model
+        return response
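A streaming sketch under the same assumption about the decorator; iterating yields `(XAICallResponseChunk, XAITool | None)` pairs, and accessing `stream.cost` rebuilds an `XAICallResponse` via `construct_call_response` so xAI pricing applies:

from mirascope.core import xai


@xai.call("grok-3", stream=True)
def recommend_book(genre: str) -> str:
    return f"Recommend a {genre} book"


stream = recommend_book("mystery")
for chunk, _tool in stream:
    print(chunk.content, end="", flush=True)
print(f"\ncost: {stream.cost}")  # XAICallResponse built from the stream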
mirascope/core/xai/tool.py
ADDED
@@ -0,0 +1,13 @@
+"""The `XAITool` class for easy tool usage with xAI LLM calls.
+
+usage docs: learn/tools.md
+"""
+
+from ..openai import OpenAITool
+
+
+class XAITool(OpenAITool):
+    """A simple wrapper around `OpenAITool`.
+
+    Since xAI supports the OpenAI spec, we change nothing here.
+    """
mirascope/integrations/_middleware_factory.py
CHANGED
@@ -175,9 +175,9 @@ def middleware_factory(
     def new_stream_aiter(
         self: Any,  # noqa: ANN401
     ) -> AsyncGenerator[tuple[Any, Any | None], Any]:  # noqa: ANN401
-        async def generator() ->
-
-
+        async def generator() -> AsyncGenerator[
+            tuple[Any, Any | None], Any
+        ]:
             try:
                 async for chunk, tool in original_aiter():
                     yield chunk, tool
@@ -226,9 +226,9 @@ def middleware_factory(
     def new_aiter(
         self: Any,  # noqa: ANN401
     ) -> AsyncGenerator[tuple[Any, Any | None], Any]:  # noqa: ANN401
-        async def generator() ->
-
-
+        async def generator() -> AsyncGenerator[
+            tuple[Any, Any | None], Any
+        ]:
             try:
                 async for chunk in original_aiter():
                     yield chunk
mirascope/integrations/logfire/_utils.py
CHANGED
@@ -23,7 +23,7 @@ def custom_context_manager(
 ) -> Generator[logfire.LogfireSpan, Any, None]:
     metadata: Metadata = _utils.get_metadata(fn, None)
     tags = metadata.get("tags", [])
-    with logfire.with_settings(custom_scope_suffix="mirascope", tags=list(tags)).span(
+    with logfire.with_settings(custom_scope_suffix="mirascope", tags=list(tags)).span(  # pyright: ignore[reportGeneralTypeIssues]
         fn.__name__
     ) as logfire_span:
         yield logfire_span
mirascope/llm/__init__.py
CHANGED
@@ -1,6 +1,6 @@
-from ._protocols import Provider
+from ._protocols import LocalProvider, Provider
 from .call_response import CallResponse
 from .llm_call import call
 from .llm_override import override

-__all__ = ["CallResponse", "Provider", "call", "override"]
+__all__ = ["CallResponse", "LocalProvider", "Provider", "call", "override"]
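The widened export list in action, a one-liner confirming the new name is importable:

from mirascope.llm import CallResponse, LocalProvider, Provider, call, override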
mirascope/llm/_protocols.py
CHANGED
@@ -73,6 +73,12 @@ Provider: TypeAlias = Literal[
     "mistral",
     "openai",
     "vertex",
+    "xai",
+]
+
+LocalProvider: TypeAlias = Literal[
+    "ollama",
+    "vllm",
 ]


@@ -90,9 +96,9 @@ class _CallDecorator(
     ],
 ):
     @overload
-    def __call__(
+    def __call__(  # pyright: ignore[reportOverlappingOverload]
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -110,9 +116,9 @@ class _CallDecorator(
     ]: ...

     @overload
-    def __call__(
+    def __call__(  # pyright: ignore[reportOverlappingOverload]
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -127,7 +133,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -140,9 +146,9 @@ class _CallDecorator(
     ) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _BaseCallResponseT]: ...

     @overload
-    def __call__(
+    def __call__(  # pyright: ignore[reportOverlappingOverload]
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -159,7 +165,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -174,7 +180,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -189,7 +195,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -207,7 +213,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig = True,
@@ -224,7 +230,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig = True,
@@ -239,7 +245,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig = True,
@@ -254,7 +260,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig = True,
@@ -272,7 +278,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig = True,
@@ -288,9 +294,9 @@ class _CallDecorator(
     ) -> NoReturn: ...

     @overload
-    def __call__(
+    def __call__(  # pyright: ignore[reportOverlappingOverload]
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -307,7 +313,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -322,7 +328,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -337,7 +343,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -354,7 +360,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -369,7 +375,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[False] = False,
@@ -382,9 +388,9 @@ class _CallDecorator(
     ) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _ParsedOutputT]: ...

     @overload
-    def __call__(
+    def __call__(  # pyright: ignore[reportOverlappingOverload]
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig,
@@ -404,7 +410,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig,
@@ -421,7 +427,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig,
@@ -436,7 +442,7 @@ class _CallDecorator(
     @overload
     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: Literal[True] | StreamConfig,
@@ -456,7 +462,7 @@ class _CallDecorator(

     def __call__(
         self,
-        provider: Provider,
+        provider: Provider | LocalProvider,
         model: str,
         *,
         stream: bool | StreamConfig = False,
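A sketch of the widened decorator signature: `provider` now also accepts the two `LocalProvider` values alongside the new `"xai"` entry. How `"ollama"` and `"vllm"` are routed lives in `llm_call.py` (changed in this release but not shown here), and the model names below are illustrative:

from mirascope import llm


@llm.call(provider="xai", model="grok-3")
def recommend_book(genre: str) -> str:
    return f"Recommend a {genre} book"


@llm.call(provider="ollama", model="llama3.2")  # illustrative local model name
def summarize(text: str) -> str:
    return f"Summarize: {text}"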