mirascope 2.0.0a1__py3-none-any.whl → 2.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. mirascope/__init__.py +2 -2
  2. mirascope/api/__init__.py +6 -0
  3. mirascope/api/_generated/README.md +207 -0
  4. mirascope/api/_generated/__init__.py +85 -0
  5. mirascope/api/_generated/client.py +155 -0
  6. mirascope/api/_generated/core/__init__.py +52 -0
  7. mirascope/api/_generated/core/api_error.py +23 -0
  8. mirascope/api/_generated/core/client_wrapper.py +58 -0
  9. mirascope/api/_generated/core/datetime_utils.py +30 -0
  10. mirascope/api/_generated/core/file.py +70 -0
  11. mirascope/api/_generated/core/force_multipart.py +16 -0
  12. mirascope/api/_generated/core/http_client.py +619 -0
  13. mirascope/api/_generated/core/http_response.py +55 -0
  14. mirascope/api/_generated/core/jsonable_encoder.py +102 -0
  15. mirascope/api/_generated/core/pydantic_utilities.py +310 -0
  16. mirascope/api/_generated/core/query_encoder.py +60 -0
  17. mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
  18. mirascope/api/_generated/core/request_options.py +35 -0
  19. mirascope/api/_generated/core/serialization.py +282 -0
  20. mirascope/api/_generated/docs/__init__.py +4 -0
  21. mirascope/api/_generated/docs/client.py +95 -0
  22. mirascope/api/_generated/docs/raw_client.py +132 -0
  23. mirascope/api/_generated/environment.py +9 -0
  24. mirascope/api/_generated/errors/__init__.py +7 -0
  25. mirascope/api/_generated/errors/bad_request_error.py +15 -0
  26. mirascope/api/_generated/health/__init__.py +7 -0
  27. mirascope/api/_generated/health/client.py +96 -0
  28. mirascope/api/_generated/health/raw_client.py +129 -0
  29. mirascope/api/_generated/health/types/__init__.py +8 -0
  30. mirascope/api/_generated/health/types/health_check_response.py +24 -0
  31. mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
  32. mirascope/api/_generated/reference.md +167 -0
  33. mirascope/api/_generated/traces/__init__.py +55 -0
  34. mirascope/api/_generated/traces/client.py +162 -0
  35. mirascope/api/_generated/traces/raw_client.py +168 -0
  36. mirascope/api/_generated/traces/types/__init__.py +95 -0
  37. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
  38. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
  39. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
  40. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
  41. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
  42. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
  43. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
  44. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
  45. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
  46. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
  47. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
  48. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
  49. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
  50. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
  51. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
  52. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
  53. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
  54. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
  55. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
  56. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
  57. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
  58. mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
  59. mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
  60. mirascope/api/_generated/types/__init__.py +21 -0
  61. mirascope/api/_generated/types/http_api_decode_error.py +31 -0
  62. mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
  63. mirascope/api/_generated/types/issue.py +44 -0
  64. mirascope/api/_generated/types/issue_tag.py +17 -0
  65. mirascope/api/_generated/types/property_key.py +7 -0
  66. mirascope/api/_generated/types/property_key_tag.py +29 -0
  67. mirascope/api/_generated/types/property_key_tag_tag.py +5 -0
  68. mirascope/api/client.py +255 -0
  69. mirascope/api/settings.py +81 -0
  70. mirascope/llm/__init__.py +41 -11
  71. mirascope/llm/calls/calls.py +81 -57
  72. mirascope/llm/calls/decorator.py +121 -115
  73. mirascope/llm/content/__init__.py +3 -2
  74. mirascope/llm/context/_utils.py +19 -6
  75. mirascope/llm/exceptions.py +30 -16
  76. mirascope/llm/formatting/_utils.py +9 -5
  77. mirascope/llm/formatting/format.py +2 -2
  78. mirascope/llm/formatting/from_call_args.py +2 -2
  79. mirascope/llm/messages/message.py +13 -5
  80. mirascope/llm/models/__init__.py +2 -2
  81. mirascope/llm/models/models.py +189 -81
  82. mirascope/llm/prompts/__init__.py +13 -12
  83. mirascope/llm/prompts/_utils.py +27 -24
  84. mirascope/llm/prompts/decorator.py +133 -204
  85. mirascope/llm/prompts/prompts.py +424 -0
  86. mirascope/llm/prompts/protocols.py +25 -59
  87. mirascope/llm/providers/__init__.py +38 -0
  88. mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
  89. mirascope/llm/providers/anthropic/__init__.py +24 -0
  90. mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +5 -4
  91. mirascope/llm/{clients → providers}/anthropic/_utils/encode.py +31 -10
  92. mirascope/llm/providers/anthropic/model_id.py +40 -0
  93. mirascope/llm/{clients/anthropic/clients.py → providers/anthropic/provider.py} +33 -418
  94. mirascope/llm/{clients → providers}/base/__init__.py +3 -3
  95. mirascope/llm/{clients → providers}/base/_utils.py +10 -7
  96. mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
  97. mirascope/llm/providers/google/__init__.py +21 -0
  98. mirascope/llm/{clients → providers}/google/_utils/decode.py +6 -4
  99. mirascope/llm/{clients → providers}/google/_utils/encode.py +30 -24
  100. mirascope/llm/providers/google/model_id.py +28 -0
  101. mirascope/llm/providers/google/provider.py +438 -0
  102. mirascope/llm/providers/load_provider.py +48 -0
  103. mirascope/llm/providers/mlx/__init__.py +24 -0
  104. mirascope/llm/providers/mlx/_utils.py +107 -0
  105. mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
  106. mirascope/llm/providers/mlx/encoding/base.py +69 -0
  107. mirascope/llm/providers/mlx/encoding/transformers.py +131 -0
  108. mirascope/llm/providers/mlx/mlx.py +237 -0
  109. mirascope/llm/providers/mlx/model_id.py +17 -0
  110. mirascope/llm/providers/mlx/provider.py +411 -0
  111. mirascope/llm/providers/model_id.py +16 -0
  112. mirascope/llm/providers/openai/__init__.py +6 -0
  113. mirascope/llm/providers/openai/completions/__init__.py +20 -0
  114. mirascope/llm/{clients/openai/responses → providers/openai/completions}/_utils/__init__.py +2 -0
  115. mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +5 -3
  116. mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +33 -23
  117. mirascope/llm/providers/openai/completions/provider.py +456 -0
  118. mirascope/llm/providers/openai/model_id.py +31 -0
  119. mirascope/llm/providers/openai/model_info.py +246 -0
  120. mirascope/llm/providers/openai/provider.py +386 -0
  121. mirascope/llm/providers/openai/responses/__init__.py +21 -0
  122. mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +5 -3
  123. mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +28 -17
  124. mirascope/llm/providers/openai/responses/provider.py +470 -0
  125. mirascope/llm/{clients → providers}/openai/shared/_utils.py +7 -3
  126. mirascope/llm/providers/provider_id.py +13 -0
  127. mirascope/llm/providers/provider_registry.py +167 -0
  128. mirascope/llm/responses/base_response.py +10 -5
  129. mirascope/llm/responses/base_stream_response.py +10 -5
  130. mirascope/llm/responses/response.py +24 -13
  131. mirascope/llm/responses/root_response.py +7 -12
  132. mirascope/llm/responses/stream_response.py +35 -23
  133. mirascope/llm/tools/__init__.py +9 -2
  134. mirascope/llm/tools/_utils.py +12 -3
  135. mirascope/llm/tools/decorator.py +10 -10
  136. mirascope/llm/tools/protocols.py +4 -4
  137. mirascope/llm/tools/tool_schema.py +44 -9
  138. mirascope/llm/tools/tools.py +12 -11
  139. mirascope/ops/__init__.py +156 -0
  140. mirascope/ops/_internal/__init__.py +5 -0
  141. mirascope/ops/_internal/closure.py +1118 -0
  142. mirascope/ops/_internal/configuration.py +126 -0
  143. mirascope/ops/_internal/context.py +76 -0
  144. mirascope/ops/_internal/exporters/__init__.py +26 -0
  145. mirascope/ops/_internal/exporters/exporters.py +342 -0
  146. mirascope/ops/_internal/exporters/processors.py +104 -0
  147. mirascope/ops/_internal/exporters/types.py +165 -0
  148. mirascope/ops/_internal/exporters/utils.py +29 -0
  149. mirascope/ops/_internal/instrumentation/__init__.py +8 -0
  150. mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
  151. mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
  152. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
  153. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
  154. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
  155. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
  156. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
  157. mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
  158. mirascope/ops/_internal/propagation.py +198 -0
  159. mirascope/ops/_internal/protocols.py +51 -0
  160. mirascope/ops/_internal/session.py +139 -0
  161. mirascope/ops/_internal/spans.py +232 -0
  162. mirascope/ops/_internal/traced_calls.py +371 -0
  163. mirascope/ops/_internal/traced_functions.py +394 -0
  164. mirascope/ops/_internal/tracing.py +276 -0
  165. mirascope/ops/_internal/types.py +13 -0
  166. mirascope/ops/_internal/utils.py +75 -0
  167. mirascope/ops/_internal/versioned_calls.py +512 -0
  168. mirascope/ops/_internal/versioned_functions.py +346 -0
  169. mirascope/ops/_internal/versioning.py +303 -0
  170. mirascope/ops/exceptions.py +21 -0
  171. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/METADATA +77 -1
  172. mirascope-2.0.0a3.dist-info/RECORD +206 -0
  173. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/WHEEL +1 -1
  174. mirascope/graphs/__init__.py +0 -22
  175. mirascope/graphs/finite_state_machine.py +0 -625
  176. mirascope/llm/agents/__init__.py +0 -15
  177. mirascope/llm/agents/agent.py +0 -97
  178. mirascope/llm/agents/agent_template.py +0 -45
  179. mirascope/llm/agents/decorator.py +0 -176
  180. mirascope/llm/calls/base_call.py +0 -33
  181. mirascope/llm/clients/__init__.py +0 -34
  182. mirascope/llm/clients/anthropic/__init__.py +0 -25
  183. mirascope/llm/clients/anthropic/model_ids.py +0 -8
  184. mirascope/llm/clients/google/__init__.py +0 -20
  185. mirascope/llm/clients/google/clients.py +0 -853
  186. mirascope/llm/clients/google/model_ids.py +0 -15
  187. mirascope/llm/clients/openai/__init__.py +0 -25
  188. mirascope/llm/clients/openai/completions/__init__.py +0 -28
  189. mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
  190. mirascope/llm/clients/openai/completions/clients.py +0 -833
  191. mirascope/llm/clients/openai/completions/model_ids.py +0 -8
  192. mirascope/llm/clients/openai/responses/__init__.py +0 -26
  193. mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
  194. mirascope/llm/clients/openai/responses/clients.py +0 -832
  195. mirascope/llm/clients/openai/responses/model_ids.py +0 -8
  196. mirascope/llm/clients/providers.py +0 -175
  197. mirascope-2.0.0a1.dist-info/RECORD +0 -102
  198. /mirascope/llm/{clients → providers}/anthropic/_utils/__init__.py +0 -0
  199. /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
  200. /mirascope/llm/{clients → providers}/base/params.py +0 -0
  201. /mirascope/llm/{clients → providers}/google/_utils/__init__.py +0 -0
  202. /mirascope/llm/{clients → providers}/google/message.py +0 -0
  203. /mirascope/llm/{clients/openai/completions → providers/openai/responses}/_utils/__init__.py +0 -0
  204. /mirascope/llm/{clients → providers}/openai/shared/__init__.py +0 -0
  205. {mirascope-2.0.0a1.dist-info → mirascope-2.0.0a3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,411 @@
+ from collections.abc import Sequence
+ from functools import cache, lru_cache
+ from typing import cast
+ from typing_extensions import Unpack
+
+ import mlx.nn as nn
+ from mlx_lm import load as mlx_load
+ from transformers import PreTrainedTokenizer
+
+ from ...context import Context, DepsT
+ from ...formatting import Format, FormattableT
+ from ...messages import Message
+ from ...responses import (
+     AsyncContextResponse,
+     AsyncContextStreamResponse,
+     AsyncResponse,
+     AsyncStreamResponse,
+     ContextResponse,
+     ContextStreamResponse,
+     Response,
+     StreamResponse,
+ )
+ from ...tools import (
+     AsyncContextTool,
+     AsyncContextToolkit,
+     AsyncTool,
+     AsyncToolkit,
+     ContextTool,
+     ContextToolkit,
+     Tool,
+     Toolkit,
+ )
+ from ..base import BaseProvider, Params
+ from . import _utils
+ from .encoding import TransformersEncoder
+ from .mlx import MLX
+ from .model_id import MLXModelId
+
+
+ @cache
+ def _mlx_client_singleton() -> "MLXProvider":
+     """Get or create the singleton MLX client instance."""
+     return MLXProvider()
+
+
+ def client() -> "MLXProvider":
+     """Get the MLX client singleton instance."""
+     return _mlx_client_singleton()
+
+
+ @lru_cache(maxsize=16)
+ def _get_mlx(model_id: MLXModelId) -> MLX:
+     model, tokenizer = cast(tuple[nn.Module, PreTrainedTokenizer], mlx_load(model_id))
+     encoder = TransformersEncoder(tokenizer)
+     return MLX(
+         model_id,
+         model,
+         tokenizer,
+         encoder,
+     )
+
+
+ class MLXProvider(BaseProvider[None]):
+     """Client for interacting with MLX language models.
+
+     This client provides methods for generating responses from MLX models,
+     supporting both synchronous and asynchronous operations, as well as
+     streaming responses.
+     """
+
+     id = "mlx"
+     default_scope = "mlx-community/"
+
+     def _call(
+         self,
+         *,
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[Tool] | Toolkit | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> Response | Response[FormattableT]:
+         """Generate an `llm.Response` using MLX model.
+
+         Args:
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.Response` object containing the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, assistant_message, response = mlx.generate(
+             messages, tools, format, params
+         )
+
+         return Response(
+             raw=response,
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             assistant_message=assistant_message,
+             finish_reason=_utils.extract_finish_reason(response),
+             format=format,
+         )
+
+     def _context_call(
+         self,
+         *,
+         ctx: Context[DepsT],
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[Tool | ContextTool[DepsT]]
+         | ContextToolkit[DepsT]
+         | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> ContextResponse[DepsT, None] | ContextResponse[DepsT, FormattableT]:
+         """Generate an `llm.ContextResponse` using MLX model.
+
+         Args:
+             ctx: Context object with dependencies for tools.
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.ContextResponse` object containing the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, assistant_message, response = mlx.generate(
+             messages, tools, format, params
+         )
+
+         return ContextResponse(
+             raw=response,
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             assistant_message=assistant_message,
+             finish_reason=_utils.extract_finish_reason(response),
+             format=format,
+         )
+
+     async def _call_async(
+         self,
+         *,
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> AsyncResponse | AsyncResponse[FormattableT]:
+         """Generate an `llm.AsyncResponse` using MLX model by asynchronously calling
+         `asyncio.to_thread`.
+
+         Args:
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.AsyncResponse` object containing the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         (
+             input_messages,
+             format,
+             assistant_message,
+             response,
+         ) = await mlx.generate_async(messages, tools, format, params)
+
+         return AsyncResponse(
+             raw=response,
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             assistant_message=assistant_message,
+             finish_reason=_utils.extract_finish_reason(response),
+             format=format,
+         )
+
+     async def _context_call_async(
+         self,
+         *,
+         ctx: Context[DepsT],
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
+         | AsyncContextToolkit[DepsT]
+         | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> AsyncContextResponse[DepsT, None] | AsyncContextResponse[DepsT, FormattableT]:
+         """Generate an `llm.AsyncContextResponse` using MLX model by asynchronously calling
+         `asyncio.to_thread`.
+
+         Args:
+             ctx: Context object with dependencies for tools.
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.AsyncContextResponse` object containing the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         (
+             input_messages,
+             format,
+             assistant_message,
+             response,
+         ) = await mlx.generate_async(messages, tools, format, params)
+
+         return AsyncContextResponse(
+             raw=response,
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             assistant_message=assistant_message,
+             finish_reason=_utils.extract_finish_reason(response),
+             format=format,
+         )
+
+     def _stream(
+         self,
+         *,
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[Tool] | Toolkit | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> StreamResponse | StreamResponse[FormattableT]:
+         """Generate an `llm.StreamResponse` by synchronously streaming from MLX model output.
+
+         Args:
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.StreamResponse` object for iterating over the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, chunk_iterator = mlx.stream(
+             messages, tools, format, params
+         )
+
+         return StreamResponse(
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             chunk_iterator=chunk_iterator,
+             format=format,
+         )
+
+     def _context_stream(
+         self,
+         *,
+         ctx: Context[DepsT],
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[Tool | ContextTool[DepsT]]
+         | ContextToolkit[DepsT]
+         | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> ContextStreamResponse[DepsT] | ContextStreamResponse[DepsT, FormattableT]:
+         """Generate an `llm.ContextStreamResponse` by synchronously streaming from MLX model output.
+
+         Args:
+             ctx: Context object with dependencies for tools.
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.ContextStreamResponse` object for iterating over the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, chunk_iterator = mlx.stream(
+             messages, tools, format, params
+         )
+
+         return ContextStreamResponse(
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             chunk_iterator=chunk_iterator,
+             format=format,
+         )
+
+     async def _stream_async(
+         self,
+         *,
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[AsyncTool] | AsyncToolkit | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> AsyncStreamResponse | AsyncStreamResponse[FormattableT]:
+         """Generate an `llm.AsyncStreamResponse` by asynchronously streaming from MLX model output.
+
+         Args:
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.AsyncStreamResponse` object for asynchronously iterating over the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, chunk_iterator = await mlx.stream_async(
+             messages, tools, format, params
+         )
+
+         return AsyncStreamResponse(
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             chunk_iterator=chunk_iterator,
+             format=format,
+         )
+
+     async def _context_stream_async(
+         self,
+         *,
+         ctx: Context[DepsT],
+         model_id: MLXModelId,
+         messages: Sequence[Message],
+         tools: Sequence[AsyncTool | AsyncContextTool[DepsT]]
+         | AsyncContextToolkit[DepsT]
+         | None = None,
+         format: type[FormattableT] | Format[FormattableT] | None = None,
+         **params: Unpack[Params],
+     ) -> (
+         AsyncContextStreamResponse[DepsT]
+         | AsyncContextStreamResponse[DepsT, FormattableT]
+     ):
+         """Generate an `llm.AsyncContextStreamResponse` by asynchronously streaming from MLX model output.
+
+         Args:
+             ctx: Context object with dependencies for tools.
+             model_id: Model identifier to use.
+             messages: Messages to send to the LLM.
+             tools: Optional tools that the model may invoke.
+             format: Optional response format specifier.
+             **params: Additional parameters to configure output (e.g. temperature). See `llm.Params`.
+
+         Returns:
+             An `llm.AsyncContextStreamResponse` object for asynchronously iterating over the LLM-generated content.
+         """
+         mlx = _get_mlx(model_id)
+
+         input_messages, format, chunk_iterator = await mlx.stream_async(
+             messages, tools, format, params
+         )
+
+         return AsyncContextStreamResponse(
+             provider_id="mlx",
+             model_id=model_id,
+             provider_model_name=model_id,
+             params=params,
+             tools=tools,
+             input_messages=input_messages,
+             chunk_iterator=chunk_iterator,
+             format=format,
+         )
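
The 411-line hunk above is the new MLX provider (`mirascope/llm/providers/mlx/provider.py` in the file list). It keeps a process-wide provider singleton via `functools.cache` and reuses loaded model weights per model id via `functools.lru_cache(maxsize=16)`, so repeated calls with the same `model_id` never reload the weights. A minimal sketch of that caching pattern in isolation, with a stand-in loader (`load_weights` and `Provider` below are hypothetical, not part of mirascope or mlx_lm):

    from functools import cache, lru_cache


    def load_weights(model_id: str) -> dict:
        """Stand-in for an expensive model load (hypothetical)."""
        print(f"loading {model_id} ...")
        return {"model_id": model_id}


    class Provider:
        """Stateless facade; the heavy state lives behind the caches below."""


    @cache
    def provider_singleton() -> Provider:
        # Created once per process, no matter how many callers ask for it.
        return Provider()


    @lru_cache(maxsize=16)
    def get_model(model_id: str) -> dict:
        # At most 16 distinct models stay resident; least-recently-used entries are evicted.
        return load_weights(model_id)


    if __name__ == "__main__":
        assert provider_singleton() is provider_singleton()
        get_model("mlx-community/example-model")  # loads once
        get_model("mlx-community/example-model")  # cache hit, no reload
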
@@ -0,0 +1,16 @@
+ from typing import TypeAlias
+
+ from .anthropic import (
+     AnthropicModelId,
+ )
+ from .google import (
+     GoogleModelId,
+ )
+ from .mlx import (
+     MLXModelId,
+ )
+ from .openai import (
+     OpenAIModelId,
+ )
+
+ ModelId: TypeAlias = AnthropicModelId | GoogleModelId | OpenAIModelId | MLXModelId | str
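
This 16-line hunk matches `mirascope/llm/providers/model_id.py`: `ModelId` unions the per-provider id types with a plain `str` fallback, so arbitrary or custom model ids still type-check while known ids keep their narrower types. A rough sketch of the idea, assuming the provider id types are `Literal` aliases (their definitions are not shown in this diff, and the literal values below are hypothetical):

    from typing import Literal, TypeAlias

    # Hypothetical stand-ins; the real aliases live in each provider's model_id.py.
    AnthropicModelId: TypeAlias = Literal["claude-example-model"]
    OpenAIModelId: TypeAlias = Literal["gpt-example-model"]

    ModelId: TypeAlias = AnthropicModelId | OpenAIModelId | str


    def describe(model_id: ModelId) -> str:
        # Both the known literals and arbitrary strings are accepted.
        return f"using {model_id}"


    describe("gpt-example-model")            # known literal
    describe("mlx-community/custom-model")   # arbitrary string still valid
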
@@ -0,0 +1,6 @@
+ """OpenAI client implementation."""
+
+ from .model_id import OpenAIModelId
+ from .provider import OpenAIProvider
+
+ __all__ = ["OpenAIModelId", "OpenAIProvider"]
@@ -0,0 +1,20 @@
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from .provider import OpenAICompletionsProvider
+ else:
+     try:
+         from .provider import OpenAICompletionsProvider
+     except ImportError:  # pragma: no cover
+         from ..._missing_import_stubs import (
+             create_import_error_stub,
+             create_provider_stub,
+         )
+
+         OpenAICompletionsProvider = create_provider_stub(
+             "openai", "OpenAICompletionsProvider"
+         )
+
+ __all__ = [
+     "OpenAICompletionsProvider",
+ ]
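
This 20-line hunk matches `mirascope/llm/providers/openai/completions/__init__.py` and shows the optional-dependency pattern used across the new `providers` package: under `TYPE_CHECKING` the real provider class is imported for type checkers, while at runtime a stub is substituted when the `openai` extra is not installed. A minimal sketch of the general pattern; `make_stub` below is a hypothetical stand-in for mirascope's `create_provider_stub`, whose implementation is not shown in this diff:

    from typing import TYPE_CHECKING


    def make_stub(package: str, name: str) -> type:
        """Return a placeholder class that raises a helpful error when instantiated."""

        class _Stub:
            def __init__(self, *args: object, **kwargs: object) -> None:
                raise ImportError(
                    f"{name} requires the optional '{package}' dependency. "
                    f"Install it with: pip install {package}"
                )

        _Stub.__name__ = name
        return _Stub


    if TYPE_CHECKING:
        from some_optional_package import RealProvider  # hypothetical optional module
    else:
        try:
            from some_optional_package import RealProvider
        except ImportError:  # pragma: no cover
            # Defer the failure until someone actually tries to use the provider.
            RealProvider = make_stub("some_optional_package", "RealProvider")
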
@@ -2,6 +2,7 @@ from .decode import (
      decode_async_stream,
      decode_response,
      decode_stream,
+     model_name,
  )
  from .encode import encode_request

@@ -10,4 +11,5 @@ __all__ = [
      "decode_response",
      "decode_stream",
      "encode_request",
+     "model_name",
  ]
@@ -24,7 +24,7 @@ from .....responses import (
  FinishReasonChunk,
  RawStreamEventChunk,
  )
- from ..model_ids import OpenAICompletionsModelId
+ from ...model_id import OpenAIModelId, model_name

  OPENAI_FINISH_REASON_MAP = {
  "length": FinishReason.MAX_TOKENS,
@@ -34,7 +34,8 @@ OPENAI_FINISH_REASON_MAP = {

  def decode_response(
  response: openai_types.ChatCompletion,
- model_id: OpenAICompletionsModelId,
+ model_id: OpenAIModelId,
+ provider_id: Literal["openai", "openai:completions"],
  ) -> tuple[AssistantMessage, FinishReason | None]:
  """Convert OpenAI ChatCompletion to mirascope AssistantMessage."""
  choice = response.choices[0]
@@ -69,8 +70,9 @@ def decode_response(

  assistant_message = AssistantMessage(
  content=parts,
- provider="openai:completions",
+ provider_id=provider_id,
  model_id=model_id,
+ provider_model_name=model_name(model_id, "completions"),
  raw_message=message.model_dump(exclude_none=True),
  )

@@ -19,11 +19,11 @@ from .....formatting import (
  resolve_format,
  )
  from .....messages import AssistantMessage, Message, UserMessage
- from .....tools import FORMAT_TOOL_NAME, BaseToolkit, ToolSchema
+ from .....tools import FORMAT_TOOL_NAME, AnyToolSchema, BaseToolkit
  from ....base import Params, _utils as _base_utils
+ from ...model_id import OpenAIModelId, model_name
+ from ...model_info import MODELS_WITHOUT_AUDIO_SUPPORT
  from ...shared import _utils as _shared_utils
- from ..model_ids import OpenAICompletionsModelId
- from .model_features import MODEL_FEATURES


  class ChatCompletionCreateKwargs(TypedDict, total=False):
@@ -49,7 +49,7 @@ class ChatCompletionCreateKwargs(TypedDict, total=False):

  def _encode_user_message(
  message: UserMessage,
- model_id: OpenAICompletionsModelId,
+ model_id: OpenAIModelId,
  ) -> list[openai_types.ChatCompletionMessageParam]:
  """Convert Mirascope `UserMessage` to a list of OpenAI `ChatCompletionMessageParam`.

@@ -98,11 +98,11 @@ def _encode_user_message(
  )
  current_content.append(content)
  elif part.type == "audio":
- model_status = MODEL_FEATURES.get(model_id)
- if model_status == "no_audio_support":
+ base_model_name = model_name(model_id, None)
+ if base_model_name in MODELS_WITHOUT_AUDIO_SUPPORT:
  raise FeatureNotSupportedError(
  feature="Audio inputs",
- provider="openai:completions",
+ provider_id="openai",
  message=f"Model '{model_id}' does not support audio inputs.",
  )

@@ -111,7 +111,7 @@ def _encode_user_message(
  if audio_format not in ("wav", "mp3"):
  raise FeatureNotSupportedError(
  feature=f"Audio format: {audio_format}",
- provider="openai:completions",
+ provider_id="openai",
  message="OpenAI only supports 'wav' and 'mp3' audio formats.",
  )  # pragma: no cover
  audio_content = openai_types.ChatCompletionContentPartInputAudioParam(
@@ -141,13 +141,14 @@ def _encode_user_message(


  def _encode_assistant_message(
- message: AssistantMessage, model_id: OpenAICompletionsModelId, encode_thoughts: bool
+ message: AssistantMessage, model_id: OpenAIModelId, encode_thoughts: bool
  ) -> openai_types.ChatCompletionAssistantMessageParam:
  """Convert Mirascope `AssistantMessage` to OpenAI `ChatCompletionAssistantMessageParam`."""

  if (
- message.provider == "openai:completions"
- and message.model_id == model_id
+ message.provider_id in ("openai", "openai:completions")
+ and message.provider_model_name
+ == model_name(model_id=model_id, api_mode="completions")
  and message.raw_message
  and not encode_thoughts
  ):
@@ -188,7 +189,7 @@ def _encode_assistant_message(
  elif text_params:
  content = text_params

- message_params = {
+ message_params: openai_types.ChatCompletionAssistantMessageParam = {
  "role": "assistant",
  "content": content,
  }
@@ -199,7 +200,7 @@ def _encode_assistant_message(


  def _encode_message(
- message: Message, model_id: OpenAICompletionsModelId, encode_thoughts: bool
+ message: Message, model_id: OpenAIModelId, encode_thoughts: bool
  ) -> list[openai_types.ChatCompletionMessageParam]:
  """Convert a Mirascope `Message` to OpenAI `ChatCompletionMessageParam` format.

@@ -227,12 +228,12 @@ def _encode_message(

  @lru_cache(maxsize=128)
  def _convert_tool_to_tool_param(
- tool: ToolSchema,
+ tool: AnyToolSchema,
  ) -> openai_types.ChatCompletionToolParam:
  """Convert a single Mirascope `Tool` to OpenAI ChatCompletionToolParam with caching."""
  schema_dict = tool.parameters.model_dump(by_alias=True, exclude_none=True)
  schema_dict["type"] = "object"
- _shared_utils._ensure_additional_properties_false(schema_dict)
+ _shared_utils.ensure_additional_properties_false(schema_dict)
  return openai_types.ChatCompletionToolParam(
  type="function",
  function={
@@ -257,7 +258,7 @@ def _create_strict_response_format(
  """
  schema = format.schema.copy()

- _shared_utils._ensure_additional_properties_false(schema)
+ _shared_utils.ensure_additional_properties_false(schema)

  json_schema = JSONSchema(
  name=format.name,
@@ -274,23 +275,32 @@ def _create_strict_response_format(

  def encode_request(
  *,
- model_id: OpenAICompletionsModelId,
+ model_id: OpenAIModelId,
  messages: Sequence[Message],
- tools: Sequence[ToolSchema] | BaseToolkit | None,
+ tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
  format: type[FormattableT] | Format[FormattableT] | None,
  params: Params,
  ) -> tuple[Sequence[Message], Format[FormattableT] | None, ChatCompletionCreateKwargs]:
  """Prepares a request for the `OpenAI.chat.completions.create` method."""
+ if model_id.endswith(":responses"):
+ raise FeatureNotSupportedError(
+ feature="responses API",
+ provider_id="openai:completions",
+ model_id=model_id,
+ message=f"Can't use completions client for responses model: {model_id}",
+ )
+ base_model_name = model_name(model_id, None)
+
  kwargs: ChatCompletionCreateKwargs = ChatCompletionCreateKwargs(
  {
- "model": model_id,
+ "model": base_model_name,
  }
  )
  encode_thoughts = False

  with _base_utils.ensure_all_params_accessed(
  params=params,
- provider="openai:completions",
+ provider_id="openai",
  unsupported_params=["top_k", "thinking"],
  ) as param_accessor:
  if param_accessor.temperature is not None:
@@ -312,7 +322,7 @@ def encode_request(
  openai_tools = [_convert_tool_to_tool_param(tool) for tool in tools]

  model_supports_strict = (
- model_id not in _shared_utils.MODELS_WITHOUT_JSON_SCHEMA_SUPPORT
+ base_model_name not in _shared_utils.MODELS_WITHOUT_JSON_SCHEMA_SUPPORT
  )
  default_mode = "strict" if model_supports_strict else "tool"
  format = resolve_format(format, default_mode=default_mode)
@@ -321,7 +331,7 @@ def encode_request(
  if not model_supports_strict:
  raise FormattingModeNotSupportedError(
  formatting_mode="strict",
- provider="openai:completions",
+ provider_id="openai",
  model_id=model_id,
  )
  kwargs["response_format"] = _create_strict_response_format(format)
@@ -338,7 +348,7 @@ def encode_request(
  openai_tools.append(_convert_tool_to_tool_param(format_tool_schema))
  elif (
  format.mode == "json"
- and model_id not in _shared_utils.MODELS_WITHOUT_JSON_OBJECT_SUPPORT
+ and base_model_name not in _shared_utils.MODELS_WITHOUT_JSON_OBJECT_SUPPORT
  ):
  kwargs["response_format"] = {"type": "json_object"}
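
Taken together, the `encode.py` changes route every capability check through a base model name rather than the raw `model_id`, and look that name up in capability sets such as `MODELS_WITHOUT_AUDIO_SUPPORT`, `MODELS_WITHOUT_JSON_SCHEMA_SUPPORT`, and `MODELS_WITHOUT_JSON_OBJECT_SUPPORT`. A rough sketch of that gating logic, under the assumption that model ids may carry an API-mode suffix such as `:completions` or `:responses` (the suffix handling is inferred from the diff; the real `model_name` helper lives in `providers/openai/model_id.py` and is not shown here, and the set contents below are hypothetical):

    # Hypothetical capability sets; the real ones live in model_info.py and shared/_utils.py.
    MODELS_WITHOUT_AUDIO_SUPPORT = {"example-text-only-model"}
    MODELS_WITHOUT_JSON_SCHEMA_SUPPORT = {"example-legacy-model"}


    def base_model_name(model_id: str) -> str:
        """Strip an API-mode suffix like ':completions' or ':responses' (assumed convention)."""
        for suffix in (":completions", ":responses"):
            if model_id.endswith(suffix):
                return model_id[: -len(suffix)]
        return model_id


    def pick_format_mode(model_id: str) -> str:
        # Mirrors the diff: default to strict structured output only when the base model supports it.
        name = base_model_name(model_id)
        return "strict" if name not in MODELS_WITHOUT_JSON_SCHEMA_SUPPORT else "tool"


    assert pick_format_mode("example-model:completions") == "strict"
    assert pick_format_mode("example-legacy-model") == "tool"
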