mirascope 2.0.1__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/_stubs.py +39 -18
- mirascope/_utils.py +34 -0
- mirascope/api/_generated/__init__.py +4 -0
- mirascope/api/_generated/organization_invitations/client.py +2 -2
- mirascope/api/_generated/organization_invitations/raw_client.py +2 -2
- mirascope/api/_generated/project_memberships/__init__.py +4 -0
- mirascope/api/_generated/project_memberships/client.py +91 -0
- mirascope/api/_generated/project_memberships/raw_client.py +239 -0
- mirascope/api/_generated/project_memberships/types/__init__.py +4 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_get_response.py +33 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_get_response_role.py +7 -0
- mirascope/api/_generated/reference.md +73 -1
- mirascope/llm/__init__.py +19 -0
- mirascope/llm/calls/calls.py +28 -21
- mirascope/llm/calls/decorator.py +17 -24
- mirascope/llm/formatting/__init__.py +2 -2
- mirascope/llm/formatting/format.py +2 -4
- mirascope/llm/formatting/types.py +19 -2
- mirascope/llm/models/models.py +66 -146
- mirascope/llm/prompts/decorator.py +5 -16
- mirascope/llm/prompts/prompts.py +35 -38
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +22 -7
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +22 -16
- mirascope/llm/providers/anthropic/_utils/decode.py +45 -7
- mirascope/llm/providers/anthropic/_utils/encode.py +28 -15
- mirascope/llm/providers/anthropic/beta_provider.py +33 -69
- mirascope/llm/providers/anthropic/provider.py +52 -91
- mirascope/llm/providers/base/_utils.py +4 -9
- mirascope/llm/providers/base/base_provider.py +89 -205
- mirascope/llm/providers/google/_utils/decode.py +51 -1
- mirascope/llm/providers/google/_utils/encode.py +38 -21
- mirascope/llm/providers/google/provider.py +33 -69
- mirascope/llm/providers/mirascope/provider.py +25 -61
- mirascope/llm/providers/mlx/encoding/base.py +3 -6
- mirascope/llm/providers/mlx/encoding/transformers.py +4 -8
- mirascope/llm/providers/mlx/mlx.py +9 -21
- mirascope/llm/providers/mlx/provider.py +33 -69
- mirascope/llm/providers/openai/completions/_utils/encode.py +39 -20
- mirascope/llm/providers/openai/completions/base_provider.py +34 -75
- mirascope/llm/providers/openai/provider.py +25 -61
- mirascope/llm/providers/openai/responses/_utils/decode.py +31 -2
- mirascope/llm/providers/openai/responses/_utils/encode.py +32 -17
- mirascope/llm/providers/openai/responses/provider.py +34 -75
- mirascope/llm/responses/__init__.py +2 -1
- mirascope/llm/responses/base_stream_response.py +4 -0
- mirascope/llm/responses/response.py +8 -12
- mirascope/llm/responses/stream_response.py +8 -12
- mirascope/llm/responses/usage.py +44 -0
- mirascope/llm/tools/__init__.py +24 -0
- mirascope/llm/tools/provider_tools.py +18 -0
- mirascope/llm/tools/tool_schema.py +11 -4
- mirascope/llm/tools/toolkit.py +24 -6
- mirascope/llm/tools/types.py +112 -0
- mirascope/llm/tools/web_search_tool.py +32 -0
- mirascope/ops/__init__.py +19 -1
- mirascope/ops/_internal/closure.py +4 -1
- mirascope/ops/_internal/exporters/exporters.py +13 -46
- mirascope/ops/_internal/exporters/utils.py +37 -0
- mirascope/ops/_internal/instrumentation/__init__.py +20 -0
- mirascope/ops/_internal/instrumentation/llm/common.py +19 -49
- mirascope/ops/_internal/instrumentation/llm/model.py +61 -82
- mirascope/ops/_internal/instrumentation/llm/serialize.py +36 -12
- mirascope/ops/_internal/instrumentation/providers/__init__.py +29 -0
- mirascope/ops/_internal/instrumentation/providers/anthropic.py +78 -0
- mirascope/ops/_internal/instrumentation/providers/base.py +179 -0
- mirascope/ops/_internal/instrumentation/providers/google_genai.py +85 -0
- mirascope/ops/_internal/instrumentation/providers/openai.py +82 -0
- mirascope/ops/_internal/traced_calls.py +14 -0
- mirascope/ops/_internal/traced_functions.py +7 -2
- mirascope/ops/_internal/utils.py +12 -4
- mirascope/ops/_internal/versioned_functions.py +1 -1
- {mirascope-2.0.1.dist-info → mirascope-2.1.0.dist-info}/METADATA +96 -68
- {mirascope-2.0.1.dist-info → mirascope-2.1.0.dist-info}/RECORD +75 -64
- {mirascope-2.0.1.dist-info → mirascope-2.1.0.dist-info}/WHEEL +0 -0
- {mirascope-2.0.1.dist-info → mirascope-2.1.0.dist-info}/licenses/LICENSE +0 -0
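The provider and MLX hunks shown below all apply one migration: the optional `tools` parameter becomes a required `toolkit` argument, and the inline `type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT]` union collapses into the new `FormatSpec` import. The alias itself is defined in mirascope/llm/formatting/types.py (+19 -2 above) and is not displayed in this diff; a minimal sketch of what it presumably abbreviates, with stand-in classes for `Format` and `OutputParser`:

# Hypothetical reconstruction, not mirascope's actual definition: the real
# FormatSpec lives in mirascope/llm/formatting/types.py, which this diff
# changes but does not display.
from typing import Generic, TypeAlias, TypeVar

FormattableT = TypeVar("FormattableT")


class Format(Generic[FormattableT]):
    """Stand-in for mirascope.llm.formatting.Format."""


class OutputParser(Generic[FormattableT]):
    """Stand-in for mirascope.llm.formatting.OutputParser."""


# 2.0.1 spelled this union out in every provider signature; 2.1.0 presumably
# names it once so `format: FormatSpec[FormattableT] | None = None` works:
FormatSpec: TypeAlias = (
    type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT]
)

Whatever the real alias looks like, the call-site effect in the hunks below is the same: one optional `format` argument instead of a hand-written union.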
--- a/mirascope/llm/providers/mirascope/provider.py
+++ b/mirascope/llm/providers/mirascope/provider.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
 from typing_extensions import Unpack
 
 from ...context import Context, DepsT
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import FormatSpec, FormattableT
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -21,13 +21,9 @@ from ...responses import (
     StreamResponse,
 )
 from ...tools import (
-    AsyncContextTool,
     AsyncContextToolkit,
-    AsyncTool,
     AsyncToolkit,
-    ContextTool,
     ContextToolkit,
-    Tool,
     Toolkit,
 )
 from ..base import BaseProvider, Provider
@@ -156,11 +152,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` by calling through the Mirascope Router."""
@@ -168,7 +161,7 @@ class MirascopeProvider(BaseProvider[None]):
         return provider.call(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -179,13 +172,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` by calling through the Mirascope Router."""
@@ -194,7 +182,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -204,11 +192,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` by calling through the Mirascope Router."""
@@ -216,7 +201,7 @@ class MirascopeProvider(BaseProvider[None]):
         return await provider.call_async(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -227,13 +212,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncContextResponse` by calling through the Mirascope Router."""
@@ -242,7 +222,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -252,11 +232,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Stream an `llm.StreamResponse` by calling through the Mirascope Router."""
@@ -264,7 +241,7 @@ class MirascopeProvider(BaseProvider[None]):
         return provider.stream(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -275,13 +252,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> (
         ContextStreamResponse[DepsT, None] | ContextStreamResponse[DepsT, FormattableT]
@@ -292,7 +264,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -302,11 +274,8 @@ class MirascopeProvider(BaseProvider[None]):
         *,
         model_id: str,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Stream an `llm.AsyncStreamResponse` by calling through the Mirascope Router."""
@@ -314,7 +283,7 @@ class MirascopeProvider(BaseProvider[None]):
         return await provider.stream_async(
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
@@ -325,13 +294,8 @@ class MirascopeProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: str,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT, None]
@@ -343,7 +307,7 @@ class MirascopeProvider(BaseProvider[None]):
             ctx=ctx,
             model_id=model_id,
             messages=messages,
-            tools=tools,
+            toolkit=toolkit,
             format=format,
             **params,
         )
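Since `toolkit` has no default in any of the new signatures, callers of the provider layer must now always supply one, even when no tools are in play. A runnable sketch with stand-in types (the real `Toolkit` constructor is not part of this diff, so its `tools=` keyword below is an assumption):

# Runnable sketch with stand-ins; the real Toolkit constructor is not shown
# in this diff, so `Toolkit(tools=...)` is an assumed shape.
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence


@dataclass
class Toolkit:
    """Stand-in for mirascope's Toolkit, which 2.1.0 providers now require."""

    tools: Sequence[Callable[..., Any]] = field(default_factory=tuple)


def call(*, model_id: str, messages: Sequence[str], toolkit: Toolkit,
         format: object | None = None) -> str:
    """Mimics the new keyword-only provider shape: toolkit is required."""
    tool_note = f" ({len(toolkit.tools)} tool(s))" if toolkit.tools else ""
    return f"[{model_id}] {messages[-1]}{tool_note}"


# 2.0.1 callers could omit `tools`; 2.1.0 callers pass an explicit toolkit,
# empty or not:
print(call(model_id="demo", messages=["hi"], toolkit=Toolkit()))
print(call(model_id="demo", messages=["hi"], toolkit=Toolkit(tools=[len])))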
--- a/mirascope/llm/providers/mlx/encoding/base.py
+++ b/mirascope/llm/providers/mlx/encoding/base.py
@@ -6,7 +6,7 @@ from typing import TypeAlias
 
 from mlx_lm.generate import GenerationResponse
 
-from ....formatting import Format, FormattableT, OutputParser
+from ....formatting import Format, FormatSpec, FormattableT
 from ....messages import AssistantContent, Message
 from ....responses import ChunkIterator
 from ....tools import AnyToolSchema, BaseToolkit
@@ -21,11 +21,8 @@ class BaseEncoder(abc.ABC):
     def encode_request(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
         """Encode the request messages into a format suitable for the model.
 
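The encoder interface now receives the toolkit object itself (`tools: BaseToolkit[AnyToolSchema]`) rather than a pre-extracted schema list, which is why the `TransformersEncoder` hunk below can guard on `tools.tools` directly. A runnable sketch of that contract with stand-in types (only the guard logic and the return shape come from this diff; everything else is assumed):

# Sketch of the new encode_request contract with stand-in types.
from typing import Sequence

Message = str          # stand-in for mirascope's Message
TokenIds = list[int]   # the return annotation above names a TokenIds alias


class StubToolkit:
    """Stand-in for BaseToolkit[AnyToolSchema]; exposes .tools so the
    `if tools.tools:` guard from the TransformersEncoder hunk works."""

    tools: tuple = ()


def encode_request(
    messages: Sequence[Message],
    tools: StubToolkit,
    format: object | None,
) -> tuple[Sequence[Message], None, TokenIds]:
    # 2.1.0 inspects the toolkit directly instead of a derived schema list.
    if tools.tools:
        raise NotImplementedError("Tool usage is not supported.")
    if format is not None:
        raise NotImplementedError("Formatting is not supported.")
    token_ids = [len(m) for m in messages]  # placeholder "tokenization"
    return messages, None, token_ids


print(encode_request(["hello"], StubToolkit(), None))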
--- a/mirascope/llm/providers/mlx/encoding/transformers.py
+++ b/mirascope/llm/providers/mlx/encoding/transformers.py
@@ -8,7 +8,7 @@ from mlx_lm.generate import GenerationResponse
 from transformers import PreTrainedTokenizer
 
 from ....content import ContentPart, TextChunk, TextEndChunk, TextStartChunk
-from ....formatting import Format, FormattableT, OutputParser
+from ....formatting import Format, FormatSpec, FormattableT
 from ....messages import AssistantContent, Message
 from ....responses import (
     ChunkIterator,
@@ -80,15 +80,11 @@ class TransformersEncoder(BaseEncoder):
     def encode_request(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
         """Encode a request into a format suitable for the model."""
-
-        if len(tool_schemas) > 0:
+        if tools.tools:
             raise NotImplementedError("Tool usage is not supported.")
         if format is not None:
             raise NotImplementedError("Formatting is not supported.")
--- a/mirascope/llm/providers/mlx/mlx.py
+++ b/mirascope/llm/providers/mlx/mlx.py
@@ -13,7 +13,7 @@ from mlx_lm import stream_generate # type: ignore[reportPrivateImportUsage]
 from mlx_lm.generate import GenerationResponse
 from transformers import PreTrainedTokenizer
 
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import Format, FormatSpec, FormattableT
 from ...messages import AssistantMessage, Message, assistant
 from ...responses import AsyncChunkIterator, ChunkIterator, StreamResponseChunk
 from ...tools import AnyToolSchema, BaseToolkit
@@ -137,11 +137,8 @@ class MLX:
     def stream(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, ChunkIterator]:
         """Stream response chunks synchronously.
@@ -163,11 +160,8 @@ class MLX:
     async def stream_async(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[Sequence[Message], Format[FormattableT] | None, AsyncChunkIterator]:
         """Stream response chunks asynchronously.
@@ -190,11 +184,8 @@ class MLX:
     def generate(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[
         Sequence[Message],
@@ -229,11 +220,8 @@ class MLX:
     async def generate_async(
         self,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None,
+        tools: BaseToolkit[AnyToolSchema],
+        format: FormatSpec[FormattableT] | None,
         params: Params,
     ) -> tuple[
         Sequence[Message],
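Note that `MLX.stream`/`generate` and their async variants keep a positional `(messages, tools, format, params)` convention even though the provider layer above them is keyword-only, and they hand back tuples that `MLXProvider` unpacks in the provider.py hunks that follow. A stand-in sketch of the 4-tuple shape unpacked there (names and stub behavior are assumptions):

# Stand-in sketch; only the tuple shape matches the hunks below.
from typing import Any, Sequence


class MLXStub:
    def generate(
        self,
        messages: Sequence[str],
        tools: Any,          # BaseToolkit[AnyToolSchema] in the real code
        format: Any,         # FormatSpec[FormattableT] | None in the real code
        params: dict[str, Any],
    ) -> tuple[Sequence[str], Any, str, dict[str, Any]]:
        # Returns (input_messages, format, assistant_message, raw_response).
        return messages, format, "ok", {"finish_reason": "stop"}


mlx = MLXStub()
input_messages, format_, assistant_message, response = mlx.generate(
    ["hi"], None, None, {}
)
print(assistant_message, response["finish_reason"])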
--- a/mirascope/llm/providers/mlx/provider.py
+++ b/mirascope/llm/providers/mlx/provider.py
@@ -10,7 +10,7 @@ from mlx_lm import load as mlx_load
 from transformers import PreTrainedTokenizer
 
 from ...context import Context, DepsT
-from ...formatting import Format, FormattableT, OutputParser
+from ...formatting import FormatSpec, FormattableT
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -23,13 +23,9 @@ from ...responses import (
     StreamResponse,
 )
 from ...tools import (
-    AsyncContextTool,
     AsyncContextToolkit,
-    AsyncTool,
     AsyncToolkit,
-    ContextTool,
     ContextToolkit,
-    Tool,
     Toolkit,
 )
 from ..base import BaseProvider
@@ -89,11 +85,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` using MLX model.
@@ -111,7 +104,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, assistant_message, response = mlx.generate(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return Response(
@@ -120,7 +113,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -134,13 +127,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` using MLX model.
@@ -159,7 +147,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, assistant_message, response = mlx.generate(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return ContextResponse(
@@ -168,7 +156,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -181,11 +169,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -208,7 +193,7 @@ class MLXProvider(BaseProvider[None]):
             format,
             assistant_message,
             response,
-        ) = await mlx.generate_async(messages, tools, format, params)
+        ) = await mlx.generate_async(messages, toolkit, format, params)
 
         return AsyncResponse(
             raw=response,
@@ -216,7 +201,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -230,13 +215,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -260,7 +240,7 @@ class MLXProvider(BaseProvider[None]):
             format,
             assistant_message,
             response,
-        ) = await mlx.generate_async(messages, tools, format, params)
+        ) = await mlx.generate_async(messages, toolkit, format, params)
 
         return AsyncContextResponse(
             raw=response,
@@ -268,7 +248,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
@@ -281,11 +261,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: Toolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generate an `llm.StreamResponse` by synchronously streaming from MLX model output.
@@ -303,7 +280,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = mlx.stream(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return StreamResponse(
@@ -311,7 +288,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -323,13 +300,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: ContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextStreamResponse` by synchronously streaming from MLX model output.
@@ -348,7 +320,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = mlx.stream(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return ContextStreamResponse(
@@ -356,7 +328,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -367,11 +339,8 @@ class MLXProvider(BaseProvider[None]):
         *,
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncToolkit,
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from MLX model output.
@@ -389,7 +358,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = await mlx.stream_async(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return AsyncStreamResponse(
@@ -397,7 +366,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,
@@ -409,13 +378,8 @@ class MLXProvider(BaseProvider[None]):
         ctx: Context[DepsT],
         model_id: MLXModelId,
         messages: Sequence[Message],
-        tools:
-
-        | None = None,
-        format: type[FormattableT]
-        | Format[FormattableT]
-        | OutputParser[FormattableT]
-        | None = None,
+        toolkit: AsyncContextToolkit[DepsT],
+        format: FormatSpec[FormattableT] | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT]
@@ -437,7 +401,7 @@ class MLXProvider(BaseProvider[None]):
         mlx = _get_mlx(model_id)
 
         input_messages, format, chunk_iterator = await mlx.stream_async(
-            messages, tools, format, params
+            messages, toolkit, format, params
         )
 
         return AsyncContextStreamResponse(
@@ -445,7 +409,7 @@ class MLXProvider(BaseProvider[None]):
             model_id=model_id,
             provider_model_name=model_id,
             params=params,
-            tools=tools,
+            tools=toolkit,
             input_messages=input_messages,
             chunk_iterator=chunk_iterator,
             format=format,