mirascope 2.0.0a2__py3-none-any.whl → 2.0.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +2 -2
- mirascope/api/__init__.py +6 -0
- mirascope/api/_generated/README.md +207 -0
- mirascope/api/_generated/__init__.py +141 -0
- mirascope/api/_generated/client.py +163 -0
- mirascope/api/_generated/core/__init__.py +52 -0
- mirascope/api/_generated/core/api_error.py +23 -0
- mirascope/api/_generated/core/client_wrapper.py +58 -0
- mirascope/api/_generated/core/datetime_utils.py +30 -0
- mirascope/api/_generated/core/file.py +70 -0
- mirascope/api/_generated/core/force_multipart.py +16 -0
- mirascope/api/_generated/core/http_client.py +619 -0
- mirascope/api/_generated/core/http_response.py +55 -0
- mirascope/api/_generated/core/jsonable_encoder.py +102 -0
- mirascope/api/_generated/core/pydantic_utilities.py +310 -0
- mirascope/api/_generated/core/query_encoder.py +60 -0
- mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
- mirascope/api/_generated/core/request_options.py +35 -0
- mirascope/api/_generated/core/serialization.py +282 -0
- mirascope/api/_generated/docs/__init__.py +4 -0
- mirascope/api/_generated/docs/client.py +95 -0
- mirascope/api/_generated/docs/raw_client.py +132 -0
- mirascope/api/_generated/environment.py +9 -0
- mirascope/api/_generated/errors/__init__.py +17 -0
- mirascope/api/_generated/errors/bad_request_error.py +15 -0
- mirascope/api/_generated/errors/conflict_error.py +15 -0
- mirascope/api/_generated/errors/forbidden_error.py +15 -0
- mirascope/api/_generated/errors/internal_server_error.py +15 -0
- mirascope/api/_generated/errors/not_found_error.py +15 -0
- mirascope/api/_generated/health/__init__.py +7 -0
- mirascope/api/_generated/health/client.py +96 -0
- mirascope/api/_generated/health/raw_client.py +129 -0
- mirascope/api/_generated/health/types/__init__.py +8 -0
- mirascope/api/_generated/health/types/health_check_response.py +24 -0
- mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
- mirascope/api/_generated/organizations/__init__.py +25 -0
- mirascope/api/_generated/organizations/client.py +380 -0
- mirascope/api/_generated/organizations/raw_client.py +876 -0
- mirascope/api/_generated/organizations/types/__init__.py +23 -0
- mirascope/api/_generated/organizations/types/organizations_create_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_create_response_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_get_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_get_response_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_list_response_item.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_list_response_item_role.py +7 -0
- mirascope/api/_generated/organizations/types/organizations_update_response.py +24 -0
- mirascope/api/_generated/organizations/types/organizations_update_response_role.py +7 -0
- mirascope/api/_generated/projects/__init__.py +17 -0
- mirascope/api/_generated/projects/client.py +458 -0
- mirascope/api/_generated/projects/raw_client.py +1016 -0
- mirascope/api/_generated/projects/types/__init__.py +15 -0
- mirascope/api/_generated/projects/types/projects_create_response.py +30 -0
- mirascope/api/_generated/projects/types/projects_get_response.py +30 -0
- mirascope/api/_generated/projects/types/projects_list_response_item.py +30 -0
- mirascope/api/_generated/projects/types/projects_update_response.py +30 -0
- mirascope/api/_generated/reference.md +753 -0
- mirascope/api/_generated/traces/__init__.py +55 -0
- mirascope/api/_generated/traces/client.py +162 -0
- mirascope/api/_generated/traces/raw_client.py +168 -0
- mirascope/api/_generated/traces/types/__init__.py +95 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
- mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
- mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
- mirascope/api/_generated/types/__init__.py +37 -0
- mirascope/api/_generated/types/already_exists_error.py +24 -0
- mirascope/api/_generated/types/already_exists_error_tag.py +5 -0
- mirascope/api/_generated/types/database_error.py +24 -0
- mirascope/api/_generated/types/database_error_tag.py +5 -0
- mirascope/api/_generated/types/http_api_decode_error.py +29 -0
- mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
- mirascope/api/_generated/types/issue.py +40 -0
- mirascope/api/_generated/types/issue_tag.py +17 -0
- mirascope/api/_generated/types/not_found_error_body.py +24 -0
- mirascope/api/_generated/types/not_found_error_tag.py +5 -0
- mirascope/api/_generated/types/permission_denied_error.py +24 -0
- mirascope/api/_generated/types/permission_denied_error_tag.py +7 -0
- mirascope/api/_generated/types/property_key.py +7 -0
- mirascope/api/_generated/types/property_key_key.py +27 -0
- mirascope/api/_generated/types/property_key_key_tag.py +5 -0
- mirascope/api/client.py +255 -0
- mirascope/api/settings.py +81 -0
- mirascope/llm/__init__.py +45 -11
- mirascope/llm/calls/calls.py +81 -57
- mirascope/llm/calls/decorator.py +121 -115
- mirascope/llm/content/__init__.py +3 -2
- mirascope/llm/context/_utils.py +19 -6
- mirascope/llm/exceptions.py +30 -16
- mirascope/llm/formatting/_utils.py +9 -5
- mirascope/llm/formatting/format.py +2 -2
- mirascope/llm/formatting/from_call_args.py +2 -2
- mirascope/llm/messages/message.py +13 -5
- mirascope/llm/models/__init__.py +2 -2
- mirascope/llm/models/models.py +189 -81
- mirascope/llm/prompts/__init__.py +13 -12
- mirascope/llm/prompts/_utils.py +27 -24
- mirascope/llm/prompts/decorator.py +133 -204
- mirascope/llm/prompts/prompts.py +424 -0
- mirascope/llm/prompts/protocols.py +25 -59
- mirascope/llm/providers/__init__.py +44 -0
- mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
- mirascope/llm/providers/anthropic/__init__.py +29 -0
- mirascope/llm/providers/anthropic/_utils/__init__.py +23 -0
- mirascope/llm/providers/anthropic/_utils/beta_decode.py +271 -0
- mirascope/llm/providers/anthropic/_utils/beta_encode.py +216 -0
- mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +44 -11
- mirascope/llm/providers/anthropic/_utils/encode.py +356 -0
- mirascope/llm/providers/anthropic/beta_provider.py +322 -0
- mirascope/llm/providers/anthropic/model_id.py +23 -0
- mirascope/llm/providers/anthropic/model_info.py +87 -0
- mirascope/llm/providers/anthropic/provider.py +416 -0
- mirascope/llm/{clients → providers}/base/__init__.py +3 -3
- mirascope/llm/{clients → providers}/base/_utils.py +25 -8
- mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
- mirascope/llm/providers/google/__init__.py +21 -0
- mirascope/llm/{clients → providers}/google/_utils/decode.py +61 -7
- mirascope/llm/{clients → providers}/google/_utils/encode.py +44 -30
- mirascope/llm/providers/google/model_id.py +22 -0
- mirascope/llm/providers/google/model_info.py +62 -0
- mirascope/llm/providers/google/provider.py +442 -0
- mirascope/llm/providers/load_provider.py +54 -0
- mirascope/llm/providers/mlx/__init__.py +24 -0
- mirascope/llm/providers/mlx/_utils.py +129 -0
- mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
- mirascope/llm/providers/mlx/encoding/base.py +69 -0
- mirascope/llm/providers/mlx/encoding/transformers.py +147 -0
- mirascope/llm/providers/mlx/mlx.py +237 -0
- mirascope/llm/providers/mlx/model_id.py +17 -0
- mirascope/llm/providers/mlx/provider.py +415 -0
- mirascope/llm/providers/model_id.py +16 -0
- mirascope/llm/providers/ollama/__init__.py +19 -0
- mirascope/llm/providers/ollama/provider.py +71 -0
- mirascope/llm/providers/openai/__init__.py +6 -0
- mirascope/llm/providers/openai/completions/__init__.py +25 -0
- mirascope/llm/{clients → providers}/openai/completions/_utils/__init__.py +2 -0
- mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +60 -6
- mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +37 -26
- mirascope/llm/providers/openai/completions/base_provider.py +513 -0
- mirascope/llm/providers/openai/completions/provider.py +22 -0
- mirascope/llm/providers/openai/model_id.py +31 -0
- mirascope/llm/providers/openai/model_info.py +303 -0
- mirascope/llm/providers/openai/provider.py +398 -0
- mirascope/llm/providers/openai/responses/__init__.py +21 -0
- mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +59 -6
- mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +34 -23
- mirascope/llm/providers/openai/responses/provider.py +469 -0
- mirascope/llm/providers/provider_id.py +23 -0
- mirascope/llm/providers/provider_registry.py +169 -0
- mirascope/llm/providers/together/__init__.py +19 -0
- mirascope/llm/providers/together/provider.py +40 -0
- mirascope/llm/responses/__init__.py +3 -0
- mirascope/llm/responses/base_response.py +14 -5
- mirascope/llm/responses/base_stream_response.py +35 -6
- mirascope/llm/responses/finish_reason.py +1 -0
- mirascope/llm/responses/response.py +33 -13
- mirascope/llm/responses/root_response.py +12 -13
- mirascope/llm/responses/stream_response.py +35 -23
- mirascope/llm/responses/usage.py +95 -0
- mirascope/llm/tools/__init__.py +9 -2
- mirascope/llm/tools/_utils.py +12 -3
- mirascope/llm/tools/protocols.py +4 -4
- mirascope/llm/tools/tool_schema.py +44 -9
- mirascope/llm/tools/tools.py +10 -9
- mirascope/ops/__init__.py +156 -0
- mirascope/ops/_internal/__init__.py +5 -0
- mirascope/ops/_internal/closure.py +1118 -0
- mirascope/ops/_internal/configuration.py +126 -0
- mirascope/ops/_internal/context.py +76 -0
- mirascope/ops/_internal/exporters/__init__.py +26 -0
- mirascope/ops/_internal/exporters/exporters.py +342 -0
- mirascope/ops/_internal/exporters/processors.py +104 -0
- mirascope/ops/_internal/exporters/types.py +165 -0
- mirascope/ops/_internal/exporters/utils.py +29 -0
- mirascope/ops/_internal/instrumentation/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
- mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
- mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
- mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
- mirascope/ops/_internal/propagation.py +198 -0
- mirascope/ops/_internal/protocols.py +51 -0
- mirascope/ops/_internal/session.py +139 -0
- mirascope/ops/_internal/spans.py +232 -0
- mirascope/ops/_internal/traced_calls.py +371 -0
- mirascope/ops/_internal/traced_functions.py +394 -0
- mirascope/ops/_internal/tracing.py +276 -0
- mirascope/ops/_internal/types.py +13 -0
- mirascope/ops/_internal/utils.py +75 -0
- mirascope/ops/_internal/versioned_calls.py +512 -0
- mirascope/ops/_internal/versioned_functions.py +346 -0
- mirascope/ops/_internal/versioning.py +303 -0
- mirascope/ops/exceptions.py +21 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/METADATA +78 -3
- mirascope-2.0.0a4.dist-info/RECORD +247 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/WHEEL +1 -1
- mirascope/graphs/__init__.py +0 -22
- mirascope/graphs/finite_state_machine.py +0 -625
- mirascope/llm/agents/__init__.py +0 -15
- mirascope/llm/agents/agent.py +0 -97
- mirascope/llm/agents/agent_template.py +0 -45
- mirascope/llm/agents/decorator.py +0 -176
- mirascope/llm/calls/base_call.py +0 -33
- mirascope/llm/clients/__init__.py +0 -34
- mirascope/llm/clients/anthropic/__init__.py +0 -25
- mirascope/llm/clients/anthropic/_utils/encode.py +0 -243
- mirascope/llm/clients/anthropic/clients.py +0 -819
- mirascope/llm/clients/anthropic/model_ids.py +0 -8
- mirascope/llm/clients/google/__init__.py +0 -20
- mirascope/llm/clients/google/clients.py +0 -853
- mirascope/llm/clients/google/model_ids.py +0 -15
- mirascope/llm/clients/openai/__init__.py +0 -25
- mirascope/llm/clients/openai/completions/__init__.py +0 -28
- mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
- mirascope/llm/clients/openai/completions/clients.py +0 -833
- mirascope/llm/clients/openai/completions/model_ids.py +0 -8
- mirascope/llm/clients/openai/responses/__init__.py +0 -26
- mirascope/llm/clients/openai/responses/_utils/__init__.py +0 -13
- mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
- mirascope/llm/clients/openai/responses/clients.py +0 -832
- mirascope/llm/clients/openai/responses/model_ids.py +0 -8
- mirascope/llm/clients/openai/shared/__init__.py +0 -7
- mirascope/llm/clients/openai/shared/_utils.py +0 -55
- mirascope/llm/clients/providers.py +0 -175
- mirascope-2.0.0a2.dist-info/RECORD +0 -102
- /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
- /mirascope/llm/{clients → providers}/base/params.py +0 -0
- /mirascope/llm/{clients/anthropic → providers/google}/_utils/__init__.py +0 -0
- /mirascope/llm/{clients → providers}/google/message.py +0 -0
- /mirascope/llm/{clients/google → providers/openai/responses}/_utils/__init__.py +0 -0
- {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a4.dist-info}/licenses/LICENSE +0 -0
mirascope/llm/providers/mlx/provider.py
@@ -0,0 +1,415 @@
+from collections.abc import Sequence
+from functools import cache, lru_cache
+from typing import cast
+from typing_extensions import Unpack
+
+import mlx.nn as nn
+from mlx_lm import load as mlx_load
+from transformers import PreTrainedTokenizer
+
+from ...context import Context, DepsT
+from ...formatting import Format, FormattableT
+from ...messages import Message
+from ...responses import (
+    AsyncContextResponse,
+    AsyncContextStreamResponse,
+    AsyncResponse,
+    AsyncStreamResponse,
+    ContextResponse,
+    ContextStreamResponse,
+    Response,
+    StreamResponse,
+)
+from ...tools import (
+    AsyncContextTool,
+    AsyncContextToolkit,
+    AsyncTool,
+    AsyncToolkit,
+    ContextTool,
+    ContextToolkit,
+    Tool,
+    Toolkit,
+)
+from ..base import BaseProvider, Params
+from . import _utils
+from .encoding import TransformersEncoder
+from .mlx import MLX
+from .model_id import MLXModelId
+
+
+@cache
+def _mlx_client_singleton() -> "MLXProvider":
+    """Get or create the singleton MLX client instance."""
+    return MLXProvider()
+
+
+def client() -> "MLXProvider":
+    """Get the MLX client singleton instance."""
+    return _mlx_client_singleton()
+
+
+@lru_cache(maxsize=16)
+def _get_mlx(model_id: MLXModelId) -> MLX:
+    model, tokenizer = cast(tuple[nn.Module, PreTrainedTokenizer], mlx_load(model_id))
+    encoder = TransformersEncoder(tokenizer)
+    return MLX(
+        model_id,
+        model,
+        tokenizer,
+        encoder,
+    )
+
+
+class MLXProvider(BaseProvider[None]):
+    """Client for interacting with MLX language models.
+
+    This client provides methods for generating responses from MLX models,
+    supporting both synchronous and asynchronous operations, as well as
+    streaming responses.
+    """
+
+    id = "mlx"
+    default_scope = "mlx-community/"
+
+    def _call(
+        self,
+        *,
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[Tool] | Toolkit | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> Response | Response[FormattableT]:
+        """Generate an `llm.Response` using MLX model.
+
+        Args:
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.Response` object containing the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, assistant_message, response = mlx.generate(
+            messages, tools, format, params
+        )
+
+        return Response(
+            raw=response,
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            assistant_message=assistant_message,
+            finish_reason=_utils.extract_finish_reason(response),
+            usage=_utils.extract_usage(response),
+            format=format,
+        )
+
+    def _context_call(
+        self,
+        *,
+        ctx: Context[DepsT],
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[Tool | ContextTool[DepsT]]
+        | ContextToolkit[DepsT]
+        | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
+        """Generate an `llm.ContextResponse` using MLX model.
+
+        Args:
+            ctx: Context object with dependencies for tools.
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.ContextResponse` object containing the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, assistant_message, response = mlx.generate(
+            messages, tools, format, params
+        )
+
+        return ContextResponse(
+            raw=response,
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            assistant_message=assistant_message,
+            finish_reason=_utils.extract_finish_reason(response),
+            usage=_utils.extract_usage(response),
+            format=format,
+        )
+
+    async def _call_async(
+        self,
+        *,
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> AsyncResponse | AsyncResponse[FormattableT]:
+        """Generate an `llm.AsyncResponse` using MLX model by asynchronously calling
+        `asyncio.to_thread`.
+
+        Args:
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.AsyncResponse` object containing the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        (
+            input_messages,
+            format,
+            assistant_message,
+            response,
+        ) = await mlx.generate_async(messages, tools, format, params)
+
+        return AsyncResponse(
+            raw=response,
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            assistant_message=assistant_message,
+            finish_reason=_utils.extract_finish_reason(response),
+            usage=_utils.extract_usage(response),
+            format=format,
+        )
+
+    async def _context_call_async(
+        self,
+        *,
+        ctx: Context[DepsT],
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
+        | AsyncContextToolkit[DepsT]
+        | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
+        """Generate an `llm.AsyncContextResponse` using MLX model by asynchronously calling
+        `asyncio.to_thread`.
+
+        Args:
+            ctx: Context object with dependencies for tools.
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.AsyncContextResponse` object containing the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        (
+            input_messages,
+            format,
+            assistant_message,
+            response,
+        ) = await mlx.generate_async(messages, tools, format, params)
+
+        return AsyncContextResponse(
+            raw=response,
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            assistant_message=assistant_message,
+            finish_reason=_utils.extract_finish_reason(response),
+            usage=_utils.extract_usage(response),
+            format=format,
+        )
+
+    def _stream(
+        self,
+        *,
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[Tool] | Toolkit | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> StreamResponse | StreamResponse[FormattableT]:
+        """Generate an `llm.StreamResponse` by synchronously streaming from MLX model output.
+
+        Args:
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.StreamResponse` object for iterating over the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, chunk_iterator = mlx.stream(
+            messages, tools, format, params
+        )
+
+        return StreamResponse(
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            chunk_iterator=chunk_iterator,
+            format=format,
+        )
+
+    def _context_stream(
+        self,
+        *,
+        ctx: Context[DepsT],
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[Tool | ContextTool[DepsT]]
+        | ContextToolkit[DepsT]
+        | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
+        """Generate an `llm.ContextStreamResponse` by synchronously streaming from MLX model output.
+
+        Args:
+            ctx: Context object with dependencies for tools.
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.ContextStreamResponse` object for iterating over the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, chunk_iterator = mlx.stream(
+            messages, tools, format, params
+        )
+
+        return ContextStreamResponse(
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            chunk_iterator=chunk_iterator,
+            format=format,
+        )
+
+    async def _stream_async(
+        self,
+        *,
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
+        """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from MLX model output.
+
+        Args:
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.AsyncStreamResponse` object for asynchronously iterating over the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, chunk_iterator = await mlx.stream_async(
+            messages, tools, format, params
+        )
+
+        return AsyncStreamResponse(
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            chunk_iterator=chunk_iterator,
+            format=format,
+        )
+
+    async def _context_stream_async(
+        self,
+        *,
+        ctx: Context[DepsT],
+        model_id: MLXModelId,
+        messages: Sequence[Message],
+        tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
+        | AsyncContextToolkit[DepsT]
+        | None = None,
+        format: type[FormattableT] | Format[FormattableT] | None = None,
+        **params: Unpack[Params],
+    ) -> (
+        AsyncContextStreamResponse[DepsT]
+        | AsyncContextStreamResponse[DepsT, FormattableT]
+    ):
+        """Generate an `llm.AsyncContextStreamResponse` by asynchronously streaming from MLX model output.
+
+        Args:
+            ctx: Context object with dependencies for tools.
+            model_id: Model identifier to use.
+            messages: Messages to send to the LLM.
+            tools: Optional tools that the model may invoke.
+            format: Optional response format specifier.
+            **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+        Returns:
+            An `llm.AsyncContextStreamResponse` object for asynchronously iterating over the LLM-generated content.
+        """
+        mlx = _get_mlx(model_id)
+
+        input_messages, format, chunk_iterator = await mlx.stream_async(
+            messages, tools, format, params
+        )
+
+        return AsyncContextStreamResponse(
+            provider_id="mlx",
+            model_id=model_id,
+            provider_model_name=model_id,
+            params=params,
+            tools=tools,
+            input_messages=input_messages,
+            chunk_iterator=chunk_iterator,
+            format=format,
+        )
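A minimal usage sketch for the new MLX provider. This is not taken from the diff itself: the import path and model name are assumptions, based only on `default_scope = "mlx-community/"` above and the `llm.call` / `llm.messages.user` pattern shown in the OllamaProvider docstring later in this diff. Running it requires `mlx` and `mlx-lm` to be installed; otherwise the `_missing_import_stubs` fallback is loaded instead of the provider.

```python
# Hypothetical sketch: import path and model id are placeholders, not from this diff.
from mirascope import llm


@llm.call("mlx-community/Llama-3.2-3B-Instruct-4bit")  # placeholder model id
def my_prompt():
    return [llm.messages.user("Recommend a fantasy book.")]


# The first call loads the model via mlx_lm.load() and caches it (lru_cache above),
# so later calls with the same model id reuse the loaded weights and tokenizer.
response = my_prompt()
print(response)
```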
mirascope/llm/providers/model_id.py
@@ -0,0 +1,16 @@
+from typing import TypeAlias
+
+from .anthropic import (
+    AnthropicModelId,
+)
+from .google import (
+    GoogleModelId,
+)
+from .mlx import (
+    MLXModelId,
+)
+from .openai import (
+    OpenAIModelId,
+)
+
+ModelId: TypeAlias = AnthropicModelId | GoogleModelId | OpenAIModelId | MLXModelId | str
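Because `ModelId` includes plain `str`, scoped identifiers for custom or locally registered providers type-check alongside the provider-specific ID types. An illustrative sketch (the concrete model strings are assumptions, not values defined in this diff):

```python
# Illustrative only: shows the shape of the ModelId alias, not a list of valid IDs.
from mirascope.llm.providers.model_id import ModelId

known_model: ModelId = "gpt-4o"          # assumed to be covered by OpenAIModelId
scoped_model: ModelId = "ollama/llama2"  # plain str, e.g. for a registered custom provider
```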
mirascope/llm/providers/ollama/__init__.py
@@ -0,0 +1,19 @@
+"""Ollama provider implementation."""
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .provider import OllamaProvider
+else:
+    try:
+        from .provider import OllamaProvider
+    except ImportError:  # pragma: no cover
+        from .._missing_import_stubs import (
+            create_provider_stub,
+        )
+
+        OllamaProvider = create_provider_stub("openai", "OllamaProvider")
+
+__all__ = [
+    "OllamaProvider",
+]
mirascope/llm/providers/ollama/provider.py
@@ -0,0 +1,71 @@
+"""Ollama provider implementation."""
+
+import os
+from typing import ClassVar
+
+from openai import AsyncOpenAI, OpenAI
+
+from ..openai.completions.base_provider import BaseOpenAICompletionsProvider
+
+
+class OllamaProvider(BaseOpenAICompletionsProvider):
+    """Provider for Ollama's OpenAI-compatible API.
+
+    Inherits from BaseOpenAICompletionsProvider with Ollama-specific configuration:
+    - Uses Ollama's local API endpoint (default: http://localhost:11434/v1/)
+    - API key is not required (Ollama ignores API keys)
+    - Supports OLLAMA_BASE_URL environment variable
+
+    Usage:
+        Register the provider with model ID prefixes you want to use:
+
+        ```python
+        import llm
+
+        # Register for ollama models
+        llm.register_provider("ollama", "ollama/")
+
+        # Now you can use ollama models directly
+        @llm.call("ollama/llama2")
+        def my_prompt():
+            return [llm.messages.user("Hello!")]
+        ```
+    """
+
+    id: ClassVar[str] = "ollama"
+    default_scope: ClassVar[str | list[str]] = "ollama/"
+    default_base_url: ClassVar[str | None] = "http://localhost:11434/v1/"
+    api_key_env_var: ClassVar[str] = "OLLAMA_API_KEY"
+    api_key_required: ClassVar[bool] = False
+    provider_name: ClassVar[str | None] = "Ollama"
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        base_url: str | None = None,
+    ) -> None:
+        """Initialize the Ollama provider.
+
+        Args:
+            api_key: API key (optional). Defaults to OLLAMA_API_KEY env var or 'ollama'.
+            base_url: Custom base URL. Defaults to OLLAMA_BASE_URL env var
+                or http://localhost:11434/v1/.
+        """
+        resolved_api_key = api_key or os.environ.get(self.api_key_env_var) or "ollama"
+        resolved_base_url = (
+            base_url or os.environ.get("OLLAMA_BASE_URL") or self.default_base_url
+        )
+
+        self.client = OpenAI(
+            api_key=resolved_api_key,
+            base_url=resolved_base_url,
+        )
+        self.async_client = AsyncOpenAI(
+            api_key=resolved_api_key,
+            base_url=resolved_base_url,
+        )
+
+    def _model_name(self, model_id: str) -> str:
+        """Strip 'ollama/' prefix from model ID for Ollama API."""
+        return model_id.removeprefix("ollama/")
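The constructor above resolves the endpoint from the `base_url` argument first, then the `OLLAMA_BASE_URL` environment variable, then the localhost default, and falls back to the literal key "ollama" when no API key is set. A short sketch of pointing the provider at a remote Ollama host (the host URL is a placeholder):

```python
import os

from mirascope.llm.providers.ollama.provider import OllamaProvider

# Option 1: environment variable, picked up when no base_url argument is passed.
os.environ["OLLAMA_BASE_URL"] = "http://gpu-box.internal:11434/v1/"  # placeholder host
env_provider = OllamaProvider()

# Option 2: explicit argument, which takes precedence over the environment variable.
arg_provider = OllamaProvider(base_url="http://gpu-box.internal:11434/v1/")
```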
mirascope/llm/providers/openai/completions/__init__.py
@@ -0,0 +1,25 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .base_provider import BaseOpenAICompletionsProvider
+    from .provider import OpenAICompletionsProvider
+else:
+    try:
+        from .base_provider import BaseOpenAICompletionsProvider
+        from .provider import OpenAICompletionsProvider
+    except ImportError:  # pragma: no cover
+        from ..._missing_import_stubs import (
+            create_provider_stub,
+        )
+
+        BaseOpenAICompletionsProvider = create_provider_stub(
+            "openai", "BaseOpenAICompletionsProvider"
+        )
+        OpenAICompletionsProvider = create_provider_stub(
+            "openai", "OpenAICompletionsProvider"
+        )
+
+__all__ = [
+    "BaseOpenAICompletionsProvider",
+    "OpenAICompletionsProvider",
+]
mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py
@@ -4,6 +4,7 @@ from typing import Literal
 
 from openai import AsyncStream, Stream
 from openai.types import chat as openai_types
+from openai.types.completion_usage import CompletionUsage
 
 from .....content import (
     AssistantContentPart,
@@ -23,8 +24,10 @@ from .....responses import (
     FinishReason,
     FinishReasonChunk,
     RawStreamEventChunk,
+    Usage,
+    UsageDeltaChunk,
 )
-from
+from ...model_id import OpenAIModelId, model_name
 
 OPENAI_FINISH_REASON_MAP = {
     "length": FinishReason.MAX_TOKENS,
@@ -32,11 +35,40 @@ OPENAI_FINISH_REASON_MAP = {
 }
 
 
+def _decode_usage(
+    usage: CompletionUsage | None,
+) -> Usage | None:
+    """Convert OpenAI CompletionUsage to Mirascope Usage."""
+    if usage is None:  # pragma: no cover
+        return None
+
+    return Usage(
+        input_tokens=usage.prompt_tokens,
+        output_tokens=usage.completion_tokens,
+        cache_read_tokens=(
+            usage.prompt_tokens_details.cached_tokens
+            if usage.prompt_tokens_details
+            else None
+        )
+        or 0,
+        cache_write_tokens=0,
+        reasoning_tokens=(
+            usage.completion_tokens_details.reasoning_tokens
+            if usage.completion_tokens_details
+            else None
+        )
+        or 0,
+        raw=usage,
+    )
+
+
 def decode_response(
     response: openai_types.ChatCompletion,
-    model_id:
-
-
+    model_id: OpenAIModelId,
+    provider_id: str,
+    provider_model_name: str | None = None,
+) -> tuple[AssistantMessage, FinishReason | None, Usage | None]:
+    """Convert OpenAI ChatCompletion to mirascope AssistantMessage and usage."""
     choice = response.choices[0]
     message = choice.message
     refused = False
@@ -69,12 +101,14 @@ def decode_response(
 
     assistant_message = AssistantMessage(
         content=parts,
-
+        provider_id=provider_id,
         model_id=model_id,
+        provider_model_name=provider_model_name or model_name(model_id, "completions"),
         raw_message=message.model_dump(exclude_none=True),
     )
 
-
+    usage = _decode_usage(response.usage)
+    return assistant_message, finish_reason, usage
 
 
 class _OpenAIChunkProcessor:
@@ -89,6 +123,26 @@ class _OpenAIChunkProcessor:
         """Process a single OpenAI chunk and yield the appropriate content chunks."""
         yield RawStreamEventChunk(raw_stream_event=chunk)
 
+        if chunk.usage:
+            usage = chunk.usage
+            yield UsageDeltaChunk(
+                input_tokens=usage.prompt_tokens,
+                output_tokens=usage.completion_tokens,
+                cache_read_tokens=(
+                    usage.prompt_tokens_details.cached_tokens
+                    if usage.prompt_tokens_details
+                    else None
+                )
+                or 0,
+                cache_write_tokens=0,
+                reasoning_tokens=(
+                    usage.completion_tokens_details.reasoning_tokens
+                    if usage.completion_tokens_details
+                    else None
+                )
+                or 0,
+            )
+
         choice = chunk.choices[0] if chunk.choices else None
         if not choice:
             return  # pragma: no cover
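For reference, a small sketch of the token mapping `_decode_usage` performs, using the openai SDK's `CompletionUsage` model (the numbers are made up, and the details classes are assumed to be importable from `openai.types.completion_usage`):

```python
from openai.types.completion_usage import (
    CompletionTokensDetails,
    CompletionUsage,
    PromptTokensDetails,
)

raw_usage = CompletionUsage(
    prompt_tokens=120,
    completion_tokens=48,
    total_tokens=168,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=100),
    completion_tokens_details=CompletionTokensDetails(reasoning_tokens=16),
)

# _decode_usage(raw_usage) would then produce a Usage with:
#   input_tokens=120, output_tokens=48,
#   cache_read_tokens=100, cache_write_tokens=0, reasoning_tokens=16,
#   raw=raw_usage
```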