mirascope 2.0.0a6__py3-none-any.whl → 2.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/_utils.py +34 -0
- mirascope/api/_generated/__init__.py +186 -5
- mirascope/api/_generated/annotations/client.py +38 -6
- mirascope/api/_generated/annotations/raw_client.py +366 -47
- mirascope/api/_generated/annotations/types/annotations_create_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_get_response.py +19 -6
- mirascope/api/_generated/annotations/types/annotations_list_response_annotations_item.py +22 -7
- mirascope/api/_generated/annotations/types/annotations_update_response.py +19 -6
- mirascope/api/_generated/api_keys/__init__.py +12 -2
- mirascope/api/_generated/api_keys/client.py +107 -6
- mirascope/api/_generated/api_keys/raw_client.py +486 -38
- mirascope/api/_generated/api_keys/types/__init__.py +7 -1
- mirascope/api/_generated/api_keys/types/api_keys_list_all_for_org_response_item.py +40 -0
- mirascope/api/_generated/client.py +36 -0
- mirascope/api/_generated/docs/raw_client.py +71 -9
- mirascope/api/_generated/environment.py +3 -3
- mirascope/api/_generated/environments/__init__.py +6 -0
- mirascope/api/_generated/environments/client.py +158 -9
- mirascope/api/_generated/environments/raw_client.py +620 -52
- mirascope/api/_generated/environments/types/__init__.py +10 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response.py +60 -0
- mirascope/api/_generated/environments/types/environments_get_analytics_response_top_functions_item.py +24 -0
- mirascope/api/_generated/{organizations/types/organizations_credits_response.py → environments/types/environments_get_analytics_response_top_models_item.py} +6 -3
- mirascope/api/_generated/errors/__init__.py +6 -0
- mirascope/api/_generated/errors/bad_request_error.py +5 -2
- mirascope/api/_generated/errors/conflict_error.py +5 -2
- mirascope/api/_generated/errors/payment_required_error.py +15 -0
- mirascope/api/_generated/errors/service_unavailable_error.py +14 -0
- mirascope/api/_generated/errors/too_many_requests_error.py +15 -0
- mirascope/api/_generated/functions/__init__.py +10 -0
- mirascope/api/_generated/functions/client.py +222 -8
- mirascope/api/_generated/functions/raw_client.py +975 -134
- mirascope/api/_generated/functions/types/__init__.py +28 -4
- mirascope/api/_generated/functions/types/functions_get_by_env_response.py +53 -0
- mirascope/api/_generated/functions/types/functions_get_by_env_response_dependencies_value.py +22 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response.py +25 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item.py +56 -0
- mirascope/api/_generated/functions/types/functions_list_by_env_response_functions_item_dependencies_value.py +22 -0
- mirascope/api/_generated/health/raw_client.py +74 -10
- mirascope/api/_generated/organization_invitations/__init__.py +33 -0
- mirascope/api/_generated/organization_invitations/client.py +546 -0
- mirascope/api/_generated/organization_invitations/raw_client.py +1519 -0
- mirascope/api/_generated/organization_invitations/types/__init__.py +53 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response.py +34 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_accept_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_request_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_create_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_get_response_status.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item.py +48 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_invitations/types/organization_invitations_list_response_item_status.py +7 -0
- mirascope/api/_generated/organization_memberships/__init__.py +19 -0
- mirascope/api/_generated/organization_memberships/client.py +302 -0
- mirascope/api/_generated/organization_memberships/raw_client.py +736 -0
- mirascope/api/_generated/organization_memberships/types/__init__.py +27 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response.py +31 -0
- mirascope/api/_generated/organization_memberships/types/organization_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/organizations/__init__.py +26 -2
- mirascope/api/_generated/organizations/client.py +442 -20
- mirascope/api/_generated/organizations/raw_client.py +1763 -164
- mirascope/api/_generated/organizations/types/__init__.py +48 -2
- mirascope/api/_generated/organizations/types/organizations_create_payment_intent_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response.py +47 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item.py +33 -0
- mirascope/api/_generated/organizations/types/organizations_preview_subscription_change_response_validation_errors_item_resource.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_router_balance_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response.py +53 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_current_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_payment_method.py +26 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change.py +34 -0
- mirascope/api/_generated/organizations/types/organizations_subscription_response_scheduled_change_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_request_target_plan.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_subscription_response.py +35 -0
- mirascope/api/_generated/project_memberships/__init__.py +25 -0
- mirascope/api/_generated/project_memberships/client.py +437 -0
- mirascope/api/_generated/project_memberships/raw_client.py +1039 -0
- mirascope/api/_generated/project_memberships/types/__init__.py +29 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_create_response_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item.py +33 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_list_response_item_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_request_role.py +7 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response.py +35 -0
- mirascope/api/_generated/project_memberships/types/project_memberships_update_response_role.py +7 -0
- mirascope/api/_generated/projects/raw_client.py +415 -58
- mirascope/api/_generated/reference.md +2767 -397
- mirascope/api/_generated/tags/__init__.py +19 -0
- mirascope/api/_generated/tags/client.py +504 -0
- mirascope/api/_generated/tags/raw_client.py +1288 -0
- mirascope/api/_generated/tags/types/__init__.py +17 -0
- mirascope/api/_generated/tags/types/tags_create_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_get_response.py +41 -0
- mirascope/api/_generated/tags/types/tags_list_response.py +23 -0
- mirascope/api/_generated/tags/types/tags_list_response_tags_item.py +41 -0
- mirascope/api/_generated/tags/types/tags_update_response.py +41 -0
- mirascope/api/_generated/token_cost/__init__.py +7 -0
- mirascope/api/_generated/token_cost/client.py +160 -0
- mirascope/api/_generated/token_cost/raw_client.py +264 -0
- mirascope/api/_generated/token_cost/types/__init__.py +8 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_request_usage.py +54 -0
- mirascope/api/_generated/token_cost/types/token_cost_calculate_response.py +52 -0
- mirascope/api/_generated/traces/__init__.py +20 -0
- mirascope/api/_generated/traces/client.py +543 -0
- mirascope/api/_generated/traces/raw_client.py +1366 -96
- mirascope/api/_generated/traces/types/__init__.py +28 -0
- mirascope/api/_generated/traces/types/traces_get_analytics_summary_response.py +6 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response.py +33 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_by_env_response_spans_item.py +88 -0
- mirascope/api/_generated/traces/types/traces_get_trace_detail_response_spans_item.py +0 -2
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response.py +25 -0
- mirascope/api/_generated/traces/types/traces_list_by_function_hash_response_traces_item.py +44 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_attribute_filters_item_operator.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_by.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_request_sort_order.py +7 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response.py +26 -0
- mirascope/api/_generated/traces/types/traces_search_by_env_response_spans_item.py +50 -0
- mirascope/api/_generated/traces/types/traces_search_response_spans_item.py +10 -1
- mirascope/api/_generated/types/__init__.py +32 -2
- mirascope/api/_generated/types/bad_request_error_body.py +50 -0
- mirascope/api/_generated/types/date.py +3 -0
- mirascope/api/_generated/types/immutable_resource_error.py +22 -0
- mirascope/api/_generated/types/internal_server_error_body.py +3 -3
- mirascope/api/_generated/types/plan_limit_exceeded_error.py +32 -0
- mirascope/api/_generated/types/plan_limit_exceeded_error_tag.py +7 -0
- mirascope/api/_generated/types/pricing_unavailable_error.py +23 -0
- mirascope/api/_generated/types/rate_limit_error.py +31 -0
- mirascope/api/_generated/types/rate_limit_error_tag.py +5 -0
- mirascope/api/_generated/types/service_unavailable_error_body.py +24 -0
- mirascope/api/_generated/types/service_unavailable_error_tag.py +7 -0
- mirascope/api/_generated/types/subscription_past_due_error.py +31 -0
- mirascope/api/_generated/types/subscription_past_due_error_tag.py +7 -0
- mirascope/api/settings.py +19 -1
- mirascope/llm/__init__.py +53 -10
- mirascope/llm/calls/__init__.py +2 -1
- mirascope/llm/calls/calls.py +29 -20
- mirascope/llm/calls/decorator.py +21 -7
- mirascope/llm/content/tool_output.py +22 -5
- mirascope/llm/exceptions.py +284 -71
- mirascope/llm/formatting/__init__.py +17 -0
- mirascope/llm/formatting/format.py +112 -35
- mirascope/llm/formatting/output_parser.py +178 -0
- mirascope/llm/formatting/partial.py +80 -7
- mirascope/llm/formatting/primitives.py +192 -0
- mirascope/llm/formatting/types.py +20 -8
- mirascope/llm/messages/__init__.py +3 -0
- mirascope/llm/messages/_utils.py +34 -0
- mirascope/llm/models/__init__.py +5 -0
- mirascope/llm/models/models.py +137 -69
- mirascope/llm/{providers/base → models}/params.py +7 -57
- mirascope/llm/models/thinking_config.py +61 -0
- mirascope/llm/prompts/_utils.py +0 -32
- mirascope/llm/prompts/decorator.py +16 -5
- mirascope/llm/prompts/prompts.py +160 -92
- mirascope/llm/providers/__init__.py +1 -4
- mirascope/llm/providers/anthropic/_utils/__init__.py +2 -0
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +62 -13
- mirascope/llm/providers/anthropic/_utils/decode.py +18 -9
- mirascope/llm/providers/anthropic/_utils/encode.py +26 -7
- mirascope/llm/providers/anthropic/_utils/errors.py +2 -2
- mirascope/llm/providers/anthropic/beta_provider.py +64 -18
- mirascope/llm/providers/anthropic/provider.py +91 -33
- mirascope/llm/providers/base/__init__.py +0 -4
- mirascope/llm/providers/base/_utils.py +55 -6
- mirascope/llm/providers/base/base_provider.py +116 -37
- mirascope/llm/providers/google/_utils/__init__.py +2 -0
- mirascope/llm/providers/google/_utils/decode.py +20 -7
- mirascope/llm/providers/google/_utils/encode.py +26 -7
- mirascope/llm/providers/google/_utils/errors.py +3 -2
- mirascope/llm/providers/google/provider.py +64 -18
- mirascope/llm/providers/mirascope/_utils.py +13 -17
- mirascope/llm/providers/mirascope/provider.py +49 -18
- mirascope/llm/providers/mlx/_utils.py +7 -2
- mirascope/llm/providers/mlx/encoding/base.py +5 -2
- mirascope/llm/providers/mlx/encoding/transformers.py +5 -2
- mirascope/llm/providers/mlx/mlx.py +23 -6
- mirascope/llm/providers/mlx/provider.py +42 -13
- mirascope/llm/providers/openai/_utils/errors.py +2 -2
- mirascope/llm/providers/openai/completions/_utils/encode.py +20 -16
- mirascope/llm/providers/openai/completions/base_provider.py +40 -11
- mirascope/llm/providers/openai/provider.py +40 -10
- mirascope/llm/providers/openai/responses/_utils/__init__.py +2 -0
- mirascope/llm/providers/openai/responses/_utils/decode.py +19 -6
- mirascope/llm/providers/openai/responses/_utils/encode.py +22 -10
- mirascope/llm/providers/openai/responses/provider.py +56 -18
- mirascope/llm/providers/provider_registry.py +93 -19
- mirascope/llm/responses/__init__.py +6 -1
- mirascope/llm/responses/_utils.py +102 -12
- mirascope/llm/responses/base_response.py +5 -2
- mirascope/llm/responses/base_stream_response.py +115 -25
- mirascope/llm/responses/response.py +2 -1
- mirascope/llm/responses/root_response.py +89 -17
- mirascope/llm/responses/stream_response.py +6 -9
- mirascope/llm/tools/decorator.py +9 -4
- mirascope/llm/tools/tool_schema.py +17 -6
- mirascope/llm/tools/toolkit.py +35 -27
- mirascope/llm/tools/tools.py +45 -20
- mirascope/ops/__init__.py +4 -0
- mirascope/ops/_internal/closure.py +4 -1
- mirascope/ops/_internal/configuration.py +82 -31
- mirascope/ops/_internal/exporters/exporters.py +55 -35
- mirascope/ops/_internal/exporters/utils.py +37 -0
- mirascope/ops/_internal/instrumentation/llm/common.py +530 -0
- mirascope/ops/_internal/instrumentation/llm/cost.py +190 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +1 -1
- mirascope/ops/_internal/instrumentation/llm/llm.py +116 -1242
- mirascope/ops/_internal/instrumentation/llm/model.py +1798 -0
- mirascope/ops/_internal/instrumentation/llm/response.py +521 -0
- mirascope/ops/_internal/instrumentation/llm/serialize.py +300 -0
- mirascope/ops/_internal/protocols.py +83 -1
- mirascope/ops/_internal/traced_calls.py +18 -0
- mirascope/ops/_internal/traced_functions.py +125 -10
- mirascope/ops/_internal/tracing.py +78 -1
- mirascope/ops/_internal/utils.py +60 -4
- mirascope/ops/_internal/versioned_functions.py +1 -1
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/METADATA +12 -11
- mirascope-2.0.2.dist-info/RECORD +424 -0
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/licenses/LICENSE +1 -1
- mirascope-2.0.0a6.dist-info/RECORD +0 -316
- {mirascope-2.0.0a6.dist-info → mirascope-2.0.2.dist-info}/WHEEL +0 -0
mirascope/llm/prompts/prompts.py
CHANGED
@@ -1,11 +1,15 @@
 """Concrete Prompt classes for generating messages with tools and formatting."""
 
-from dataclasses import dataclass
-from typing import Generic, overload
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass, field
+from typing import Any, Generic, TypeVar, overload
 
+from ..._utils import copy_function_metadata
 from ..context import Context, DepsT
-from ..formatting import Format, FormattableT
+from ..formatting import Format, FormattableT, OutputParser
+from ..messages import Message, promote_to_messages
 from ..models import Model
+from ..providers import ModelId
 from ..responses import (
     AsyncContextResponse,
     AsyncContextStreamResponse,
@@ -16,14 +20,8 @@ from ..responses import (
     Response,
     StreamResponse,
 )
-from ..tools import (
-    AsyncContextToolkit,
-    AsyncToolkit,
-    ContextToolkit,
-    Toolkit,
-)
+from ..tools import AsyncContextToolkit, AsyncToolkit, ContextToolkit, Toolkit
 from ..types import P
-from . import _utils
 from .protocols import (
     AsyncContextMessageTemplate,
     AsyncMessageTemplate,
@@ -31,9 +29,26 @@ from .protocols import (
     MessageTemplate,
 )
 
+FunctionT = TypeVar("FunctionT", bound=Callable[..., Any])
+
+
+@dataclass(kw_only=True)
+class BasePrompt(Generic[FunctionT]):
+    """Base class for all Prompt types with shared metadata functionality."""
+
+    fn: FunctionT
+    """The underlying prompt function that generates message content."""
+
+    __name__: str = field(init=False, repr=False, default="")
+    """The name of the underlying function (preserved for decorator stacking)."""
+
+    def __post_init__(self) -> None:
+        """Preserve standard function attributes for decorator stacking."""
+        copy_function_metadata(self, self.fn)
+
 
 @dataclass
-class Prompt(Generic[P, FormattableT]):
+class Prompt(BasePrompt[MessageTemplate[P]], Generic[P, FormattableT]):
     """A prompt that can be called with a model to generate a response.
 
     Created by decorating a `MessageTemplate` with `llm.prompt`. The decorated
@@ -43,70 +58,93 @@ class Prompt(Generic[P, FormattableT]):
     It can be invoked with a model: `prompt(model, *args, **kwargs)`.
     """
 
-    fn: MessageTemplate[P]
-    """The underlying prompt function that generates message content."""
-
     toolkit: Toolkit
     """The toolkit containing this prompt's tools."""
 
-    format: type[FormattableT] | Format[FormattableT] | None
+    format: (
+        type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT] | None
+    )
     """The response format for the generated response."""
 
+    def messages(self, *args: P.args, **kwargs: P.kwargs) -> Sequence[Message]:
+        """Return the `Messages` from invoking this prompt."""
+        return promote_to_messages(self.fn(*args, **kwargs))
+
     @overload
     def __call__(
-        self: "Prompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Response: ...
 
     @overload
     def __call__(
-        self: "Prompt[P, FormattableT]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, FormattableT]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Response[FormattableT]: ...
 
     def __call__(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> Response | Response[FormattableT]:
         """Generates a response using the provided model."""
         return self.call(model, *args, **kwargs)
 
     @overload
     def call(
-        self: "Prompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> Response: ...
 
     @overload
     def call(
-        self: "Prompt[P, FormattableT]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, FormattableT]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
    ) -> Response[FormattableT]: ...
 
     def call(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> Response | Response[FormattableT]:
         """Generates a response using the provided model."""
-
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = self.messages(*args, **kwargs)
+        return model.call(messages, tools=self.toolkit, format=self.format)
 
     @overload
     def stream(
-        self: "Prompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> StreamResponse: ...
 
     @overload
     def stream(
-        self: "Prompt[P, FormattableT]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "Prompt[P, FormattableT]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> StreamResponse[FormattableT]: ...
 
     def stream(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> StreamResponse | StreamResponse[FormattableT]:
         """Generates a streaming response using the provided model."""
-
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = self.messages(*args, **kwargs)
+        return model.stream(messages, tools=self.toolkit, format=self.format)
 
 
 @dataclass
-class AsyncPrompt(Generic[P, FormattableT]):
+class AsyncPrompt(BasePrompt[AsyncMessageTemplate[P]], Generic[P, FormattableT]):
     """An async prompt that can be called with a model to generate a response.
 
     Created by decorating an async `MessageTemplate` with `llm.prompt`. The decorated
@@ -116,83 +154,97 @@ class AsyncPrompt(Generic[P, FormattableT]):
     It can be invoked with a model: `await prompt(model, *args, **kwargs)`.
     """
 
-    fn: AsyncMessageTemplate[P]
-    """The underlying async prompt function that generates message content."""
-
     toolkit: AsyncToolkit
     """The toolkit containing this prompt's async tools."""
 
-    format: type[FormattableT] | Format[FormattableT] | None
+    format: (
+        type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT] | None
+    )
     """The response format for the generated response."""
 
+    async def messages(self, *args: P.args, **kwargs: P.kwargs) -> Sequence[Message]:
+        """Return the `Messages` from invoking this prompt."""
+        return promote_to_messages(await self.fn(*args, **kwargs))
+
     @overload
     async def __call__(
-        self: "AsyncPrompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "AsyncPrompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> AsyncResponse: ...
 
     @overload
     async def __call__(
         self: "AsyncPrompt[P, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> AsyncResponse[FormattableT]: ...
 
     async def __call__(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generates a response using the provided model asynchronously."""
         return await self.call(model, *args, **kwargs)
 
     @overload
     async def call(
-        self: "AsyncPrompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "AsyncPrompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> AsyncResponse: ...
 
     @overload
     async def call(
         self: "AsyncPrompt[P, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> AsyncResponse[FormattableT]: ...
 
     async def call(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> AsyncResponse | AsyncResponse[FormattableT]:
         """Generates a response using the provided model asynchronously."""
-
-
-
-
-        )
+        if isinstance(model, str):
+            model = Model(model)
+        messages = await self.messages(*args, **kwargs)
+        return await model.call_async(messages, tools=self.toolkit, format=self.format)
 
     @overload
     async def stream(
-        self: "AsyncPrompt[P, None]", model: Model, *args: P.args, **kwargs: P.kwargs
+        self: "AsyncPrompt[P, None]",
+        model: Model | ModelId,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> AsyncStreamResponse: ...
 
     @overload
     async def stream(
         self: "AsyncPrompt[P, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> AsyncStreamResponse[FormattableT]: ...
 
     async def stream(
-        self, model: Model, *args: P.args, **kwargs: P.kwargs
+        self, model: Model | ModelId, *args: P.args, **kwargs: P.kwargs
     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
         """Generates a streaming response using the provided model asynchronously."""
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = await self.messages(*args, **kwargs)
         return await model.stream_async(
-            messages
+            messages, tools=self.toolkit, format=self.format
         )
 
 
 @dataclass
-class ContextPrompt(Generic[P, DepsT, FormattableT]):
+class ContextPrompt(
+    BasePrompt[ContextMessageTemplate[P, DepsT]], Generic[P, DepsT, FormattableT]
+):
     """A context-aware prompt that can be called with a model to generate a response.
 
     Created by decorating a `ContextMessageTemplate` with `llm.prompt`. The decorated
@@ -203,19 +255,24 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
     It can be invoked with a model: `prompt(model, ctx, *args, **kwargs)`.
     """
 
-    fn: ContextMessageTemplate[P, DepsT]
-    """The underlying context-aware prompt function that generates message content."""
-
     toolkit: ContextToolkit[DepsT]
     """The toolkit containing this prompt's context-aware tools."""
 
-    format: type[FormattableT] | Format[FormattableT] | None
+    format: (
+        type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT] | None
+    )
     """The response format for the generated response."""
 
+    def messages(
+        self, ctx: Context[DepsT], *args: P.args, **kwargs: P.kwargs
+    ) -> Sequence[Message]:
+        """Return the `Messages` from invoking this prompt."""
+        return promote_to_messages(self.fn(ctx, *args, **kwargs))
+
     @overload
     def __call__(
         self: "ContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -224,7 +281,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     def __call__(
         self: "ContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -232,7 +289,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
 
     def __call__(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -243,7 +300,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     def call(
         self: "ContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -252,7 +309,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     def call(
         self: "ContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -260,22 +317,23 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
 
     def call(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
         """Generates a response using the provided model."""
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = self.messages(ctx, *args, **kwargs)
         return model.context_call(
-            ctx=ctx,
+            messages, ctx=ctx, tools=self.toolkit, format=self.format
        )
 
     @overload
     def stream(
         self: "ContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -284,7 +342,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     def stream(
         self: "ContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -292,7 +350,7 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
 
     def stream(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -300,15 +358,18 @@ class ContextPrompt(Generic[P, DepsT, FormattableT]):
         ContextStreamResponse[DepsT, None] | ContextStreamResponse[DepsT, FormattableT]
     ):
         """Generates a streaming response using the provided model."""
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = self.messages(ctx, *args, **kwargs)
         return model.context_stream(
-            ctx=ctx,
+            messages, ctx=ctx, tools=self.toolkit, format=self.format
        )
 
 
 @dataclass
-class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
+class AsyncContextPrompt(
+    BasePrompt[AsyncContextMessageTemplate[P, DepsT]], Generic[P, DepsT, FormattableT]
+):
     """An async context-aware prompt that can be called with a model to generate a response.
 
     Created by decorating an async `ContextMessageTemplate` with `llm.prompt`. The decorated
@@ -319,19 +380,24 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
     It can be invoked with a model: `await prompt(model, ctx, *args, **kwargs)`.
     """
 
-    fn: AsyncContextMessageTemplate[P, DepsT]
-    """The underlying async context-aware prompt function that generates message content."""
-
     toolkit: AsyncContextToolkit[DepsT]
     """The toolkit containing this prompt's async context-aware tools."""
 
-    format: type[FormattableT] | Format[FormattableT] | None
+    format: (
+        type[FormattableT] | Format[FormattableT] | OutputParser[FormattableT] | None
+    )
     """The response format for the generated response."""
 
+    async def messages(
+        self, ctx: Context[DepsT], *args: P.args, **kwargs: P.kwargs
+    ) -> Sequence[Message]:
+        """Return the `Messages` from invoking this prompt."""
+        return promote_to_messages(await self.fn(ctx, *args, **kwargs))
+
     @overload
     async def __call__(
         self: "AsyncContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -340,7 +406,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     async def __call__(
         self: "AsyncContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -348,7 +414,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
 
     async def __call__(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -359,7 +425,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     async def call(
         self: "AsyncContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -368,7 +434,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     async def call(
         self: "AsyncContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -376,22 +442,23 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
 
     async def call(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
         """Generates a response using the provided model asynchronously."""
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = await self.messages(ctx, *args, **kwargs)
         return await model.context_call_async(
-            ctx=ctx,
+            messages, ctx=ctx, tools=self.toolkit, format=self.format
        )
 
     @overload
     async def stream(
         self: "AsyncContextPrompt[P, DepsT, None]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -400,7 +467,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
     @overload
     async def stream(
         self: "AsyncContextPrompt[P, DepsT, FormattableT]",
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -408,7 +475,7 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
 
     async def stream(
         self,
-        model: Model,
+        model: Model | ModelId,
         ctx: Context[DepsT],
         *args: P.args,
         **kwargs: P.kwargs,
@@ -417,8 +484,9 @@ class AsyncContextPrompt(Generic[P, DepsT, FormattableT]):
         | AsyncContextStreamResponse[DepsT, FormattableT]
     ):
         """Generates a streaming response using the provided model asynchronously."""
-
-
+        if isinstance(model, str):
+            model = Model(model)
+        messages = await self.messages(ctx, *args, **kwargs)
         return await model.context_stream_async(
-            ctx=ctx,
+            messages, ctx=ctx, tools=self.toolkit, format=self.format
        )
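Taken together, the prompts.py changes do three user-visible things: every `call`/`stream`/`__call__` variant now accepts a bare `ModelId` string and coerces it with `Model(model)`, the new `messages()` method renders a prompt's messages without making a provider call, and `BasePrompt.__post_init__` copies function metadata so other decorators can stack on top of prompts. A minimal sketch of the first two (the prompt function and the model-id string below are illustrative, not taken from this diff):

```python
from mirascope import llm


@llm.prompt
def recommend_book(genre: str) -> str:
    return f"Recommend a {genre} book."


# New in 2.0.2: render the prompt's messages without calling any model.
print(recommend_book.messages("fantasy"))

# New in 2.0.2: pass a model-id string directly; Prompt.call wraps it in
# Model(...). The id format is whatever ModelId accepts; this is a placeholder.
response = recommend_book.call("anthropic:claude-sonnet-4-5", "fantasy")
print(response)
```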
|
@@ -19,7 +19,7 @@ from .anthropic import (
|
|
|
19
19
|
AnthropicModelId,
|
|
20
20
|
AnthropicProvider,
|
|
21
21
|
)
|
|
22
|
-
from .base import BaseProvider,
|
|
22
|
+
from .base import BaseProvider, Provider
|
|
23
23
|
from .google import GoogleModelId, GoogleProvider
|
|
24
24
|
from .mirascope import MirascopeProvider
|
|
25
25
|
from .mlx import MLXModelId, MLXProvider
|
|
@@ -53,11 +53,8 @@ __all__ = [
|
|
|
53
53
|
"OllamaProvider",
|
|
54
54
|
"OpenAIModelId",
|
|
55
55
|
"OpenAIProvider",
|
|
56
|
-
"Params",
|
|
57
56
|
"Provider",
|
|
58
57
|
"ProviderId",
|
|
59
|
-
"ThinkingConfig",
|
|
60
|
-
"ThinkingLevel",
|
|
61
58
|
"TogetherProvider",
|
|
62
59
|
"get_provider_for_model",
|
|
63
60
|
"register_provider",
|
|
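The removed exports line up with the file moves listed above: `params.py` moved from `providers/base` to `models`, and `models/thinking_config.py` is new, so `Params`, `ThinkingConfig`, and `ThinkingLevel` now belong to the model layer rather than the provider layer. A hedged sketch of the updated import, assuming these names are re-exported from `mirascope.llm.models` after the move:

```python
# Assumed re-export location after the params.py move noted in the file list;
# in 2.0.0a6 these names were re-exported from mirascope.llm.providers instead.
from mirascope.llm.models import Params, ThinkingConfig, ThinkingLevel
```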

mirascope/llm/providers/anthropic/_utils/__init__.py
CHANGED
@@ -1,5 +1,6 @@
 """Shared Anthropic utilities."""
 
+from ...base._utils import get_include_thoughts
 from .decode import decode_async_stream, decode_response, decode_stream
 from .encode import (
     DEFAULT_FORMAT_MODE,
@@ -21,5 +22,6 @@ __all__ = [
     "decode_stream",
     "encode_image_mime_type",
     "encode_request",
+    "get_include_thoughts",
     "process_params",
 ]

mirascope/llm/providers/anthropic/_utils/beta_decode.py
CHANGED
@@ -74,10 +74,15 @@ def _decode_beta_assistant_content(content: BetaContentBlock) -> AssistantContent:
 def beta_decode_response(
     response: ParsedBetaMessage[Any],
     model_id: str,
+    *,
+    include_thoughts: bool,
 ) -> tuple[AssistantMessage, FinishReason | None, Usage]:
     """Convert Beta message to mirascope AssistantMessage and usage."""
+    content = [_decode_beta_assistant_content(part) for part in response.content]
+    if not include_thoughts:
+        content = [part for part in content if part.type != "thought"]
     assistant_message = AssistantMessage(
-        content=[_decode_beta_assistant_content(part) for part in response.content],
+        content=content,
         provider_id="anthropic",
         model_id=model_id,
         provider_model_name=model_name(model_id),
@@ -108,10 +113,11 @@ BetaContentBlockParam: TypeAlias = (
 class _BetaChunkProcessor:
     """Processes Beta stream events and maintains state across events."""
 
-    def __init__(self) -> None:
+    def __init__(self, *, include_thoughts: bool) -> None:
         self.current_block_param: BetaContentBlockParam | None = None
         self.accumulated_tool_json: str = ""
         self.accumulated_blocks: list[BetaContentBlockParam] = []
+        self.include_thoughts = include_thoughts
 
     def process_event(self, event: BetaRawMessageStreamEvent) -> ChunkIterator:
         """Process a single Beta event and yield the appropriate content chunks."""
@@ -144,7 +150,8 @@ class _BetaChunkProcessor:
                     "thinking": "",
                     "signature": "",
                 }
-                yield ThoughtStartChunk()
+                if self.include_thoughts:
+                    yield ThoughtStartChunk()
             elif content_block.type == "redacted_thinking":  # pragma: no cover
                 self.current_block_param = {
                     "type": "redacted_thinking",
@@ -183,7 +190,8 @@ class _BetaChunkProcessor:
                     f"Received thinking_delta for {self.current_block_param['type']} block"
                 )
             self.current_block_param["thinking"] += delta.thinking
-            yield ThoughtChunk(delta=delta.thinking)
+            if self.include_thoughts:
+                yield ThoughtChunk(delta=delta.thinking)
         elif delta.type == "signature_delta":
             if self.current_block_param["type"] != "thinking":  # pragma: no cover
                 raise RuntimeError(
@@ -215,7 +223,8 @@ class _BetaChunkProcessor:
             )
             yield ToolCallEndChunk(id=self.current_block_param["id"])
         elif block_type == "thinking":
-            yield ThoughtEndChunk()
+            if self.include_thoughts:
+                yield ThoughtEndChunk()
         else:
             raise NotImplementedError
 
@@ -251,10 +260,10 @@ class _BetaChunkProcessor:
 
 
 def beta_decode_stream(
-    beta_stream_manager: BetaMessageStreamManager[Any],
+    beta_stream_manager: BetaMessageStreamManager[Any], *, include_thoughts: bool
 ) -> ChunkIterator:
     """Returns a ChunkIterator converted from a Beta MessageStreamManager."""
-    processor = _BetaChunkProcessor()
+    processor = _BetaChunkProcessor(include_thoughts=include_thoughts)
     with beta_stream_manager as stream:
         for event in stream._raw_stream:  # pyright: ignore[reportPrivateUsage]
             yield from processor.process_event(event)
@@ -262,10 +271,10 @@ def beta_decode_stream(
 
 
 async def beta_decode_async_stream(
-    beta_stream_manager: BetaAsyncMessageStreamManager[Any],
+    beta_stream_manager: BetaAsyncMessageStreamManager[Any], *, include_thoughts: bool
 ) -> AsyncChunkIterator:
     """Returns an AsyncChunkIterator converted from a Beta MessageStreamManager."""
-    processor = _BetaChunkProcessor()
+    processor = _BetaChunkProcessor(include_thoughts=include_thoughts)
     async with beta_stream_manager as stream:
         async for event in stream._raw_stream:  # pyright: ignore[reportPrivateUsage]
             for item in processor.process_event(event):