mirascope 2.0.0a2__py3-none-any.whl → 2.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +2 -2
- mirascope/api/__init__.py +6 -0
- mirascope/api/_generated/README.md +207 -0
- mirascope/api/_generated/__init__.py +141 -0
- mirascope/api/_generated/client.py +163 -0
- mirascope/api/_generated/core/__init__.py +52 -0
- mirascope/api/_generated/core/api_error.py +23 -0
- mirascope/api/_generated/core/client_wrapper.py +58 -0
- mirascope/api/_generated/core/datetime_utils.py +30 -0
- mirascope/api/_generated/core/file.py +70 -0
- mirascope/api/_generated/core/force_multipart.py +16 -0
- mirascope/api/_generated/core/http_client.py +619 -0
- mirascope/api/_generated/core/http_response.py +55 -0
- mirascope/api/_generated/core/jsonable_encoder.py +102 -0
- mirascope/api/_generated/core/pydantic_utilities.py +310 -0
- mirascope/api/_generated/core/query_encoder.py +60 -0
- mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
- mirascope/api/_generated/core/request_options.py +35 -0
- mirascope/api/_generated/core/serialization.py +282 -0
- mirascope/api/_generated/docs/__init__.py +4 -0
- mirascope/api/_generated/docs/client.py +95 -0
- mirascope/api/_generated/docs/raw_client.py +132 -0
- mirascope/api/_generated/environment.py +9 -0
- mirascope/api/_generated/errors/__init__.py +17 -0
- mirascope/api/_generated/errors/bad_request_error.py +15 -0
- mirascope/api/_generated/errors/conflict_error.py +15 -0
- mirascope/api/_generated/errors/forbidden_error.py +15 -0
- mirascope/api/_generated/errors/internal_server_error.py +15 -0
- mirascope/api/_generated/errors/not_found_error.py +15 -0
- mirascope/api/_generated/health/__init__.py +7 -0
- mirascope/api/_generated/health/client.py +96 -0
- mirascope/api/_generated/health/raw_client.py +129 -0
- mirascope/api/_generated/health/types/__init__.py +8 -0
- mirascope/api/_generated/health/types/health_check_response.py +24 -0
- mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
- mirascope/api/_generated/organizations/__init__.py +25 -0
- mirascope/api/_generated/organizations/client.py +380 -0
- mirascope/api/_generated/organizations/raw_client.py +876 -0
- mirascope/api/_generated/organizations/types/__init__.py +23 -0
- mirascope/api/_generated/organizations/types/organizations_create_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_create_response_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_get_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_get_response_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_list_response_item.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_list_response_item_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_update_response_role.py +7 -0
- mirascope/api/_generated/projects/__init__.py +17 -0
- mirascope/api/_generated/projects/client.py +458 -0
- mirascope/api/_generated/projects/raw_client.py +1016 -0
- mirascope/api/_generated/projects/types/__init__.py +15 -0
- mirascope/api/_generated/projects/types/projects_create_response.py +30 -0
- mirascope/api/_generated/projects/types/projects_get_response.py +30 -0
- mirascope/api/_generated/projects/types/projects_list_response_item.py +30 -0
- mirascope/api/_generated/projects/types/projects_update_response.py +30 -0
- mirascope/api/_generated/reference.md +753 -0
- mirascope/api/_generated/traces/__init__.py +55 -0
- mirascope/api/_generated/traces/client.py +162 -0
- mirascope/api/_generated/traces/raw_client.py +168 -0
- mirascope/api/_generated/traces/types/__init__.py +95 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
- mirascope/api/_generated/types/__init__.py +37 -0
- mirascope/api/_generated/types/already_exists_error.py +24 -0
- mirascope/api/_generated/types/already_exists_error_tag.py +5 -0
- mirascope/api/_generated/types/database_error.py +24 -0
- mirascope/api/_generated/types/database_error_tag.py +5 -0
- mirascope/api/_generated/types/http_api_decode_error.py +29 -0
- mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
- mirascope/api/_generated/types/issue.py +40 -0
- mirascope/api/_generated/types/issue_tag.py +17 -0
- mirascope/api/_generated/types/not_found_error_body.py +24 -0
- mirascope/api/_generated/types/not_found_error_tag.py +5 -0
- mirascope/api/_generated/types/permission_denied_error.py +24 -0
- mirascope/api/_generated/types/permission_denied_error_tag.py +7 -0
- mirascope/api/_generated/types/property_key.py +7 -0
- mirascope/api/_generated/types/property_key_key.py +27 -0
- mirascope/api/_generated/types/property_key_key_tag.py +5 -0
- mirascope/api/client.py +255 -0
- mirascope/api/settings.py +81 -0
- mirascope/llm/__init__.py +45 -11
- mirascope/llm/calls/calls.py +81 -57
- mirascope/llm/calls/decorator.py +121 -115
- mirascope/llm/content/__init__.py +3 -2
- mirascope/llm/context/_utils.py +19 -6
- mirascope/llm/exceptions.py +30 -16
- mirascope/llm/formatting/_utils.py +9 -5
- mirascope/llm/formatting/format.py +2 -2
- mirascope/llm/formatting/from_call_args.py +2 -2
- mirascope/llm/messages/message.py +13 -5
- mirascope/llm/models/__init__.py +2 -2
- mirascope/llm/models/models.py +189 -81
- mirascope/llm/prompts/__init__.py +13 -12
- mirascope/llm/prompts/_utils.py +27 -24
- mirascope/llm/prompts/decorator.py +133 -204
- mirascope/llm/prompts/prompts.py +424 -0
- mirascope/llm/prompts/protocols.py +25 -59
- mirascope/llm/providers/__init__.py +44 -0
- mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
- mirascope/llm/providers/anthropic/__init__.py +29 -0
- mirascope/llm/providers/anthropic/_utils/__init__.py +23 -0
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +271 -0
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +216 -0
- mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +44 -11
- mirascope/llm/providers/anthropic/_utils/encode.py +356 -0
- mirascope/llm/providers/anthropic/beta_provider.py +322 -0
- mirascope/llm/providers/anthropic/model_id.py +23 -0
- mirascope/llm/providers/anthropic/model_info.py +87 -0
- mirascope/llm/providers/anthropic/provider.py +416 -0
- mirascope/llm/{clients → providers}/base/__init__.py +3 -3
- mirascope/llm/{clients → providers}/base/_utils.py +25 -8
- mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
- mirascope/llm/providers/google/__init__.py +21 -0
- mirascope/llm/{clients → providers}/google/_utils/decode.py +61 -7
- mirascope/llm/{clients → providers}/google/_utils/encode.py +44 -30
- mirascope/llm/providers/google/model_id.py +22 -0
- mirascope/llm/providers/google/model_info.py +62 -0
- mirascope/llm/providers/google/provider.py +442 -0
- mirascope/llm/providers/load_provider.py +54 -0
- mirascope/llm/providers/mlx/__init__.py +24 -0
- mirascope/llm/providers/mlx/_utils.py +129 -0
- mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
- mirascope/llm/providers/mlx/encoding/base.py +69 -0
- mirascope/llm/providers/mlx/encoding/transformers.py +147 -0
- mirascope/llm/providers/mlx/mlx.py +237 -0
- mirascope/llm/providers/mlx/model_id.py +17 -0
- mirascope/llm/providers/mlx/provider.py +415 -0
- mirascope/llm/providers/model_id.py +16 -0
- mirascope/llm/providers/ollama/__init__.py +19 -0
- mirascope/llm/providers/ollama/provider.py +71 -0
- mirascope/llm/providers/openai/__init__.py +6 -0
- mirascope/llm/providers/openai/completions/__init__.py +25 -0
- mirascope/llm/{clients → providers}/openai/completions/_utils/__init__.py +2 -0
- mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +60 -6
- mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +37 -26
- mirascope/llm/providers/openai/completions/base_provider.py +513 -0
- mirascope/llm/providers/openai/completions/provider.py +22 -0
- mirascope/llm/providers/openai/model_id.py +31 -0
- mirascope/llm/providers/openai/model_info.py +303 -0
- mirascope/llm/providers/openai/provider.py +398 -0
- mirascope/llm/providers/openai/responses/__init__.py +21 -0
- mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +59 -6
- mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +34 -23
- mirascope/llm/providers/openai/responses/provider.py +469 -0
- mirascope/llm/providers/provider_id.py +23 -0
- mirascope/llm/providers/provider_registry.py +169 -0
- mirascope/llm/providers/together/__init__.py +19 -0
- mirascope/llm/providers/together/provider.py +40 -0
- mirascope/llm/responses/__init__.py +3 -0
- mirascope/llm/responses/base_response.py +14 -5
- mirascope/llm/responses/base_stream_response.py +35 -6
- mirascope/llm/responses/finish_reason.py +1 -0
- mirascope/llm/responses/response.py +33 -13
- mirascope/llm/responses/root_response.py +12 -13
- mirascope/llm/responses/stream_response.py +35 -23
- mirascope/llm/responses/usage.py +95 -0
- mirascope/llm/tools/__init__.py +9 -2
- mirascope/llm/tools/_utils.py +12 -3
- mirascope/llm/tools/protocols.py +4 -4
- mirascope/llm/tools/tool_schema.py +44 -9
- mirascope/llm/tools/tools.py +10 -9
- mirascope/ops/__init__.py +156 -0
- mirascope/ops/_internal/__init__.py +5 -0
- mirascope/ops/_internal/closure.py +1118 -0
- mirascope/ops/_internal/configuration.py +126 -0
- mirascope/ops/_internal/context.py +76 -0
- mirascope/ops/_internal/exporters/__init__.py +26 -0
- mirascope/ops/_internal/exporters/exporters.py +342 -0
- mirascope/ops/_internal/exporters/processors.py +104 -0
- mirascope/ops/_internal/exporters/types.py +165 -0
- mirascope/ops/_internal/exporters/utils.py +29 -0
- mirascope/ops/_internal/instrumentation/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
- mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
- mirascope/ops/_internal/propagation.py +198 -0
- mirascope/ops/_internal/protocols.py +51 -0
- mirascope/ops/_internal/session.py +139 -0
- mirascope/ops/_internal/spans.py +232 -0
- mirascope/ops/_internal/traced_calls.py +371 -0
- mirascope/ops/_internal/traced_functions.py +394 -0
- mirascope/ops/_internal/tracing.py +276 -0
- mirascope/ops/_internal/types.py +13 -0
- mirascope/ops/_internal/utils.py +75 -0
- mirascope/ops/_internal/versioned_calls.py +512 -0
- mirascope/ops/_internal/versioned_functions.py +346 -0
- mirascope/ops/_internal/versioning.py +303 -0
- mirascope/ops/exceptions.py +21 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/METADATA +78 -3
- mirascope-2.0.0a4.dist-info/RECORD +247 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/WHEEL +1 -1
- mirascope/graphs/__init__.py +0 -22
- mirascope/graphs/finite_state_machine.py +0 -625
- mirascope/llm/agents/__init__.py +0 -15
- mirascope/llm/agents/agent.py +0 -97
- mirascope/llm/agents/agent_template.py +0 -45
- mirascope/llm/agents/decorator.py +0 -176
- mirascope/llm/calls/base_call.py +0 -33
- mirascope/llm/clients/__init__.py +0 -34
- mirascope/llm/clients/anthropic/__init__.py +0 -25
- mirascope/llm/clients/anthropic/_utils/encode.py +0 -243
- mirascope/llm/clients/anthropic/clients.py +0 -819
- mirascope/llm/clients/anthropic/model_ids.py +0 -8
- mirascope/llm/clients/google/__init__.py +0 -20
- mirascope/llm/clients/google/clients.py +0 -853
- mirascope/llm/clients/google/model_ids.py +0 -15
- mirascope/llm/clients/openai/__init__.py +0 -25
- mirascope/llm/clients/openai/completions/__init__.py +0 -28
- mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
- mirascope/llm/clients/openai/completions/clients.py +0 -833
- mirascope/llm/clients/openai/completions/model_ids.py +0 -8
- mirascope/llm/clients/openai/responses/__init__.py +0 -26
- mirascope/llm/clients/openai/responses/_utils/__init__.py +0 -13
- mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
- mirascope/llm/clients/openai/responses/clients.py +0 -832
- mirascope/llm/clients/openai/responses/model_ids.py +0 -8
- mirascope/llm/clients/openai/shared/__init__.py +0 -7
- mirascope/llm/clients/openai/shared/_utils.py +0 -55
- mirascope/llm/clients/providers.py +0 -175
- mirascope-2.0.0a2.dist-info/RECORD +0 -102
- /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
- /mirascope/llm/{clients → providers}/base/params.py +0 -0
- /mirascope/llm/{clients/anthropic → providers/google}/_utils/__init__.py +0 -0
- /mirascope/llm/{clients → providers}/google/message.py +0 -0
- /mirascope/llm/{clients/google → providers/openai/responses}/_utils/__init__.py +0 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
"""Google provider implementation."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing_extensions import Unpack
|
|
5
|
+
|
|
6
|
+
from google.genai import Client
|
|
7
|
+
from google.genai.types import HttpOptions
|
|
8
|
+
|
|
9
|
+
from ...context import Context, DepsT
|
|
10
|
+
from ...formatting import Format, FormattableT
|
|
11
|
+
from ...messages import Message
|
|
12
|
+
from ...responses import (
|
|
13
|
+
AsyncContextResponse,
|
|
14
|
+
AsyncContextStreamResponse,
|
|
15
|
+
AsyncResponse,
|
|
16
|
+
AsyncStreamResponse,
|
|
17
|
+
ContextResponse,
|
|
18
|
+
ContextStreamResponse,
|
|
19
|
+
Response,
|
|
20
|
+
StreamResponse,
|
|
21
|
+
)
|
|
22
|
+
from ...tools import (
|
|
23
|
+
AsyncContextTool,
|
|
24
|
+
AsyncContextToolkit,
|
|
25
|
+
AsyncTool,
|
|
26
|
+
AsyncToolkit,
|
|
27
|
+
ContextTool,
|
|
28
|
+
ContextToolkit,
|
|
29
|
+
Tool,
|
|
30
|
+
Toolkit,
|
|
31
|
+
)
|
|
32
|
+
from ..base import BaseProvider, Params
|
|
33
|
+
from . import _utils
|
|
34
|
+
from .model_id import GoogleModelId, model_name
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class GoogleProvider(BaseProvider[Client]):
    """Provider that sends LLM requests to Google's GenAI API.

    Wraps a `google.genai.Client` (stored on `self.client`). It implements
    synchronous and asynchronous call and stream variants, each with and
    without an `llm.Context`.
    """

    # Provider identifier; the same value the call/stream methods pass as `provider_id`.
    id = "google"
    # Presumably the model-id prefix this provider claims by default — TODO confirm
    # against the provider registry.
    default_scope = "google/"
|
|
42
|
+
|
|
43
|
+
def __init__(
    self, *, api_key: str | None = None, base_url: str | None = None
) -> None:
    """Construct the underlying `google.genai.Client` for this provider.

    Args:
        api_key: API key forwarded unchanged to `google.genai.Client`.
        base_url: Optional endpoint override. When truthy it is wrapped in
            `HttpOptions(base_url=...)`; otherwise no HTTP options are passed.
    """
    # An empty/None base_url means "use the SDK's default endpoint".
    options = HttpOptions(base_url=base_url) if base_url else None
    self.client = Client(api_key=api_key, http_options=options)
|
|
52
|
+
|
|
53
|
+
def _call(
    self,
    *,
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[Tool] | Toolkit | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> Response | Response[FormattableT]:
    """Synchronously call the Google GenAI API and wrap the reply in an `llm.Response`.

    The request is built by `_utils.encode_request`, sent with
    `client.models.generate_content`, and the raw reply is decoded by
    `_utils.decode_response`.

    Args:
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional tools that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.Response` holding the LLM-generated content, the raw Google
        response, the finish reason, and usage.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    raw_response = self.client.models.generate_content(**request_kwargs)
    reply, reason, token_usage = _utils.decode_response(raw_response, model_id)

    return Response(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        raw=raw_response,
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        assistant_message=reply,
        finish_reason=reason,
        usage=token_usage,
        format=resolved_format,
    )
|
|
101
|
+
|
|
102
|
+
def _context_call(
    self,
    *,
    ctx: Context[DepsT],
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[Tool | ContextTool[DepsT]]
    | ContextToolkit[DepsT]
    | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
    """Synchronously call the Google GenAI API and wrap the reply in an `llm.ContextResponse`.

    `ctx` is accepted for interface compatibility but is not read here.

    Args:
        ctx: Context object with dependencies for tools.
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional tools (plain or context-aware) that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.ContextResponse` holding the LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    raw_response = self.client.models.generate_content(**request_kwargs)
    reply, reason, token_usage = _utils.decode_response(raw_response, model_id)

    return ContextResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        raw=raw_response,
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        assistant_message=reply,
        finish_reason=reason,
        usage=token_usage,
        format=resolved_format,
    )
|
|
154
|
+
|
|
155
|
+
async def _call_async(
    self,
    *,
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> AsyncResponse | AsyncResponse[FormattableT]:
    """Asynchronously call the Google GenAI API and wrap the reply in an `llm.AsyncResponse`.

    Uses the SDK's async surface (`client.aio.models.generate_content`).

    Args:
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional async tools that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.AsyncResponse` holding the LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    raw_response = await self.client.aio.models.generate_content(**request_kwargs)
    reply, reason, token_usage = _utils.decode_response(raw_response, model_id)

    return AsyncResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        raw=raw_response,
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        assistant_message=reply,
        finish_reason=reason,
        usage=token_usage,
        format=resolved_format,
    )
|
|
203
|
+
|
|
204
|
+
async def _context_call_async(
    self,
    *,
    ctx: Context[DepsT],
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
    | AsyncContextToolkit[DepsT]
    | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
    """Asynchronously call the Google GenAI API and wrap the reply in an `llm.AsyncContextResponse`.

    `ctx` is accepted for interface compatibility but is not read here.

    Args:
        ctx: Context object with dependencies for tools.
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional async tools (plain or context-aware) that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.AsyncContextResponse` holding the LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    raw_response = await self.client.aio.models.generate_content(**request_kwargs)
    reply, reason, token_usage = _utils.decode_response(raw_response, model_id)

    return AsyncContextResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        raw=raw_response,
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        assistant_message=reply,
        finish_reason=reason,
        usage=token_usage,
        format=resolved_format,
    )
|
|
256
|
+
|
|
257
|
+
def _stream(
    self,
    *,
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[Tool] | Toolkit | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> StreamResponse | StreamResponse[FormattableT]:
    """Synchronously stream from the Google GenAI API as an `llm.StreamResponse`.

    The SDK stream from `client.models.generate_content_stream` is adapted into
    mirascope chunks by `_utils.decode_stream`.

    Args:
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional tools that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.StreamResponse` for iterating over the LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    chunks = _utils.decode_stream(
        self.client.models.generate_content_stream(**request_kwargs)
    )

    return StreamResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        chunk_iterator=chunks,
        format=resolved_format,
    )
|
|
300
|
+
|
|
301
|
+
def _context_stream(
    self,
    *,
    ctx: Context[DepsT],
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[Tool | ContextTool[DepsT]]
    | ContextToolkit[DepsT]
    | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
    """Synchronously stream from the Google GenAI API as an `llm.ContextStreamResponse`.

    `ctx` is accepted for interface compatibility but is not read here.

    Args:
        ctx: Context object with dependencies for tools.
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional tools (plain or context-aware) that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.ContextStreamResponse` for iterating over the LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    chunks = _utils.decode_stream(
        self.client.models.generate_content_stream(**request_kwargs)
    )

    return ContextStreamResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        chunk_iterator=chunks,
        format=resolved_format,
    )
|
|
348
|
+
|
|
349
|
+
async def _stream_async(
    self,
    *,
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
    """Asynchronously stream from the Google GenAI API as an `llm.AsyncStreamResponse`.

    `client.aio.models.generate_content_stream` is awaited to obtain the SDK
    stream, which `_utils.decode_async_stream` adapts into mirascope chunks.

    Args:
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional async tools that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature).
            See `llm.Params`.

    Returns:
        An `llm.AsyncStreamResponse` for asynchronously iterating over the
        LLM-generated content.
    """
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )
    chunks = _utils.decode_async_stream(
        await self.client.aio.models.generate_content_stream(**request_kwargs)
    )

    return AsyncStreamResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        chunk_iterator=chunks,
        format=resolved_format,
    )
|
|
392
|
+
|
|
393
|
+
async def _context_stream_async(
    self,
    *,
    ctx: Context[DepsT],
    model_id: GoogleModelId,
    messages: Sequence[Message],
    tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
    | AsyncContextToolkit[DepsT]
    | None = None,
    format: type[FormattableT] | Format[FormattableT] | None = None,
    **params: Unpack[Params],
) -> (
    AsyncContextStreamResponse[DepsT]
    | AsyncContextStreamResponse[DepsT, FormattableT]
):
    """Asynchronously stream a context-aware completion from the Google GenAI API.

    The request is encoded with `_utils.encode_request`, sent through
    `self.client.aio.models.generate_content_stream`, and the raw stream is
    wrapped with `_utils.decode_async_stream` before being handed to the
    returned context response object.

    Args:
        ctx: Context object with dependencies for tools (not referenced directly
            in this method).
        model_id: Model identifier to use.
        messages: Messages to send to the LLM.
        tools: Optional tools that the model may invoke.
        format: Optional response format specifier.
        **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.

    Returns:
        An `llm.AsyncContextStreamResponse` object for asynchronously iterating over the LLM-generated content.
    """
    # `encode_request` may normalize the requested format, so the resolved
    # value (not the caller's argument) is what the response receives.
    encoded_messages, resolved_format, request_kwargs = _utils.encode_request(
        model_id=model_id,
        messages=messages,
        tools=tools,
        format=format,
        params=params,
    )

    raw_stream = await self.client.aio.models.generate_content_stream(
        **request_kwargs
    )
    chunk_stream = _utils.decode_async_stream(raw_stream)

    return AsyncContextStreamResponse(
        provider_id="google",
        model_id=model_id,
        provider_model_name=model_name(model_id),
        params=params,
        tools=tools,
        input_messages=encoded_messages,
        chunk_iterator=chunk_stream,
        format=resolved_format,
    )
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
|
|
3
|
+
from .anthropic import AnthropicProvider
|
|
4
|
+
from .base import Provider
|
|
5
|
+
from .google import GoogleProvider
|
|
6
|
+
from .mlx import MLXProvider
|
|
7
|
+
from .ollama import OllamaProvider
|
|
8
|
+
from .openai import OpenAIProvider
|
|
9
|
+
from .openai.completions.provider import OpenAICompletionsProvider
|
|
10
|
+
from .openai.responses.provider import OpenAIResponsesProvider
|
|
11
|
+
from .provider_id import ProviderId
|
|
12
|
+
from .together import TogetherProvider
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@lru_cache(maxsize=256)
|
|
16
|
+
def load_provider(
|
|
17
|
+
provider_id: ProviderId, *, api_key: str | None = None, base_url: str | None = None
|
|
18
|
+
) -> Provider:
|
|
19
|
+
"""Create a cached provider instance for the specified provider id.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
provider_id: The provider name ("openai", "anthropic", or "google").
|
|
23
|
+
api_key: API key for authentication. If None, uses provider-specific env var.
|
|
24
|
+
base_url: Base URL for the API. If None, uses provider-specific env var.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
A cached provider instance for the specified provider with the given parameters.
|
|
28
|
+
|
|
29
|
+
Raises:
|
|
30
|
+
ValueError: If the provider_id is not supported.
|
|
31
|
+
"""
|
|
32
|
+
match provider_id:
|
|
33
|
+
case "anthropic":
|
|
34
|
+
return AnthropicProvider(api_key=api_key, base_url=base_url)
|
|
35
|
+
case "google":
|
|
36
|
+
return GoogleProvider(api_key=api_key, base_url=base_url)
|
|
37
|
+
case "mlx": # pragma: no cover (MLX is only available on macOS)
|
|
38
|
+
return MLXProvider()
|
|
39
|
+
case "ollama":
|
|
40
|
+
return OllamaProvider(api_key=api_key, base_url=base_url)
|
|
41
|
+
case "openai":
|
|
42
|
+
return OpenAIProvider(api_key=api_key, base_url=base_url)
|
|
43
|
+
case "openai:completions":
|
|
44
|
+
return OpenAICompletionsProvider(api_key=api_key, base_url=base_url)
|
|
45
|
+
case "openai:responses":
|
|
46
|
+
return OpenAIResponsesProvider(api_key=api_key, base_url=base_url)
|
|
47
|
+
case "together":
|
|
48
|
+
return TogetherProvider(api_key=api_key, base_url=base_url)
|
|
49
|
+
case _: # pragma: no cover
|
|
50
|
+
raise ValueError(f"Unknown provider: '{provider_id}'")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
load = load_provider
|
|
54
|
+
"""Convenient alias as `llm.providers.load`"""
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""MLX client implementation."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from .model_id import MLXModelId
|
|
7
|
+
from .provider import MLXProvider
|
|
8
|
+
else:
|
|
9
|
+
try:
|
|
10
|
+
from .model_id import MLXModelId
|
|
11
|
+
from .provider import MLXProvider
|
|
12
|
+
except ImportError: # pragma: no cover
|
|
13
|
+
from .._missing_import_stubs import (
|
|
14
|
+
create_import_error_stub,
|
|
15
|
+
create_provider_stub,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
MLXProvider = create_provider_stub("mlx", "MLXProvider")
|
|
19
|
+
MLXModelId = str
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"MLXModelId",
|
|
23
|
+
"MLXProvider",
|
|
24
|
+
]
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import TypeAlias, TypedDict
|
|
3
|
+
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
from mlx_lm.generate import GenerationResponse
|
|
6
|
+
from mlx_lm.sample_utils import make_sampler
|
|
7
|
+
|
|
8
|
+
from ...responses import FinishReason, Usage
|
|
9
|
+
from ..base import Params, _utils as _base_utils
|
|
10
|
+
|
|
11
|
+
# Type of the sampler callable passed to mlx-lm's `stream_generate`: it takes an
# `mx.array` (the logits to sample from) and returns an `mx.array`.
Sampler: TypeAlias = Callable[[mx.array], mx.array]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MakeSamplerKwargs(TypedDict, total=False):
    """Keyword arguments for `mlx_lm.sample_utils.make_sampler`.

    Some of these settings directly match the generic client parameters
    defined in the `Params` class (`temp` <- `temperature`, plus `top_k` and
    `top_p`). See `mirascope.llm.providers.Params` for more details.
    Every key is optional (`total=False`).
    """

    temp: float
    "The temperature for sampling; if 0, the argmax is used."

    top_p: float
    "Nucleus sampling; higher values make the model consider more unlikely tokens."

    min_p: float
    """The minimum value (scaled by the top token's probability) that a token
    probability must have to be considered."""

    min_tokens_to_keep: int
    "Minimum number of tokens that cannot be filtered by min_p sampling."

    top_k: int
    "The top k tokens ranked by probability to constrain the sampling to."

    xtc_probability: float
    "The probability of applying XTC sampling."

    xtc_threshold: float
    "The threshold the probabilities need to reach for being sampled."

    xtc_special_tokens: list[int]
    "List of special token IDs to be excluded from XTC sampling."
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class StreamGenerateKwargs(TypedDict, total=False):
    """Keyword arguments for the `mlx_lm` `stream_generate` function.

    Every key is optional (`total=False`); `encode_params` in this module always
    fills in both of them.
    """

    max_tokens: int
    "The maximum number of tokens to generate."

    sampler: Sampler
    "A sampler for sampling a token from a vector of logits."
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def encode_params(params: Params) -> tuple[int | None, StreamGenerateKwargs]:
    """Convert generic params to mlx-lm `stream_generate` kwargs.

    `stop_sequences`, `thinking`, and `encode_thoughts_as_text` are passed to
    `_base_utils.ensure_all_params_accessed` as unsupported for the "mlx"
    provider.

    Args:
        params: The generic parameters.

    Returns:
        A `(seed, kwargs)` tuple. `seed` is the `seed` param (None if unset).
        `kwargs` are the mlx-lm specific `stream_generate` keyword arguments:
        `max_tokens` is always set (-1 when no `max_tokens` param was given), and
        `sampler` is always built with `make_sampler` from whichever of
        `temperature`, `top_k`, and `top_p` were provided.
    """
    kwargs: StreamGenerateKwargs = {}

    with _base_utils.ensure_all_params_accessed(
        params=params,
        provider_id="mlx",
        unsupported_params=["stop_sequences", "thinking", "encode_thoughts_as_text"],
    ) as param_accessor:
        max_tokens = param_accessor.max_tokens
        # NOTE(review): -1 appears to be mlx-lm's "no limit" value for max_tokens — confirm.
        kwargs["max_tokens"] = max_tokens if max_tokens is not None else -1

        # Only forward sampler settings the caller actually set, so make_sampler's
        # own defaults apply to everything else.
        sampler_kwargs = MakeSamplerKwargs()
        if param_accessor.temperature is not None:
            sampler_kwargs["temp"] = param_accessor.temperature
        if param_accessor.top_k is not None:
            sampler_kwargs["top_k"] = param_accessor.top_k
        if param_accessor.top_p is not None:
            sampler_kwargs["top_p"] = param_accessor.top_p

        kwargs["sampler"] = make_sampler(**sampler_kwargs)

        # Read `seed` while the accessor context is still open so the access is
        # recorded before `ensure_all_params_accessed` exits.
        return param_accessor.seed, kwargs
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def extract_finish_reason(response: GenerationResponse | None) -> FinishReason | None:
    """Map an MLX generation response's finish reason to a `FinishReason`.

    Only a `"length"` finish reason is translated (to `FinishReason.MAX_TOKENS`);
    a missing response or any other finish reason yields None.

    Args:
        response: The MLX generation response to extract from.

    Returns:
        The normalized finish reason, or None if not applicable.
    """
    hit_token_limit = response is not None and response.finish_reason == "length"
    return FinishReason.MAX_TOKENS if hit_token_limit else None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def extract_usage(response: GenerationResponse | None) -> Usage | None:
    """Build a `Usage` record from an MLX generation response.

    Only the prompt and generation token counts are read from the response;
    cache-read, cache-write, and reasoning token counts are always reported as 0.
    The response itself is kept as the raw usage payload.

    Args:
        response: The MLX generation response to extract from.

    Returns:
        The Usage object with token counts, or None if not applicable.
    """
    return (
        None
        if response is None
        else Usage(
            input_tokens=response.prompt_tokens,
            output_tokens=response.generation_tokens,
            cache_read_tokens=0,
            cache_write_tokens=0,
            reasoning_tokens=0,
            raw=response,
        )
    )
|