mirascope 2.0.0a6__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/api/_generated/__init__.py +186 -5
- mirascope/api/_generated/annotations/client.py +38 -6
- mirascope/api/_generated/annotations/raw_client.py +366 -47
- mirascope/api/_generated/annotations/types/annotations_create_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_get_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_list_response_annotations_item.py +22 -7
- mirascope/api/_generated/annotations/types/annotations_update_response.py +19 -6
- mirascope/api/_generated/api_keys/__init__.py +12 -2
- mirascope/api/_generated/api_keys/client.py +107 -6
- mirascope/api/_generated/api_keys/raw_client.py +486 -38
- mirascope/api/_generated/api_keys/types/__init__.py +7 -1
- mirascope/api/_generated/api_keys/types/api_keys_list_all_for_org_response_item.py +40 -0
- mirascope/api/_generated/client.py +36 -0
- mirascope/api/_generated/docs/raw_client.py +71 -9
- mirascope/api/_generated/environment.py +3 -3
- mirascope/api/_generated/environments/__init__.py +6 -0
- mirascope/api/_generated/environments/client.py +158 -9
- mirascope/api/_generated/environments/raw_client.py +620 -52
- mirascope/api/_generated/environments/types/__init__.py +10 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response.py +60 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response_top_functions_item.py +24 -0
- mirascope/api/_generated/{organizations/types/organizations_credits_response.py → environments/types/environments_get_analytics_response_top_models_item.py} +6 -3
- mirascope/api/_generated/errors/__init__.py +6 -0
- mirascope/api/_generated/errors/bad_request_error.py +5 -2
- mirascope/api/_generated/errors/conflict_error.py +5 -2
- mirascope/api/_generated/errors/payment_required_error.py +15 -0
- mirascope/api/_generated/errors/service_unavailable_error.py +14 -0
- mirascope/api/_generated/errors/too_many_requests_error.py +15 -0
- mirascope/api/_generated/functions/__init__.py +10 -0
- mirascope/api/_generated/functions/client.py +222 -8
- mirascope/api/_generated/functions/raw_client.py +975 -134
- mirascope/api/_generated/functions/types/__init__.py +28 -4
- mirascope/api/_generated/functions/types/functions_get_by_env_response.py +53 -0
- mirascope/api/_generated/functions/types/functions_get_by_env_response_dependencies_value.py +22 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response.py +25 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item.py +56 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item_dependencies_value.py +22 -0
- mirascope/api/_generated/health/raw_client.py +74 -10
- mirascope/api/_generated/organization_invitations/__init__.py +33 -0
- mirascope/api/_generated/organization_invitations/client.py +546 -0
- mirascope/api/_generated/organization_invitations/raw_client.py +1519 -0
- mirascope/api/_generated/organization_invitations/types/__init__.py +53 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response.py +34 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_request_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_status.py +7 -0
- mirascope/api/_generated/organization_memberships/__init__.py +19 -0
- mirascope/api/_generated/organization_memberships/client.py +302 -0
- mirascope/api/_generated/organization_memberships/raw_client.py +736 -0
- mirascope/api/_generated/organization_memberships/types/__init__.py +27 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response.py +31 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/organizations/__init__.py +26 -2
- mirascope/api/_generated/organizations/client.py +442 -20
- mirascope/api/_generated/organizations/raw_client.py +1763 -164
- mirascope/api/_generated/organizations/types/__init__.py +48 -2
- mirascope/api/_generated/organizations/types/organizations_create_payment_intent_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response.py +47 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item.py +33 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item_resource.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_router_balance_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response.py +53 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_current_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_payment_method.py +26 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change.py +34 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_response.py +35 -0
- mirascope/api/_generated/project_memberships/__init__.py +25 -0
- mirascope/api/_generated/project_memberships/client.py +437 -0
- mirascope/api/_generated/project_memberships/raw_client.py +1039 -0
- mirascope/api/_generated/project_memberships/types/__init__.py +29 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/projects/raw_client.py +415 -58
- mirascope/api/_generated/reference.md +2767 -397
- mirascope/api/_generated/tags/__init__.py +19 -0
- mirascope/api/_generated/tags/client.py +504 -0
- mirascope/api/_generated/tags/raw_client.py +1288 -0
- mirascope/api/_generated/tags/types/__init__.py +17 -0
- mirascope/api/_generated/tags/types/tags_create_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_get_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_list_response.py +23 -0
- mirascope/api/_generated/tags/types/tags_list_response_tags_item.py +41 -0
- mirascope/api/_generated/tags/types/tags_update_response.py +41 -0
- mirascope/api/_generated/token_cost/__init__.py +7 -0
- mirascope/api/_generated/token_cost/client.py +160 -0
- mirascope/api/_generated/token_cost/raw_client.py +264 -0
- mirascope/api/_generated/token_cost/types/__init__.py +8 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_request_usage.py +54 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_response.py +52 -0
- mirascope/api/_generated/traces/__init__.py +20 -0
- mirascope/api/_generated/traces/client.py +543 -0
- mirascope/api/_generated/traces/raw_client.py +1366 -96
- mirascope/api/_generated/traces/types/__init__.py +28 -0
- mirascope/api/_generated/traces/types/traces_get_analytics_summary_response.py +6 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response.py +33 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response_spans_item.py +88 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_response_spans_item.py +0 -2
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response.py +25 -0
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response_traces_item.py +44 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item_operator.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_by.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_order.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response_spans_item.py +50 -0
- mirascope/api/_generated/traces/types/traces_search_response_spans_item.py +10 -1
- mirascope/api/_generated/types/__init__.py +32 -2
- mirascope/api/_generated/types/bad_request_error_body.py +50 -0
- mirascope/api/_generated/types/date.py +3 -0
- mirascope/api/_generated/types/immutable_resource_error.py +22 -0
- mirascope/api/_generated/types/internal_server_error_body.py +3 -3
- mirascope/api/_generated/types/plan_limit_exceeded_error.py +32 -0
- mirascope/api/_generated/types/plan_limit_exceeded_error_tag.py +7 -0
- mirascope/api/_generated/types/pricing_unavailable_error.py +23 -0
- mirascope/api/_generated/types/rate_limit_error.py +31 -0
- mirascope/api/_generated/types/rate_limit_error_tag.py +5 -0
- mirascope/api/_generated/types/service_unavailable_error_body.py +24 -0
- mirascope/api/_generated/types/service_unavailable_error_tag.py +7 -0
- mirascope/api/_generated/types/subscription_past_due_error.py +31 -0
- mirascope/api/_generated/types/subscription_past_due_error_tag.py +7 -0
- mirascope/api/settings.py +19 -1
- mirascope/llm/__init__.py +53 -10
- mirascope/llm/calls/__init__.py +2 -1
- mirascope/llm/calls/calls.py +3 -1
- mirascope/llm/calls/decorator.py +21 -7
- mirascope/llm/content/tool_output.py +22 -5
- mirascope/llm/exceptions.py +284 -71
- mirascope/llm/formatting/__init__.py +17 -0
- mirascope/llm/formatting/format.py +112 -35
- mirascope/llm/formatting/output_parser.py +178 -0
- mirascope/llm/formatting/partial.py +80 -7
- mirascope/llm/formatting/primitives.py +192 -0
- mirascope/llm/formatting/types.py +20 -8
- mirascope/llm/messages/__init__.py +3 -0
- mirascope/llm/messages/_utils.py +34 -0
- mirascope/llm/models/__init__.py +5 -0
- mirascope/llm/models/models.py +137 -69
- mirascope/llm/{providers/base → models}/params.py +7 -57
- mirascope/llm/models/thinking_config.py +61 -0
- mirascope/llm/prompts/_utils.py +0 -32
- mirascope/llm/prompts/decorator.py +16 -5
- mirascope/llm/prompts/prompts.py +131 -68
- mirascope/llm/providers/__init__.py +1 -4
- mirascope/llm/providers/anthropic/_utils/__init__.py +2 -0
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +62 -13
- mirascope/llm/providers/anthropic/_utils/decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/encode.py +26 -7
- mirascope/llm/providers/anthropic/_utils/errors.py +2 -2
- mirascope/llm/providers/anthropic/beta_provider.py +64 -18
- mirascope/llm/providers/anthropic/provider.py +91 -33
- mirascope/llm/providers/base/__init__.py +0 -4
- mirascope/llm/providers/base/_utils.py +55 -6
- mirascope/llm/providers/base/base_provider.py +116 -37
- mirascope/llm/providers/google/_utils/__init__.py +2 -0
- mirascope/llm/providers/google/_utils/decode.py +20 -7
- mirascope/llm/providers/google/_utils/encode.py +26 -7
- mirascope/llm/providers/google/_utils/errors.py +3 -2
- mirascope/llm/providers/google/provider.py +64 -18
- mirascope/llm/providers/mirascope/_utils.py +13 -17
- mirascope/llm/providers/mirascope/provider.py +49 -18
- mirascope/llm/providers/mlx/_utils.py +7 -2
- mirascope/llm/providers/mlx/encoding/base.py +5 -2
- mirascope/llm/providers/mlx/encoding/transformers.py +5 -2
- mirascope/llm/providers/mlx/mlx.py +23 -6
- mirascope/llm/providers/mlx/provider.py +42 -13
- mirascope/llm/providers/openai/_utils/errors.py +2 -2
- mirascope/llm/providers/openai/completions/_utils/encode.py +20 -16
- mirascope/llm/providers/openai/completions/base_provider.py +40 -11
- mirascope/llm/providers/openai/provider.py +40 -10
- mirascope/llm/providers/openai/responses/_utils/__init__.py +2 -0
- mirascope/llm/providers/openai/responses/_utils/decode.py +19 -6
- mirascope/llm/providers/openai/responses/_utils/encode.py +22 -10
- mirascope/llm/providers/openai/responses/provider.py +56 -18
- mirascope/llm/providers/provider_registry.py +93 -19
- mirascope/llm/responses/__init__.py +6 -1
- mirascope/llm/responses/_utils.py +102 -12
- mirascope/llm/responses/base_response.py +5 -2
- mirascope/llm/responses/base_stream_response.py +115 -25
- mirascope/llm/responses/response.py +2 -1
- mirascope/llm/responses/root_response.py +89 -17
- mirascope/llm/responses/stream_response.py +6 -9
- mirascope/llm/tools/decorator.py +9 -4
- mirascope/llm/tools/tool_schema.py +12 -6
- mirascope/llm/tools/toolkit.py +35 -27
- mirascope/llm/tools/tools.py +45 -20
- mirascope/ops/__init__.py +4 -0
- mirascope/ops/_internal/configuration.py +82 -31
- mirascope/ops/_internal/exporters/exporters.py +64 -11
- mirascope/ops/_internal/instrumentation/llm/common.py +530 -0
- mirascope/ops/_internal/instrumentation/llm/cost.py +190 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +1 -1
- mirascope/ops/_internal/instrumentation/llm/llm.py +116 -1242
- mirascope/ops/_internal/instrumentation/llm/model.py +1798 -0
- mirascope/ops/_internal/instrumentation/llm/response.py +521 -0
- mirascope/ops/_internal/instrumentation/llm/serialize.py +300 -0
- mirascope/ops/_internal/protocols.py +83 -1
- mirascope/ops/_internal/traced_calls.py +4 -0
- mirascope/ops/_internal/traced_functions.py +118 -8
- mirascope/ops/_internal/tracing.py +78 -1
- mirascope/ops/_internal/utils.py +52 -4
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.1.dist-info}/METADATA +12 -11
- mirascope-2.0.1.dist-info/RECORD +423 -0
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.1.dist-info}/licenses/LICENSE +1 -1
- mirascope-2.0.0a6.dist-info/RECORD +0 -316
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.1.dist-info}/WHEEL +0 -0
mirascope/llm/providers/mlx/provider.py

@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from collections.abc import Sequence
 from functools import cache, lru_cache
-from typing import cast
+from typing import TYPE_CHECKING, cast
 from typing_extensions import Unpack
 
 import mlx.nn as nn
@@ -8,7 +10,7 @@ from mlx_lm import load as mlx_load
 from transformers import PreTrainedTokenizer
 
 from ...context import Context, DepsT
-from ...formatting import Format, FormattableT
+from ...formatting import Format, FormattableT, OutputParser
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -30,20 +32,23 @@ from ...tools import (
     Tool,
     Toolkit,
 )
-from ..base import BaseProvider, Params
+from ..base import BaseProvider
 from . import _utils
 from .encoding import TransformersEncoder
 from .mlx import MLX
 from .model_id import MLXModelId
 
+if TYPE_CHECKING:
+    from ...models import Params
+
 
 @cache
-def _mlx_client_singleton() -> "MLXProvider":
+def _mlx_client_singleton() -> MLXProvider:
     """Get or create the singleton MLX client instance."""
     return MLXProvider()
 
 
-def client() -> "MLXProvider":
+def client() -> MLXProvider:
     """Get the MLX client singleton instance."""
     return _mlx_client_singleton()
 
@@ -85,7 +90,10 @@ class MLXProvider(BaseProvider[None]):
         model_id: MLXModelId,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` using MLX model.
@@ -129,7 +137,10 @@ class MLXProvider(BaseProvider[None]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` using MLX model.
@@ -171,7 +182,10 @@ class MLXProvider(BaseProvider[None]):
         model_id: MLXModelId,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
    ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -219,7 +233,10 @@ class MLXProvider(BaseProvider[None]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
         | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calloing
@@ -265,7 +282,10 @@ class MLXProvider(BaseProvider[None]):
         model_id: MLXModelId,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generate an `llm.StreamResponse` by synchronously streaming from MLX model output.
@@ -306,7 +326,10 @@ class MLXProvider(BaseProvider[None]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextStreamResponse` by synchronously streaming from MLX model output.
@@ -345,7 +368,10 @@ class MLXProvider(BaseProvider[None]):
         model_id: MLXModelId,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from MLX model output.
@@ -386,7 +412,10 @@ class MLXProvider(BaseProvider[None]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
         | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT]
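Every call-method signature above widens the `format` parameter to also accept an `OutputParser[FormattableT]`. The following is a minimal sketch of the dispatch pattern such a union enables, with stand-in `Format` and `OutputParser` classes and a hypothetical `resolve` helper; it is not mirascope's actual `resolve_format` implementation:

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


@dataclass
class Format(Generic[T]):
    """Stand-in for a structured-output format specification."""
    output_type: type[T]
    mode: str = "strict"


@dataclass
class OutputParser(Generic[T]):
    """Stand-in for a parser that post-processes raw model text."""
    parse: Callable[[str], T]


def resolve(
    format: type[T] | Format[T] | OutputParser[T] | None,
) -> Format[T] | OutputParser[T] | None:
    """Normalize the widened union: bare types get promoted to a default Format."""
    if format is None or isinstance(format, (Format, OutputParser)):
        return format
    return Format(output_type=format)


# All three spellings are accepted by the same parameter:
print(resolve(int))                       # Format(output_type=<class 'int'>, mode='strict')
print(resolve(Format(int, mode="json")))  # passed through unchanged
print(resolve(OutputParser(parse=int)))   # passed through unchanged
```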
mirascope/llm/providers/openai/_utils/errors.py

@@ -16,12 +16,12 @@ from openai import (
 )
 
 from ....exceptions import (
-    APIError,
     AuthenticationError,
     BadRequestError,
     ConnectionError,
     NotFoundError,
     PermissionError,
+    ProviderError,
     RateLimitError,
     ResponseValidationError,
     ServerError,
@@ -42,5 +42,5 @@ OPENAI_ERROR_MAP: ProviderErrorMap = {
     OpenAIAPITimeoutError: TimeoutError,
     OpenAIAPIConnectionError: ConnectionError,
     OpenAIAPIResponseValidationError: ResponseValidationError,
-    OpenAIError: APIError,
+    OpenAIError: ProviderError,  # Catch-all for unknown OpenAI errors
 }
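Unrecognized `OpenAIError`s previously mapped to the now-removed `APIError` import and now fall through to `ProviderError`. A hedged sketch of how a base-class entry can act as a catch-all when such a map is applied along the exception's MRO; the lookup mechanics here are illustrative, not mirascope's code:

```python
class OpenAIError(Exception): ...
class OpenAIRateLimitError(OpenAIError): ...

class ProviderError(Exception): ...
class RateLimitError(ProviderError): ...

ERROR_MAP: dict[type[Exception], type[Exception]] = {
    OpenAIRateLimitError: RateLimitError,
    OpenAIError: ProviderError,  # catch-all for any otherwise-unmapped OpenAI error
}


def translate(exc: Exception) -> Exception:
    # Walk the MRO so a specific mapping wins before the base-class catch-all.
    for cls in type(exc).__mro__:
        if cls in ERROR_MAP:
            return ERROR_MAP[cls](str(exc))
    return exc


print(type(translate(OpenAIRateLimitError("429"))).__name__)  # RateLimitError
print(type(translate(OpenAIError("boom"))).__name__)          # ProviderError
```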
mirascope/llm/providers/openai/completions/_utils/encode.py

@@ -1,25 +1,20 @@
 """OpenAI completions message encoding and request preparation."""
 
+from __future__ import annotations
+
 from collections.abc import Sequence
 from functools import lru_cache
-from typing import TypedDict, cast
+from typing import TYPE_CHECKING, TypedDict, cast
 
 from openai import Omit
 from openai.types import chat as openai_types, shared_params as shared_openai_types
 from openai.types.shared_params.response_format_json_schema import JSONSchema
 
-from .....exceptions import (
-
-    FormattingModeNotSupportedError,
-)
-from .....formatting import (
-    Format,
-    FormattableT,
-    resolve_format,
-)
+from .....exceptions import FeatureNotSupportedError
+from .....formatting import Format, FormattableT, OutputParser, resolve_format
 from .....messages import AssistantMessage, Message, UserMessage
 from .....tools import FORMAT_TOOL_NAME, AnyToolSchema, BaseToolkit
-from ....base import Params, _utils as _base_utils
+from ....base import _utils as _base_utils
 from ...model_id import OpenAIModelId, model_name
 from ...model_info import (
     MODELS_WITHOUT_AUDIO_SUPPORT,
@@ -27,6 +22,9 @@ from ...model_info import (
     MODELS_WITHOUT_JSON_SCHEMA_SUPPORT,
 )
 
+if TYPE_CHECKING:
+    from .....models import Params
+
 
 class ChatCompletionCreateKwargs(TypedDict, total=False):
     """Kwargs for OpenAI ChatCompletion.create method."""
@@ -129,7 +127,7 @@ def _encode_user_message(
         result.append(
             openai_types.ChatCompletionToolMessageParam(
                 role="tool",
-                content=str(part.
+                content=str(part.result),
                 tool_call_id=part.id,
             )
         )
@@ -236,13 +234,16 @@ def _convert_tool_to_tool_param(
     schema_dict = tool.parameters.model_dump(by_alias=True, exclude_none=True)
     schema_dict["type"] = "object"
     _base_utils.ensure_additional_properties_false(schema_dict)
+    strict = True if tool.strict is None else tool.strict
+    if strict:
+        _base_utils.ensure_all_properties_required(schema_dict)
     return openai_types.ChatCompletionToolParam(
         type="function",
         function={
             "name": tool.name,
             "description": tool.description,
             "parameters": schema_dict,
-            "strict":
+            "strict": strict,
         },
     )
 
@@ -280,7 +281,10 @@ def encode_request(
     model_id: OpenAIModelId,
     messages: Sequence[Message],
     tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
-    format: type[FormattableT] | Format[FormattableT] | None,
+    format: type[FormattableT]
+    | Format[FormattableT]
+    | OutputParser[FormattableT]
+    | None,
     params: Params,
 ) -> tuple[Sequence[Message], Format[FormattableT] | None, ChatCompletionCreateKwargs]:
     """Prepares a request for the `OpenAI.chat.completions.create` method."""
@@ -331,8 +335,8 @@ def encode_request(
     if format is not None:
         if format.mode == "strict":
             if not model_supports_strict:
-                raise FormattingModeNotSupportedError(
-
+                raise FeatureNotSupportedError(
+                    feature="formatting_mode:strict",
                     provider_id="openai",
                     model_id=model_id,
                 )
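The new `_convert_tool_to_tool_param` logic defaults `strict` to `True` when a tool leaves it unset, and strict schemas additionally get every property marked as required, since OpenAI's strict function calling rejects schemas with optional properties. A small self-contained sketch of that defaulting, using a hypothetical helper name rather than mirascope's internal `_base_utils`:

```python
def prepare_tool_schema(schema: dict, strict_flag: bool | None) -> tuple[dict, bool]:
    """Mirror the hunk's defaulting: unset strictness means strict=True."""
    strict = True if strict_flag is None else strict_flag
    schema["additionalProperties"] = False
    if strict:
        # Strict function calling requires every property to be listed as required.
        schema["required"] = list(schema.get("properties", {}))
    return schema, strict


schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}, "units": {"type": "string"}},
}
prepared, strict = prepare_tool_schema(schema, None)
print(strict)                # True (defaulted from None)
print(prepared["required"])  # ['city', 'units']
```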
mirascope/llm/providers/openai/completions/base_provider.py

@@ -1,14 +1,16 @@
 """Base class for OpenAI Completions-compatible providers."""
 
+from __future__ import annotations
+
 import os
 from collections.abc import Sequence
-from typing import ClassVar
+from typing import TYPE_CHECKING, ClassVar
 from typing_extensions import Unpack
 
 from openai import AsyncOpenAI, OpenAI
 
 from ....context import Context, DepsT
-from ....formatting import Format, FormattableT
+from ....formatting import Format, FormattableT, OutputParser
 from ....messages import Message
 from ....responses import (
     AsyncContextResponse,
@@ -30,11 +32,14 @@ from ....tools import (
     Tool,
     Toolkit,
 )
-from ...base import BaseProvider, Params
+from ...base import BaseProvider
 from .. import _utils as _shared_utils
 from ..model_id import model_name as openai_model_name
 from . import _utils
 
+if TYPE_CHECKING:
+    from ....models import Params
+
 
 class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
     """Base class for providers that use OpenAI Completions-compatible APIs."""
@@ -97,7 +102,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         model_id: str,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` by synchronously calling the API.
@@ -152,7 +160,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` by synchronously calling the API.
@@ -205,7 +216,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         model_id: str,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` by asynchronously calling the API.
@@ -260,7 +274,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
         | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncContextResponse` by asynchronously calling the API.
@@ -313,7 +330,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         model_id: str,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generate an `llm.StreamResponse` by synchronously streaming from the API.
@@ -364,7 +384,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextStreamResponse` by synchronously streaming from the API.
@@ -414,7 +437,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         model_id: str,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from the API.
@@ -465,7 +491,10 @@ class BaseOpenAICompletionsProvider(BaseProvider[OpenAI]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
         | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT]
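As in the other provider modules, `Params` moves behind an `if TYPE_CHECKING:` guard, which works because `from __future__ import annotations` makes every annotation a lazy string. This pattern is typically used to break runtime import cycles or defer import cost; the diff itself does not state the motivation. A minimal sketch, using `Decimal` as a stand-in for `Params`:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime, so a cycle
    # between this module and the imported one cannot occur.
    from decimal import Decimal  # stand-in for `from ....models import Params`


def price(amount: Decimal) -> str:
    # With lazy annotations, `Decimal` here is just the string "Decimal"
    # at runtime, so the guarded import is sufficient.
    return f"${amount:.2f}"


from decimal import Decimal as RealDecimal

print(price(RealDecimal("3.50")))  # $3.50
```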
mirascope/llm/providers/openai/provider.py

@@ -1,13 +1,16 @@
 """Unified OpenAI client implementation."""
 
+from __future__ import annotations
+
 from collections.abc import Sequence
+from typing import TYPE_CHECKING
 from typing_extensions import Unpack
 
 from openai import BadRequestError as OpenAIBadRequestError, OpenAI
 
 from ...context import Context, DepsT
 from ...exceptions import BadRequestError, NotFoundError
-from ...formatting import Format, FormattableT
+from ...formatting import Format, FormattableT, OutputParser
 from ...messages import Message
 from ...responses import (
     AsyncContextResponse,
@@ -29,12 +32,15 @@ from ...tools import (
     Tool,
     Toolkit,
 )
-from ..base import BaseProvider, Params
+from ..base import BaseProvider
 from . import _utils
 from .completions import OpenAICompletionsProvider
 from .model_id import OPENAI_KNOWN_MODELS, OpenAIModelId
 from .responses import OpenAIResponsesProvider
 
+if TYPE_CHECKING:
+    from ...models import Params
+
 
 def _has_audio_content(messages: Sequence[Message]) -> bool:
     """Returns whether a sequence of messages contains any audio content."""
@@ -157,7 +163,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         model_id: OpenAIModelId,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> Response | Response[FormattableT]:
         """Generate an `llm.Response` by synchronously calling the OpenAI API.
@@ -190,7 +199,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextResponse` by synchronously calling the OpenAI API.
@@ -222,7 +234,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         model_id: OpenAIModelId,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generate an `llm.AsyncResponse` by asynchronously calling the OpenAI API.
@@ -254,7 +269,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
         | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generate an `llm.AsyncContextResponse` by asynchronously calling the OpenAI API.
@@ -285,7 +303,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         model_id: OpenAIModelId,
         messages: Sequence[Message],
         tools: Sequence[Tool] | Toolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generate an `llm.StreamResponse` by synchronously streaming from the OpenAI API.
@@ -318,7 +339,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         tools: Sequence[Tool | ContextTool[DepsT]]
         | ContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
         """Generate an `llm.ContextStreamResponse` by synchronously streaming from the OpenAI API.
@@ -350,7 +374,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         model_id: OpenAIModelId,
         messages: Sequence[Message],
         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from the OpenAI API.
@@ -382,7 +409,10 @@ class OpenAIProvider(BaseProvider[OpenAI]):
         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
        | AsyncContextToolkit[DepsT]
         | None = None,
-        format: type[FormattableT] | Format[FormattableT] | None = None,
+        format: type[FormattableT]
+        | Format[FormattableT]
+        | OutputParser[FormattableT]
+        | None = None,
         **params: Unpack[Params],
     ) -> (
         AsyncContextStreamResponse[DepsT]
mirascope/llm/providers/openai/responses/_utils/decode.py

@@ -78,6 +78,8 @@ def decode_response(
     response: openai_types.Response,
     model_id: OpenAIModelId,
     provider_id: str,
+    *,
+    include_thoughts: bool,
 ) -> tuple[AssistantMessage, FinishReason | None, Usage | None]:
     """Convert OpenAI Responses Response to mirascope AssistantMessage and usage."""
     parts: list[AssistantContentPart] = []
@@ -114,6 +116,9 @@ def decode_response(
     else:
         raise NotImplementedError(f"Unsupported output item: {output_item.type}")
 
+    if not include_thoughts:
+        parts = [part for part in parts if part.type != "thought"]
+
     if refused:
         finish_reason = FinishReason.REFUSAL
     elif details := response.incomplete_details:
@@ -136,9 +141,10 @@ def decode_response(
 class _OpenAIResponsesChunkProcessor:
     """Processes OpenAI Responses streaming events and maintains state across chunks."""
 
-    def __init__(self) -> None:
+    def __init__(self, *, include_thoughts: bool) -> None:
         self.current_content_type: Literal["text", "tool_call", "thought"] | None = None
         self.refusal_encountered = False
+        self.include_thoughts = include_thoughts
 
     def process_chunk(self, event: ResponseStreamEvent) -> ChunkIterator:
         """Process a single OpenAI Responses stream event and yield the appropriate content chunks."""
@@ -182,14 +188,17 @@ class _OpenAIResponsesChunkProcessor:
             or event.type == "response.reasoning_summary_text.delta"
         ):
             if not self.current_content_type:
-                yield ThoughtStartChunk()
+                if self.include_thoughts:
+                    yield ThoughtStartChunk()
                 self.current_content_type = "thought"
-            yield ThoughtChunk(delta=event.delta)
+            if self.include_thoughts:
+                yield ThoughtChunk(delta=event.delta)
         elif (
             event.type == "response.reasoning_summary_text.done"
             or event.type == "response.reasoning_text.done"
         ):
-            yield ThoughtEndChunk()
+            if self.include_thoughts:
+                yield ThoughtEndChunk()
             self.current_content_type = None
         elif event.type == "response.incomplete":
             details = event.response.incomplete_details
@@ -230,18 +239,22 @@ class _OpenAIResponsesChunkProcessor:
 
 def decode_stream(
     openai_stream: Stream[ResponseStreamEvent],
+    *,
+    include_thoughts: bool,
 ) -> ChunkIterator:
     """Returns a ChunkIterator converted from an OpenAI Stream[ResponseStreamEvent]"""
-    processor = _OpenAIResponsesChunkProcessor()
+    processor = _OpenAIResponsesChunkProcessor(include_thoughts=include_thoughts)
     for event in openai_stream:
         yield from processor.process_chunk(event)
 
 
 async def decode_async_stream(
     openai_stream: AsyncStream[ResponseStreamEvent],
+    *,
+    include_thoughts: bool,
 ) -> AsyncChunkIterator:
     """Returns an AsyncChunkIterator converted from an OpenAI AsyncStream[ResponseStreamEvent]"""
-    processor = _OpenAIResponsesChunkProcessor()
+    processor = _OpenAIResponsesChunkProcessor(include_thoughts=include_thoughts)
    async for event in openai_stream:
         for item in processor.process_chunk(event):
             yield item
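Both decode paths now take a keyword-only `include_thoughts` flag: batch decoding filters already-built thought parts, while the stream processor suppresses thought chunks as events arrive. A simplified sketch of the two behaviors, with stand-in types rather than mirascope's:

```python
from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class Part:
    type: str  # "text" | "thought" | ...
    content: str


def decode_parts(parts: list[Part], *, include_thoughts: bool) -> list[Part]:
    """Batch path: drop thought parts after decoding, as in decode_response."""
    if not include_thoughts:
        parts = [part for part in parts if part.type != "thought"]
    return parts


def stream_chunks(events: list[Part], *, include_thoughts: bool) -> Iterator[Part]:
    """Streaming path: suppress thought events at the source instead."""
    for event in events:
        if event.type == "thought" and not include_thoughts:
            continue
        yield event


parts = [Part("thought", "planning..."), Part("text", "Hello!")]
print(decode_parts(parts, include_thoughts=False))        # only the text part
print(list(stream_chunks(parts, include_thoughts=True)))  # both parts
```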