mirascope 1.18.2__py3-none-any.whl → 1.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +20 -1
- mirascope/beta/openai/__init__.py +1 -1
- mirascope/beta/openai/realtime/__init__.py +1 -1
- mirascope/beta/openai/realtime/tool.py +1 -1
- mirascope/beta/rag/__init__.py +2 -2
- mirascope/beta/rag/base/__init__.py +2 -2
- mirascope/beta/rag/weaviate/__init__.py +1 -1
- mirascope/core/__init__.py +29 -6
- mirascope/core/anthropic/__init__.py +3 -3
- mirascope/core/anthropic/_utils/_calculate_cost.py +114 -47
- mirascope/core/anthropic/call_response.py +9 -3
- mirascope/core/anthropic/call_response_chunk.py +7 -0
- mirascope/core/anthropic/stream.py +3 -1
- mirascope/core/azure/__init__.py +2 -2
- mirascope/core/azure/_utils/_calculate_cost.py +4 -1
- mirascope/core/azure/call_response.py +9 -3
- mirascope/core/azure/call_response_chunk.py +5 -0
- mirascope/core/azure/stream.py +3 -1
- mirascope/core/base/__init__.py +11 -9
- mirascope/core/base/_utils/__init__.py +10 -10
- mirascope/core/base/_utils/_get_common_usage.py +8 -4
- mirascope/core/base/_utils/_get_create_fn_or_async_create_fn.py +2 -2
- mirascope/core/base/_utils/_protocols.py +9 -8
- mirascope/core/base/call_response.py +22 -22
- mirascope/core/base/call_response_chunk.py +12 -1
- mirascope/core/base/stream.py +24 -21
- mirascope/core/base/tool.py +7 -5
- mirascope/core/base/types.py +22 -5
- mirascope/core/bedrock/__init__.py +3 -3
- mirascope/core/bedrock/_utils/_calculate_cost.py +4 -1
- mirascope/core/bedrock/call_response.py +8 -3
- mirascope/core/bedrock/call_response_chunk.py +5 -0
- mirascope/core/bedrock/stream.py +3 -1
- mirascope/core/cohere/__init__.py +2 -2
- mirascope/core/cohere/_utils/_calculate_cost.py +4 -3
- mirascope/core/cohere/call_response.py +9 -3
- mirascope/core/cohere/call_response_chunk.py +5 -0
- mirascope/core/cohere/stream.py +3 -1
- mirascope/core/gemini/__init__.py +2 -2
- mirascope/core/gemini/_utils/_calculate_cost.py +4 -1
- mirascope/core/gemini/_utils/_convert_message_params.py +1 -1
- mirascope/core/gemini/call_response.py +9 -3
- mirascope/core/gemini/call_response_chunk.py +5 -0
- mirascope/core/gemini/stream.py +3 -1
- mirascope/core/google/__init__.py +2 -2
- mirascope/core/google/_utils/_calculate_cost.py +141 -14
- mirascope/core/google/_utils/_convert_message_params.py +120 -115
- mirascope/core/google/_utils/_message_param_converter.py +34 -33
- mirascope/core/google/_utils/_validate_media_type.py +34 -0
- mirascope/core/google/call_response.py +38 -10
- mirascope/core/google/call_response_chunk.py +17 -9
- mirascope/core/google/stream.py +20 -2
- mirascope/core/groq/__init__.py +2 -2
- mirascope/core/groq/_utils/_calculate_cost.py +12 -11
- mirascope/core/groq/call_response.py +9 -3
- mirascope/core/groq/call_response_chunk.py +5 -0
- mirascope/core/groq/stream.py +3 -1
- mirascope/core/litellm/__init__.py +1 -1
- mirascope/core/litellm/_utils/_setup_call.py +7 -3
- mirascope/core/mistral/__init__.py +2 -2
- mirascope/core/mistral/_utils/_calculate_cost.py +10 -9
- mirascope/core/mistral/call_response.py +9 -3
- mirascope/core/mistral/call_response_chunk.py +5 -0
- mirascope/core/mistral/stream.py +3 -1
- mirascope/core/openai/__init__.py +2 -2
- mirascope/core/openai/_utils/_calculate_cost.py +78 -37
- mirascope/core/openai/call_params.py +13 -0
- mirascope/core/openai/call_response.py +14 -3
- mirascope/core/openai/call_response_chunk.py +12 -0
- mirascope/core/openai/stream.py +6 -4
- mirascope/core/vertex/__init__.py +1 -1
- mirascope/core/vertex/_utils/_calculate_cost.py +1 -0
- mirascope/core/vertex/_utils/_convert_message_params.py +1 -1
- mirascope/core/vertex/call_response.py +9 -3
- mirascope/core/vertex/call_response_chunk.py +5 -0
- mirascope/core/vertex/stream.py +3 -1
- mirascope/integrations/_middleware_factory.py +6 -6
- mirascope/integrations/logfire/_utils.py +1 -1
- mirascope/llm/__init__.py +3 -1
- mirascope/llm/_protocols.py +5 -5
- mirascope/llm/call_response.py +16 -9
- mirascope/llm/llm_call.py +53 -25
- mirascope/llm/stream.py +43 -31
- mirascope/retries/__init__.py +1 -1
- mirascope/tools/__init__.py +2 -2
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/METADATA +2 -2
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/RECORD +89 -88
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/WHEEL +0 -0
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -43,21 +43,26 @@ from ._setup_call import setup_call
|
|
|
43
43
|
from ._setup_extract_tool import setup_extract_tool
|
|
44
44
|
|
|
45
45
|
__all__ = [
|
|
46
|
+
"DEFAULT_TOOL_DOCSTRING",
|
|
46
47
|
"AsyncCreateFn",
|
|
47
|
-
"SameSyncAndAsyncClientSetupCall",
|
|
48
48
|
"BaseType",
|
|
49
49
|
"CalculateCost",
|
|
50
|
+
"CreateFn",
|
|
51
|
+
"GetJsonOutput",
|
|
52
|
+
"HandleStream",
|
|
53
|
+
"HandleStreamAsync",
|
|
54
|
+
"LLMFunctionDecorator",
|
|
55
|
+
"MessagesDecorator",
|
|
56
|
+
"SameSyncAndAsyncClientSetupCall",
|
|
57
|
+
"SetupCall",
|
|
50
58
|
"convert_base_model_to_base_tool",
|
|
51
59
|
"convert_base_type_to_base_tool",
|
|
52
60
|
"convert_function_to_base_tool",
|
|
53
|
-
"CreateFn",
|
|
54
|
-
"DEFAULT_TOOL_DOCSTRING",
|
|
55
61
|
"extract_tool_return",
|
|
56
62
|
"fn_is_async",
|
|
57
63
|
"format_template",
|
|
58
|
-
"GetJsonOutput",
|
|
59
|
-
"get_audio_type",
|
|
60
64
|
"get_async_create_fn",
|
|
65
|
+
"get_audio_type",
|
|
61
66
|
"get_common_usage",
|
|
62
67
|
"get_create_fn",
|
|
63
68
|
"get_document_type",
|
|
@@ -70,18 +75,13 @@ __all__ = [
|
|
|
70
75
|
"get_template_values",
|
|
71
76
|
"get_template_variables",
|
|
72
77
|
"get_unsupported_tool_config_keys",
|
|
73
|
-
"HandleStream",
|
|
74
|
-
"HandleStreamAsync",
|
|
75
78
|
"is_base_type",
|
|
76
79
|
"is_prompt_template",
|
|
77
80
|
"json_mode_content",
|
|
78
|
-
"LLMFunctionDecorator",
|
|
79
|
-
"MessagesDecorator",
|
|
80
81
|
"messages_decorator",
|
|
81
82
|
"parse_content_template",
|
|
82
83
|
"parse_prompt_messages",
|
|
83
84
|
"pil_image_to_bytes",
|
|
84
|
-
"SetupCall",
|
|
85
85
|
"setup_call",
|
|
86
86
|
"setup_extract_tool",
|
|
87
87
|
]
|
|
@@ -2,15 +2,19 @@ from mirascope.core.base.types import Usage
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def get_common_usage(
|
|
5
|
-
input_tokens: int | float | None,
|
|
5
|
+
input_tokens: int | float | None,
|
|
6
|
+
cached_tokens: int | float | None,
|
|
7
|
+
output_tokens: int | float | None,
|
|
6
8
|
) -> Usage | None:
|
|
7
9
|
"""Get common usage from input and output tokens."""
|
|
8
|
-
if input_tokens is None and output_tokens is None:
|
|
10
|
+
if input_tokens is None and cached_tokens is None and output_tokens is None:
|
|
9
11
|
return None
|
|
10
12
|
input_tokens = int(input_tokens or 0)
|
|
13
|
+
cached_tokens = int(cached_tokens or 0)
|
|
11
14
|
output_tokens = int(output_tokens or 0)
|
|
12
15
|
return Usage(
|
|
13
|
-
|
|
14
|
-
|
|
16
|
+
input_tokens=input_tokens,
|
|
17
|
+
cached_tokens=cached_tokens,
|
|
18
|
+
output_tokens=output_tokens,
|
|
15
19
|
total_tokens=input_tokens + output_tokens,
|
|
16
20
|
)
|
|
@@ -53,7 +53,7 @@ def get_async_create_fn(
|
|
|
53
53
|
async_generator_func: _AsyncGeneratorFunc[_StreamedResponse] | None = None,
|
|
54
54
|
) -> AsyncCreateFn[_NonStreamedResponse, _StreamedResponse]:
|
|
55
55
|
@overload
|
|
56
|
-
def create_or_stream(
|
|
56
|
+
def create_or_stream( # pyright: ignore[reportOverlappingOverload]
|
|
57
57
|
*,
|
|
58
58
|
stream: Literal[True] | StreamConfig = True,
|
|
59
59
|
**kwargs: Any, # noqa: ANN401
|
|
@@ -100,7 +100,7 @@ def get_create_fn(
|
|
|
100
100
|
sync_generator_func: _SyncGeneratorFunc[_StreamedResponse] | None = None,
|
|
101
101
|
) -> CreateFn[_NonStreamedResponse, _StreamedResponse]:
|
|
102
102
|
@overload
|
|
103
|
-
def create_or_stream(
|
|
103
|
+
def create_or_stream( # pyright: ignore[reportOverlappingOverload]
|
|
104
104
|
*,
|
|
105
105
|
stream: Literal[True] | StreamConfig = True,
|
|
106
106
|
**kwargs: Any, # noqa: ANN401
|
|
@@ -129,7 +129,7 @@ class LLMFunctionDecorator(
|
|
|
129
129
|
|
|
130
130
|
class AsyncCreateFn(Protocol[_ResponseT, _ResponseChunkT]):
|
|
131
131
|
@overload
|
|
132
|
-
def __call__(
|
|
132
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
133
133
|
self,
|
|
134
134
|
*,
|
|
135
135
|
stream: Literal[False] = False,
|
|
@@ -153,7 +153,7 @@ class AsyncCreateFn(Protocol[_ResponseT, _ResponseChunkT]):
|
|
|
153
153
|
|
|
154
154
|
class CreateFn(Protocol[_ResponseT, _ResponseChunkT]):
|
|
155
155
|
@overload
|
|
156
|
-
def __call__(
|
|
156
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
157
157
|
self,
|
|
158
158
|
*,
|
|
159
159
|
stream: Literal[False] = False,
|
|
@@ -371,6 +371,7 @@ class CalculateCost(Protocol):
|
|
|
371
371
|
def __call__(
|
|
372
372
|
self,
|
|
373
373
|
input_tokens: int | float | None,
|
|
374
|
+
cached_tokens: int | float | None,
|
|
374
375
|
output_tokens: int | float | None,
|
|
375
376
|
model: str,
|
|
376
377
|
) -> float | None: ... # pragma: no cover
|
|
@@ -390,7 +391,7 @@ class CallDecorator(
|
|
|
390
391
|
],
|
|
391
392
|
):
|
|
392
393
|
@overload
|
|
393
|
-
def __call__(
|
|
394
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
394
395
|
self,
|
|
395
396
|
model: str,
|
|
396
397
|
*,
|
|
@@ -409,7 +410,7 @@ class CallDecorator(
|
|
|
409
410
|
]: ...
|
|
410
411
|
|
|
411
412
|
@overload
|
|
412
|
-
def __call__(
|
|
413
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
413
414
|
self,
|
|
414
415
|
model: str,
|
|
415
416
|
*,
|
|
@@ -437,7 +438,7 @@ class CallDecorator(
|
|
|
437
438
|
) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _BaseCallResponseT]: ...
|
|
438
439
|
|
|
439
440
|
@overload
|
|
440
|
-
def __call__(
|
|
441
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
441
442
|
self,
|
|
442
443
|
model: str,
|
|
443
444
|
*,
|
|
@@ -576,7 +577,7 @@ class CallDecorator(
|
|
|
576
577
|
) -> NoReturn: ...
|
|
577
578
|
|
|
578
579
|
@overload
|
|
579
|
-
def __call__(
|
|
580
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
580
581
|
self,
|
|
581
582
|
model: str,
|
|
582
583
|
*,
|
|
@@ -664,7 +665,7 @@ class CallDecorator(
|
|
|
664
665
|
) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _ParsedOutputT]: ...
|
|
665
666
|
|
|
666
667
|
@overload
|
|
667
|
-
def __call__(
|
|
668
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
668
669
|
self,
|
|
669
670
|
model: str,
|
|
670
671
|
*,
|
|
@@ -774,7 +775,7 @@ class CallDecorator(
|
|
|
774
775
|
####
|
|
775
776
|
|
|
776
777
|
@overload
|
|
777
|
-
def __call__(
|
|
778
|
+
def __call__( # pyright: ignore[reportOverlappingOverload]
|
|
778
779
|
self,
|
|
779
780
|
model: str,
|
|
780
781
|
*,
|
|
@@ -7,7 +7,7 @@ import json
|
|
|
7
7
|
from abc import ABC, abstractmethod
|
|
8
8
|
from collections.abc import Callable
|
|
9
9
|
from functools import cached_property, wraps
|
|
10
|
-
from typing import TYPE_CHECKING, Any, ClassVar, Generic,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar
|
|
11
11
|
|
|
12
12
|
from pydantic import (
|
|
13
13
|
BaseModel,
|
|
@@ -24,7 +24,7 @@ from .call_params import BaseCallParams
|
|
|
24
24
|
from .dynamic_config import BaseDynamicConfig
|
|
25
25
|
from .metadata import Metadata
|
|
26
26
|
from .tool import BaseTool
|
|
27
|
-
from .types import FinishReason, Usage
|
|
27
|
+
from .types import FinishReason, JsonableType, Usage
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
30
|
from ...llm.tool import Tool
|
|
@@ -35,36 +35,26 @@ _BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
|
|
|
35
35
|
_ToolSchemaT = TypeVar("_ToolSchemaT")
|
|
36
36
|
_BaseDynamicConfigT = TypeVar("_BaseDynamicConfigT", bound=BaseDynamicConfig)
|
|
37
37
|
_MessageParamT = TypeVar("_MessageParamT", bound=Any)
|
|
38
|
+
_ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
|
|
38
39
|
_CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
|
|
39
40
|
_UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
|
|
40
41
|
_BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
|
|
41
42
|
|
|
42
43
|
|
|
43
|
-
JsonableType: TypeAlias = (
|
|
44
|
-
str
|
|
45
|
-
| int
|
|
46
|
-
| float
|
|
47
|
-
| bool
|
|
48
|
-
| bytes
|
|
49
|
-
| list["JsonableType"]
|
|
50
|
-
| set["JsonableType"]
|
|
51
|
-
| tuple["JsonableType", ...]
|
|
52
|
-
| dict[str, "JsonableType"]
|
|
53
|
-
| BaseModel
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
|
|
57
44
|
def transform_tool_outputs(
|
|
58
|
-
fn: Callable[
|
|
45
|
+
fn: Callable[
|
|
46
|
+
[type[_BaseCallResponseT], list[tuple[_BaseToolT, str]]],
|
|
47
|
+
list[_ToolMessageParamT],
|
|
48
|
+
],
|
|
59
49
|
) -> Callable[
|
|
60
50
|
[type[_BaseCallResponseT], list[tuple[_BaseToolT, JsonableType]]],
|
|
61
|
-
list[
|
|
51
|
+
list[_ToolMessageParamT],
|
|
62
52
|
]:
|
|
63
53
|
@wraps(fn)
|
|
64
54
|
def wrapper(
|
|
65
55
|
cls: type[_BaseCallResponseT],
|
|
66
56
|
tools_and_outputs: list[tuple[_BaseToolT, JsonableType]],
|
|
67
|
-
) -> list[
|
|
57
|
+
) -> list[_ToolMessageParamT]:
|
|
68
58
|
def recursive_serializer(value: JsonableType) -> BaseType:
|
|
69
59
|
if isinstance(value, str):
|
|
70
60
|
return value
|
|
@@ -212,6 +202,16 @@ class BaseCallResponse(
|
|
|
212
202
|
"""
|
|
213
203
|
...
|
|
214
204
|
|
|
205
|
+
@computed_field
|
|
206
|
+
@property
|
|
207
|
+
@abstractmethod
|
|
208
|
+
def cached_tokens(self) -> int | float | None:
|
|
209
|
+
"""Should return the number of cached tokens.
|
|
210
|
+
|
|
211
|
+
If there is no cached_tokens, this method must return None.
|
|
212
|
+
"""
|
|
213
|
+
...
|
|
214
|
+
|
|
215
215
|
@computed_field
|
|
216
216
|
@property
|
|
217
217
|
@abstractmethod
|
|
@@ -239,14 +239,12 @@ class BaseCallResponse(
|
|
|
239
239
|
"""Returns the assistant's response as a message parameter."""
|
|
240
240
|
...
|
|
241
241
|
|
|
242
|
-
@computed_field
|
|
243
242
|
@cached_property
|
|
244
243
|
@abstractmethod
|
|
245
244
|
def tools(self) -> list[_BaseToolT] | None:
|
|
246
245
|
"""Returns the tools for the 0th choice message."""
|
|
247
246
|
...
|
|
248
247
|
|
|
249
|
-
@computed_field
|
|
250
248
|
@cached_property
|
|
251
249
|
@abstractmethod
|
|
252
250
|
def tool(self) -> _BaseToolT | None:
|
|
@@ -297,4 +295,6 @@ class BaseCallResponse(
|
|
|
297
295
|
@property
|
|
298
296
|
def common_usage(self) -> Usage | None:
|
|
299
297
|
"""Provider-agnostic usage info."""
|
|
300
|
-
return get_common_usage(
|
|
298
|
+
return get_common_usage(
|
|
299
|
+
self.input_tokens, self.cached_tokens, self.output_tokens
|
|
300
|
+
)
|
|
@@ -84,6 +84,15 @@ class BaseCallResponseChunk(BaseModel, Generic[_ChunkT, _FinishReasonT], ABC):
|
|
|
84
84
|
"""
|
|
85
85
|
...
|
|
86
86
|
|
|
87
|
+
@property
|
|
88
|
+
@abstractmethod
|
|
89
|
+
def cached_tokens(self) -> int | float | None:
|
|
90
|
+
"""Should return the number of cached tokens.
|
|
91
|
+
|
|
92
|
+
If there is no cached_tokens, this method must return None.
|
|
93
|
+
"""
|
|
94
|
+
...
|
|
95
|
+
|
|
87
96
|
@property
|
|
88
97
|
@abstractmethod
|
|
89
98
|
def output_tokens(self) -> int | float | None:
|
|
@@ -102,4 +111,6 @@ class BaseCallResponseChunk(BaseModel, Generic[_ChunkT, _FinishReasonT], ABC):
|
|
|
102
111
|
@property
|
|
103
112
|
def common_usage(self) -> Usage | None:
|
|
104
113
|
"""Provider-agnostic usage info."""
|
|
105
|
-
return get_common_usage(
|
|
114
|
+
return get_common_usage(
|
|
115
|
+
self.input_tokens, self.cached_tokens, self.output_tokens
|
|
116
|
+
)
|
mirascope/core/base/stream.py
CHANGED
|
@@ -92,6 +92,7 @@ class BaseStream(
|
|
|
92
92
|
user_message_param: _UserMessageParamT | None = None
|
|
93
93
|
message_param: _AssistantMessageParamT
|
|
94
94
|
input_tokens: int | float | None = None
|
|
95
|
+
cached_tokens: int | float | None = None
|
|
95
96
|
output_tokens: int | float | None = None
|
|
96
97
|
id: str | None = None
|
|
97
98
|
finish_reasons: list[_FinishReason] | None = None
|
|
@@ -138,9 +139,9 @@ class BaseStream(
|
|
|
138
139
|
self,
|
|
139
140
|
) -> Generator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None, None]:
|
|
140
141
|
"""Iterator over the stream and stores useful information."""
|
|
141
|
-
assert isinstance(
|
|
142
|
-
|
|
143
|
-
)
|
|
142
|
+
assert isinstance(self.stream, Generator), (
|
|
143
|
+
"Stream must be a generator for __iter__"
|
|
144
|
+
)
|
|
144
145
|
self.content, tool_calls = "", []
|
|
145
146
|
self.start_time = datetime.datetime.now().timestamp() * 1000
|
|
146
147
|
for chunk, tool in self.stream:
|
|
@@ -161,12 +162,12 @@ class BaseStream(
|
|
|
161
162
|
"""Iterates over the stream and stores useful information."""
|
|
162
163
|
self.content = ""
|
|
163
164
|
|
|
164
|
-
async def generator() ->
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
assert isinstance(
|
|
168
|
-
|
|
169
|
-
)
|
|
165
|
+
async def generator() -> AsyncGenerator[
|
|
166
|
+
tuple[_BaseCallResponseChunkT, _BaseToolT | None], None
|
|
167
|
+
]:
|
|
168
|
+
assert isinstance(self.stream, AsyncGenerator), (
|
|
169
|
+
"Stream must be an async generator for __aiter__"
|
|
170
|
+
)
|
|
170
171
|
tool_calls = []
|
|
171
172
|
async for chunk, tool in self.stream:
|
|
172
173
|
self._update_properties(chunk)
|
|
@@ -190,6 +191,12 @@ class BaseStream(
|
|
|
190
191
|
if not self.input_tokens
|
|
191
192
|
else self.input_tokens + chunk.input_tokens
|
|
192
193
|
)
|
|
194
|
+
if chunk.cached_tokens is not None:
|
|
195
|
+
self.cached_tokens = (
|
|
196
|
+
chunk.cached_tokens
|
|
197
|
+
if not self.cached_tokens
|
|
198
|
+
else self.cached_tokens + chunk.cached_tokens
|
|
199
|
+
)
|
|
193
200
|
if chunk.output_tokens is not None:
|
|
194
201
|
self.output_tokens = (
|
|
195
202
|
chunk.output_tokens
|
|
@@ -373,11 +380,9 @@ def stream_factory( # noqa: ANN201
|
|
|
373
380
|
stream=True,
|
|
374
381
|
)
|
|
375
382
|
|
|
376
|
-
async def generator() ->
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
]
|
|
380
|
-
):
|
|
383
|
+
async def generator() -> AsyncGenerator[
|
|
384
|
+
tuple[_BaseCallResponseChunkT, _BaseToolT | None], None
|
|
385
|
+
]:
|
|
381
386
|
async for chunk, tool in handle_stream_async(
|
|
382
387
|
await create(stream=True, **call_kwargs),
|
|
383
388
|
tool_types,
|
|
@@ -422,13 +427,11 @@ def stream_factory( # noqa: ANN201
|
|
|
422
427
|
stream=True,
|
|
423
428
|
)
|
|
424
429
|
|
|
425
|
-
def generator() ->
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
]
|
|
431
|
-
):
|
|
430
|
+
def generator() -> Generator[
|
|
431
|
+
tuple[_BaseCallResponseChunkT, _BaseToolT | None],
|
|
432
|
+
None,
|
|
433
|
+
None,
|
|
434
|
+
]:
|
|
432
435
|
yield from handle_stream(
|
|
433
436
|
create(stream=True, **call_kwargs),
|
|
434
437
|
tool_types,
|
mirascope/core/base/tool.py
CHANGED
|
@@ -36,17 +36,19 @@ class ToolConfig(TypedDict, total=False):
|
|
|
36
36
|
class GenerateJsonSchemaNoTitles(GenerateJsonSchema):
|
|
37
37
|
_openai_strict: ClassVar[bool] = False
|
|
38
38
|
|
|
39
|
-
def _remove_title(self, obj: Any) -> Any: # noqa: ANN401
|
|
39
|
+
def _remove_title(self, key: str | None, obj: Any) -> Any: # noqa: ANN401
|
|
40
40
|
if isinstance(obj, dict):
|
|
41
41
|
if self._openai_strict and "type" in obj and obj["type"] == "object":
|
|
42
42
|
obj["additionalProperties"] = False
|
|
43
43
|
if "type" in obj or "$ref" in obj or "properties" in obj:
|
|
44
|
-
obj.pop("title", None)
|
|
44
|
+
title = obj.pop("title", None)
|
|
45
|
+
if key and title and key.lower() != title.lower():
|
|
46
|
+
obj["title"] = title
|
|
45
47
|
|
|
46
48
|
for key, value in list(obj.items()):
|
|
47
|
-
obj[key] = self._remove_title(value)
|
|
49
|
+
obj[key] = self._remove_title(key, value)
|
|
48
50
|
elif isinstance(obj, list):
|
|
49
|
-
return [self._remove_title(item) for item in obj]
|
|
51
|
+
return [self._remove_title(None, item) for item in obj]
|
|
50
52
|
|
|
51
53
|
return obj
|
|
52
54
|
|
|
@@ -56,7 +58,7 @@ class GenerateJsonSchemaNoTitles(GenerateJsonSchema):
|
|
|
56
58
|
json_schema = super().generate(schema, mode=mode)
|
|
57
59
|
json_schema.pop("title", None)
|
|
58
60
|
json_schema.pop("description", None)
|
|
59
|
-
json_schema = self._remove_title(json_schema)
|
|
61
|
+
json_schema = self._remove_title(None, json_schema)
|
|
60
62
|
return json_schema
|
|
61
63
|
|
|
62
64
|
|
mirascope/core/base/types.py
CHANGED
|
@@ -43,11 +43,28 @@ FinishReason: TypeAlias = Literal["stop", "length", "tool_calls", "content_filte
|
|
|
43
43
|
|
|
44
44
|
|
|
45
45
|
class Usage(BaseModel):
|
|
46
|
-
|
|
47
|
-
"""Number of tokens in the generated completion."""
|
|
48
|
-
|
|
49
|
-
prompt_tokens: int = 0
|
|
46
|
+
input_tokens: int
|
|
50
47
|
"""Number of tokens in the prompt."""
|
|
51
48
|
|
|
52
|
-
|
|
49
|
+
cached_tokens: int
|
|
50
|
+
"""Number of tokens used that were previously cached (and thus cheaper)."""
|
|
51
|
+
|
|
52
|
+
output_tokens: int
|
|
53
|
+
"""Number of tokens in the generated output."""
|
|
54
|
+
|
|
55
|
+
total_tokens: int
|
|
53
56
|
"""Total number of tokens used in the request (prompt + completion)."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
JsonableType: TypeAlias = (
|
|
60
|
+
str
|
|
61
|
+
| int
|
|
62
|
+
| float
|
|
63
|
+
| bool
|
|
64
|
+
| bytes
|
|
65
|
+
| list["JsonableType"]
|
|
66
|
+
| set["JsonableType"]
|
|
67
|
+
| tuple["JsonableType", ...]
|
|
68
|
+
| dict[str, "JsonableType"]
|
|
69
|
+
| BaseModel
|
|
70
|
+
)
|
|
@@ -18,9 +18,8 @@ from .tool import BedrockTool, BedrockToolConfig
|
|
|
18
18
|
BedrockMessageParam: TypeAlias = InternalBedrockMessageParam | BaseMessageParam
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
21
|
-
"AsyncBedrockDynamicConfig",
|
|
22
21
|
"AssistantMessageTypeDef",
|
|
23
|
-
"
|
|
22
|
+
"AsyncBedrockDynamicConfig",
|
|
24
23
|
"BedrockCallParams",
|
|
25
24
|
"BedrockCallResponse",
|
|
26
25
|
"BedrockCallResponseChunk",
|
|
@@ -29,6 +28,7 @@ __all__ = [
|
|
|
29
28
|
"BedrockStream",
|
|
30
29
|
"BedrockTool",
|
|
31
30
|
"BedrockToolConfig",
|
|
32
|
-
"call",
|
|
33
31
|
"UserMessageTypeDef",
|
|
32
|
+
"bedrock_call",
|
|
33
|
+
"call",
|
|
34
34
|
]
|
|
@@ -2,7 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def calculate_cost(
|
|
5
|
-
input_tokens: int | float | None,
|
|
5
|
+
input_tokens: int | float | None,
|
|
6
|
+
cached_tokens: int | float | None,
|
|
7
|
+
output_tokens: int | float | None,
|
|
8
|
+
model: str,
|
|
6
9
|
) -> float | None:
|
|
7
10
|
"""Calculate the cost of a completion using the Bedrock API."""
|
|
8
11
|
# NOTE: We are currently investigating a dynamic approach to determine costs
|
|
@@ -129,6 +129,11 @@ class BedrockCallResponse(
|
|
|
129
129
|
"""Returns the number of input tokens."""
|
|
130
130
|
return self.usage["inputTokens"] if self.usage else None
|
|
131
131
|
|
|
132
|
+
@property
|
|
133
|
+
def cached_tokens(self) -> int | None:
|
|
134
|
+
"""Returns the number of cached tokens."""
|
|
135
|
+
return None
|
|
136
|
+
|
|
132
137
|
@computed_field
|
|
133
138
|
@property
|
|
134
139
|
def output_tokens(self) -> int | None:
|
|
@@ -139,7 +144,9 @@ class BedrockCallResponse(
|
|
|
139
144
|
@property
|
|
140
145
|
def cost(self) -> float | None:
|
|
141
146
|
"""Returns the cost of the call."""
|
|
142
|
-
return calculate_cost(
|
|
147
|
+
return calculate_cost(
|
|
148
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
149
|
+
)
|
|
143
150
|
|
|
144
151
|
@computed_field
|
|
145
152
|
@cached_property
|
|
@@ -150,7 +157,6 @@ class BedrockCallResponse(
|
|
|
150
157
|
return AssistantMessageTypeDef(role="assistant", content=[])
|
|
151
158
|
return AssistantMessageTypeDef(role="assistant", content=message["content"])
|
|
152
159
|
|
|
153
|
-
@computed_field
|
|
154
160
|
@cached_property
|
|
155
161
|
def tools(self) -> list[BedrockTool] | None:
|
|
156
162
|
"""Returns any available tool calls as their `BedrockTool` definition.
|
|
@@ -180,7 +186,6 @@ class BedrockCallResponse(
|
|
|
180
186
|
|
|
181
187
|
return extracted_tools
|
|
182
188
|
|
|
183
|
-
@computed_field
|
|
184
189
|
@cached_property
|
|
185
190
|
def tool(self) -> BedrockTool | None:
|
|
186
191
|
"""Returns the 0th tool for the 0th choice message.
|
|
@@ -86,6 +86,11 @@ class BedrockCallResponseChunk(
|
|
|
86
86
|
return self.usage["inputTokens"]
|
|
87
87
|
return None
|
|
88
88
|
|
|
89
|
+
@property
|
|
90
|
+
def cached_tokens(self) -> int | None:
|
|
91
|
+
"""Returns the number of cached tokens."""
|
|
92
|
+
return None
|
|
93
|
+
|
|
89
94
|
@property
|
|
90
95
|
def output_tokens(self) -> int | None:
|
|
91
96
|
"""Returns the number of output tokens."""
|
mirascope/core/bedrock/stream.py
CHANGED
|
@@ -94,7 +94,9 @@ class BedrockStream(
|
|
|
94
94
|
@property
|
|
95
95
|
def cost(self) -> float | None:
|
|
96
96
|
"""Returns the cost of the call."""
|
|
97
|
-
return calculate_cost(
|
|
97
|
+
return calculate_cost(
|
|
98
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
99
|
+
)
|
|
98
100
|
|
|
99
101
|
def _construct_message_param(
|
|
100
102
|
self,
|
|
@@ -18,13 +18,13 @@ CohereMessageParam: TypeAlias = ChatMessage | ToolResult | BaseMessageParam
|
|
|
18
18
|
|
|
19
19
|
__all__ = [
|
|
20
20
|
"AsyncCohereDynamicConfig",
|
|
21
|
-
"call",
|
|
22
|
-
"CohereDynamicConfig",
|
|
23
21
|
"CohereCallParams",
|
|
24
22
|
"CohereCallResponse",
|
|
25
23
|
"CohereCallResponseChunk",
|
|
24
|
+
"CohereDynamicConfig",
|
|
26
25
|
"CohereMessageParam",
|
|
27
26
|
"CohereStream",
|
|
28
27
|
"CohereTool",
|
|
28
|
+
"call",
|
|
29
29
|
"cohere_call",
|
|
30
30
|
]
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
def calculate_cost(
|
|
5
5
|
input_tokens: int | float | None,
|
|
6
|
+
cached_tokens: int | float | None,
|
|
6
7
|
output_tokens: int | float | None,
|
|
7
8
|
model: str = "command-r-plus",
|
|
8
9
|
) -> float | None:
|
|
@@ -10,9 +11,9 @@ def calculate_cost(
|
|
|
10
11
|
|
|
11
12
|
https://cohere.com/pricing
|
|
12
13
|
|
|
13
|
-
Model Input Output
|
|
14
|
-
command-r $0.5 / 1M tokens
|
|
15
|
-
command-r-plus $3 / 1M tokens
|
|
14
|
+
Model Input Cached Output
|
|
15
|
+
command-r $0.5 / 1M tokens $1.5 / 1M tokens
|
|
16
|
+
command-r-plus $3 / 1M tokens $15 / 1M tokens
|
|
16
17
|
"""
|
|
17
18
|
pricing = {
|
|
18
19
|
"command-r": {
|
|
@@ -105,6 +105,12 @@ class CohereCallResponse(
|
|
|
105
105
|
return self.usage.input_tokens
|
|
106
106
|
return None
|
|
107
107
|
|
|
108
|
+
@computed_field
|
|
109
|
+
@property
|
|
110
|
+
def cached_tokens(self) -> float | None:
|
|
111
|
+
"""Returns the number of cached tokens."""
|
|
112
|
+
return None
|
|
113
|
+
|
|
108
114
|
@computed_field
|
|
109
115
|
@property
|
|
110
116
|
def output_tokens(self) -> float | None:
|
|
@@ -117,7 +123,9 @@ class CohereCallResponse(
|
|
|
117
123
|
@property
|
|
118
124
|
def cost(self) -> float | None:
|
|
119
125
|
"""Returns the cost of the response."""
|
|
120
|
-
return calculate_cost(
|
|
126
|
+
return calculate_cost(
|
|
127
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
128
|
+
)
|
|
121
129
|
|
|
122
130
|
@computed_field
|
|
123
131
|
@cached_property
|
|
@@ -129,7 +137,6 @@ class CohereCallResponse(
|
|
|
129
137
|
role="assistant", # pyright: ignore [reportCallIssue]
|
|
130
138
|
)
|
|
131
139
|
|
|
132
|
-
@computed_field
|
|
133
140
|
@cached_property
|
|
134
141
|
def tools(self) -> list[CohereTool] | None:
|
|
135
142
|
"""Returns the tools for the 0th choice message.
|
|
@@ -147,7 +154,6 @@ class CohereCallResponse(
|
|
|
147
154
|
break
|
|
148
155
|
return extracted_tools
|
|
149
156
|
|
|
150
|
-
@computed_field
|
|
151
157
|
@cached_property
|
|
152
158
|
def tool(self) -> CohereTool | None:
|
|
153
159
|
"""Returns the 0th tool for the 0th choice message.
|
|
@@ -100,6 +100,11 @@ class CohereCallResponseChunk(
|
|
|
100
100
|
return self.usage.input_tokens
|
|
101
101
|
return None
|
|
102
102
|
|
|
103
|
+
@property
|
|
104
|
+
def cached_tokens(self) -> float | None:
|
|
105
|
+
"""Returns the number of cached tokens."""
|
|
106
|
+
return None
|
|
107
|
+
|
|
103
108
|
@property
|
|
104
109
|
def output_tokens(self) -> float | None:
|
|
105
110
|
"""Returns the number of output tokens."""
|
mirascope/core/cohere/stream.py
CHANGED
|
@@ -62,7 +62,9 @@ class CohereStream(
|
|
|
62
62
|
@property
|
|
63
63
|
def cost(self) -> float | None:
|
|
64
64
|
"""Returns the cost of the call."""
|
|
65
|
-
return calculate_cost(
|
|
65
|
+
return calculate_cost(
|
|
66
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
67
|
+
)
|
|
66
68
|
|
|
67
69
|
def _construct_message_param(
|
|
68
70
|
self, tool_calls: list[ToolCall] | None = None, content: str | None = None
|
|
@@ -28,13 +28,13 @@ warnings.warn(
|
|
|
28
28
|
)
|
|
29
29
|
|
|
30
30
|
__all__ = [
|
|
31
|
-
"call",
|
|
32
|
-
"GeminiDynamicConfig",
|
|
33
31
|
"GeminiCallParams",
|
|
34
32
|
"GeminiCallResponse",
|
|
35
33
|
"GeminiCallResponseChunk",
|
|
34
|
+
"GeminiDynamicConfig",
|
|
36
35
|
"GeminiMessageParam",
|
|
37
36
|
"GeminiStream",
|
|
38
37
|
"GeminiTool",
|
|
38
|
+
"call",
|
|
39
39
|
"gemini_call",
|
|
40
40
|
]
|