mirascope 2.0.0a6__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/_utils.py +34 -0
- mirascope/api/_generated/__init__.py +186 -5
- mirascope/api/_generated/annotations/client.py +38 -6
- mirascope/api/_generated/annotations/raw_client.py +366 -47
- mirascope/api/_generated/annotations/types/annotations_create_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_get_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_list_response_annotations_item.py +22 -7
- mirascope/api/_generated/annotations/types/annotations_update_response.py +19 -6
- mirascope/api/_generated/api_keys/__init__.py +12 -2
- mirascope/api/_generated/api_keys/client.py +107 -6
- mirascope/api/_generated/api_keys/raw_client.py +486 -38
- mirascope/api/_generated/api_keys/types/__init__.py +7 -1
- mirascope/api/_generated/api_keys/types/api_keys_list_all_for_org_response_item.py +40 -0
- mirascope/api/_generated/client.py +36 -0
- mirascope/api/_generated/docs/raw_client.py +71 -9
- mirascope/api/_generated/environment.py +3 -3
- mirascope/api/_generated/environments/__init__.py +6 -0
- mirascope/api/_generated/environments/client.py +158 -9
- mirascope/api/_generated/environments/raw_client.py +620 -52
- mirascope/api/_generated/environments/types/__init__.py +10 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response.py +60 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response_top_functions_item.py +24 -0
- mirascope/api/_generated/{organizations/types/organizations_credits_response.py → environments/types/environments_get_analytics_response_top_models_item.py} +6 -3
- mirascope/api/_generated/errors/__init__.py +6 -0
- mirascope/api/_generated/errors/bad_request_error.py +5 -2
- mirascope/api/_generated/errors/conflict_error.py +5 -2
- mirascope/api/_generated/errors/payment_required_error.py +15 -0
- mirascope/api/_generated/errors/service_unavailable_error.py +14 -0
- mirascope/api/_generated/errors/too_many_requests_error.py +15 -0
- mirascope/api/_generated/functions/__init__.py +10 -0
- mirascope/api/_generated/functions/client.py +222 -8
- mirascope/api/_generated/functions/raw_client.py +975 -134
- mirascope/api/_generated/functions/types/__init__.py +28 -4
- mirascope/api/_generated/functions/types/functions_get_by_env_response.py +53 -0
- mirascope/api/_generated/functions/types/functions_get_by_env_response_dependencies_value.py +22 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response.py +25 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item.py +56 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item_dependencies_value.py +22 -0
- mirascope/api/_generated/health/raw_client.py +74 -10
- mirascope/api/_generated/organization_invitations/__init__.py +33 -0
- mirascope/api/_generated/organization_invitations/client.py +546 -0
- mirascope/api/_generated/organization_invitations/raw_client.py +1519 -0
- mirascope/api/_generated/organization_invitations/types/__init__.py +53 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response.py +34 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_request_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_status.py +7 -0
- mirascope/api/_generated/organization_memberships/__init__.py +19 -0
- mirascope/api/_generated/organization_memberships/client.py +302 -0
- mirascope/api/_generated/organization_memberships/raw_client.py +736 -0
- mirascope/api/_generated/organization_memberships/types/__init__.py +27 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response.py +31 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/organizations/__init__.py +26 -2
- mirascope/api/_generated/organizations/client.py +442 -20
- mirascope/api/_generated/organizations/raw_client.py +1763 -164
- mirascope/api/_generated/organizations/types/__init__.py +48 -2
- mirascope/api/_generated/organizations/types/organizations_create_payment_intent_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response.py +47 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item.py +33 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item_resource.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_router_balance_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response.py +53 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_current_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_payment_method.py +26 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change.py +34 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_response.py +35 -0
- mirascope/api/_generated/project_memberships/__init__.py +25 -0
- mirascope/api/_generated/project_memberships/client.py +437 -0
- mirascope/api/_generated/project_memberships/raw_client.py +1039 -0
- mirascope/api/_generated/project_memberships/types/__init__.py +29 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/projects/raw_client.py +415 -58
- mirascope/api/_generated/reference.md +2767 -397
- mirascope/api/_generated/tags/__init__.py +19 -0
- mirascope/api/_generated/tags/client.py +504 -0
- mirascope/api/_generated/tags/raw_client.py +1288 -0
- mirascope/api/_generated/tags/types/__init__.py +17 -0
- mirascope/api/_generated/tags/types/tags_create_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_get_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_list_response.py +23 -0
- mirascope/api/_generated/tags/types/tags_list_response_tags_item.py +41 -0
- mirascope/api/_generated/tags/types/tags_update_response.py +41 -0
- mirascope/api/_generated/token_cost/__init__.py +7 -0
- mirascope/api/_generated/token_cost/client.py +160 -0
- mirascope/api/_generated/token_cost/raw_client.py +264 -0
- mirascope/api/_generated/token_cost/types/__init__.py +8 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_request_usage.py +54 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_response.py +52 -0
- mirascope/api/_generated/traces/__init__.py +20 -0
- mirascope/api/_generated/traces/client.py +543 -0
- mirascope/api/_generated/traces/raw_client.py +1366 -96
- mirascope/api/_generated/traces/types/__init__.py +28 -0
- mirascope/api/_generated/traces/types/traces_get_analytics_summary_response.py +6 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response.py +33 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response_spans_item.py +88 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_response_spans_item.py +0 -2
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response.py +25 -0
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response_traces_item.py +44 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item_operator.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_by.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_order.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response_spans_item.py +50 -0
- mirascope/api/_generated/traces/types/traces_search_response_spans_item.py +10 -1
- mirascope/api/_generated/types/__init__.py +32 -2
- mirascope/api/_generated/types/bad_request_error_body.py +50 -0
- mirascope/api/_generated/types/date.py +3 -0
- mirascope/api/_generated/types/immutable_resource_error.py +22 -0
- mirascope/api/_generated/types/internal_server_error_body.py +3 -3
- mirascope/api/_generated/types/plan_limit_exceeded_error.py +32 -0
- mirascope/api/_generated/types/plan_limit_exceeded_error_tag.py +7 -0
- mirascope/api/_generated/types/pricing_unavailable_error.py +23 -0
- mirascope/api/_generated/types/rate_limit_error.py +31 -0
- mirascope/api/_generated/types/rate_limit_error_tag.py +5 -0
- mirascope/api/_generated/types/service_unavailable_error_body.py +24 -0
- mirascope/api/_generated/types/service_unavailable_error_tag.py +7 -0
- mirascope/api/_generated/types/subscription_past_due_error.py +31 -0
- mirascope/api/_generated/types/subscription_past_due_error_tag.py +7 -0
- mirascope/api/settings.py +19 -1
- mirascope/llm/__init__.py +53 -10
- mirascope/llm/calls/__init__.py +2 -1
- mirascope/llm/calls/calls.py +29 -20
- mirascope/llm/calls/decorator.py +21 -7
- mirascope/llm/content/tool_output.py +22 -5
- mirascope/llm/exceptions.py +284 -71
- mirascope/llm/formatting/__init__.py +17 -0
- mirascope/llm/formatting/format.py +112 -35
- mirascope/llm/formatting/output_parser.py +178 -0
- mirascope/llm/formatting/partial.py +80 -7
- mirascope/llm/formatting/primitives.py +192 -0
- mirascope/llm/formatting/types.py +20 -8
- mirascope/llm/messages/__init__.py +3 -0
- mirascope/llm/messages/_utils.py +34 -0
- mirascope/llm/models/__init__.py +5 -0
- mirascope/llm/models/models.py +137 -69
- mirascope/llm/{providers/base → models}/params.py +7 -57
- mirascope/llm/models/thinking_config.py +61 -0
- mirascope/llm/prompts/_utils.py +0 -32
- mirascope/llm/prompts/decorator.py +16 -5
- mirascope/llm/prompts/prompts.py +160 -92
- mirascope/llm/providers/__init__.py +1 -4
- mirascope/llm/providers/anthropic/_utils/__init__.py +2 -0
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +62 -13
- mirascope/llm/providers/anthropic/_utils/decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/encode.py +26 -7
- mirascope/llm/providers/anthropic/_utils/errors.py +2 -2
- mirascope/llm/providers/anthropic/beta_provider.py +64 -18
- mirascope/llm/providers/anthropic/provider.py +91 -33
- mirascope/llm/providers/base/__init__.py +0 -4
- mirascope/llm/providers/base/_utils.py +55 -6
- mirascope/llm/providers/base/base_provider.py +116 -37
- mirascope/llm/providers/google/_utils/__init__.py +2 -0
- mirascope/llm/providers/google/_utils/decode.py +20 -7
- mirascope/llm/providers/google/_utils/encode.py +26 -7
- mirascope/llm/providers/google/_utils/errors.py +3 -2
- mirascope/llm/providers/google/provider.py +64 -18
- mirascope/llm/providers/mirascope/_utils.py +13 -17
- mirascope/llm/providers/mirascope/provider.py +49 -18
- mirascope/llm/providers/mlx/_utils.py +7 -2
- mirascope/llm/providers/mlx/encoding/base.py +5 -2
- mirascope/llm/providers/mlx/encoding/transformers.py +5 -2
- mirascope/llm/providers/mlx/mlx.py +23 -6
- mirascope/llm/providers/mlx/provider.py +42 -13
- mirascope/llm/providers/openai/_utils/errors.py +2 -2
- mirascope/llm/providers/openai/completions/_utils/encode.py +20 -16
- mirascope/llm/providers/openai/completions/base_provider.py +40 -11
- mirascope/llm/providers/openai/provider.py +40 -10
- mirascope/llm/providers/openai/responses/_utils/__init__.py +2 -0
- mirascope/llm/providers/openai/responses/_utils/decode.py +19 -6
- mirascope/llm/providers/openai/responses/_utils/encode.py +22 -10
- mirascope/llm/providers/openai/responses/provider.py +56 -18
- mirascope/llm/providers/provider_registry.py +93 -19
- mirascope/llm/responses/__init__.py +6 -1
- mirascope/llm/responses/_utils.py +102 -12
- mirascope/llm/responses/base_response.py +5 -2
- mirascope/llm/responses/base_stream_response.py +115 -25
- mirascope/llm/responses/response.py +2 -1
- mirascope/llm/responses/root_response.py +89 -17
- mirascope/llm/responses/stream_response.py +6 -9
- mirascope/llm/tools/decorator.py +9 -4
- mirascope/llm/tools/tool_schema.py +17 -6
- mirascope/llm/tools/toolkit.py +35 -27
- mirascope/llm/tools/tools.py +45 -20
- mirascope/ops/__init__.py +4 -0
- mirascope/ops/_internal/closure.py +4 -1
- mirascope/ops/_internal/configuration.py +82 -31
- mirascope/ops/_internal/exporters/exporters.py +55 -35
- mirascope/ops/_internal/exporters/utils.py +37 -0
- mirascope/ops/_internal/instrumentation/llm/common.py +530 -0
- mirascope/ops/_internal/instrumentation/llm/cost.py +190 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +1 -1
- mirascope/ops/_internal/instrumentation/llm/llm.py +116 -1242
- mirascope/ops/_internal/instrumentation/llm/model.py +1798 -0
- mirascope/ops/_internal/instrumentation/llm/response.py +521 -0
- mirascope/ops/_internal/instrumentation/llm/serialize.py +300 -0
- mirascope/ops/_internal/protocols.py +83 -1
- mirascope/ops/_internal/traced_calls.py +18 -0
- mirascope/ops/_internal/traced_functions.py +125 -10
- mirascope/ops/_internal/tracing.py +78 -1
- mirascope/ops/_internal/utils.py +60 -4
- mirascope/ops/_internal/versioned_functions.py +1 -1
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/METADATA +12 -11
- mirascope-2.0.2.dist-info/RECORD +424 -0
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/licenses/LICENSE +1 -1
- mirascope-2.0.0a6.dist-info/RECORD +0 -316
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/WHEEL +0 -0
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
"""Google provider implementation."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
from collections.abc import Sequence
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
4
7
|
from typing_extensions import Unpack
|
|
5
8
|
|
|
6
9
|
from google.genai import Client
|
|
7
10
|
from google.genai.types import HttpOptions
|
|
8
11
|
|
|
9
12
|
from ...context import Context, DepsT
|
|
10
|
-
from ...formatting import Format, FormattableT
|
|
13
|
+
from ...formatting import Format, FormattableT, OutputParser
|
|
11
14
|
from ...messages import Message
|
|
12
15
|
from ...responses import (
|
|
13
16
|
AsyncContextResponse,
|
|
@@ -29,10 +32,13 @@ from ...tools import (
|
|
|
29
32
|
Tool,
|
|
30
33
|
Toolkit,
|
|
31
34
|
)
|
|
32
|
-
from ..base import BaseProvider
|
|
35
|
+
from ..base import BaseProvider
|
|
33
36
|
from . import _utils
|
|
34
37
|
from .model_id import GoogleModelId, model_name
|
|
35
38
|
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
from ...models import Params
|
|
41
|
+
|
|
36
42
|
|
|
37
43
|
class GoogleProvider(BaseProvider[Client]):
|
|
38
44
|
"""The client for the Google LLM model."""
|
|
@@ -64,7 +70,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
64
70
|
model_id: GoogleModelId,
|
|
65
71
|
messages: Sequence[Message],
|
|
66
72
|
tools: Sequence[Tool] | Toolkit | None = None,
|
|
67
|
-
format: type[FormattableT]
|
|
73
|
+
format: type[FormattableT]
|
|
74
|
+
| Format[FormattableT]
|
|
75
|
+
| OutputParser[FormattableT]
|
|
76
|
+
| None = None,
|
|
68
77
|
**params: Unpack[Params],
|
|
69
78
|
) -> Response | Response[FormattableT]:
|
|
70
79
|
"""Generate an `llm.Response` by synchronously calling the Google GenAI API.
|
|
@@ -88,8 +97,9 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
88
97
|
)
|
|
89
98
|
google_response = self.client.models.generate_content(**kwargs)
|
|
90
99
|
|
|
100
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
91
101
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
92
|
-
google_response, model_id
|
|
102
|
+
google_response, model_id, include_thoughts=include_thoughts
|
|
93
103
|
)
|
|
94
104
|
|
|
95
105
|
return Response(
|
|
@@ -115,7 +125,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
115
125
|
tools: Sequence[Tool | ContextTool[DepsT]]
|
|
116
126
|
| ContextToolkit[DepsT]
|
|
117
127
|
| None = None,
|
|
118
|
-
format: type[FormattableT]
|
|
128
|
+
format: type[FormattableT]
|
|
129
|
+
| Format[FormattableT]
|
|
130
|
+
| OutputParser[FormattableT]
|
|
131
|
+
| None = None,
|
|
119
132
|
**params: Unpack[Params],
|
|
120
133
|
) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
|
|
121
134
|
"""Generate an `llm.ContextResponse` by synchronously calling the Google GenAI API.
|
|
@@ -140,8 +153,9 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
140
153
|
)
|
|
141
154
|
google_response = self.client.models.generate_content(**kwargs)
|
|
142
155
|
|
|
156
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
143
157
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
144
|
-
google_response, model_id
|
|
158
|
+
google_response, model_id, include_thoughts=include_thoughts
|
|
145
159
|
)
|
|
146
160
|
|
|
147
161
|
return ContextResponse(
|
|
@@ -164,7 +178,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
164
178
|
model_id: GoogleModelId,
|
|
165
179
|
messages: Sequence[Message],
|
|
166
180
|
tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
|
|
167
|
-
format: type[FormattableT]
|
|
181
|
+
format: type[FormattableT]
|
|
182
|
+
| Format[FormattableT]
|
|
183
|
+
| OutputParser[FormattableT]
|
|
184
|
+
| None = None,
|
|
168
185
|
**params: Unpack[Params],
|
|
169
186
|
) -> AsyncResponse | AsyncResponse[FormattableT]:
|
|
170
187
|
"""Generate an `llm.AsyncResponse` by asynchronously calling the Google GenAI API.
|
|
@@ -188,8 +205,9 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
188
205
|
)
|
|
189
206
|
google_response = await self.client.aio.models.generate_content(**kwargs)
|
|
190
207
|
|
|
208
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
191
209
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
192
|
-
google_response, model_id
|
|
210
|
+
google_response, model_id, include_thoughts=include_thoughts
|
|
193
211
|
)
|
|
194
212
|
|
|
195
213
|
return AsyncResponse(
|
|
@@ -215,7 +233,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
215
233
|
tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
|
|
216
234
|
| AsyncContextToolkit[DepsT]
|
|
217
235
|
| None = None,
|
|
218
|
-
format: type[FormattableT]
|
|
236
|
+
format: type[FormattableT]
|
|
237
|
+
| Format[FormattableT]
|
|
238
|
+
| OutputParser[FormattableT]
|
|
239
|
+
| None = None,
|
|
219
240
|
**params: Unpack[Params],
|
|
220
241
|
) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
|
|
221
242
|
"""Generate an `llm.AsyncContextResponse` by asynchronously calling the Google GenAI API.
|
|
@@ -240,8 +261,9 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
240
261
|
)
|
|
241
262
|
google_response = await self.client.aio.models.generate_content(**kwargs)
|
|
242
263
|
|
|
264
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
243
265
|
assistant_message, finish_reason, usage = _utils.decode_response(
|
|
244
|
-
google_response, model_id
|
|
266
|
+
google_response, model_id, include_thoughts=include_thoughts
|
|
245
267
|
)
|
|
246
268
|
|
|
247
269
|
return AsyncContextResponse(
|
|
@@ -264,7 +286,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
264
286
|
model_id: GoogleModelId,
|
|
265
287
|
messages: Sequence[Message],
|
|
266
288
|
tools: Sequence[Tool] | Toolkit | None = None,
|
|
267
|
-
format: type[FormattableT]
|
|
289
|
+
format: type[FormattableT]
|
|
290
|
+
| Format[FormattableT]
|
|
291
|
+
| OutputParser[FormattableT]
|
|
292
|
+
| None = None,
|
|
268
293
|
**params: Unpack[Params],
|
|
269
294
|
) -> StreamResponse | StreamResponse[FormattableT]:
|
|
270
295
|
"""Generate an `llm.StreamResponse` by synchronously streaming from the Google GenAI API.
|
|
@@ -289,7 +314,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
289
314
|
|
|
290
315
|
google_stream = self.client.models.generate_content_stream(**kwargs)
|
|
291
316
|
|
|
292
|
-
|
|
317
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
318
|
+
chunk_iterator = _utils.decode_stream(
|
|
319
|
+
google_stream, include_thoughts=include_thoughts
|
|
320
|
+
)
|
|
293
321
|
|
|
294
322
|
return StreamResponse(
|
|
295
323
|
provider_id="google",
|
|
@@ -311,7 +339,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
311
339
|
tools: Sequence[Tool | ContextTool[DepsT]]
|
|
312
340
|
| ContextToolkit[DepsT]
|
|
313
341
|
| None = None,
|
|
314
|
-
format: type[FormattableT]
|
|
342
|
+
format: type[FormattableT]
|
|
343
|
+
| Format[FormattableT]
|
|
344
|
+
| OutputParser[FormattableT]
|
|
345
|
+
| None = None,
|
|
315
346
|
**params: Unpack[Params],
|
|
316
347
|
) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
|
|
317
348
|
"""Generate an `llm.ContextStreamResponse` by synchronously streaming from the Google GenAI API.
|
|
@@ -337,7 +368,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
337
368
|
|
|
338
369
|
google_stream = self.client.models.generate_content_stream(**kwargs)
|
|
339
370
|
|
|
340
|
-
|
|
371
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
372
|
+
chunk_iterator = _utils.decode_stream(
|
|
373
|
+
google_stream, include_thoughts=include_thoughts
|
|
374
|
+
)
|
|
341
375
|
|
|
342
376
|
return ContextStreamResponse(
|
|
343
377
|
provider_id="google",
|
|
@@ -356,7 +390,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
356
390
|
model_id: GoogleModelId,
|
|
357
391
|
messages: Sequence[Message],
|
|
358
392
|
tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
|
|
359
|
-
format: type[FormattableT]
|
|
393
|
+
format: type[FormattableT]
|
|
394
|
+
| Format[FormattableT]
|
|
395
|
+
| OutputParser[FormattableT]
|
|
396
|
+
| None = None,
|
|
360
397
|
**params: Unpack[Params],
|
|
361
398
|
) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
|
|
362
399
|
"""Generate an `llm.AsyncStreamResponse` by asynchronously streaming from the Google GenAI API.
|
|
@@ -381,7 +418,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
381
418
|
|
|
382
419
|
google_stream = await self.client.aio.models.generate_content_stream(**kwargs)
|
|
383
420
|
|
|
384
|
-
|
|
421
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
422
|
+
chunk_iterator = _utils.decode_async_stream(
|
|
423
|
+
google_stream, include_thoughts=include_thoughts
|
|
424
|
+
)
|
|
385
425
|
|
|
386
426
|
return AsyncStreamResponse(
|
|
387
427
|
provider_id="google",
|
|
@@ -403,7 +443,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
403
443
|
tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
|
|
404
444
|
| AsyncContextToolkit[DepsT]
|
|
405
445
|
| None = None,
|
|
406
|
-
format: type[FormattableT]
|
|
446
|
+
format: type[FormattableT]
|
|
447
|
+
| Format[FormattableT]
|
|
448
|
+
| OutputParser[FormattableT]
|
|
449
|
+
| None = None,
|
|
407
450
|
**params: Unpack[Params],
|
|
408
451
|
) -> (
|
|
409
452
|
AsyncContextStreamResponse[DepsT]
|
|
@@ -432,7 +475,10 @@ class GoogleProvider(BaseProvider[Client]):
|
|
|
432
475
|
|
|
433
476
|
google_stream = await self.client.aio.models.generate_content_stream(**kwargs)
|
|
434
477
|
|
|
435
|
-
|
|
478
|
+
include_thoughts = _utils.get_include_thoughts(params)
|
|
479
|
+
chunk_iterator = _utils.decode_async_stream(
|
|
480
|
+
google_stream, include_thoughts=include_thoughts
|
|
481
|
+
)
|
|
436
482
|
|
|
437
483
|
return AsyncContextStreamResponse(
|
|
438
484
|
provider_id="google",
|
|
@@ -7,15 +7,15 @@ from ..base import Provider
|
|
|
7
7
|
from ..provider_id import ProviderId
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def
|
|
11
|
-
"""Extract
|
|
10
|
+
def extract_model_scope(model_id: str) -> str | None:
|
|
11
|
+
"""Extract model scope from model ID.
|
|
12
12
|
|
|
13
13
|
Args:
|
|
14
|
-
model_id: Model identifier in the format "
|
|
14
|
+
model_id: Model identifier in the format "scope/model-name"
|
|
15
15
|
e.g., "openai/gpt-4", "anthropic/claude-3", "google/gemini-pro"
|
|
16
16
|
|
|
17
17
|
Returns:
|
|
18
|
-
The
|
|
18
|
+
The model scope (e.g., "openai", "anthropic", "google") or None if invalid format.
|
|
19
19
|
"""
|
|
20
20
|
if "/" not in model_id:
|
|
21
21
|
return None
|
|
@@ -29,12 +29,12 @@ def get_default_router_base_url() -> str:
|
|
|
29
29
|
The router base URL (without trailing provider path).
|
|
30
30
|
"""
|
|
31
31
|
return os.environ.get(
|
|
32
|
-
"MIRASCOPE_ROUTER_BASE_URL", "https://mirascope.com/router/
|
|
32
|
+
"MIRASCOPE_ROUTER_BASE_URL", "https://mirascope.com/router/v2"
|
|
33
33
|
)
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def create_underlying_provider(
|
|
37
|
-
|
|
37
|
+
model_scope: str, api_key: str, router_base_url: str
|
|
38
38
|
) -> Provider:
|
|
39
39
|
"""Create and cache an underlying provider instance using provider_singleton.
|
|
40
40
|
|
|
@@ -42,10 +42,9 @@ def create_underlying_provider(
|
|
|
42
42
|
delegates to provider_singleton for caching and instantiation.
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
|
-
|
|
46
|
-
"openai:completions", "openai:responses")
|
|
45
|
+
model_scope: The model scope (e.g., "openai", "anthropic", "google")
|
|
47
46
|
api_key: The API key to use for authentication
|
|
48
|
-
router_base_url: The base router URL (e.g., "http://mirascope.com/router/
|
|
47
|
+
router_base_url: The base router URL (e.g., "http://mirascope.com/router/v2")
|
|
49
48
|
|
|
50
49
|
Returns:
|
|
51
50
|
A cached provider instance configured for the Mirascope Router.
|
|
@@ -53,17 +52,14 @@ def create_underlying_provider(
|
|
|
53
52
|
Raises:
|
|
54
53
|
ValueError: If the provider is unsupported.
|
|
55
54
|
"""
|
|
56
|
-
|
|
57
|
-
base_provider = provider_prefix.split(":")[0]
|
|
58
|
-
|
|
59
|
-
if base_provider not in ["anthropic", "google", "openai"]:
|
|
55
|
+
if model_scope not in ["anthropic", "google", "openai"]:
|
|
60
56
|
raise ValueError(
|
|
61
|
-
f"Unsupported provider: {
|
|
57
|
+
f"Unsupported provider: {model_scope}. "
|
|
62
58
|
f"Mirascope Router currently supports: anthropic, google, openai"
|
|
63
59
|
)
|
|
64
60
|
|
|
65
|
-
base_url = f"{router_base_url}/{
|
|
66
|
-
if
|
|
61
|
+
base_url = f"{router_base_url}/{model_scope}"
|
|
62
|
+
if model_scope == "openai": # OpenAI expects /v1, which their SDK doesn't add
|
|
67
63
|
base_url = f"{base_url}/v1"
|
|
68
64
|
|
|
69
65
|
# Lazy import to avoid circular dependencies
|
|
@@ -71,7 +67,7 @@ def create_underlying_provider(
|
|
|
71
67
|
|
|
72
68
|
# Use provider_singleton which provides caching
|
|
73
69
|
return provider_singleton(
|
|
74
|
-
cast(ProviderId,
|
|
70
|
+
cast(ProviderId, model_scope),
|
|
75
71
|
api_key=api_key,
|
|
76
72
|
base_url=base_url,
|
|
77
73
|
)
|
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
"""Mirascope Router provider that routes requests through the Mirascope Router API."""
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import os
|
|
4
6
|
from collections.abc import Sequence
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
5
8
|
from typing_extensions import Unpack
|
|
6
9
|
|
|
7
10
|
from ...context import Context, DepsT
|
|
8
|
-
from ...formatting import Format, FormattableT
|
|
11
|
+
from ...formatting import Format, FormattableT, OutputParser
|
|
9
12
|
from ...messages import Message
|
|
10
13
|
from ...responses import (
|
|
11
14
|
AsyncContextResponse,
|
|
@@ -27,9 +30,12 @@ from ...tools import (
|
|
|
27
30
|
Tool,
|
|
28
31
|
Toolkit,
|
|
29
32
|
)
|
|
30
|
-
from ..base import BaseProvider,
|
|
33
|
+
from ..base import BaseProvider, Provider
|
|
31
34
|
from . import _utils
|
|
32
35
|
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from ...models import Params
|
|
38
|
+
|
|
33
39
|
|
|
34
40
|
class MirascopeProvider(BaseProvider[None]):
|
|
35
41
|
"""Provider that routes LLM requests through the Mirascope Router API.
|
|
@@ -38,14 +44,15 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
38
44
|
(Anthropic, Google, OpenAI) with usage tracking and cost calculation.
|
|
39
45
|
|
|
40
46
|
This provider:
|
|
41
|
-
- Takes model IDs in the format "
|
|
47
|
+
- Takes model IDs in the format "scope/model-name" (e.g., "openai/gpt-4")
|
|
42
48
|
- Routes requests to the Mirascope Router endpoint
|
|
43
49
|
- Delegates to the appropriate underlying provider (Anthropic, Google, or OpenAI)
|
|
44
50
|
- Uses MIRASCOPE_API_KEY for authentication
|
|
45
51
|
|
|
46
52
|
Environment Variables:
|
|
47
53
|
MIRASCOPE_API_KEY: Required API key for Mirascope Router authentication
|
|
48
|
-
MIRASCOPE_ROUTER_BASE_URL: Optional base URL override
|
|
54
|
+
MIRASCOPE_ROUTER_BASE_URL: Optional base URL override
|
|
55
|
+
(default: https://mirascope.com/router/v2)
|
|
49
56
|
|
|
50
57
|
Example:
|
|
51
58
|
```python
|
|
@@ -90,7 +97,7 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
90
97
|
environment variable.
|
|
91
98
|
base_url: Optional base URL override for the Mirascope Router. If not
|
|
92
99
|
provided, reads from MIRASCOPE_ROUTER_BASE_URL environment variable
|
|
93
|
-
or defaults to https://mirascope.com/router/
|
|
100
|
+
or defaults to https://mirascope.com/router/v2
|
|
94
101
|
"""
|
|
95
102
|
api_key = api_key or os.environ.get("MIRASCOPE_API_KEY")
|
|
96
103
|
if not api_key:
|
|
@@ -122,7 +129,7 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
122
129
|
"""Get the underlying provider for a model ID.
|
|
123
130
|
|
|
124
131
|
Args:
|
|
125
|
-
model_id: Model identifier in format "
|
|
132
|
+
model_id: Model identifier in format "scope/model-name"
|
|
126
133
|
|
|
127
134
|
Returns:
|
|
128
135
|
The appropriate cached provider instance (Anthropic, Google, or OpenAI)
|
|
@@ -130,16 +137,16 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
130
137
|
Raises:
|
|
131
138
|
ValueError: If the model ID format is invalid or provider is unsupported
|
|
132
139
|
"""
|
|
133
|
-
|
|
134
|
-
if not
|
|
140
|
+
model_scope = _utils.extract_model_scope(model_id)
|
|
141
|
+
if not model_scope:
|
|
135
142
|
raise ValueError(
|
|
136
143
|
f"Invalid model ID format: {model_id}. "
|
|
137
|
-
f"Expected format '
|
|
144
|
+
f"Expected format 'scope/model-name' (e.g., 'openai/gpt-4')"
|
|
138
145
|
)
|
|
139
146
|
|
|
140
147
|
# Use the cached function to get/create the provider
|
|
141
148
|
return _utils.create_underlying_provider(
|
|
142
|
-
|
|
149
|
+
model_scope=model_scope,
|
|
143
150
|
api_key=self.api_key,
|
|
144
151
|
router_base_url=self.router_base_url,
|
|
145
152
|
)
|
|
@@ -150,7 +157,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
150
157
|
model_id: str,
|
|
151
158
|
messages: Sequence[Message],
|
|
152
159
|
tools: Sequence[Tool] | Toolkit | None = None,
|
|
153
|
-
format: type[FormattableT]
|
|
160
|
+
format: type[FormattableT]
|
|
161
|
+
| Format[FormattableT]
|
|
162
|
+
| OutputParser[FormattableT]
|
|
163
|
+
| None = None,
|
|
154
164
|
**params: Unpack[Params],
|
|
155
165
|
) -> Response | Response[FormattableT]:
|
|
156
166
|
"""Generate an `llm.Response` by calling through the Mirascope Router."""
|
|
@@ -172,7 +182,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
172
182
|
tools: Sequence[Tool | ContextTool[DepsT]]
|
|
173
183
|
| ContextToolkit[DepsT]
|
|
174
184
|
| None = None,
|
|
175
|
-
format: type[FormattableT]
|
|
185
|
+
format: type[FormattableT]
|
|
186
|
+
| Format[FormattableT]
|
|
187
|
+
| OutputParser[FormattableT]
|
|
188
|
+
| None = None,
|
|
176
189
|
**params: Unpack[Params],
|
|
177
190
|
) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
|
|
178
191
|
"""Generate an `llm.ContextResponse` by calling through the Mirascope Router."""
|
|
@@ -192,7 +205,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
192
205
|
model_id: str,
|
|
193
206
|
messages: Sequence[Message],
|
|
194
207
|
tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
|
|
195
|
-
format: type[FormattableT]
|
|
208
|
+
format: type[FormattableT]
|
|
209
|
+
| Format[FormattableT]
|
|
210
|
+
| OutputParser[FormattableT]
|
|
211
|
+
| None = None,
|
|
196
212
|
**params: Unpack[Params],
|
|
197
213
|
) -> AsyncResponse | AsyncResponse[FormattableT]:
|
|
198
214
|
"""Generate an `llm.AsyncResponse` by calling through the Mirascope Router."""
|
|
@@ -214,7 +230,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
214
230
|
tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
|
|
215
231
|
| AsyncContextToolkit[DepsT]
|
|
216
232
|
| None = None,
|
|
217
|
-
format: type[FormattableT]
|
|
233
|
+
format: type[FormattableT]
|
|
234
|
+
| Format[FormattableT]
|
|
235
|
+
| OutputParser[FormattableT]
|
|
236
|
+
| None = None,
|
|
218
237
|
**params: Unpack[Params],
|
|
219
238
|
) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
|
|
220
239
|
"""Generate an `llm.AsyncContextResponse` by calling through the Mirascope Router."""
|
|
@@ -234,7 +253,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
234
253
|
model_id: str,
|
|
235
254
|
messages: Sequence[Message],
|
|
236
255
|
tools: Sequence[Tool] | Toolkit | None = None,
|
|
237
|
-
format: type[FormattableT]
|
|
256
|
+
format: type[FormattableT]
|
|
257
|
+
| Format[FormattableT]
|
|
258
|
+
| OutputParser[FormattableT]
|
|
259
|
+
| None = None,
|
|
238
260
|
**params: Unpack[Params],
|
|
239
261
|
) -> StreamResponse | StreamResponse[FormattableT]:
|
|
240
262
|
"""Stream an `llm.StreamResponse` by calling through the Mirascope Router."""
|
|
@@ -256,7 +278,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
256
278
|
tools: Sequence[Tool | ContextTool[DepsT]]
|
|
257
279
|
| ContextToolkit[DepsT]
|
|
258
280
|
| None = None,
|
|
259
|
-
format: type[FormattableT]
|
|
281
|
+
format: type[FormattableT]
|
|
282
|
+
| Format[FormattableT]
|
|
283
|
+
| OutputParser[FormattableT]
|
|
284
|
+
| None = None,
|
|
260
285
|
**params: Unpack[Params],
|
|
261
286
|
) -> (
|
|
262
287
|
ContextStreamResponse[DepsT, None] | ContextStreamResponse[DepsT, FormattableT]
|
|
@@ -278,7 +303,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
278
303
|
model_id: str,
|
|
279
304
|
messages: Sequence[Message],
|
|
280
305
|
tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
|
|
281
|
-
format: type[FormattableT]
|
|
306
|
+
format: type[FormattableT]
|
|
307
|
+
| Format[FormattableT]
|
|
308
|
+
| OutputParser[FormattableT]
|
|
309
|
+
| None = None,
|
|
282
310
|
**params: Unpack[Params],
|
|
283
311
|
) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
|
|
284
312
|
"""Stream an `llm.AsyncStreamResponse` by calling through the Mirascope Router."""
|
|
@@ -300,7 +328,10 @@ class MirascopeProvider(BaseProvider[None]):
|
|
|
300
328
|
tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
|
|
301
329
|
| AsyncContextToolkit[DepsT]
|
|
302
330
|
| None = None,
|
|
303
|
-
format: type[FormattableT]
|
|
331
|
+
format: type[FormattableT]
|
|
332
|
+
| Format[FormattableT]
|
|
333
|
+
| OutputParser[FormattableT]
|
|
334
|
+
| None = None,
|
|
304
335
|
**params: Unpack[Params],
|
|
305
336
|
) -> (
|
|
306
337
|
AsyncContextStreamResponse[DepsT, None]
|
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from collections.abc import Callable
|
|
2
|
-
from typing import TypeAlias, TypedDict
|
|
4
|
+
from typing import TYPE_CHECKING, TypeAlias, TypedDict
|
|
3
5
|
|
|
4
6
|
import mlx.core as mx
|
|
5
7
|
from huggingface_hub.errors import LocalEntryNotFoundError
|
|
@@ -8,7 +10,10 @@ from mlx_lm.sample_utils import make_sampler
|
|
|
8
10
|
|
|
9
11
|
from ...exceptions import NotFoundError
|
|
10
12
|
from ...responses import FinishReason, Usage
|
|
11
|
-
from ..base import
|
|
13
|
+
from ..base import ProviderErrorMap, _utils as _base_utils
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from ...models import Params
|
|
12
17
|
|
|
13
18
|
Sampler: TypeAlias = Callable[[mx.array], mx.array]
|
|
14
19
|
|
|
@@ -6,7 +6,7 @@ from typing import TypeAlias
|
|
|
6
6
|
|
|
7
7
|
from mlx_lm.generate import GenerationResponse
|
|
8
8
|
|
|
9
|
-
from ....formatting import Format, FormattableT
|
|
9
|
+
from ....formatting import Format, FormattableT, OutputParser
|
|
10
10
|
from ....messages import AssistantContent, Message
|
|
11
11
|
from ....responses import ChunkIterator
|
|
12
12
|
from ....tools import AnyToolSchema, BaseToolkit
|
|
@@ -22,7 +22,10 @@ class BaseEncoder(abc.ABC):
|
|
|
22
22
|
self,
|
|
23
23
|
messages: Sequence[Message],
|
|
24
24
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
25
|
-
format: type[FormattableT]
|
|
25
|
+
format: type[FormattableT]
|
|
26
|
+
| Format[FormattableT]
|
|
27
|
+
| OutputParser[FormattableT]
|
|
28
|
+
| None,
|
|
26
29
|
) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
|
|
27
30
|
"""Encode the request messages into a format suitable for the model.
|
|
28
31
|
|
|
@@ -8,7 +8,7 @@ from mlx_lm.generate import GenerationResponse
|
|
|
8
8
|
from transformers import PreTrainedTokenizer
|
|
9
9
|
|
|
10
10
|
from ....content import ContentPart, TextChunk, TextEndChunk, TextStartChunk
|
|
11
|
-
from ....formatting import Format, FormattableT
|
|
11
|
+
from ....formatting import Format, FormattableT, OutputParser
|
|
12
12
|
from ....messages import AssistantContent, Message
|
|
13
13
|
from ....responses import (
|
|
14
14
|
ChunkIterator,
|
|
@@ -81,7 +81,10 @@ class TransformersEncoder(BaseEncoder):
|
|
|
81
81
|
self,
|
|
82
82
|
messages: Sequence[Message],
|
|
83
83
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
84
|
-
format: type[FormattableT]
|
|
84
|
+
format: type[FormattableT]
|
|
85
|
+
| Format[FormattableT]
|
|
86
|
+
| OutputParser[FormattableT]
|
|
87
|
+
| None,
|
|
85
88
|
) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
|
|
86
89
|
"""Encode a request into a format suitable for the model."""
|
|
87
90
|
tool_schemas = tools.tools if isinstance(tools, BaseToolkit) else tools or []
|
|
@@ -1,7 +1,10 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import asyncio
|
|
2
4
|
import threading
|
|
3
5
|
from collections.abc import Iterable, Sequence
|
|
4
6
|
from dataclasses import dataclass, field
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
5
8
|
from typing_extensions import Unpack
|
|
6
9
|
|
|
7
10
|
import mlx.core as mx
|
|
@@ -10,15 +13,17 @@ from mlx_lm import stream_generate # type: ignore[reportPrivateImportUsage]
|
|
|
10
13
|
from mlx_lm.generate import GenerationResponse
|
|
11
14
|
from transformers import PreTrainedTokenizer
|
|
12
15
|
|
|
13
|
-
from ...formatting import Format, FormattableT
|
|
16
|
+
from ...formatting import Format, FormattableT, OutputParser
|
|
14
17
|
from ...messages import AssistantMessage, Message, assistant
|
|
15
18
|
from ...responses import AsyncChunkIterator, ChunkIterator, StreamResponseChunk
|
|
16
19
|
from ...tools import AnyToolSchema, BaseToolkit
|
|
17
|
-
from ..base import Params
|
|
18
20
|
from . import _utils
|
|
19
21
|
from .encoding import BaseEncoder, TokenIds
|
|
20
22
|
from .model_id import MLXModelId
|
|
21
23
|
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ...models import Params
|
|
26
|
+
|
|
22
27
|
|
|
23
28
|
def _consume_sync_stream_into_queue(
|
|
24
29
|
generation_stream: ChunkIterator,
|
|
@@ -133,7 +138,10 @@ class MLX:
|
|
|
133
138
|
self,
|
|
134
139
|
messages: Sequence[Message],
|
|
135
140
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
136
|
-
format: type[FormattableT]
|
|
141
|
+
format: type[FormattableT]
|
|
142
|
+
| Format[FormattableT]
|
|
143
|
+
| OutputParser[FormattableT]
|
|
144
|
+
| None,
|
|
137
145
|
params: Params,
|
|
138
146
|
) -> tuple[Sequence[Message], Format[FormattableT] | None, ChunkIterator]:
|
|
139
147
|
"""Stream response chunks synchronously.
|
|
@@ -156,7 +164,10 @@ class MLX:
|
|
|
156
164
|
self,
|
|
157
165
|
messages: Sequence[Message],
|
|
158
166
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
159
|
-
format: type[FormattableT]
|
|
167
|
+
format: type[FormattableT]
|
|
168
|
+
| Format[FormattableT]
|
|
169
|
+
| OutputParser[FormattableT]
|
|
170
|
+
| None,
|
|
160
171
|
params: Params,
|
|
161
172
|
) -> tuple[Sequence[Message], Format[FormattableT] | None, AsyncChunkIterator]:
|
|
162
173
|
"""Stream response chunks asynchronously.
|
|
@@ -180,7 +191,10 @@ class MLX:
|
|
|
180
191
|
self,
|
|
181
192
|
messages: Sequence[Message],
|
|
182
193
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
183
|
-
format: type[FormattableT]
|
|
194
|
+
format: type[FormattableT]
|
|
195
|
+
| Format[FormattableT]
|
|
196
|
+
| OutputParser[FormattableT]
|
|
197
|
+
| None,
|
|
184
198
|
params: Params,
|
|
185
199
|
) -> tuple[
|
|
186
200
|
Sequence[Message],
|
|
@@ -216,7 +230,10 @@ class MLX:
|
|
|
216
230
|
self,
|
|
217
231
|
messages: Sequence[Message],
|
|
218
232
|
tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
|
|
219
|
-
format: type[FormattableT]
|
|
233
|
+
format: type[FormattableT]
|
|
234
|
+
| Format[FormattableT]
|
|
235
|
+
| OutputParser[FormattableT]
|
|
236
|
+
| None,
|
|
220
237
|
params: Params,
|
|
221
238
|
) -> tuple[
|
|
222
239
|
Sequence[Message],
|