mirascope 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@ from .base import (
14
14
  BaseTool,
15
15
  BaseToolKit,
16
16
  CacheControlPart,
17
+ CostMetadata,
17
18
  DocumentPart,
18
19
  FromCallArgs,
19
20
  ImagePart,
@@ -74,6 +75,7 @@ __all__ = [
74
75
  "BaseTool",
75
76
  "BaseToolKit",
76
77
  "CacheControlPart",
78
+ "CostMetadata",
77
79
  "DocumentPart",
78
80
  "FromCallArgs",
79
81
  "ImagePart",
@@ -20,9 +20,7 @@ from ..base.types import CostMetadata
20
20
  from ._utils._convert_finish_reason_to_common_finish_reasons import (
21
21
  _convert_finish_reasons_to_common_finish_reasons,
22
22
  )
23
- from ._utils._message_param_converter import (
24
- AnthropicMessageParamConverter,
25
- )
23
+ from ._utils._message_param_converter import AnthropicMessageParamConverter
26
24
  from .call_params import AnthropicCallParams
27
25
  from .dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
28
26
  from .tool import AnthropicTool
@@ -37,6 +35,7 @@ class AnthropicCallResponse(
37
35
  MessageParam,
38
36
  AnthropicCallParams,
39
37
  MessageParam,
38
+ AnthropicMessageParamConverter,
40
39
  ]
41
40
  ):
42
41
  """A convenience wrapper around the Anthropic `Message` response.
@@ -62,6 +61,10 @@ class AnthropicCallResponse(
62
61
  ```
63
62
  """
64
63
 
64
+ _message_converter: type[AnthropicMessageParamConverter] = (
65
+ AnthropicMessageParamConverter
66
+ )
67
+
65
68
  _provider = "anthropic"
66
69
 
67
70
  @computed_field
@@ -38,6 +38,7 @@ class AzureCallResponse(
38
38
  ChatRequestMessage,
39
39
  AzureCallParams,
40
40
  UserMessage,
41
+ AzureMessageParamConverter,
41
42
  ]
42
43
  ):
43
44
  """A convenience wrapper around the Azure `ChatCompletion` response.
@@ -64,6 +65,7 @@ class AzureCallResponse(
64
65
  """
65
66
 
66
67
  response: SkipValidation[ChatCompletions]
68
+ _message_converter: type[AzureMessageParamConverter] = AzureMessageParamConverter
67
69
 
68
70
  _provider = "azure"
69
71
 
@@ -30,7 +30,14 @@ from .stream import BaseStream
30
30
  from .structured_stream import BaseStructuredStream
31
31
  from .tool import BaseTool, GenerateJsonSchemaNoTitles, ToolConfig
32
32
  from .toolkit import BaseToolKit, toolkit_tool
33
- from .types import AudioSegment, JsonableType, LocalProvider, Provider, Usage
33
+ from .types import (
34
+ AudioSegment,
35
+ CostMetadata,
36
+ JsonableType,
37
+ LocalProvider,
38
+ Provider,
39
+ Usage,
40
+ )
34
41
 
35
42
  __all__ = [
36
43
  "AudioPart",
@@ -50,6 +57,7 @@ __all__ = [
50
57
  "BaseType",
51
58
  "CacheControlPart",
52
59
  "CommonCallParams",
60
+ "CostMetadata",
53
61
  "DocumentPart",
54
62
  "FromCallArgs",
55
63
  "GenerateJsonSchemaNoTitles",
@@ -187,7 +187,7 @@ def create_factory( # noqa: ANN202
187
187
  start_time = datetime.datetime.now().timestamp() * 1000
188
188
  response = await create(stream=False, **call_kwargs)
189
189
  end_time = datetime.datetime.now().timestamp() * 1000
190
- output = TCallResponse(
190
+ output = TCallResponse( # pyright: ignore [reportCallIssue]
191
191
  metadata=get_metadata(fn, dynamic_config),
192
192
  response=response,
193
193
  tool_types=tool_types, # pyright: ignore [reportArgumentType]
@@ -231,7 +231,7 @@ def create_factory( # noqa: ANN202
231
231
  start_time = datetime.datetime.now().timestamp() * 1000
232
232
  response = create(stream=False, **call_kwargs)
233
233
  end_time = datetime.datetime.now().timestamp() * 1000
234
- output = TCallResponse(
234
+ output = TCallResponse( # pyright: ignore [reportCallIssue]
235
235
  metadata=get_metadata(fn, dynamic_config),
236
236
  response=response,
237
237
  tool_types=tool_types, # pyright: ignore [reportArgumentType]
@@ -1,5 +1,6 @@
1
1
  """Internal Utilities."""
2
2
 
3
+ from ._base_message_param_converter import BaseMessageParamConverter
3
4
  from ._base_type import BaseType, is_base_type
4
5
  from ._convert_base_model_to_base_tool import convert_base_model_to_base_tool
5
6
  from ._convert_base_type_to_base_tool import convert_base_type_to_base_tool
@@ -46,8 +47,10 @@ from ._setup_extract_tool import setup_extract_tool
46
47
  __all__ = [
47
48
  "DEFAULT_TOOL_DOCSTRING",
48
49
  "AsyncCreateFn",
50
+ "BaseMessageParamConverter",
49
51
  "BaseType",
50
52
  "CalculateCost",
53
+ "CallDecorator",
51
54
  "CreateFn",
52
55
  "GetJsonOutput",
53
56
  "HandleStream",
@@ -3,7 +3,7 @@
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import Any
5
5
 
6
- from mirascope.core import BaseMessageParam
6
+ from mirascope.core.base.message_param import BaseMessageParam
7
7
 
8
8
 
9
9
  class BaseMessageParamConverter(ABC):
@@ -19,7 +19,7 @@ from pydantic import (
19
19
  )
20
20
 
21
21
  from ..costs import calculate_cost
22
- from ._utils import BaseType, get_common_usage
22
+ from ._utils import BaseMessageParamConverter, BaseType, get_common_usage
23
23
  from .call_kwargs import BaseCallKwargs
24
24
  from .call_params import BaseCallParams
25
25
  from .dynamic_config import BaseDynamicConfig
@@ -40,6 +40,9 @@ _ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
40
40
  _CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
41
41
  _UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
42
42
  _BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
43
+ _BaseMessageParamConverterT = TypeVar(
44
+ "_BaseMessageParamConverterT", bound=BaseMessageParamConverter
45
+ )
43
46
 
44
47
 
45
48
  def transform_tool_outputs(
@@ -97,6 +100,7 @@ class BaseCallResponse(
97
100
  _MessageParamT,
98
101
  _CallParamsT,
99
102
  _UserMessageParamT,
103
+ _BaseMessageParamConverterT,
100
104
  ],
101
105
  ABC,
102
106
  ):
@@ -131,6 +135,7 @@ class BaseCallResponse(
131
135
  start_time: float
132
136
  end_time: float
133
137
 
138
+ _message_converter: type[_BaseMessageParamConverterT]
134
139
  _provider: ClassVar[str] = "NO PROVIDER"
135
140
  _model: str = "NO MODEL"
136
141
 
@@ -313,6 +318,11 @@ class BaseCallResponse(
313
318
  """Provider-agnostic user message param."""
314
319
  ...
315
320
 
321
+ @property
322
+ def common_messages(self) -> list[BaseMessageParam]:
323
+ """Provider-agnostic list of messages."""
324
+ return self._message_converter.from_provider(self.messages)
325
+
316
326
  @property
317
327
  def common_tools(self) -> list[Tool] | None:
318
328
  """Provider-agnostic tools."""
@@ -51,6 +51,7 @@ class BedrockCallResponse(
51
51
  InternalBedrockMessageParam,
52
52
  BedrockCallParams,
53
53
  UserMessageTypeDef,
54
+ BedrockMessageParamConverter,
54
55
  ]
55
56
  ):
56
57
  """A convenience wrapper around the Bedrock `ChatCompletion` response.
@@ -78,6 +79,9 @@ class BedrockCallResponse(
78
79
  """
79
80
 
80
81
  response: SkipValidation[SyncConverseResponseTypeDef | AsyncConverseResponseTypeDef]
82
+ _message_converter: type[BedrockMessageParamConverter] = (
83
+ BedrockMessageParamConverter
84
+ )
81
85
 
82
86
  _provider = "bedrock"
83
87
 
@@ -35,6 +35,7 @@ class CohereCallResponse(
35
35
  SkipValidation[ChatMessage],
36
36
  CohereCallParams,
37
37
  SkipValidation[ChatMessage],
38
+ CohereMessageParamConverter,
38
39
  ]
39
40
  ):
40
41
  """A convenience wrapper around the Cohere `ChatCompletion` response.
@@ -60,6 +61,8 @@ class CohereCallResponse(
60
61
  ```
61
62
  """
62
63
 
64
+ _message_converter: type[CohereMessageParamConverter] = CohereMessageParamConverter
65
+
63
66
  _provider = "cohere"
64
67
 
65
68
  @computed_field
@@ -36,6 +36,7 @@ class GeminiCallResponse(
36
36
  ContentsType,
37
37
  GeminiCallParams,
38
38
  ContentDict,
39
+ GeminiMessageParamConverter,
39
40
  ]
40
41
  ):
41
42
  """A convenience wrapper around the Gemini API response.
@@ -62,6 +63,8 @@ class GeminiCallResponse(
62
63
  ```
63
64
  """
64
65
 
66
+ _message_converter: type[GeminiMessageParamConverter] = GeminiMessageParamConverter
67
+
65
68
  _provider = "gemini"
66
69
 
67
70
  @computed_field
@@ -40,6 +40,7 @@ class GoogleCallResponse(
40
40
  ContentListUnion | ContentListUnionDict,
41
41
  GoogleCallParams,
42
42
  ContentDict,
43
+ GoogleMessageParamConverter,
43
44
  ]
44
45
  ):
45
46
  """A convenience wrapper around the Google API response.
@@ -66,6 +67,8 @@ class GoogleCallResponse(
66
67
  ```
67
68
  """
68
69
 
70
+ _message_converter: type[GoogleMessageParamConverter] = GoogleMessageParamConverter
71
+
69
72
  _provider = "google"
70
73
 
71
74
  @computed_field
@@ -70,17 +70,48 @@ class GoogleTool(BaseTool):
70
70
  fn["parameters"] = model_schema
71
71
 
72
72
  if "parameters" in fn:
73
+ # Resolve $defs and $ref
73
74
  if "$defs" in fn["parameters"]:
74
- raise ValueError(
75
- "Unfortunately Google's Google API cannot handle nested structures "
76
- "with $defs."
77
- )
75
+ defs = fn["parameters"].pop("$defs")
76
+
77
+ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
78
+ """Recursively resolve $ref references using the $defs dictionary."""
79
+ # If this is a reference, resolve it
80
+ if "$ref" in schema:
81
+ ref = schema["$ref"]
82
+ if ref.startswith("#/$defs/"):
83
+ ref_key = ref.replace("#/$defs/", "")
84
+ if ref_key in defs:
85
+ # Merge the definition with the current schema (excluding $ref)
86
+ resolved = {
87
+ **{k: v for k, v in schema.items() if k != "$ref"},
88
+ **resolve_refs(defs[ref_key]),
89
+ }
90
+ return resolved
91
+
92
+ # Process all other keys recursively
93
+ result = {}
94
+ for key, value in schema.items():
95
+ if isinstance(value, dict):
96
+ result[key] = resolve_refs(value)
97
+ elif isinstance(value, list):
98
+ result[key] = [
99
+ resolve_refs(item) if isinstance(item, dict) else item
100
+ for item in value
101
+ ]
102
+ else:
103
+ result[key] = value
104
+ return result
105
+
106
+ # Resolve all references in the parameters
107
+ fn["parameters"] = resolve_refs(fn["parameters"])
78
108
 
79
109
  def handle_enum_schema(prop_schema: dict[str, Any]) -> dict[str, Any]:
80
110
  if "enum" in prop_schema:
81
111
  prop_schema["format"] = "enum"
82
112
  return prop_schema
83
113
 
114
+ # Process properties after resolving references
84
115
  fn["parameters"]["properties"] = {
85
116
  prop: {
86
117
  key: value
@@ -35,6 +35,7 @@ class GroqCallResponse(
35
35
  ChatCompletionMessageParam,
36
36
  GroqCallParams,
37
37
  ChatCompletionUserMessageParam,
38
+ GroqMessageParamConverter,
38
39
  ]
39
40
  ):
40
41
  """A convenience wrapper around the Groq `ChatCompletion` response.
@@ -59,6 +60,8 @@ class GroqCallResponse(
59
60
  ```
60
61
  """
61
62
 
63
+ _message_converter: type[GroqMessageParamConverter] = GroqMessageParamConverter
64
+
62
65
  _provider = "groq"
63
66
 
64
67
  @computed_field
@@ -38,6 +38,7 @@ class MistralCallResponse(
38
38
  AssistantMessage | SystemMessage | ToolMessage | UserMessage,
39
39
  MistralCallParams,
40
40
  UserMessage,
41
+ MistralMessageParamConverter,
41
42
  ]
42
43
  ):
43
44
  """A convenience wrapper around the Mistral `ChatCompletion` response.
@@ -62,6 +63,10 @@ class MistralCallResponse(
62
63
  ```
63
64
  """
64
65
 
66
+ _message_converter: type[MistralMessageParamConverter] = (
67
+ MistralMessageParamConverter
68
+ )
69
+
65
70
  _provider = "mistral"
66
71
 
67
72
  @property
@@ -58,6 +58,7 @@ class OpenAICallResponse(
58
58
  ChatCompletionMessageParam,
59
59
  OpenAICallParams,
60
60
  ChatCompletionUserMessageParam,
61
+ OpenAIMessageParamConverter,
61
62
  ]
62
63
  ):
63
64
  """A convenience wrapper around the OpenAI `ChatCompletion` response.
@@ -84,6 +85,7 @@ class OpenAICallResponse(
84
85
  """
85
86
 
86
87
  response: SkipValidation[ChatCompletion]
88
+ _message_converter: type[OpenAIMessageParamConverter] = OpenAIMessageParamConverter
87
89
 
88
90
  _provider = "openai"
89
91
 
@@ -30,6 +30,7 @@ class VertexCallResponse(
30
30
  Content,
31
31
  VertexCallParams,
32
32
  Content,
33
+ VertexMessageParamConverter,
33
34
  ]
34
35
  ):
35
36
  """A convenience wrapper around the Vertex AI `GenerateContentResponse`.
@@ -55,6 +56,8 @@ class VertexCallResponse(
55
56
  ```
56
57
  """
57
58
 
59
+ _message_converter: type[VertexMessageParamConverter] = VertexMessageParamConverter
60
+
58
61
  _provider = "vertex"
59
62
 
60
63
  @computed_field
mirascope/llm/__init__.py CHANGED
@@ -1,13 +1,18 @@
1
- from ..core import LocalProvider, Provider, calculate_cost
1
+ from ..core import CostMetadata, LocalProvider, Provider, calculate_cost
2
+ from ._call import call
3
+ from ._context import context
4
+ from ._override import override
2
5
  from .call_response import CallResponse
3
- from .llm_call import call
4
- from .llm_override import override
6
+ from .stream import Stream
5
7
 
6
8
  __all__ = [
7
9
  "CallResponse",
10
+ "CostMetadata",
8
11
  "LocalProvider",
9
12
  "Provider",
13
+ "Stream",
10
14
  "calculate_cost",
11
15
  "call",
16
+ "context",
12
17
  "override",
13
18
  ]
@@ -20,6 +20,7 @@ from ..core.base import (
20
20
  from ..core.base._utils import fn_is_async
21
21
  from ..core.base.stream_config import StreamConfig
22
22
  from ..core.base.types import LocalProvider, Provider
23
+ from ._context import CallArgs, apply_context_overrides_to_call_args
23
24
  from ._protocols import (
24
25
  AsyncLLMFunctionDecorator,
25
26
  CallDecorator,
@@ -193,13 +194,9 @@ def _call(
193
194
  ]
194
195
  ):
195
196
  """Decorator for defining a function that calls a language model."""
196
- if provider in get_args(LocalProvider):
197
- provider_call, client = _get_local_provider_call(
198
- cast(LocalProvider, provider), client
199
- )
200
- else:
201
- provider_call = _get_provider_call(cast(Provider, provider))
202
- _original_args = {
197
+ # Store original call args that will be used for each function call
198
+ original_call_args: CallArgs = {
199
+ "provider": provider,
203
200
  "model": model,
204
201
  "stream": stream,
205
202
  "tools": tools,
@@ -214,36 +211,112 @@ def _call(
214
211
  fn: Callable[_P, _R | Awaitable[_R]],
215
212
  ) -> Callable[
216
213
  _P,
217
- CallResponse | Stream | Awaitable[CallResponse | Stream],
214
+ CallResponse
215
+ | Stream
216
+ | _ResponseModelT
217
+ | _ParsedOutputT
218
+ | (_ResponseModelT | CallResponse)
219
+ | Awaitable[CallResponse]
220
+ | Awaitable[Stream]
221
+ | Awaitable[_ResponseModelT]
222
+ | Awaitable[_ParsedOutputT]
223
+ | Awaitable[(_ResponseModelT | CallResponse)],
218
224
  ]:
219
- decorated = provider_call(**_original_args)(fn)
225
+ if fn_is_async(fn):
220
226
 
221
- if fn_is_async(decorated):
222
-
223
- @wraps(decorated)
227
+ @wraps(fn)
224
228
  async def inner_async(
225
229
  *args: _P.args, **kwargs: _P.kwargs
226
- ) -> CallResponse | Stream:
230
+ ) -> (
231
+ CallResponse
232
+ | Stream
233
+ | _ResponseModelT
234
+ | _ParsedOutputT
235
+ | (_ResponseModelT | CallResponse)
236
+ ):
237
+ # Apply any context overrides to the original call args
238
+ effective_call_args = apply_context_overrides_to_call_args(
239
+ original_call_args
240
+ )
241
+
242
+ # Get the appropriate provider call function with the possibly overridden provider
243
+ effective_provider = effective_call_args["provider"]
244
+ effective_client = effective_call_args["client"]
245
+
246
+ if effective_provider in get_args(LocalProvider):
247
+ provider_call, effective_client = _get_local_provider_call(
248
+ cast(LocalProvider, effective_provider), effective_client
249
+ )
250
+ effective_call_args["client"] = effective_client
251
+ else:
252
+ provider_call = _get_provider_call(
253
+ cast(Provider, effective_provider)
254
+ )
255
+
256
+ # Use the provider-specific call function with overridden args
257
+ call_kwargs = dict(effective_call_args)
258
+ del call_kwargs[
259
+ "provider"
260
+ ] # Remove provider as it's not a parameter to provider_call
261
+
262
+ # Get decorated function using provider_call
263
+ decorated = provider_call(**call_kwargs)(fn)
264
+
265
+ # Call the decorated function and wrap the result
227
266
  result = await decorated(*args, **kwargs)
228
267
  return _wrap_result(result)
229
268
 
230
- inner_async._original_args = _original_args # pyright: ignore [reportAttributeAccessIssue]
231
- inner_async._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
269
+ inner_async._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
232
270
  inner_async._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
233
- inner_async._original_provider = provider # pyright: ignore [reportAttributeAccessIssue]
234
271
 
235
272
  return inner_async
236
273
  else:
237
274
 
238
- @wraps(decorated)
239
- def inner(*args: _P.args, **kwargs: _P.kwargs) -> CallResponse | Stream:
275
+ @wraps(fn)
276
+ def inner(
277
+ *args: _P.args, **kwargs: _P.kwargs
278
+ ) -> (
279
+ CallResponse
280
+ | Stream
281
+ | _ResponseModelT
282
+ | _ParsedOutputT
283
+ | (_ResponseModelT | CallResponse)
284
+ ):
285
+ # Apply any context overrides to the original call args
286
+ effective_call_args = apply_context_overrides_to_call_args(
287
+ original_call_args
288
+ )
289
+
290
+ # Get the appropriate provider call function with the possibly overridden provider
291
+ effective_provider = effective_call_args["provider"]
292
+ effective_client = effective_call_args["client"]
293
+
294
+ if effective_provider in get_args(LocalProvider):
295
+ provider_call, effective_client = _get_local_provider_call(
296
+ cast(LocalProvider, effective_provider), effective_client
297
+ )
298
+ effective_call_args["client"] = effective_client
299
+ else:
300
+ provider_call = _get_provider_call(
301
+ cast(Provider, effective_provider)
302
+ )
303
+
304
+ # Use the provider-specific call function with overridden args
305
+ call_kwargs = dict(effective_call_args)
306
+ del call_kwargs[
307
+ "provider"
308
+ ] # Remove provider as it's not a parameter to provider_call
309
+
310
+ # Get decorated function using provider_call
311
+ decorated = provider_call(**call_kwargs)(fn)
312
+
313
+ # Call the decorated function and wrap the result
240
314
  result = decorated(*args, **kwargs)
241
315
  return _wrap_result(result)
242
316
 
243
- inner._original_args = _original_args # pyright: ignore [reportAttributeAccessIssue]
244
- inner._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
317
+ inner._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
245
318
  inner._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
246
- inner._original_provider = provider # pyright: ignore [reportAttributeAccessIssue]
319
+
247
320
  return inner
248
321
 
249
322
  return wrapper # pyright: ignore [reportReturnType]