mirascope 1.18.2__py3-none-any.whl → 1.18.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. mirascope/__init__.py +20 -1
  2. mirascope/beta/openai/__init__.py +1 -1
  3. mirascope/beta/openai/realtime/__init__.py +1 -1
  4. mirascope/beta/openai/realtime/tool.py +1 -1
  5. mirascope/beta/rag/__init__.py +2 -2
  6. mirascope/beta/rag/base/__init__.py +2 -2
  7. mirascope/beta/rag/weaviate/__init__.py +1 -1
  8. mirascope/core/__init__.py +29 -6
  9. mirascope/core/anthropic/__init__.py +3 -3
  10. mirascope/core/anthropic/_utils/_calculate_cost.py +114 -47
  11. mirascope/core/anthropic/call_response.py +9 -3
  12. mirascope/core/anthropic/call_response_chunk.py +7 -0
  13. mirascope/core/anthropic/stream.py +3 -1
  14. mirascope/core/azure/__init__.py +2 -2
  15. mirascope/core/azure/_utils/_calculate_cost.py +4 -1
  16. mirascope/core/azure/call_response.py +9 -3
  17. mirascope/core/azure/call_response_chunk.py +5 -0
  18. mirascope/core/azure/stream.py +3 -1
  19. mirascope/core/base/__init__.py +11 -9
  20. mirascope/core/base/_utils/__init__.py +10 -10
  21. mirascope/core/base/_utils/_get_common_usage.py +8 -4
  22. mirascope/core/base/_utils/_get_create_fn_or_async_create_fn.py +2 -2
  23. mirascope/core/base/_utils/_protocols.py +9 -8
  24. mirascope/core/base/call_response.py +22 -22
  25. mirascope/core/base/call_response_chunk.py +12 -1
  26. mirascope/core/base/stream.py +24 -21
  27. mirascope/core/base/tool.py +7 -5
  28. mirascope/core/base/types.py +22 -5
  29. mirascope/core/bedrock/__init__.py +3 -3
  30. mirascope/core/bedrock/_utils/_calculate_cost.py +4 -1
  31. mirascope/core/bedrock/call_response.py +8 -3
  32. mirascope/core/bedrock/call_response_chunk.py +5 -0
  33. mirascope/core/bedrock/stream.py +3 -1
  34. mirascope/core/cohere/__init__.py +2 -2
  35. mirascope/core/cohere/_utils/_calculate_cost.py +4 -3
  36. mirascope/core/cohere/call_response.py +9 -3
  37. mirascope/core/cohere/call_response_chunk.py +5 -0
  38. mirascope/core/cohere/stream.py +3 -1
  39. mirascope/core/gemini/__init__.py +2 -2
  40. mirascope/core/gemini/_utils/_calculate_cost.py +4 -1
  41. mirascope/core/gemini/_utils/_convert_message_params.py +1 -1
  42. mirascope/core/gemini/call_response.py +9 -3
  43. mirascope/core/gemini/call_response_chunk.py +5 -0
  44. mirascope/core/gemini/stream.py +3 -1
  45. mirascope/core/google/__init__.py +2 -2
  46. mirascope/core/google/_utils/_calculate_cost.py +141 -14
  47. mirascope/core/google/_utils/_convert_message_params.py +120 -115
  48. mirascope/core/google/_utils/_message_param_converter.py +34 -33
  49. mirascope/core/google/_utils/_validate_media_type.py +34 -0
  50. mirascope/core/google/call_response.py +38 -10
  51. mirascope/core/google/call_response_chunk.py +17 -9
  52. mirascope/core/google/stream.py +20 -2
  53. mirascope/core/groq/__init__.py +2 -2
  54. mirascope/core/groq/_utils/_calculate_cost.py +12 -11
  55. mirascope/core/groq/call_response.py +9 -3
  56. mirascope/core/groq/call_response_chunk.py +5 -0
  57. mirascope/core/groq/stream.py +3 -1
  58. mirascope/core/litellm/__init__.py +1 -1
  59. mirascope/core/litellm/_utils/_setup_call.py +7 -3
  60. mirascope/core/mistral/__init__.py +2 -2
  61. mirascope/core/mistral/_utils/_calculate_cost.py +10 -9
  62. mirascope/core/mistral/call_response.py +9 -3
  63. mirascope/core/mistral/call_response_chunk.py +5 -0
  64. mirascope/core/mistral/stream.py +3 -1
  65. mirascope/core/openai/__init__.py +2 -2
  66. mirascope/core/openai/_utils/_calculate_cost.py +78 -37
  67. mirascope/core/openai/call_params.py +13 -0
  68. mirascope/core/openai/call_response.py +14 -3
  69. mirascope/core/openai/call_response_chunk.py +12 -0
  70. mirascope/core/openai/stream.py +6 -4
  71. mirascope/core/vertex/__init__.py +1 -1
  72. mirascope/core/vertex/_utils/_calculate_cost.py +1 -0
  73. mirascope/core/vertex/_utils/_convert_message_params.py +1 -1
  74. mirascope/core/vertex/call_response.py +9 -3
  75. mirascope/core/vertex/call_response_chunk.py +5 -0
  76. mirascope/core/vertex/stream.py +3 -1
  77. mirascope/integrations/_middleware_factory.py +6 -6
  78. mirascope/integrations/logfire/_utils.py +1 -1
  79. mirascope/llm/__init__.py +3 -1
  80. mirascope/llm/_protocols.py +5 -5
  81. mirascope/llm/call_response.py +16 -9
  82. mirascope/llm/llm_call.py +53 -25
  83. mirascope/llm/stream.py +43 -31
  84. mirascope/retries/__init__.py +1 -1
  85. mirascope/tools/__init__.py +2 -2
  86. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/METADATA +2 -2
  87. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/RECORD +89 -88
  88. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/WHEEL +0 -0
  89. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/licenses/LICENSE +0 -0
@@ -43,21 +43,26 @@ from ._setup_call import setup_call
43
43
  from ._setup_extract_tool import setup_extract_tool
44
44
 
45
45
  __all__ = [
46
+ "DEFAULT_TOOL_DOCSTRING",
46
47
  "AsyncCreateFn",
47
- "SameSyncAndAsyncClientSetupCall",
48
48
  "BaseType",
49
49
  "CalculateCost",
50
+ "CreateFn",
51
+ "GetJsonOutput",
52
+ "HandleStream",
53
+ "HandleStreamAsync",
54
+ "LLMFunctionDecorator",
55
+ "MessagesDecorator",
56
+ "SameSyncAndAsyncClientSetupCall",
57
+ "SetupCall",
50
58
  "convert_base_model_to_base_tool",
51
59
  "convert_base_type_to_base_tool",
52
60
  "convert_function_to_base_tool",
53
- "CreateFn",
54
- "DEFAULT_TOOL_DOCSTRING",
55
61
  "extract_tool_return",
56
62
  "fn_is_async",
57
63
  "format_template",
58
- "GetJsonOutput",
59
- "get_audio_type",
60
64
  "get_async_create_fn",
65
+ "get_audio_type",
61
66
  "get_common_usage",
62
67
  "get_create_fn",
63
68
  "get_document_type",
@@ -70,18 +75,13 @@ __all__ = [
70
75
  "get_template_values",
71
76
  "get_template_variables",
72
77
  "get_unsupported_tool_config_keys",
73
- "HandleStream",
74
- "HandleStreamAsync",
75
78
  "is_base_type",
76
79
  "is_prompt_template",
77
80
  "json_mode_content",
78
- "LLMFunctionDecorator",
79
- "MessagesDecorator",
80
81
  "messages_decorator",
81
82
  "parse_content_template",
82
83
  "parse_prompt_messages",
83
84
  "pil_image_to_bytes",
84
- "SetupCall",
85
85
  "setup_call",
86
86
  "setup_extract_tool",
87
87
  ]
@@ -2,15 +2,19 @@ from mirascope.core.base.types import Usage
2
2
 
3
3
 
4
4
  def get_common_usage(
5
- input_tokens: int | float | None, output_tokens: int | float | None
5
+ input_tokens: int | float | None,
6
+ cached_tokens: int | float | None,
7
+ output_tokens: int | float | None,
6
8
  ) -> Usage | None:
7
9
  """Get common usage from input and output tokens."""
8
- if input_tokens is None and output_tokens is None:
10
+ if input_tokens is None and cached_tokens is None and output_tokens is None:
9
11
  return None
10
12
  input_tokens = int(input_tokens or 0)
13
+ cached_tokens = int(cached_tokens or 0)
11
14
  output_tokens = int(output_tokens or 0)
12
15
  return Usage(
13
- prompt_tokens=input_tokens,
14
- completion_tokens=output_tokens,
16
+ input_tokens=input_tokens,
17
+ cached_tokens=cached_tokens,
18
+ output_tokens=output_tokens,
15
19
  total_tokens=input_tokens + output_tokens,
16
20
  )
@@ -53,7 +53,7 @@ def get_async_create_fn(
53
53
  async_generator_func: _AsyncGeneratorFunc[_StreamedResponse] | None = None,
54
54
  ) -> AsyncCreateFn[_NonStreamedResponse, _StreamedResponse]:
55
55
  @overload
56
- def create_or_stream(
56
+ def create_or_stream( # pyright: ignore[reportOverlappingOverload]
57
57
  *,
58
58
  stream: Literal[True] | StreamConfig = True,
59
59
  **kwargs: Any, # noqa: ANN401
@@ -100,7 +100,7 @@ def get_create_fn(
100
100
  sync_generator_func: _SyncGeneratorFunc[_StreamedResponse] | None = None,
101
101
  ) -> CreateFn[_NonStreamedResponse, _StreamedResponse]:
102
102
  @overload
103
- def create_or_stream(
103
+ def create_or_stream( # pyright: ignore[reportOverlappingOverload]
104
104
  *,
105
105
  stream: Literal[True] | StreamConfig = True,
106
106
  **kwargs: Any, # noqa: ANN401
@@ -129,7 +129,7 @@ class LLMFunctionDecorator(
129
129
 
130
130
  class AsyncCreateFn(Protocol[_ResponseT, _ResponseChunkT]):
131
131
  @overload
132
- def __call__(
132
+ def __call__( # pyright: ignore[reportOverlappingOverload]
133
133
  self,
134
134
  *,
135
135
  stream: Literal[False] = False,
@@ -153,7 +153,7 @@ class AsyncCreateFn(Protocol[_ResponseT, _ResponseChunkT]):
153
153
 
154
154
  class CreateFn(Protocol[_ResponseT, _ResponseChunkT]):
155
155
  @overload
156
- def __call__(
156
+ def __call__( # pyright: ignore[reportOverlappingOverload]
157
157
  self,
158
158
  *,
159
159
  stream: Literal[False] = False,
@@ -371,6 +371,7 @@ class CalculateCost(Protocol):
371
371
  def __call__(
372
372
  self,
373
373
  input_tokens: int | float | None,
374
+ cached_tokens: int | float | None,
374
375
  output_tokens: int | float | None,
375
376
  model: str,
376
377
  ) -> float | None: ... # pragma: no cover
@@ -390,7 +391,7 @@ class CallDecorator(
390
391
  ],
391
392
  ):
392
393
  @overload
393
- def __call__(
394
+ def __call__( # pyright: ignore[reportOverlappingOverload]
394
395
  self,
395
396
  model: str,
396
397
  *,
@@ -409,7 +410,7 @@ class CallDecorator(
409
410
  ]: ...
410
411
 
411
412
  @overload
412
- def __call__(
413
+ def __call__( # pyright: ignore[reportOverlappingOverload]
413
414
  self,
414
415
  model: str,
415
416
  *,
@@ -437,7 +438,7 @@ class CallDecorator(
437
438
  ) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _BaseCallResponseT]: ...
438
439
 
439
440
  @overload
440
- def __call__(
441
+ def __call__( # pyright: ignore[reportOverlappingOverload]
441
442
  self,
442
443
  model: str,
443
444
  *,
@@ -576,7 +577,7 @@ class CallDecorator(
576
577
  ) -> NoReturn: ...
577
578
 
578
579
  @overload
579
- def __call__(
580
+ def __call__( # pyright: ignore[reportOverlappingOverload]
580
581
  self,
581
582
  model: str,
582
583
  *,
@@ -664,7 +665,7 @@ class CallDecorator(
664
665
  ) -> SyncLLMFunctionDecorator[_BaseDynamicConfigT, _ParsedOutputT]: ...
665
666
 
666
667
  @overload
667
- def __call__(
668
+ def __call__( # pyright: ignore[reportOverlappingOverload]
668
669
  self,
669
670
  model: str,
670
671
  *,
@@ -774,7 +775,7 @@ class CallDecorator(
774
775
  ####
775
776
 
776
777
  @overload
777
- def __call__(
778
+ def __call__( # pyright: ignore[reportOverlappingOverload]
778
779
  self,
779
780
  model: str,
780
781
  *,
@@ -7,7 +7,7 @@ import json
7
7
  from abc import ABC, abstractmethod
8
8
  from collections.abc import Callable
9
9
  from functools import cached_property, wraps
10
- from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
10
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar
11
11
 
12
12
  from pydantic import (
13
13
  BaseModel,
@@ -24,7 +24,7 @@ from .call_params import BaseCallParams
24
24
  from .dynamic_config import BaseDynamicConfig
25
25
  from .metadata import Metadata
26
26
  from .tool import BaseTool
27
- from .types import FinishReason, Usage
27
+ from .types import FinishReason, JsonableType, Usage
28
28
 
29
29
  if TYPE_CHECKING:
30
30
  from ...llm.tool import Tool
@@ -35,36 +35,26 @@ _BaseToolT = TypeVar("_BaseToolT", bound=BaseTool)
35
35
  _ToolSchemaT = TypeVar("_ToolSchemaT")
36
36
  _BaseDynamicConfigT = TypeVar("_BaseDynamicConfigT", bound=BaseDynamicConfig)
37
37
  _MessageParamT = TypeVar("_MessageParamT", bound=Any)
38
+ _ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
38
39
  _CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
39
40
  _UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
40
41
  _BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
41
42
 
42
43
 
43
- JsonableType: TypeAlias = (
44
- str
45
- | int
46
- | float
47
- | bool
48
- | bytes
49
- | list["JsonableType"]
50
- | set["JsonableType"]
51
- | tuple["JsonableType", ...]
52
- | dict[str, "JsonableType"]
53
- | BaseModel
54
- )
55
-
56
-
57
44
  def transform_tool_outputs(
58
- fn: Callable[[type[_BaseCallResponseT], list[tuple[_BaseToolT, str]]], list[Any]],
45
+ fn: Callable[
46
+ [type[_BaseCallResponseT], list[tuple[_BaseToolT, str]]],
47
+ list[_ToolMessageParamT],
48
+ ],
59
49
  ) -> Callable[
60
50
  [type[_BaseCallResponseT], list[tuple[_BaseToolT, JsonableType]]],
61
- list[Any],
51
+ list[_ToolMessageParamT],
62
52
  ]:
63
53
  @wraps(fn)
64
54
  def wrapper(
65
55
  cls: type[_BaseCallResponseT],
66
56
  tools_and_outputs: list[tuple[_BaseToolT, JsonableType]],
67
- ) -> list[Any]:
57
+ ) -> list[_ToolMessageParamT]:
68
58
  def recursive_serializer(value: JsonableType) -> BaseType:
69
59
  if isinstance(value, str):
70
60
  return value
@@ -212,6 +202,16 @@ class BaseCallResponse(
212
202
  """
213
203
  ...
214
204
 
205
+ @computed_field
206
+ @property
207
+ @abstractmethod
208
+ def cached_tokens(self) -> int | float | None:
209
+ """Should return the number of cached tokens.
210
+
211
+ If there is no cached_tokens, this method must return None.
212
+ """
213
+ ...
214
+
215
215
  @computed_field
216
216
  @property
217
217
  @abstractmethod
@@ -239,14 +239,12 @@ class BaseCallResponse(
239
239
  """Returns the assistant's response as a message parameter."""
240
240
  ...
241
241
 
242
- @computed_field
243
242
  @cached_property
244
243
  @abstractmethod
245
244
  def tools(self) -> list[_BaseToolT] | None:
246
245
  """Returns the tools for the 0th choice message."""
247
246
  ...
248
247
 
249
- @computed_field
250
248
  @cached_property
251
249
  @abstractmethod
252
250
  def tool(self) -> _BaseToolT | None:
@@ -297,4 +295,6 @@ class BaseCallResponse(
297
295
  @property
298
296
  def common_usage(self) -> Usage | None:
299
297
  """Provider-agnostic usage info."""
300
- return get_common_usage(self.input_tokens, self.output_tokens)
298
+ return get_common_usage(
299
+ self.input_tokens, self.cached_tokens, self.output_tokens
300
+ )
@@ -84,6 +84,15 @@ class BaseCallResponseChunk(BaseModel, Generic[_ChunkT, _FinishReasonT], ABC):
84
84
  """
85
85
  ...
86
86
 
87
+ @property
88
+ @abstractmethod
89
+ def cached_tokens(self) -> int | float | None:
90
+ """Should return the number of cached tokens.
91
+
92
+ If there is no cached_tokens, this method must return None.
93
+ """
94
+ ...
95
+
87
96
  @property
88
97
  @abstractmethod
89
98
  def output_tokens(self) -> int | float | None:
@@ -102,4 +111,6 @@ class BaseCallResponseChunk(BaseModel, Generic[_ChunkT, _FinishReasonT], ABC):
102
111
  @property
103
112
  def common_usage(self) -> Usage | None:
104
113
  """Provider-agnostic usage info."""
105
- return get_common_usage(self.input_tokens, self.output_tokens)
114
+ return get_common_usage(
115
+ self.input_tokens, self.cached_tokens, self.output_tokens
116
+ )
@@ -92,6 +92,7 @@ class BaseStream(
92
92
  user_message_param: _UserMessageParamT | None = None
93
93
  message_param: _AssistantMessageParamT
94
94
  input_tokens: int | float | None = None
95
+ cached_tokens: int | float | None = None
95
96
  output_tokens: int | float | None = None
96
97
  id: str | None = None
97
98
  finish_reasons: list[_FinishReason] | None = None
@@ -138,9 +139,9 @@ class BaseStream(
138
139
  self,
139
140
  ) -> Generator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None, None]:
140
141
  """Iterator over the stream and stores useful information."""
141
- assert isinstance(
142
- self.stream, Generator
143
- ), "Stream must be a generator for __iter__"
142
+ assert isinstance(self.stream, Generator), (
143
+ "Stream must be a generator for __iter__"
144
+ )
144
145
  self.content, tool_calls = "", []
145
146
  self.start_time = datetime.datetime.now().timestamp() * 1000
146
147
  for chunk, tool in self.stream:
@@ -161,12 +162,12 @@ class BaseStream(
161
162
  """Iterates over the stream and stores useful information."""
162
163
  self.content = ""
163
164
 
164
- async def generator() -> (
165
- AsyncGenerator[tuple[_BaseCallResponseChunkT, _BaseToolT | None], None]
166
- ):
167
- assert isinstance(
168
- self.stream, AsyncGenerator
169
- ), "Stream must be an async generator for __aiter__"
165
+ async def generator() -> AsyncGenerator[
166
+ tuple[_BaseCallResponseChunkT, _BaseToolT | None], None
167
+ ]:
168
+ assert isinstance(self.stream, AsyncGenerator), (
169
+ "Stream must be an async generator for __aiter__"
170
+ )
170
171
  tool_calls = []
171
172
  async for chunk, tool in self.stream:
172
173
  self._update_properties(chunk)
@@ -190,6 +191,12 @@ class BaseStream(
190
191
  if not self.input_tokens
191
192
  else self.input_tokens + chunk.input_tokens
192
193
  )
194
+ if chunk.cached_tokens is not None:
195
+ self.cached_tokens = (
196
+ chunk.cached_tokens
197
+ if not self.cached_tokens
198
+ else self.cached_tokens + chunk.cached_tokens
199
+ )
193
200
  if chunk.output_tokens is not None:
194
201
  self.output_tokens = (
195
202
  chunk.output_tokens
@@ -373,11 +380,9 @@ def stream_factory( # noqa: ANN201
373
380
  stream=True,
374
381
  )
375
382
 
376
- async def generator() -> (
377
- AsyncGenerator[
378
- tuple[_BaseCallResponseChunkT, _BaseToolT | None], None
379
- ]
380
- ):
383
+ async def generator() -> AsyncGenerator[
384
+ tuple[_BaseCallResponseChunkT, _BaseToolT | None], None
385
+ ]:
381
386
  async for chunk, tool in handle_stream_async(
382
387
  await create(stream=True, **call_kwargs),
383
388
  tool_types,
@@ -422,13 +427,11 @@ def stream_factory( # noqa: ANN201
422
427
  stream=True,
423
428
  )
424
429
 
425
- def generator() -> (
426
- Generator[
427
- tuple[_BaseCallResponseChunkT, _BaseToolT | None],
428
- None,
429
- None,
430
- ]
431
- ):
430
+ def generator() -> Generator[
431
+ tuple[_BaseCallResponseChunkT, _BaseToolT | None],
432
+ None,
433
+ None,
434
+ ]:
432
435
  yield from handle_stream(
433
436
  create(stream=True, **call_kwargs),
434
437
  tool_types,
@@ -36,17 +36,19 @@ class ToolConfig(TypedDict, total=False):
36
36
  class GenerateJsonSchemaNoTitles(GenerateJsonSchema):
37
37
  _openai_strict: ClassVar[bool] = False
38
38
 
39
- def _remove_title(self, obj: Any) -> Any: # noqa: ANN401
39
+ def _remove_title(self, key: str | None, obj: Any) -> Any: # noqa: ANN401
40
40
  if isinstance(obj, dict):
41
41
  if self._openai_strict and "type" in obj and obj["type"] == "object":
42
42
  obj["additionalProperties"] = False
43
43
  if "type" in obj or "$ref" in obj or "properties" in obj:
44
- obj.pop("title", None)
44
+ title = obj.pop("title", None)
45
+ if key and title and key.lower() != title.lower():
46
+ obj["title"] = title
45
47
 
46
48
  for key, value in list(obj.items()):
47
- obj[key] = self._remove_title(value)
49
+ obj[key] = self._remove_title(key, value)
48
50
  elif isinstance(obj, list):
49
- return [self._remove_title(item) for item in obj]
51
+ return [self._remove_title(None, item) for item in obj]
50
52
 
51
53
  return obj
52
54
 
@@ -56,7 +58,7 @@ class GenerateJsonSchemaNoTitles(GenerateJsonSchema):
56
58
  json_schema = super().generate(schema, mode=mode)
57
59
  json_schema.pop("title", None)
58
60
  json_schema.pop("description", None)
59
- json_schema = self._remove_title(json_schema)
61
+ json_schema = self._remove_title(None, json_schema)
60
62
  return json_schema
61
63
 
62
64
 
@@ -43,11 +43,28 @@ FinishReason: TypeAlias = Literal["stop", "length", "tool_calls", "content_filte
43
43
 
44
44
 
45
45
  class Usage(BaseModel):
46
- completion_tokens: int = 0
47
- """Number of tokens in the generated completion."""
48
-
49
- prompt_tokens: int = 0
46
+ input_tokens: int
50
47
  """Number of tokens in the prompt."""
51
48
 
52
- total_tokens: int = 0
49
+ cached_tokens: int
50
+ """Number of tokens used that were previously cached (and thus cheaper)."""
51
+
52
+ output_tokens: int
53
+ """Number of tokens in the generated output."""
54
+
55
+ total_tokens: int
53
56
  """Total number of tokens used in the request (prompt + completion)."""
57
+
58
+
59
+ JsonableType: TypeAlias = (
60
+ str
61
+ | int
62
+ | float
63
+ | bool
64
+ | bytes
65
+ | list["JsonableType"]
66
+ | set["JsonableType"]
67
+ | tuple["JsonableType", ...]
68
+ | dict[str, "JsonableType"]
69
+ | BaseModel
70
+ )
@@ -18,9 +18,8 @@ from .tool import BedrockTool, BedrockToolConfig
18
18
  BedrockMessageParam: TypeAlias = InternalBedrockMessageParam | BaseMessageParam
19
19
 
20
20
  __all__ = [
21
- "AsyncBedrockDynamicConfig",
22
21
  "AssistantMessageTypeDef",
23
- "bedrock_call",
22
+ "AsyncBedrockDynamicConfig",
24
23
  "BedrockCallParams",
25
24
  "BedrockCallResponse",
26
25
  "BedrockCallResponseChunk",
@@ -29,6 +28,7 @@ __all__ = [
29
28
  "BedrockStream",
30
29
  "BedrockTool",
31
30
  "BedrockToolConfig",
32
- "call",
33
31
  "UserMessageTypeDef",
32
+ "bedrock_call",
33
+ "call",
34
34
  ]
@@ -2,7 +2,10 @@
2
2
 
3
3
 
4
4
  def calculate_cost(
5
- input_tokens: int | float | None, output_tokens: int | float | None, model: str
5
+ input_tokens: int | float | None,
6
+ cached_tokens: int | float | None,
7
+ output_tokens: int | float | None,
8
+ model: str,
6
9
  ) -> float | None:
7
10
  """Calculate the cost of a completion using the Bedrock API."""
8
11
  # NOTE: We are currently investigating a dynamic approach to determine costs
@@ -129,6 +129,11 @@ class BedrockCallResponse(
129
129
  """Returns the number of input tokens."""
130
130
  return self.usage["inputTokens"] if self.usage else None
131
131
 
132
+ @property
133
+ def cached_tokens(self) -> int | None:
134
+ """Returns the number of cached tokens."""
135
+ return None
136
+
132
137
  @computed_field
133
138
  @property
134
139
  def output_tokens(self) -> int | None:
@@ -139,7 +144,9 @@ class BedrockCallResponse(
139
144
  @property
140
145
  def cost(self) -> float | None:
141
146
  """Returns the cost of the call."""
142
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
147
+ return calculate_cost(
148
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
149
+ )
143
150
 
144
151
  @computed_field
145
152
  @cached_property
@@ -150,7 +157,6 @@ class BedrockCallResponse(
150
157
  return AssistantMessageTypeDef(role="assistant", content=[])
151
158
  return AssistantMessageTypeDef(role="assistant", content=message["content"])
152
159
 
153
- @computed_field
154
160
  @cached_property
155
161
  def tools(self) -> list[BedrockTool] | None:
156
162
  """Returns any available tool calls as their `BedrockTool` definition.
@@ -180,7 +186,6 @@ class BedrockCallResponse(
180
186
 
181
187
  return extracted_tools
182
188
 
183
- @computed_field
184
189
  @cached_property
185
190
  def tool(self) -> BedrockTool | None:
186
191
  """Returns the 0th tool for the 0th choice message.
@@ -86,6 +86,11 @@ class BedrockCallResponseChunk(
86
86
  return self.usage["inputTokens"]
87
87
  return None
88
88
 
89
+ @property
90
+ def cached_tokens(self) -> int | None:
91
+ """Returns the number of cached tokens."""
92
+ return None
93
+
89
94
  @property
90
95
  def output_tokens(self) -> int | None:
91
96
  """Returns the number of output tokens."""
@@ -94,7 +94,9 @@ class BedrockStream(
94
94
  @property
95
95
  def cost(self) -> float | None:
96
96
  """Returns the cost of the call."""
97
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
97
+ return calculate_cost(
98
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
99
+ )
98
100
 
99
101
  def _construct_message_param(
100
102
  self,
@@ -18,13 +18,13 @@ CohereMessageParam: TypeAlias = ChatMessage | ToolResult | BaseMessageParam
18
18
 
19
19
  __all__ = [
20
20
  "AsyncCohereDynamicConfig",
21
- "call",
22
- "CohereDynamicConfig",
23
21
  "CohereCallParams",
24
22
  "CohereCallResponse",
25
23
  "CohereCallResponseChunk",
24
+ "CohereDynamicConfig",
26
25
  "CohereMessageParam",
27
26
  "CohereStream",
28
27
  "CohereTool",
28
+ "call",
29
29
  "cohere_call",
30
30
  ]
@@ -3,6 +3,7 @@
3
3
 
4
4
  def calculate_cost(
5
5
  input_tokens: int | float | None,
6
+ cached_tokens: int | float | None,
6
7
  output_tokens: int | float | None,
7
8
  model: str = "command-r-plus",
8
9
  ) -> float | None:
@@ -10,9 +11,9 @@ def calculate_cost(
10
11
 
11
12
  https://cohere.com/pricing
12
13
 
13
- Model Input Output
14
- command-r $0.5 / 1M tokens $1.5 / 1M tokens
15
- command-r-plus $3 / 1M tokens $15 / 1M tokens
14
+ Model Input Cached Output
15
+ command-r $0.5 / 1M tokens $1.5 / 1M tokens
16
+ command-r-plus $3 / 1M tokens $15 / 1M tokens
16
17
  """
17
18
  pricing = {
18
19
  "command-r": {
@@ -105,6 +105,12 @@ class CohereCallResponse(
105
105
  return self.usage.input_tokens
106
106
  return None
107
107
 
108
+ @computed_field
109
+ @property
110
+ def cached_tokens(self) -> float | None:
111
+ """Returns the number of cached tokens."""
112
+ return None
113
+
108
114
  @computed_field
109
115
  @property
110
116
  def output_tokens(self) -> float | None:
@@ -117,7 +123,9 @@ class CohereCallResponse(
117
123
  @property
118
124
  def cost(self) -> float | None:
119
125
  """Returns the cost of the response."""
120
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
126
+ return calculate_cost(
127
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
128
+ )
121
129
 
122
130
  @computed_field
123
131
  @cached_property
@@ -129,7 +137,6 @@ class CohereCallResponse(
129
137
  role="assistant", # pyright: ignore [reportCallIssue]
130
138
  )
131
139
 
132
- @computed_field
133
140
  @cached_property
134
141
  def tools(self) -> list[CohereTool] | None:
135
142
  """Returns the tools for the 0th choice message.
@@ -147,7 +154,6 @@ class CohereCallResponse(
147
154
  break
148
155
  return extracted_tools
149
156
 
150
- @computed_field
151
157
  @cached_property
152
158
  def tool(self) -> CohereTool | None:
153
159
  """Returns the 0th tool for the 0th choice message.
@@ -100,6 +100,11 @@ class CohereCallResponseChunk(
100
100
  return self.usage.input_tokens
101
101
  return None
102
102
 
103
+ @property
104
+ def cached_tokens(self) -> float | None:
105
+ """Returns the number of cached tokens."""
106
+ return None
107
+
103
108
  @property
104
109
  def output_tokens(self) -> float | None:
105
110
  """Returns the number of output tokens."""
@@ -62,7 +62,9 @@ class CohereStream(
62
62
  @property
63
63
  def cost(self) -> float | None:
64
64
  """Returns the cost of the call."""
65
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
65
+ return calculate_cost(
66
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
67
+ )
66
68
 
67
69
  def _construct_message_param(
68
70
  self, tool_calls: list[ToolCall] | None = None, content: str | None = None
@@ -28,13 +28,13 @@ warnings.warn(
28
28
  )
29
29
 
30
30
  __all__ = [
31
- "call",
32
- "GeminiDynamicConfig",
33
31
  "GeminiCallParams",
34
32
  "GeminiCallResponse",
35
33
  "GeminiCallResponseChunk",
34
+ "GeminiDynamicConfig",
36
35
  "GeminiMessageParam",
37
36
  "GeminiStream",
38
37
  "GeminiTool",
38
+ "call",
39
39
  "gemini_call",
40
40
  ]