mirascope 1.20.1 → 1.21.1 (mirascope-1.20.1-py3-none-any.whl → mirascope-1.21.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,9 +20,7 @@ from ..base.types import CostMetadata
20
20
  from ._utils._convert_finish_reason_to_common_finish_reasons import (
21
21
  _convert_finish_reasons_to_common_finish_reasons,
22
22
  )
23
- from ._utils._message_param_converter import (
24
- AnthropicMessageParamConverter,
25
- )
23
+ from ._utils._message_param_converter import AnthropicMessageParamConverter
26
24
  from .call_params import AnthropicCallParams
27
25
  from .dynamic_config import AnthropicDynamicConfig, AsyncAnthropicDynamicConfig
28
26
  from .tool import AnthropicTool
@@ -37,6 +35,7 @@ class AnthropicCallResponse(
37
35
  MessageParam,
38
36
  AnthropicCallParams,
39
37
  MessageParam,
38
+ AnthropicMessageParamConverter,
40
39
  ]
41
40
  ):
42
41
  """A convenience wrapper around the Anthropic `Message` response.
@@ -62,6 +61,10 @@ class AnthropicCallResponse(
62
61
  ```
63
62
  """
64
63
 
64
+ _message_converter: type[AnthropicMessageParamConverter] = (
65
+ AnthropicMessageParamConverter
66
+ )
67
+
65
68
  _provider = "anthropic"
66
69
 
67
70
  @computed_field
@@ -38,6 +38,7 @@ class AzureCallResponse(
38
38
  ChatRequestMessage,
39
39
  AzureCallParams,
40
40
  UserMessage,
41
+ AzureMessageParamConverter,
41
42
  ]
42
43
  ):
43
44
  """A convenience wrapper around the Azure `ChatCompletion` response.
@@ -64,6 +65,7 @@ class AzureCallResponse(
64
65
  """
65
66
 
66
67
  response: SkipValidation[ChatCompletions]
68
+ _message_converter: type[AzureMessageParamConverter] = AzureMessageParamConverter
67
69
 
68
70
  _provider = "azure"
69
71
 
@@ -187,7 +187,7 @@ def create_factory( # noqa: ANN202
187
187
  start_time = datetime.datetime.now().timestamp() * 1000
188
188
  response = await create(stream=False, **call_kwargs)
189
189
  end_time = datetime.datetime.now().timestamp() * 1000
190
- output = TCallResponse(
190
+ output = TCallResponse( # pyright: ignore [reportCallIssue]
191
191
  metadata=get_metadata(fn, dynamic_config),
192
192
  response=response,
193
193
  tool_types=tool_types, # pyright: ignore [reportArgumentType]
@@ -231,7 +231,7 @@ def create_factory( # noqa: ANN202
231
231
  start_time = datetime.datetime.now().timestamp() * 1000
232
232
  response = create(stream=False, **call_kwargs)
233
233
  end_time = datetime.datetime.now().timestamp() * 1000
234
- output = TCallResponse(
234
+ output = TCallResponse( # pyright: ignore [reportCallIssue]
235
235
  metadata=get_metadata(fn, dynamic_config),
236
236
  response=response,
237
237
  tool_types=tool_types, # pyright: ignore [reportArgumentType]
@@ -1,5 +1,6 @@
1
1
  """Internal Utilities."""
2
2
 
3
+ from ._base_message_param_converter import BaseMessageParamConverter
3
4
  from ._base_type import BaseType, is_base_type
4
5
  from ._convert_base_model_to_base_tool import convert_base_model_to_base_tool
5
6
  from ._convert_base_type_to_base_tool import convert_base_type_to_base_tool
@@ -46,8 +47,10 @@ from ._setup_extract_tool import setup_extract_tool
46
47
  __all__ = [
47
48
  "DEFAULT_TOOL_DOCSTRING",
48
49
  "AsyncCreateFn",
50
+ "BaseMessageParamConverter",
49
51
  "BaseType",
50
52
  "CalculateCost",
53
+ "CallDecorator",
51
54
  "CreateFn",
52
55
  "GetJsonOutput",
53
56
  "HandleStream",
@@ -3,7 +3,7 @@
3
3
  from abc import ABC, abstractmethod
4
4
  from typing import Any
5
5
 
6
- from mirascope.core import BaseMessageParam
6
+ from mirascope.core.base.message_param import BaseMessageParam
7
7
 
8
8
 
9
9
  class BaseMessageParamConverter(ABC):
@@ -19,7 +19,7 @@ from pydantic import (
19
19
  )
20
20
 
21
21
  from ..costs import calculate_cost
22
- from ._utils import BaseType, get_common_usage
22
+ from ._utils import BaseMessageParamConverter, BaseType, get_common_usage
23
23
  from .call_kwargs import BaseCallKwargs
24
24
  from .call_params import BaseCallParams
25
25
  from .dynamic_config import BaseDynamicConfig
@@ -40,6 +40,9 @@ _ToolMessageParamT = TypeVar("_ToolMessageParamT", bound=Any)
40
40
  _CallParamsT = TypeVar("_CallParamsT", bound=BaseCallParams)
41
41
  _UserMessageParamT = TypeVar("_UserMessageParamT", bound=Any)
42
42
  _BaseCallResponseT = TypeVar("_BaseCallResponseT", bound="BaseCallResponse")
43
+ _BaseMessageParamConverterT = TypeVar(
44
+ "_BaseMessageParamConverterT", bound=BaseMessageParamConverter
45
+ )
43
46
 
44
47
 
45
48
  def transform_tool_outputs(
@@ -97,6 +100,7 @@ class BaseCallResponse(
97
100
  _MessageParamT,
98
101
  _CallParamsT,
99
102
  _UserMessageParamT,
103
+ _BaseMessageParamConverterT,
100
104
  ],
101
105
  ABC,
102
106
  ):
@@ -131,6 +135,7 @@ class BaseCallResponse(
131
135
  start_time: float
132
136
  end_time: float
133
137
 
138
+ _message_converter: type[_BaseMessageParamConverterT]
134
139
  _provider: ClassVar[str] = "NO PROVIDER"
135
140
  _model: str = "NO MODEL"
136
141
 
@@ -313,6 +318,11 @@ class BaseCallResponse(
313
318
  """Provider-agnostic user message param."""
314
319
  ...
315
320
 
321
+ @property
322
+ def common_messages(self) -> list[BaseMessageParam]:
323
+ """Provider-agnostic list of messages."""
324
+ return self._message_converter.from_provider(self.messages)
325
+
316
326
  @property
317
327
  def common_tools(self) -> list[Tool] | None:
318
328
  """Provider-agnostic tools."""
@@ -51,6 +51,7 @@ class BedrockCallResponse(
51
51
  InternalBedrockMessageParam,
52
52
  BedrockCallParams,
53
53
  UserMessageTypeDef,
54
+ BedrockMessageParamConverter,
54
55
  ]
55
56
  ):
56
57
  """A convenience wrapper around the Bedrock `ChatCompletion` response.
@@ -78,6 +79,9 @@ class BedrockCallResponse(
78
79
  """
79
80
 
80
81
  response: SkipValidation[SyncConverseResponseTypeDef | AsyncConverseResponseTypeDef]
82
+ _message_converter: type[BedrockMessageParamConverter] = (
83
+ BedrockMessageParamConverter
84
+ )
81
85
 
82
86
  _provider = "bedrock"
83
87
 
@@ -35,6 +35,7 @@ class CohereCallResponse(
35
35
  SkipValidation[ChatMessage],
36
36
  CohereCallParams,
37
37
  SkipValidation[ChatMessage],
38
+ CohereMessageParamConverter,
38
39
  ]
39
40
  ):
40
41
  """A convenience wrapper around the Cohere `ChatCompletion` response.
@@ -60,6 +61,8 @@ class CohereCallResponse(
60
61
  ```
61
62
  """
62
63
 
64
+ _message_converter: type[CohereMessageParamConverter] = CohereMessageParamConverter
65
+
63
66
  _provider = "cohere"
64
67
 
65
68
  @computed_field
@@ -36,6 +36,7 @@ class GeminiCallResponse(
36
36
  ContentsType,
37
37
  GeminiCallParams,
38
38
  ContentDict,
39
+ GeminiMessageParamConverter,
39
40
  ]
40
41
  ):
41
42
  """A convenience wrapper around the Gemini API response.
@@ -62,6 +63,8 @@ class GeminiCallResponse(
62
63
  ```
63
64
  """
64
65
 
66
+ _message_converter: type[GeminiMessageParamConverter] = GeminiMessageParamConverter
67
+
65
68
  _provider = "gemini"
66
69
 
67
70
  @computed_field
@@ -40,6 +40,7 @@ class GoogleCallResponse(
40
40
  ContentListUnion | ContentListUnionDict,
41
41
  GoogleCallParams,
42
42
  ContentDict,
43
+ GoogleMessageParamConverter,
43
44
  ]
44
45
  ):
45
46
  """A convenience wrapper around the Google API response.
@@ -66,6 +67,8 @@ class GoogleCallResponse(
66
67
  ```
67
68
  """
68
69
 
70
+ _message_converter: type[GoogleMessageParamConverter] = GoogleMessageParamConverter
71
+
69
72
  _provider = "google"
70
73
 
71
74
  @computed_field
@@ -70,17 +70,48 @@ class GoogleTool(BaseTool):
70
70
  fn["parameters"] = model_schema
71
71
 
72
72
  if "parameters" in fn:
73
+ # Resolve $defs and $ref
73
74
  if "$defs" in fn["parameters"]:
74
- raise ValueError(
75
- "Unfortunately Google's Google API cannot handle nested structures "
76
- "with $defs."
77
- )
75
+ defs = fn["parameters"].pop("$defs")
76
+
77
+ def resolve_refs(schema: dict[str, Any]) -> dict[str, Any]:
78
+ """Recursively resolve $ref references using the $defs dictionary."""
79
+ # If this is a reference, resolve it
80
+ if "$ref" in schema:
81
+ ref = schema["$ref"]
82
+ if ref.startswith("#/$defs/"):
83
+ ref_key = ref.replace("#/$defs/", "")
84
+ if ref_key in defs:
85
+ # Merge the definition with the current schema (excluding $ref)
86
+ resolved = {
87
+ **{k: v for k, v in schema.items() if k != "$ref"},
88
+ **resolve_refs(defs[ref_key]),
89
+ }
90
+ return resolved
91
+
92
+ # Process all other keys recursively
93
+ result = {}
94
+ for key, value in schema.items():
95
+ if isinstance(value, dict):
96
+ result[key] = resolve_refs(value)
97
+ elif isinstance(value, list):
98
+ result[key] = [
99
+ resolve_refs(item) if isinstance(item, dict) else item
100
+ for item in value
101
+ ]
102
+ else:
103
+ result[key] = value
104
+ return result
105
+
106
+ # Resolve all references in the parameters
107
+ fn["parameters"] = resolve_refs(fn["parameters"])
78
108
 
79
109
  def handle_enum_schema(prop_schema: dict[str, Any]) -> dict[str, Any]:
80
110
  if "enum" in prop_schema:
81
111
  prop_schema["format"] = "enum"
82
112
  return prop_schema
83
113
 
114
+ # Process properties after resolving references
84
115
  fn["parameters"]["properties"] = {
85
116
  prop: {
86
117
  key: value
@@ -35,6 +35,7 @@ class GroqCallResponse(
35
35
  ChatCompletionMessageParam,
36
36
  GroqCallParams,
37
37
  ChatCompletionUserMessageParam,
38
+ GroqMessageParamConverter,
38
39
  ]
39
40
  ):
40
41
  """A convenience wrapper around the Groq `ChatCompletion` response.
@@ -59,6 +60,8 @@ class GroqCallResponse(
59
60
  ```
60
61
  """
61
62
 
63
+ _message_converter: type[GroqMessageParamConverter] = GroqMessageParamConverter
64
+
62
65
  _provider = "groq"
63
66
 
64
67
  @computed_field
@@ -38,6 +38,7 @@ class MistralCallResponse(
38
38
  AssistantMessage | SystemMessage | ToolMessage | UserMessage,
39
39
  MistralCallParams,
40
40
  UserMessage,
41
+ MistralMessageParamConverter,
41
42
  ]
42
43
  ):
43
44
  """A convenience wrapper around the Mistral `ChatCompletion` response.
@@ -62,6 +63,10 @@ class MistralCallResponse(
62
63
  ```
63
64
  """
64
65
 
66
+ _message_converter: type[MistralMessageParamConverter] = (
67
+ MistralMessageParamConverter
68
+ )
69
+
65
70
  _provider = "mistral"
66
71
 
67
72
  @property
@@ -58,6 +58,7 @@ class OpenAICallResponse(
58
58
  ChatCompletionMessageParam,
59
59
  OpenAICallParams,
60
60
  ChatCompletionUserMessageParam,
61
+ OpenAIMessageParamConverter,
61
62
  ]
62
63
  ):
63
64
  """A convenience wrapper around the OpenAI `ChatCompletion` response.
@@ -84,6 +85,7 @@ class OpenAICallResponse(
84
85
  """
85
86
 
86
87
  response: SkipValidation[ChatCompletion]
88
+ _message_converter: type[OpenAIMessageParamConverter] = OpenAIMessageParamConverter
87
89
 
88
90
  _provider = "openai"
89
91
 
@@ -30,6 +30,7 @@ class VertexCallResponse(
30
30
  Content,
31
31
  VertexCallParams,
32
32
  Content,
33
+ VertexMessageParamConverter,
33
34
  ]
34
35
  ):
35
36
  """A convenience wrapper around the Vertex AI `GenerateContentResponse`.
@@ -55,6 +56,8 @@ class VertexCallResponse(
55
56
  ```
56
57
  """
57
58
 
59
+ _message_converter: type[VertexMessageParamConverter] = VertexMessageParamConverter
60
+
58
61
  _provider = "vertex"
59
62
 
60
63
  @computed_field
mirascope/llm/__init__.py CHANGED
@@ -1,14 +1,18 @@
1
1
  from ..core import CostMetadata, LocalProvider, Provider, calculate_cost
2
+ from ._call import call
3
+ from ._context import context
4
+ from ._override import override
2
5
  from .call_response import CallResponse
3
- from .llm_call import call
4
- from .llm_override import override
6
+ from .stream import Stream
5
7
 
6
8
  __all__ = [
7
9
  "CallResponse",
8
10
  "CostMetadata",
9
11
  "LocalProvider",
10
12
  "Provider",
13
+ "Stream",
11
14
  "calculate_cost",
12
15
  "call",
16
+ "context",
13
17
  "override",
14
18
  ]
@@ -20,6 +20,7 @@ from ..core.base import (
20
20
  from ..core.base._utils import fn_is_async
21
21
  from ..core.base.stream_config import StreamConfig
22
22
  from ..core.base.types import LocalProvider, Provider
23
+ from ._context import CallArgs, apply_context_overrides_to_call_args
23
24
  from ._protocols import (
24
25
  AsyncLLMFunctionDecorator,
25
26
  CallDecorator,
@@ -49,24 +50,36 @@ _ResultT = TypeVar("_ResultT")
49
50
  def _get_local_provider_call(
50
51
  provider: LocalProvider,
51
52
  client: Any | None, # noqa: ANN401
53
+ is_async: bool,
52
54
  ) -> tuple[Callable, Any | None]:
53
55
  if provider == "ollama":
54
56
  from ..core.openai import openai_call
55
57
 
56
58
  if client:
57
59
  return openai_call, client
58
- from openai import OpenAI
60
+ if is_async:
61
+ from openai import AsyncOpenAI
59
62
 
60
- client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
63
+ client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
64
+ else:
65
+ from openai import OpenAI
66
+
67
+ client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")
61
68
  return openai_call, client
62
69
  else: # provider == "vllm"
63
70
  from ..core.openai import openai_call
64
71
 
65
72
  if client:
66
73
  return openai_call, client
67
- from openai import OpenAI
68
74
 
69
- client = OpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
75
+ if is_async:
76
+ from openai import AsyncOpenAI
77
+
78
+ client = AsyncOpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
79
+ else:
80
+ from openai import OpenAI
81
+
82
+ client = OpenAI(api_key="ollama", base_url="http://localhost:8000/v1")
70
83
  return openai_call, client
71
84
 
72
85
 
@@ -193,13 +206,9 @@ def _call(
193
206
  ]
194
207
  ):
195
208
  """Decorator for defining a function that calls a language model."""
196
- if provider in get_args(LocalProvider):
197
- provider_call, client = _get_local_provider_call(
198
- cast(LocalProvider, provider), client
199
- )
200
- else:
201
- provider_call = _get_provider_call(cast(Provider, provider))
202
- _original_args = {
209
+ # Store original call args that will be used for each function call
210
+ original_call_args: CallArgs = {
211
+ "provider": provider,
203
212
  "model": model,
204
213
  "stream": stream,
205
214
  "tools": tools,
@@ -214,36 +223,116 @@ def _call(
214
223
  fn: Callable[_P, _R | Awaitable[_R]],
215
224
  ) -> Callable[
216
225
  _P,
217
- CallResponse | Stream | Awaitable[CallResponse | Stream],
226
+ CallResponse
227
+ | Stream
228
+ | _ResponseModelT
229
+ | _ParsedOutputT
230
+ | (_ResponseModelT | CallResponse)
231
+ | Awaitable[CallResponse]
232
+ | Awaitable[Stream]
233
+ | Awaitable[_ResponseModelT]
234
+ | Awaitable[_ParsedOutputT]
235
+ | Awaitable[(_ResponseModelT | CallResponse)],
218
236
  ]:
219
- decorated = provider_call(**_original_args)(fn)
220
-
221
- if fn_is_async(decorated):
237
+ if fn_is_async(fn):
222
238
 
223
- @wraps(decorated)
239
+ @wraps(fn)
224
240
  async def inner_async(
225
241
  *args: _P.args, **kwargs: _P.kwargs
226
- ) -> CallResponse | Stream:
242
+ ) -> (
243
+ CallResponse
244
+ | Stream
245
+ | _ResponseModelT
246
+ | _ParsedOutputT
247
+ | (_ResponseModelT | CallResponse)
248
+ ):
249
+ # Apply any context overrides to the original call args
250
+ effective_call_args = apply_context_overrides_to_call_args(
251
+ original_call_args
252
+ )
253
+
254
+ # Get the appropriate provider call function with the possibly overridden provider
255
+ effective_provider = effective_call_args["provider"]
256
+ effective_client = effective_call_args["client"]
257
+
258
+ if effective_provider in get_args(LocalProvider):
259
+ provider_call, effective_client = _get_local_provider_call(
260
+ cast(LocalProvider, effective_provider),
261
+ effective_client,
262
+ True,
263
+ )
264
+ effective_call_args["client"] = effective_client
265
+ else:
266
+ provider_call = _get_provider_call(
267
+ cast(Provider, effective_provider)
268
+ )
269
+
270
+ # Use the provider-specific call function with overridden args
271
+ call_kwargs = dict(effective_call_args)
272
+ del call_kwargs[
273
+ "provider"
274
+ ] # Remove provider as it's not a parameter to provider_call
275
+
276
+ # Get decorated function using provider_call
277
+ decorated = provider_call(**call_kwargs)(fn)
278
+
279
+ # Call the decorated function and wrap the result
227
280
  result = await decorated(*args, **kwargs)
228
281
  return _wrap_result(result)
229
282
 
230
- inner_async._original_args = _original_args # pyright: ignore [reportAttributeAccessIssue]
231
- inner_async._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
283
+ inner_async._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
232
284
  inner_async._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
233
- inner_async._original_provider = provider # pyright: ignore [reportAttributeAccessIssue]
234
285
 
235
286
  return inner_async
236
287
  else:
237
288
 
238
- @wraps(decorated)
239
- def inner(*args: _P.args, **kwargs: _P.kwargs) -> CallResponse | Stream:
289
+ @wraps(fn)
290
+ def inner(
291
+ *args: _P.args, **kwargs: _P.kwargs
292
+ ) -> (
293
+ CallResponse
294
+ | Stream
295
+ | _ResponseModelT
296
+ | _ParsedOutputT
297
+ | (_ResponseModelT | CallResponse)
298
+ ):
299
+ # Apply any context overrides to the original call args
300
+ effective_call_args = apply_context_overrides_to_call_args(
301
+ original_call_args
302
+ )
303
+
304
+ # Get the appropriate provider call function with the possibly overridden provider
305
+ effective_provider = effective_call_args["provider"]
306
+ effective_client = effective_call_args["client"]
307
+
308
+ if effective_provider in get_args(LocalProvider):
309
+ provider_call, effective_client = _get_local_provider_call(
310
+ cast(LocalProvider, effective_provider),
311
+ effective_client,
312
+ False,
313
+ )
314
+ effective_call_args["client"] = effective_client
315
+ else:
316
+ provider_call = _get_provider_call(
317
+ cast(Provider, effective_provider)
318
+ )
319
+
320
+ # Use the provider-specific call function with overridden args
321
+ call_kwargs = dict(effective_call_args)
322
+ del call_kwargs[
323
+ "provider"
324
+ ] # Remove provider as it's not a parameter to provider_call
325
+
326
+ # Get decorated function using provider_call
327
+ decorated = provider_call(**call_kwargs)(fn)
328
+
329
+ # Call the decorated function and wrap the result
240
330
  result = decorated(*args, **kwargs)
241
331
  return _wrap_result(result)
242
332
 
243
- inner._original_args = _original_args # pyright: ignore [reportAttributeAccessIssue]
244
- inner._original_provider_call = provider_call # pyright: ignore [reportAttributeAccessIssue]
333
+ inner._original_call_args = original_call_args # pyright: ignore [reportAttributeAccessIssue]
245
334
  inner._original_fn = fn # pyright: ignore [reportAttributeAccessIssue]
246
- inner._original_provider = provider # pyright: ignore [reportAttributeAccessIssue]
335
+
247
336
  return inner
248
337
 
249
338
  return wrapper # pyright: ignore [reportReturnType]