mirascope-2.0.2-py3-none-any.whl → mirascope-2.1.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as published to their public registry. It is provided for informational purposes only.
Files changed (64)
  1. mirascope/_stubs.py +39 -18
  2. mirascope/api/_generated/__init__.py +4 -0
  3. mirascope/api/_generated/project_memberships/__init__.py +4 -0
  4. mirascope/api/_generated/project_memberships/client.py +91 -0
  5. mirascope/api/_generated/project_memberships/raw_client.py +239 -0
  6. mirascope/api/_generated/project_memberships/types/__init__.py +4 -0
  7. mirascope/api/_generated/project_memberships/types/project_memberships_get_response.py +33 -0
  8. mirascope/api/_generated/project_memberships/types/project_memberships_get_response_role.py +7 -0
  9. mirascope/api/_generated/reference.md +72 -0
  10. mirascope/llm/__init__.py +19 -0
  11. mirascope/llm/calls/decorator.py +17 -24
  12. mirascope/llm/formatting/__init__.py +2 -2
  13. mirascope/llm/formatting/format.py +2 -4
  14. mirascope/llm/formatting/types.py +19 -2
  15. mirascope/llm/models/models.py +66 -146
  16. mirascope/llm/prompts/decorator.py +5 -16
  17. mirascope/llm/prompts/prompts.py +5 -13
  18. mirascope/llm/providers/anthropic/_utils/beta_decode.py +22 -7
  19. mirascope/llm/providers/anthropic/_utils/beta_encode.py +22 -16
  20. mirascope/llm/providers/anthropic/_utils/decode.py +45 -7
  21. mirascope/llm/providers/anthropic/_utils/encode.py +28 -15
  22. mirascope/llm/providers/anthropic/beta_provider.py +33 -69
  23. mirascope/llm/providers/anthropic/provider.py +52 -91
  24. mirascope/llm/providers/base/_utils.py +4 -9
  25. mirascope/llm/providers/base/base_provider.py +89 -205
  26. mirascope/llm/providers/google/_utils/decode.py +51 -1
  27. mirascope/llm/providers/google/_utils/encode.py +38 -21
  28. mirascope/llm/providers/google/provider.py +33 -69
  29. mirascope/llm/providers/mirascope/provider.py +25 -61
  30. mirascope/llm/providers/mlx/encoding/base.py +3 -6
  31. mirascope/llm/providers/mlx/encoding/transformers.py +4 -8
  32. mirascope/llm/providers/mlx/mlx.py +9 -21
  33. mirascope/llm/providers/mlx/provider.py +33 -69
  34. mirascope/llm/providers/openai/completions/_utils/encode.py +39 -20
  35. mirascope/llm/providers/openai/completions/base_provider.py +34 -75
  36. mirascope/llm/providers/openai/provider.py +25 -61
  37. mirascope/llm/providers/openai/responses/_utils/decode.py +31 -2
  38. mirascope/llm/providers/openai/responses/_utils/encode.py +32 -17
  39. mirascope/llm/providers/openai/responses/provider.py +34 -75
  40. mirascope/llm/responses/__init__.py +2 -1
  41. mirascope/llm/responses/base_stream_response.py +4 -0
  42. mirascope/llm/responses/response.py +8 -12
  43. mirascope/llm/responses/stream_response.py +8 -12
  44. mirascope/llm/responses/usage.py +44 -0
  45. mirascope/llm/tools/__init__.py +24 -0
  46. mirascope/llm/tools/provider_tools.py +18 -0
  47. mirascope/llm/tools/tool_schema.py +4 -2
  48. mirascope/llm/tools/toolkit.py +24 -6
  49. mirascope/llm/tools/types.py +112 -0
  50. mirascope/llm/tools/web_search_tool.py +32 -0
  51. mirascope/ops/__init__.py +19 -1
  52. mirascope/ops/_internal/instrumentation/__init__.py +20 -0
  53. mirascope/ops/_internal/instrumentation/llm/common.py +19 -49
  54. mirascope/ops/_internal/instrumentation/llm/model.py +61 -82
  55. mirascope/ops/_internal/instrumentation/llm/serialize.py +36 -12
  56. mirascope/ops/_internal/instrumentation/providers/__init__.py +29 -0
  57. mirascope/ops/_internal/instrumentation/providers/anthropic.py +78 -0
  58. mirascope/ops/_internal/instrumentation/providers/base.py +179 -0
  59. mirascope/ops/_internal/instrumentation/providers/google_genai.py +85 -0
  60. mirascope/ops/_internal/instrumentation/providers/openai.py +82 -0
  61. {mirascope-2.0.2.dist-info → mirascope-2.1.0.dist-info}/METADATA +96 -68
  62. {mirascope-2.0.2.dist-info → mirascope-2.1.0.dist-info}/RECORD +64 -54
  63. {mirascope-2.0.2.dist-info → mirascope-2.1.0.dist-info}/WHEEL +0 -0
  64. {mirascope-2.0.2.dist-info → mirascope-2.1.0.dist-info}/licenses/LICENSE +0 -0
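
Only the last five files in the list above have their hunks reproduced below, and every one of them applies the same two-part refactor: the optional `tools` parameter, which accepted a bare sequence of tools, a toolkit, or `None`, becomes a required toolkit parameter, and the three-way union accepted by `format` collapses into the single `FormatSpec` alias. A minimal sketch of the recurring signature change, condensed from the hunks that follow (imports and class bodies are elided, so this is illustrative rather than runnable):

# Before (2.0.2): union-typed, optional parameters.
def call(
    self,
    *,
    model_id: str,
    messages: Sequence[Message],
    tools: Sequence[Tool] | Toolkit | None = None,
    format: type[FormattableT]
    | Format[FormattableT]
    | OutputParser[FormattableT]
    | None = None,
    **params: Unpack[Params],
) -> Response | Response[FormattableT]: ...

# After (2.1.0): a required Toolkit and the single FormatSpec alias.
def call(
    self,
    *,
    model_id: str,
    messages: Sequence[Message],
    toolkit: Toolkit,
    format: FormatSpec[FormattableT] | None = None,
    **params: Unpack[Params],
) -> Response | Response[FormattableT]: ...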
mirascope/llm/providers/mirascope/provider.py

@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
 from typing_extensions import Unpack
 
 from ...context import Context, DepsT
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import FormatSpec, FormattableT
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -21,13 +21,9 @@ from ...responses import (
     StreamResponse,
 )
 from ...tools import (
-    AsyncContextTool,
     AsyncContextToolkit,
-    AsyncTool,
     AsyncToolkit,
-    ContextTool,
     ContextToolkit,
-    Tool,
     Toolkit,
 )
 from ..base import BaseProvider, Provider
@@ -156,11 +152,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` by calling through the Mirascope Router."""
@@ -168,7 +161,7 @@ class MirascopeProvider(BaseProvider[None]):
         return provider.call(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -179,13 +172,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[Tool | ContextTool[DepsT]]
-        | ContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` by calling through the Mirascope Router."""
@@ -194,7 +182,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -204,11 +192,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` by calling through the Mirascope Router."""
@@ -216,7 +201,7 @@ class MirascopeProvider(BaseProvider[None]):
         return await provider.call_async(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -227,13 +212,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
-        | AsyncContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncContextResponse` by calling through the Mirascope Router."""
@@ -242,7 +222,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -252,11 +232,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Stream an `llm.StreamResponse` by calling through the Mirascope Router."""
@@ -264,7 +241,7 @@ class MirascopeProvider(BaseProvider[None]):
         return provider.stream(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -275,13 +252,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[Tool | ContextTool[DepsT]]
-        | ContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> (
         ContextStreamResponse[DepsT, None] | ContextStreamResponse[DepsT, FormattableT]
@@ -292,7 +264,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -302,11 +274,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Stream an `llm.AsyncStreamResponse` by calling through the Mirascope Router."""
@@ -314,7 +283,7 @@ class MirascopeProvider(BaseProvider[None]):
         return await provider.stream_async(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -325,13 +294,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
-        | AsyncContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
        **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT, None]
@@ -343,7 +307,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
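
None of these hunks show the definition of `FormatSpec` itself; per the file list it lands in `mirascope/llm/formatting/types.py` (+19/-2). Because it replaces the union `type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT]` verbatim at every call site, a plausible reading, offered here only as an assumption, is a generic type alias along these lines:

from typing import TypeAlias

# Assumed definition, inferred from the union that FormatSpec replaces at
# every call site above; the real alias in mirascope/llm/formatting/types.py
# may differ. Format, FormattableT, and OutputParser are the package-internal
# names imported from mirascope.llm.formatting in the 2.0.2 hunks.
FormatSpec: TypeAlias = (
    type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT]
)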
mirascope/llm/providers/mlx/encoding/base.py

@@ -6,7 +6,7 @@ from typing import TypeAlias
 
 from mlx_lm.generate import GenerationResponse
 
-from ....formatting import Format, FormattableT, OutputParser
+from ....formatting import Format, FormatSpec, FormattableT
 from ....messages import AssistantContent, Message
 from ....responses import ChunkIterator
 from ....tools import AnyToolSchema, BaseToolkit
@@ -21,11 +21,8 @@ class BaseEncoder(abc.ABC):
     def encode_request(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
         """Encode the request messages into a format suitable for the model.
 
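
Since `encode_request` is abstract, every encoder must adopt the tightened contract. A hypothetical minimal subclass under the new signature (the class name and body are invented for illustration; the parameter types, `TokenIds`, and the return shape come from the hunk above):

class NoopEncoder(BaseEncoder):
    """Illustrative only: rejects tools and formats, encodes nothing."""

    def encode_request(
        self,
        messages: Sequence[Message],
        tools: BaseToolkit[AnyToolSchema],
        format: FormatSpec[FormattableT] | None,
    ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
        if tools.tools:  # a toolkit is always passed now; inspect its contents
            raise NotImplementedError("Tool usage is not supported.")
        if format is not None:
            raise NotImplementedError("Formatting is not supported.")
        return messages, None, []  # placeholder: no real tokenization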
mirascope/llm/providers/mlx/encoding/transformers.py

@@ -8,7 +8,7 @@ from mlx_lm.generate import GenerationResponse
 from transformers import PreTrainedTokenizer
 
 from ....content import ContentPart, TextChunk, TextEndChunk, TextStartChunk
-from ....formatting import Format, FormattableT, OutputParser
+from ....formatting import Format, FormatSpec, FormattableT
 from ....messages import AssistantContent, Message
 from ....responses import (
     ChunkIterator,
@@ -80,15 +80,11 @@ class TransformersEncoder(BaseEncoder):
     def encode_request(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
         """Encode a request into a format suitable for the model."""
-        tool_schemas = tools.tools if isinstance(tools, BaseToolkit) else tools or []
-        if len(tool_schemas) > 0:
+        if tools.tools:
             raise NotImplementedError("Tool usage is not supported.")
         if format is not None:
             raise NotImplementedError("Formatting is not supported.")
mirascope/llm/providers/mlx/mlx.py

@@ -13,7 +13,7 @@ from mlx_lm import stream_generate  # type: ignore[reportPrivateImportUsage]
 from mlx_lm.generate import GenerationResponse
 from transformers import PreTrainedTokenizer
 
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import Format, FormatSpec, FormattableT
 from ...messages import AssistantMessage, Message, assistant
 from ...responses import AsyncChunkIterator, ChunkIterator, StreamResponseChunk
 from ...tools import AnyToolSchema, BaseToolkit
@@ -137,11 +137,8 @@ class MLX:
     def stream(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, ChunkIterator]:
         """Stream response chunks synchronously.
@@ -163,11 +160,8 @@ class MLX:
     async def stream_async(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, AsyncChunkIterator]:
         """Stream response chunks asynchronously.
@@ -190,11 +184,8 @@ class MLX:
     def generate(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[
         Sequence[Message],
@@ -229,11 +220,8 @@ class MLX:
     async def generate_async(
         self,
         messages: Sequence[Message],
-        tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[
         Sequence[Message],
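
With these hunks, all four `MLX` entry points (`stream`, `stream_async`, `generate`, `generate_async`) share one call shape: `(messages, tools, format, params)`, with `tools` now a required `BaseToolkit[AnyToolSchema]`. A hypothetical call under that assumption; the 4-tuple shape and parameter types come from the hunks, while the variable names are invented:

# generate returns a 4-tuple, as unpacked in the MLXProvider hunks below.
input_messages, resolved_format, assistant_message, raw_response = mlx.generate(
    messages,       # Sequence[Message]
    empty_toolkit,  # BaseToolkit[AnyToolSchema]; must hold no tools (unsupported)
    None,           # FormatSpec[FormattableT] | None; must be None (unsupported)
    params,         # Params
)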
mirascope/llm/providers/mlx/provider.py

@@ -10,7 +10,7 @@ from mlx_lm import load as mlx_load
 from transformers import PreTrainedTokenizer
 
 from ...context import Context, DepsT
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import FormatSpec, FormattableT
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -23,13 +23,9 @@ from ...responses import (
     StreamResponse,
 )
 from ...tools import (
-    AsyncContextTool,
     AsyncContextToolkit,
-    AsyncTool,
     AsyncToolkit,
-    ContextTool,
     ContextToolkit,
-    Tool,
     Toolkit,
 )
 from ..base import BaseProvider
@@ -89,11 +85,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` using MLX model.
@@ -111,7 +104,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, assistant_message, response = mlx.generate(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return Response(
@@ -120,7 +113,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -134,13 +127,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[Tool | ContextTool[DepsT]]
-        | ContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` using MLX model.
@@ -159,7 +147,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, assistant_message, response = mlx.generate(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return ContextResponse(
@@ -168,7 +156,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -181,11 +169,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -208,7 +193,7 @@ class MLXProvider(BaseProvider[None]):
             format,
             assistant_message,
             response,
-        ) = await mlx.generate_async(messages, tools, format, params)
+        ) = await mlx.generate_async(messages, toolkit, format, params)
 
         return AsyncResponse(
             raw=response,
@@ -216,7 +201,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -230,13 +215,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
-        | AsyncContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -260,7 +240,7 @@ class MLXProvider(BaseProvider[None]):
             format,
             assistant_message,
             response,
-        ) = await mlx.generate_async(messages, tools, format, params)
+        ) = await mlx.generate_async(messages, toolkit, format, params)
 
         return AsyncContextResponse(
             raw=response,
@@ -268,7 +248,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -281,11 +261,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generate an `llm.StreamResponse` by synchronously streaming from MLX model output.
@@ -303,7 +280,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = mlx.stream(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return StreamResponse(
@@ -311,7 +288,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -323,13 +300,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[Tool | ContextTool[DepsT]]
-        | ContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextStreamResponse` by synchronously streaming from MLX model output.
@@ -348,7 +320,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = mlx.stream(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return ContextStreamResponse(
@@ -356,7 +328,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -367,11 +339,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from MLX model output.
@@ -389,7 +358,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = await mlx.stream_async(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return AsyncStreamResponse(
@@ -397,7 +366,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -409,13 +378,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
-        | AsyncContextToolkit[DepsT]
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT]
@@ -437,7 +401,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = await mlx.stream_async(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return AsyncContextStreamResponse(
@@ -445,7 +409,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
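
Taken together, the practical migration for anything calling these provider methods directly: the `tools=` keyword (sequence, toolkit, or `None`) is gone, and a toolkit must always be passed. A hedged before/after sketch; the method name and keywords come from the hunks, the `Toolkit(...)` constructor call is an assumption, and the public decorator layer (`mirascope/llm/calls/decorator.py`, also changed in this release) may still normalize plain tool sequences before they reach this layer:

# 2.0.2: optional tools, union-typed format.
response = provider.call(
    model_id="some-model",
    messages=messages,
    tools=[my_tool],  # bare sequences were accepted
)

# 2.1.0: a Toolkit is required; format takes a FormatSpec or None.
response = provider.call(
    model_id="some-model",
    messages=messages,
    toolkit=Toolkit(tools=[my_tool]),  # assumed constructor, not shown in the diff
    format=None,
)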