mirascope 1.20.0__py3-none-any.whl → 1.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,381 @@
+ """This module contains the context managers for LLM API calls."""
+
+ from __future__ import annotations
+
+ import threading
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from enum import Enum
+ from types import TracebackType
+ from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast, overload
+
+ from pydantic import BaseModel
+ from typing_extensions import TypedDict
+
+ from ..core.base import BaseTool, BaseType, CommonCallParams
+ from ..core.base.stream_config import StreamConfig
+ from ..core.base.types import LocalProvider, Provider
+
+ _ResponseModelT = TypeVar("_ResponseModelT", bound=BaseModel | BaseType | Enum)
+
+ if TYPE_CHECKING:
+     from ..core.anthropic import AnthropicCallParams
+     from ..core.azure import AzureCallParams
+     from ..core.bedrock import BedrockCallParams
+     from ..core.cohere import CohereCallParams
+     from ..core.gemini import GeminiCallParams
+     from ..core.google import GoogleCallParams
+     from ..core.groq import GroqCallParams
+     from ..core.litellm import LiteLLMCallParams
+     from ..core.mistral import MistralCallParams
+     from ..core.openai import OpenAICallParams
+     from ..core.vertex import VertexCallParams
+     from ..core.xai import XAICallParams
+ else:
+     AnthropicCallParams = AzureCallParams = BedrockCallParams = CohereCallParams = (
+         GeminiCallParams
+     ) = GoogleCallParams = GroqCallParams = LiteLLMCallParams = MistralCallParams = (
+         OpenAICallParams
+     ) = VertexCallParams = XAICallParams = None
+
+
+ class CallArgs(TypedDict):
+     """TypedDict for call arguments."""
+
+     provider: Provider | LocalProvider
+     model: str
+     stream: bool | StreamConfig
+     tools: list[type[BaseTool] | Callable] | None
+     response_model: type[BaseModel] | type[BaseType] | type[Enum] | None
+     output_parser: Callable | None
+     json_mode: bool
+     client: Any | None
+     call_params: CommonCallParams | Any | None
+
+
+ # We use a thread-local variable to store the current context, so that it's thread-safe
+ _current_context_local = threading.local()
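
# Editor's sketch, not part of the diff: because the slot lives in
# `threading.local()`, a context entered in one thread is invisible to other
# threads. This assumes `context` and `get_current_context` are exported from
# `mirascope.llm` (the module above defines them; the export path is an
# assumption).

import threading

from mirascope import llm

seen: list[object] = []

def worker() -> None:
    # The worker thread has its own, empty thread-local slot.
    seen.append(llm.get_current_context())

with llm.context(provider="openai", model="gpt-4o-mini"):
    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

assert seen == [None]  # the override never leaked across threads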
+
+
+ @dataclass
+ class LLMContext:
+     """Context for LLM API calls.
+
+     This class is used to store the context for LLM API calls, including both
+     setting overrides (provider, model, client, call_params) and
+     structural overrides (stream, tools, response_model, etc.).
+     """
+
+     provider: Provider | LocalProvider | None = None
+     model: str | None = None
+     stream: bool | StreamConfig | None = None
+     tools: list[type[BaseTool] | Callable] | None = None
+     response_model: type[BaseModel] | type[BaseType] | type[Enum] | None = None
+     output_parser: Callable | None = None
+     json_mode: bool | None = None
+     client: Any | None = None
+     call_params: CommonCallParams | Any | None = None
+
+     def __enter__(self) -> LLMContext:
+         _current_context_local.context = self
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_val: BaseException | None,
+         exc_tb: TracebackType | None,
+     ) -> Literal[False]:
+         if hasattr(_current_context_local, "context"):
+             del _current_context_local.context
+         return False  # Don't suppress exceptions
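
# Editor's sketch, not part of the diff: entering an LLMContext publishes it to
# the thread-local slot and exiting clears it. Note that `__exit__` deletes the
# slot outright rather than restoring any previous context; nesting is instead
# handled by the merge logic in `_context` below. The import path here is
# hypothetical.

from mirascope.llm._context import LLMContext, get_current_context

assert get_current_context() is None
with LLMContext(provider="openai", model="gpt-4o-mini") as ctx:
    assert get_current_context() is ctx
assert get_current_context() is None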
+
+
+ def get_current_context() -> LLMContext | None:
+     """Get the current context for LLM API calls.
+
+     Returns:
+         The current context, or None if there is no context.
+     """
+     if hasattr(_current_context_local, "context"):
+         return cast(LLMContext, _current_context_local.context)
+     return None
+
+
+ def _context(
+     *,
+     provider: Provider | LocalProvider | None,
+     model: str | None,
+     stream: bool | StreamConfig | None = None,
+     tools: list[type[BaseTool] | Callable] | None = None,
+     response_model: type[_ResponseModelT] | None = None,
+     output_parser: Callable | None = None,
+     json_mode: bool | None = None,
+     client: Any | None = None,  # noqa: ANN401
+     call_params: CommonCallParams | Any | None = None,  # noqa: ANN401
+ ) -> LLMContext:
+     """Context manager for synchronous LLM API calls.
+
+     This is an internal function that allows both setting and structural
+     overrides for synchronous functions.
+
+     Unfortunately, we have not yet identified a way to properly type hint this,
+     because providing no structural overrides means the return type is that of
+     the original function. Of course, the `apply` method could pass through the
+     return type, but we have no way to know whether it should be a passthrough
+     or not.
+
+     For now, we use `_context` simply to implement `override` fully. The
+     public-facing `context` function only allows setting overrides.
+
+     Args:
+         provider: The provider to use for the LLM API call.
+         model: The model to use for the LLM API call.
+         stream: Whether to stream the response.
+         tools: The tools to use for the LLM API call.
+         response_model: The response model for the LLM API call.
+         output_parser: The output parser for the LLM API call.
+         json_mode: Whether to use JSON mode.
+         client: The client to use for the LLM API call.
+         call_params: The call parameters for the LLM API call.
+
+     Returns:
+         The context object that can be used to apply the context to a function.
+     """
+     old_context: LLMContext | None = getattr(_current_context_local, "context", None)
+     if not old_context:
+         return LLMContext(
+             provider=provider,
+             model=model,
+             stream=stream,
+             tools=tools,
+             response_model=response_model,
+             output_parser=output_parser,
+             json_mode=json_mode,
+             client=client,
+             call_params=call_params,
+         )
+     else:
+         # Ensure we properly set nested context settings. For example, we need to make
+         # sure that calling override on an overridden function applies the context to
+         # the overridden function's already overridden settings.
+         return LLMContext(
+             provider=provider or old_context.provider,
+             model=model or old_context.model,
+             stream=stream or old_context.stream,
+             tools=tools or old_context.tools,
+             response_model=response_model or old_context.response_model,
+             output_parser=output_parser or old_context.output_parser,
+             json_mode=json_mode or old_context.json_mode,
+             client=client or old_context.client,
+             call_params=call_params or old_context.call_params,
+         )
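
# Editor's sketch, not part of the diff: `_context` merges with any enclosing
# context, falling back field by field via `or`. One consequence of `or` is
# that explicitly falsy overrides (e.g. `json_mode=False`) also fall back to
# the enclosing context's value. `_context` is internal, so this is for
# exposition only; the import path is hypothetical.

from mirascope.llm._context import LLMContext, _context

with LLMContext(provider="openai", model="gpt-4o-mini", json_mode=True):
    merged = _context(provider=None, model="gpt-4o")
    assert merged.provider == "openai"  # inherited from the enclosing context
    assert merged.model == "gpt-4o"  # explicitly overridden
    assert merged.json_mode is True  # None (and False) fall back to the outer value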
+
+
+ def apply_context_overrides_to_call_args(call_args: CallArgs) -> CallArgs:
+     """Apply any active context overrides to the call arguments.
+
+     Args:
+         call_args: The original call arguments.
+
+     Returns:
+         The call arguments with any context overrides applied.
+     """
+     context = get_current_context()
+     if not context:
+         return call_args
+
+     # Create a new dict with the original args
+     overridden_args = CallArgs(call_args)
+
+     # If any structural overrides are set, we have to force all others to take their
+     # default values so the type hints match.
+     if context.stream or context.response_model or context.output_parser:
+         overridden_args["stream"] = False
+         overridden_args["response_model"] = None
+         overridden_args["output_parser"] = None
+         if context.response_model:
+             overridden_args["tools"] = None
+
+     # Apply context overrides
+     if context.provider is not None:
+         overridden_args["provider"] = context.provider
+     if context.model is not None:
+         overridden_args["model"] = context.model
+     if context.stream is not None:
+         overridden_args["stream"] = context.stream
+     if context.tools is not None:
+         overridden_args["tools"] = context.tools
+     if context.response_model is not None:
+         overridden_args["response_model"] = context.response_model
+     if context.output_parser is not None:
+         overridden_args["output_parser"] = context.output_parser
+     if context.json_mode is not None:
+         overridden_args["json_mode"] = context.json_mode
+     if context.client is not None:
+         overridden_args["client"] = context.client
+     if context.call_params is not None:
+         overridden_args["call_params"] = context.call_params
+
+     return overridden_args
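
# Editor's sketch, not part of the diff: a structural override such as
# `response_model` first resets `stream`, `output_parser`, and `tools` to their
# defaults before the explicit overrides are applied, keeping the resulting
# call args internally consistent. The import path and `Book` model are
# hypothetical.

from pydantic import BaseModel

from mirascope.llm._context import (
    CallArgs,
    LLMContext,
    apply_context_overrides_to_call_args,
)

class Book(BaseModel):
    title: str

base_args = CallArgs(
    provider="openai",
    model="gpt-4o-mini",
    stream=True,
    tools=None,
    response_model=None,
    output_parser=None,
    json_mode=False,
    client=None,
    call_params=None,
)
with LLMContext(response_model=Book):
    overridden = apply_context_overrides_to_call_args(base_args)

assert overridden["stream"] is False  # reset because a structural override is active
assert overridden["response_model"] is Book
assert overridden["provider"] == "openai"  # setting fields pass through untouched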
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["anthropic"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | AnthropicCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["azure"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | AzureCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["bedrock"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | BedrockCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["cohere"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | CohereCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["gemini"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | GeminiCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["google"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | GoogleCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["groq"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | GroqCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["litellm"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | LiteLLMCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["mistral"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | MistralCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["openai"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | OpenAICallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["vertex"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | VertexCallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ @overload
+ def context(
+     *,
+     provider: Literal["xai"],
+     model: str,
+     client: Any = None,  # noqa: ANN401
+     call_params: CommonCallParams | XAICallParams | None = None,  # noqa: ANN401
+ ) -> LLMContext: ...
+
+
+ def context(
+     *,
+     provider: Provider | LocalProvider,
+     model: str,
+     client: Any | None = None,
+     call_params: CommonCallParams | Any | None = None,  # noqa: ANN401
+ ) -> LLMContext:
+     """Context manager for LLM API calls.
+
+     This function only allows setting overrides (provider, model, client,
+     call_params) and does not allow structural overrides (stream, tools,
+     response_model, etc.).
+
+     Example:
+         ```python
+         @llm.call(provider="openai", model="gpt-4o-mini")
+         def recommend_book(genre: str) -> str:
+             return f"Recommend a {genre} book"
+
+         # Override the provider and model for a specific call
+         with llm.context(provider="anthropic", model="claude-3-5-sonnet-20240620"):
+             response = recommend_book("fantasy")  # Uses Claude 3.5 Sonnet
+         ```
+
+     Args:
+         provider: The provider to use for the LLM API call.
+         model: The model to use for the LLM API call.
+         client: The client to use for the LLM API call.
+         call_params: The call parameters for the LLM API call.
+
+     Returns:
+         The context object.
+     """
+     if (provider and not model) or (model and not provider):
+         raise ValueError(
+             "Provider and model must both be specified if either is specified."
+         )
+
+     return _context(
+         provider=provider, model=model, client=client, call_params=call_params
+     )
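
# Editor's sketch, not part of the diff: the public `context` can also swap in
# provider-specific call params alongside provider and model. `max_tokens` is
# shown as a plain dict entry here; treat the exact CommonCallParams fields as
# assumptions.

from mirascope import llm

with llm.context(
    provider="anthropic",
    model="claude-3-5-sonnet-20240620",
    call_params={"max_tokens": 512},
):
    # Any @llm.call-decorated function invoked here uses the overridden
    # provider, model, and call params.
    ...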