mirascope 2.0.0a3__py3-none-any.whl → 2.0.0a5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (118)
  1. mirascope/api/_generated/__init__.py +78 -6
  2. mirascope/api/_generated/api_keys/__init__.py +7 -0
  3. mirascope/api/_generated/api_keys/client.py +453 -0
  4. mirascope/api/_generated/api_keys/raw_client.py +853 -0
  5. mirascope/api/_generated/api_keys/types/__init__.py +9 -0
  6. mirascope/api/_generated/api_keys/types/api_keys_create_response.py +36 -0
  7. mirascope/api/_generated/api_keys/types/api_keys_get_response.py +35 -0
  8. mirascope/api/_generated/api_keys/types/api_keys_list_response_item.py +35 -0
  9. mirascope/api/_generated/client.py +14 -0
  10. mirascope/api/_generated/environments/__init__.py +17 -0
  11. mirascope/api/_generated/environments/client.py +532 -0
  12. mirascope/api/_generated/environments/raw_client.py +1088 -0
  13. mirascope/api/_generated/environments/types/__init__.py +15 -0
  14. mirascope/api/_generated/environments/types/environments_create_response.py +26 -0
  15. mirascope/api/_generated/environments/types/environments_get_response.py +26 -0
  16. mirascope/api/_generated/environments/types/environments_list_response_item.py +26 -0
  17. mirascope/api/_generated/environments/types/environments_update_response.py +26 -0
  18. mirascope/api/_generated/errors/__init__.py +11 -1
  19. mirascope/api/_generated/errors/conflict_error.py +15 -0
  20. mirascope/api/_generated/errors/forbidden_error.py +15 -0
  21. mirascope/api/_generated/errors/internal_server_error.py +15 -0
  22. mirascope/api/_generated/errors/not_found_error.py +15 -0
  23. mirascope/api/_generated/organizations/__init__.py +25 -0
  24. mirascope/api/_generated/organizations/client.py +404 -0
  25. mirascope/api/_generated/organizations/raw_client.py +902 -0
  26. mirascope/api/_generated/organizations/types/__init__.py +23 -0
  27. mirascope/api/_generated/organizations/types/organizations_create_response.py +25 -0
  28. mirascope/api/_generated/organizations/types/organizations_create_response_role.py +7 -0
  29. mirascope/api/_generated/organizations/types/organizations_get_response.py +25 -0
  30. mirascope/api/_generated/organizations/types/organizations_get_response_role.py +7 -0
  31. mirascope/api/_generated/organizations/types/organizations_list_response_item.py +25 -0
  32. mirascope/api/_generated/organizations/types/organizations_list_response_item_role.py +7 -0
  33. mirascope/api/_generated/organizations/types/organizations_update_response.py +25 -0
  34. mirascope/api/_generated/organizations/types/organizations_update_response_role.py +7 -0
  35. mirascope/api/_generated/projects/__init__.py +17 -0
  36. mirascope/api/_generated/projects/client.py +482 -0
  37. mirascope/api/_generated/projects/raw_client.py +1058 -0
  38. mirascope/api/_generated/projects/types/__init__.py +15 -0
  39. mirascope/api/_generated/projects/types/projects_create_response.py +31 -0
  40. mirascope/api/_generated/projects/types/projects_get_response.py +31 -0
  41. mirascope/api/_generated/projects/types/projects_list_response_item.py +31 -0
  42. mirascope/api/_generated/projects/types/projects_update_response.py +31 -0
  43. mirascope/api/_generated/reference.md +1311 -0
  44. mirascope/api/_generated/types/__init__.py +20 -4
  45. mirascope/api/_generated/types/already_exists_error.py +24 -0
  46. mirascope/api/_generated/types/already_exists_error_tag.py +5 -0
  47. mirascope/api/_generated/types/database_error.py +24 -0
  48. mirascope/api/_generated/types/database_error_tag.py +5 -0
  49. mirascope/api/_generated/types/http_api_decode_error.py +1 -3
  50. mirascope/api/_generated/types/issue.py +1 -5
  51. mirascope/api/_generated/types/not_found_error_body.py +24 -0
  52. mirascope/api/_generated/types/not_found_error_tag.py +5 -0
  53. mirascope/api/_generated/types/permission_denied_error.py +24 -0
  54. mirascope/api/_generated/types/permission_denied_error_tag.py +7 -0
  55. mirascope/api/_generated/types/property_key.py +2 -2
  56. mirascope/api/_generated/types/{property_key_tag.py → property_key_key.py} +3 -5
  57. mirascope/api/_generated/types/{property_key_tag_tag.py → property_key_key_tag.py} +1 -1
  58. mirascope/llm/__init__.py +6 -2
  59. mirascope/llm/exceptions.py +28 -0
  60. mirascope/llm/providers/__init__.py +12 -4
  61. mirascope/llm/providers/anthropic/__init__.py +6 -1
  62. mirascope/llm/providers/anthropic/_utils/__init__.py +17 -5
  63. mirascope/llm/providers/anthropic/_utils/beta_decode.py +271 -0
  64. mirascope/llm/providers/anthropic/_utils/beta_encode.py +216 -0
  65. mirascope/llm/providers/anthropic/_utils/decode.py +39 -7
  66. mirascope/llm/providers/anthropic/_utils/encode.py +156 -64
  67. mirascope/llm/providers/anthropic/_utils/errors.py +46 -0
  68. mirascope/llm/providers/anthropic/beta_provider.py +328 -0
  69. mirascope/llm/providers/anthropic/model_id.py +10 -27
  70. mirascope/llm/providers/anthropic/model_info.py +87 -0
  71. mirascope/llm/providers/anthropic/provider.py +132 -145
  72. mirascope/llm/providers/base/__init__.py +2 -1
  73. mirascope/llm/providers/base/_utils.py +15 -1
  74. mirascope/llm/providers/base/base_provider.py +173 -58
  75. mirascope/llm/providers/google/_utils/__init__.py +2 -0
  76. mirascope/llm/providers/google/_utils/decode.py +55 -3
  77. mirascope/llm/providers/google/_utils/encode.py +14 -6
  78. mirascope/llm/providers/google/_utils/errors.py +49 -0
  79. mirascope/llm/providers/google/model_id.py +7 -13
  80. mirascope/llm/providers/google/model_info.py +62 -0
  81. mirascope/llm/providers/google/provider.py +13 -8
  82. mirascope/llm/providers/mlx/_utils.py +31 -2
  83. mirascope/llm/providers/mlx/encoding/transformers.py +17 -1
  84. mirascope/llm/providers/mlx/provider.py +12 -0
  85. mirascope/llm/providers/ollama/__init__.py +19 -0
  86. mirascope/llm/providers/ollama/provider.py +71 -0
  87. mirascope/llm/providers/openai/__init__.py +10 -1
  88. mirascope/llm/providers/openai/_utils/__init__.py +5 -0
  89. mirascope/llm/providers/openai/_utils/errors.py +46 -0
  90. mirascope/llm/providers/openai/completions/__init__.py +6 -1
  91. mirascope/llm/providers/openai/completions/_utils/decode.py +57 -5
  92. mirascope/llm/providers/openai/completions/_utils/encode.py +9 -8
  93. mirascope/llm/providers/openai/completions/base_provider.py +513 -0
  94. mirascope/llm/providers/openai/completions/provider.py +13 -447
  95. mirascope/llm/providers/openai/model_info.py +57 -0
  96. mirascope/llm/providers/openai/provider.py +30 -5
  97. mirascope/llm/providers/openai/responses/_utils/decode.py +55 -4
  98. mirascope/llm/providers/openai/responses/_utils/encode.py +9 -9
  99. mirascope/llm/providers/openai/responses/provider.py +33 -28
  100. mirascope/llm/providers/provider_id.py +11 -1
  101. mirascope/llm/providers/provider_registry.py +59 -4
  102. mirascope/llm/providers/together/__init__.py +19 -0
  103. mirascope/llm/providers/together/provider.py +40 -0
  104. mirascope/llm/responses/__init__.py +3 -0
  105. mirascope/llm/responses/base_response.py +4 -0
  106. mirascope/llm/responses/base_stream_response.py +25 -1
  107. mirascope/llm/responses/finish_reason.py +1 -0
  108. mirascope/llm/responses/response.py +9 -0
  109. mirascope/llm/responses/root_response.py +5 -1
  110. mirascope/llm/responses/usage.py +95 -0
  111. mirascope/ops/_internal/closure.py +62 -11
  112. {mirascope-2.0.0a3.dist-info → mirascope-2.0.0a5.dist-info}/METADATA +3 -3
  113. {mirascope-2.0.0a3.dist-info → mirascope-2.0.0a5.dist-info}/RECORD +115 -56
  114. mirascope/llm/providers/load_provider.py +0 -48
  115. mirascope/llm/providers/openai/shared/__init__.py +0 -7
  116. mirascope/llm/providers/openai/shared/_utils.py +0 -59
  117. {mirascope-2.0.0a3.dist-info → mirascope-2.0.0a5.dist-info}/WHEEL +0 -0
  118. {mirascope-2.0.0a3.dist-info → mirascope-2.0.0a5.dist-info}/licenses/LICENSE +0 -0
mirascope/llm/providers/google/provider.py

@@ -39,6 +39,7 @@ class GoogleProvider(BaseProvider[Client]):
 
     id = "google"
     default_scope = "google/"
+    error_map = _utils.GOOGLE_ERROR_MAP
 
     def __init__(
         self, *, api_key: str | None = None, base_url: str | None = None
@@ -50,6 +51,10 @@ class GoogleProvider(BaseProvider[Client]):
 
         self.client = Client(api_key=api_key, http_options=http_options)
 
+    def get_error_status(self, e: Exception) -> int | None:
+        """Extract HTTP status code from Google exception."""
+        return getattr(e, "code", None)
+
     def _call(
         self,
         *,
@@ -78,10 +83,9 @@ class GoogleProvider(BaseProvider[Client]):
             format=format,
             params=params,
         )
-
         google_response = self.client.models.generate_content(**kwargs)
 
-        assistant_message, finish_reason = _utils.decode_response(
+        assistant_message, finish_reason, usage = _utils.decode_response(
            google_response, model_id
        )
 
@@ -95,6 +99,7 @@ class GoogleProvider(BaseProvider[Client]):
            input_messages=input_messages,
            assistant_message=assistant_message,
            finish_reason=finish_reason,
+           usage=usage,
            format=format,
        )
 
@@ -130,10 +135,9 @@ class GoogleProvider(BaseProvider[Client]):
            format=format,
            params=params,
        )
-
        google_response = self.client.models.generate_content(**kwargs)
 
-       assistant_message, finish_reason = _utils.decode_response(
+       assistant_message, finish_reason, usage = _utils.decode_response(
           google_response, model_id
       )
 
@@ -147,6 +151,7 @@ class GoogleProvider(BaseProvider[Client]):
           input_messages=input_messages,
           assistant_message=assistant_message,
           finish_reason=finish_reason,
+          usage=usage,
           format=format,
       )
 
@@ -178,10 +183,9 @@ class GoogleProvider(BaseProvider[Client]):
           format=format,
           params=params,
       )
-
       google_response = await self.client.aio.models.generate_content(**kwargs)
 
-      assistant_message, finish_reason = _utils.decode_response(
+      assistant_message, finish_reason, usage = _utils.decode_response(
          google_response, model_id
      )
 
@@ -195,6 +199,7 @@ class GoogleProvider(BaseProvider[Client]):
          input_messages=input_messages,
          assistant_message=assistant_message,
          finish_reason=finish_reason,
+         usage=usage,
          format=format,
      )
 
@@ -230,10 +235,9 @@ class GoogleProvider(BaseProvider[Client]):
          format=format,
          params=params,
      )
-
      google_response = await self.client.aio.models.generate_content(**kwargs)
 
-     assistant_message, finish_reason = _utils.decode_response(
+     assistant_message, finish_reason, usage = _utils.decode_response(
         google_response, model_id
     )
 
@@ -247,6 +251,7 @@ class GoogleProvider(BaseProvider[Client]):
         input_messages=input_messages,
         assistant_message=assistant_message,
         finish_reason=finish_reason,
+        usage=usage,
         format=format,
     )
 
mirascope/llm/providers/mlx/_utils.py

@@ -2,14 +2,21 @@ from collections.abc import Callable
 from typing import TypeAlias, TypedDict
 
 import mlx.core as mx
+from huggingface_hub.errors import LocalEntryNotFoundError
 from mlx_lm.generate import GenerationResponse
 from mlx_lm.sample_utils import make_sampler
 
-from ...responses import FinishReason
-from ..base import Params, _utils as _base_utils
+from ...exceptions import NotFoundError
+from ...responses import FinishReason, Usage
+from ..base import Params, ProviderErrorMap, _utils as _base_utils
 
 Sampler: TypeAlias = Callable[[mx.array], mx.array]
 
+# Error mapping for MLX provider
+MLX_ERROR_MAP: ProviderErrorMap = {
+    LocalEntryNotFoundError: NotFoundError,
+}
+
 
 class MakeSamplerKwargs(TypedDict, total=False):
     """Keyword arguments to be used for `mlx_lm`-s `make_sampler` function.
@@ -105,3 +112,25 @@ def extract_finish_reason(response: GenerationResponse | None) -> FinishReason | None:
         return FinishReason.MAX_TOKENS
 
     return None
+
+
+def extract_usage(response: GenerationResponse | None) -> Usage | None:
+    """Extract usage information from an MLX generation response.
+
+    Args:
+        response: The MLX generation response to extract from.
+
+    Returns:
+        The Usage object with token counts, or None if not applicable.
+    """
+    if response is None:
+        return None
+
+    return Usage(
+        input_tokens=response.prompt_tokens,
+        output_tokens=response.generation_tokens,
+        cache_read_tokens=0,
+        cache_write_tokens=0,
+        reasoning_tokens=0,
+        raw=response,
+    )
mirascope/llm/providers/mlx/encoding/transformers.py

@@ -10,7 +10,12 @@ from transformers import PreTrainedTokenizer
 from ....content import ContentPart, TextChunk, TextEndChunk, TextStartChunk
 from ....formatting import Format, FormattableT
 from ....messages import AssistantContent, Message
-from ....responses import ChunkIterator, FinishReasonChunk, RawStreamEventChunk
+from ....responses import (
+    ChunkIterator,
+    FinishReasonChunk,
+    RawStreamEventChunk,
+    UsageDeltaChunk,
+)
 from ....tools import AnyToolSchema, BaseToolkit
 from .. import _utils
 from .base import BaseEncoder, TokenIds
@@ -129,3 +134,14 @@ class TransformersEncoder(BaseEncoder):
             yield FinishReasonChunk(finish_reason=finish_reason)
         else:
             yield TextEndChunk()
+
+        # Emit usage delta if available
+        usage = _utils.extract_usage(response)
+        if usage:
+            yield UsageDeltaChunk(
+                input_tokens=usage.input_tokens,
+                output_tokens=usage.output_tokens,
+                cache_read_tokens=usage.cache_read_tokens,
+                cache_write_tokens=usage.cache_write_tokens,
+                reasoning_tokens=usage.reasoning_tokens,
+            )
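With `UsageDeltaChunk` now emitted at the end of the MLX stream, a consumer can total the per-chunk token counts as they arrive. A minimal illustrative sketch, assuming chunk objects expose the same field names they are constructed with above and that `chunks` is any iterable of stream chunks; the accumulation loop itself is not mirascope's stream-response code:

```python
from mirascope.llm.responses import UsageDeltaChunk

# Illustrative only: sum usage deltas from a finished stream's chunk iterator.
input_tokens = output_tokens = reasoning_tokens = 0
for chunk in chunks:  # `chunks`: hypothetical iterable of the chunk types above
    if isinstance(chunk, UsageDeltaChunk):
        input_tokens += chunk.input_tokens
        output_tokens += chunk.output_tokens
        reasoning_tokens += chunk.reasoning_tokens
```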
mirascope/llm/providers/mlx/provider.py

@@ -70,6 +70,14 @@ class MLXProvider(BaseProvider[None]):
 
     id = "mlx"
     default_scope = "mlx-community/"
+    error_map = _utils.MLX_ERROR_MAP
+
+    def get_error_status(self, e: Exception) -> int | None:
+        """Extract HTTP status code from MLX exception.
+
+        MLX/HuggingFace Hub exceptions don't have status codes.
+        """
+        return None
 
     def _call(
         self,
@@ -108,6 +116,7 @@ class MLXProvider(BaseProvider[None]):
             input_messages=input_messages,
             assistant_message=assistant_message,
             finish_reason=_utils.extract_finish_reason(response),
+            usage=_utils.extract_usage(response),
             format=format,
         )
 
@@ -152,6 +161,7 @@ class MLXProvider(BaseProvider[None]):
            input_messages=input_messages,
            assistant_message=assistant_message,
            finish_reason=_utils.extract_finish_reason(response),
+           usage=_utils.extract_usage(response),
            format=format,
        )
 
@@ -196,6 +206,7 @@ class MLXProvider(BaseProvider[None]):
           input_messages=input_messages,
           assistant_message=assistant_message,
           finish_reason=_utils.extract_finish_reason(response),
+          usage=_utils.extract_usage(response),
          format=format,
      )
 
@@ -244,6 +255,7 @@ class MLXProvider(BaseProvider[None]):
         input_messages=input_messages,
         assistant_message=assistant_message,
         finish_reason=_utils.extract_finish_reason(response),
+        usage=_utils.extract_usage(response),
        format=format,
    )
 
mirascope/llm/providers/ollama/__init__.py (new file)

@@ -0,0 +1,19 @@
+"""Ollama provider implementation."""
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .provider import OllamaProvider
+else:
+    try:
+        from .provider import OllamaProvider
+    except ImportError:  # pragma: no cover
+        from .._missing_import_stubs import (
+            create_provider_stub,
+        )
+
+        OllamaProvider = create_provider_stub("openai", "OllamaProvider")
+
+__all__ = [
+    "OllamaProvider",
+]
mirascope/llm/providers/ollama/provider.py (new file)

@@ -0,0 +1,71 @@
+"""Ollama provider implementation."""
+
+import os
+from typing import ClassVar
+
+from openai import AsyncOpenAI, OpenAI
+
+from ..openai.completions.base_provider import BaseOpenAICompletionsProvider
+
+
+class OllamaProvider(BaseOpenAICompletionsProvider):
+    """Provider for Ollama's OpenAI-compatible API.
+
+    Inherits from BaseOpenAICompletionsProvider with Ollama-specific configuration:
+    - Uses Ollama's local API endpoint (default: http://localhost:11434/v1/)
+    - API key is not required (Ollama ignores API keys)
+    - Supports OLLAMA_BASE_URL environment variable
+
+    Usage:
+        Register the provider with model ID prefixes you want to use:
+
+        ```python
+        import llm
+
+        # Register for ollama models
+        llm.register_provider("ollama", "ollama/")
+
+        # Now you can use ollama models directly
+        @llm.call("ollama/llama2")
+        def my_prompt():
+            return [llm.messages.user("Hello!")]
+        ```
+    """
+
+    id: ClassVar[str] = "ollama"
+    default_scope: ClassVar[str | list[str]] = "ollama/"
+    default_base_url: ClassVar[str | None] = "http://localhost:11434/v1/"
+    api_key_env_var: ClassVar[str] = "OLLAMA_API_KEY"
+    api_key_required: ClassVar[bool] = False
+    provider_name: ClassVar[str | None] = "Ollama"
+
+    def __init__(
+        self,
+        *,
+        api_key: str | None = None,
+        base_url: str | None = None,
+    ) -> None:
+        """Initialize the Ollama provider.
+
+        Args:
+            api_key: API key (optional). Defaults to OLLAMA_API_KEY env var or 'ollama'.
+            base_url: Custom base URL. Defaults to OLLAMA_BASE_URL env var
+                or http://localhost:11434/v1/.
+        """
+        resolved_api_key = api_key or os.environ.get(self.api_key_env_var) or "ollama"
+        resolved_base_url = (
+            base_url or os.environ.get("OLLAMA_BASE_URL") or self.default_base_url
+        )
+
+        self.client = OpenAI(
+            api_key=resolved_api_key,
+            base_url=resolved_base_url,
+        )
+        self.async_client = AsyncOpenAI(
+            api_key=resolved_api_key,
+            base_url=resolved_base_url,
+        )
+
+    def _model_name(self, model_id: str) -> str:
+        """Strip 'ollama/' prefix from model ID for Ollama API."""
+        return model_id.removeprefix("ollama/")
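Besides the `register_provider` flow shown in the class docstring, the `__init__` signature above allows pointing the provider at a non-default Ollama server. A short example based only on that signature; the host URL is a placeholder:

```python
from mirascope.llm.providers.ollama import OllamaProvider

# Explicit base_url; setting OLLAMA_BASE_URL in the environment works the same
# way, and no real API key is needed (Ollama ignores it).
provider = OllamaProvider(base_url="http://my-ollama-host:11434/v1/")
```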
mirascope/llm/providers/openai/__init__.py

@@ -1,6 +1,15 @@
 """OpenAI client implementation."""
 
+from .completions.base_provider import BaseOpenAICompletionsProvider
+from .completions.provider import OpenAICompletionsProvider
 from .model_id import OpenAIModelId
 from .provider import OpenAIProvider
+from .responses.provider import OpenAIResponsesProvider
 
-__all__ = ["OpenAIModelId", "OpenAIProvider"]
+__all__ = [
+    "BaseOpenAICompletionsProvider",
+    "OpenAICompletionsProvider",
+    "OpenAIModelId",
+    "OpenAIProvider",
+    "OpenAIResponsesProvider",
+]
mirascope/llm/providers/openai/_utils/__init__.py (new file)

@@ -0,0 +1,5 @@
+"""Shared OpenAI utilities."""
+
+from .errors import OPENAI_ERROR_MAP
+
+__all__ = ["OPENAI_ERROR_MAP"]
mirascope/llm/providers/openai/_utils/errors.py (new file)

@@ -0,0 +1,46 @@
+"""OpenAI error handling utilities."""
+
+from openai import (
+    APIConnectionError as OpenAIAPIConnectionError,
+    APIResponseValidationError as OpenAIAPIResponseValidationError,
+    APITimeoutError as OpenAIAPITimeoutError,
+    AuthenticationError as OpenAIAuthenticationError,
+    BadRequestError as OpenAIBadRequestError,
+    ConflictError as OpenAIConflictError,
+    InternalServerError as OpenAIInternalServerError,
+    NotFoundError as OpenAINotFoundError,
+    OpenAIError,
+    PermissionDeniedError as OpenAIPermissionDeniedError,
+    RateLimitError as OpenAIRateLimitError,
+    UnprocessableEntityError as OpenAIUnprocessableEntityError,
+)
+
+from ....exceptions import (
+    APIError,
+    AuthenticationError,
+    BadRequestError,
+    ConnectionError,
+    NotFoundError,
+    PermissionError,
+    RateLimitError,
+    ResponseValidationError,
+    ServerError,
+    TimeoutError,
+)
+from ...base import ProviderErrorMap
+
+# Shared error mapping used by OpenAI Responses and Completions providers
+OPENAI_ERROR_MAP: ProviderErrorMap = {
+    OpenAIAuthenticationError: AuthenticationError,
+    OpenAIPermissionDeniedError: PermissionError,
+    OpenAINotFoundError: NotFoundError,
+    OpenAIBadRequestError: BadRequestError,
+    OpenAIUnprocessableEntityError: BadRequestError,
+    OpenAIConflictError: BadRequestError,
+    OpenAIRateLimitError: RateLimitError,
+    OpenAIInternalServerError: ServerError,
+    OpenAIAPITimeoutError: TimeoutError,
+    OpenAIAPIConnectionError: ConnectionError,
+    OpenAIAPIResponseValidationError: ResponseValidationError,
+    OpenAIError: APIError,  # Catch-all for unknown OpenAI errors
+}
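The provider classes earlier in this diff (`GoogleProvider`, `MLXProvider`) pair an `error_map` like this one with a `get_error_status()` hook. A hypothetical sketch of how the two could be combined to translate SDK exceptions into `mirascope.llm` exceptions follows; the `translate_error` helper and the exception constructor call are assumptions, not mirascope's actual `BaseProvider` code:

```python
def translate_error(provider, e: Exception) -> Exception:
    """Map a provider SDK exception onto a mirascope.llm exception (sketch only)."""
    for provider_exc, mirascope_exc in provider.error_map.items():
        if isinstance(e, provider_exc):
            # e.g. getattr(e, "code", None) for Google, None for MLX
            status = provider.get_error_status(e)
            message = f"{e} (status={status})" if status is not None else str(e)
            # Constructor signature is assumed; the real exceptions may carry
            # more structured fields.
            return mirascope_exc(message)
    return e  # not in the map: leave the original exception unchanged
```

Because `ProviderErrorMap` is an ordinary dict, insertion order is preserved, so the specific OpenAI errors above would match before the trailing `OpenAIError` catch-all.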
mirascope/llm/providers/openai/completions/__init__.py

@@ -1,20 +1,25 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from .base_provider import BaseOpenAICompletionsProvider
     from .provider import OpenAICompletionsProvider
 else:
     try:
+        from .base_provider import BaseOpenAICompletionsProvider
         from .provider import OpenAICompletionsProvider
     except ImportError:  # pragma: no cover
         from ..._missing_import_stubs import (
-            create_import_error_stub,
             create_provider_stub,
         )
 
+        BaseOpenAICompletionsProvider = create_provider_stub(
+            "openai", "BaseOpenAICompletionsProvider"
+        )
         OpenAICompletionsProvider = create_provider_stub(
             "openai", "OpenAICompletionsProvider"
         )
 
 __all__ = [
+    "BaseOpenAICompletionsProvider",
     "OpenAICompletionsProvider",
 ]
mirascope/llm/providers/openai/completions/_utils/decode.py

@@ -4,6 +4,7 @@ from typing import Literal
 
 from openai import AsyncStream, Stream
 from openai.types import chat as openai_types
+from openai.types.completion_usage import CompletionUsage
 
 from .....content import (
     AssistantContentPart,
@@ -23,6 +24,8 @@ from .....responses import (
     FinishReason,
     FinishReasonChunk,
     RawStreamEventChunk,
+    Usage,
+    UsageDeltaChunk,
 )
 from ...model_id import OpenAIModelId, model_name
 
@@ -32,12 +35,40 @@ OPENAI_FINISH_REASON_MAP = {
 }
 
 
+def _decode_usage(
+    usage: CompletionUsage | None,
+) -> Usage | None:
+    """Convert OpenAI CompletionUsage to Mirascope Usage."""
+    if usage is None:  # pragma: no cover
+        return None
+
+    return Usage(
+        input_tokens=usage.prompt_tokens,
+        output_tokens=usage.completion_tokens,
+        cache_read_tokens=(
+            usage.prompt_tokens_details.cached_tokens
+            if usage.prompt_tokens_details
+            else None
+        )
+        or 0,
+        cache_write_tokens=0,
+        reasoning_tokens=(
+            usage.completion_tokens_details.reasoning_tokens
+            if usage.completion_tokens_details
+            else None
+        )
+        or 0,
+        raw=usage,
+    )
+
+
 def decode_response(
     response: openai_types.ChatCompletion,
     model_id: OpenAIModelId,
-    provider_id: Literal["openai", "openai:completions"],
-) -> tuple[AssistantMessage, FinishReason | None]:
-    """Convert OpenAI ChatCompletion to mirascope AssistantMessage."""
+    provider_id: str,
+    provider_model_name: str | None = None,
+) -> tuple[AssistantMessage, FinishReason | None, Usage | None]:
+    """Convert OpenAI ChatCompletion to mirascope AssistantMessage and usage."""
     choice = response.choices[0]
     message = choice.message
     refused = False
@@ -72,11 +103,12 @@ def decode_response(
         content=parts,
         provider_id=provider_id,
         model_id=model_id,
-        provider_model_name=model_name(model_id, "completions"),
+        provider_model_name=provider_model_name or model_name(model_id, "completions"),
         raw_message=message.model_dump(exclude_none=True),
     )
 
-    return assistant_message, finish_reason
+    usage = _decode_usage(response.usage)
+    return assistant_message, finish_reason, usage
 
 
 class _OpenAIChunkProcessor:
@@ -91,6 +123,26 @@ class _OpenAIChunkProcessor:
         """Process a single OpenAI chunk and yield the appropriate content chunks."""
         yield RawStreamEventChunk(raw_stream_event=chunk)
 
+        if chunk.usage:
+            usage = chunk.usage
+            yield UsageDeltaChunk(
+                input_tokens=usage.prompt_tokens,
+                output_tokens=usage.completion_tokens,
+                cache_read_tokens=(
+                    usage.prompt_tokens_details.cached_tokens
+                    if usage.prompt_tokens_details
+                    else None
+                )
+                or 0,
+                cache_write_tokens=0,
+                reasoning_tokens=(
+                    usage.completion_tokens_details.reasoning_tokens
+                    if usage.completion_tokens_details
+                    else None
+                )
+                or 0,
+            )
+
         choice = chunk.choices[0] if chunk.choices else None
         if not choice:
             return  # pragma: no cover
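To make `_decode_usage`'s `or 0` fallbacks concrete, here is a small worked example, assuming a `CompletionUsage` built with only its required OpenAI SDK fields and no detail blocks:

```python
from openai.types.completion_usage import CompletionUsage

# Without prompt_tokens_details / completion_tokens_details, the cached and
# reasoning token counts fall back to 0 via the `or 0` guards above.
usage = CompletionUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46)
decoded = _decode_usage(usage)  # _decode_usage as defined in this module
assert decoded is not None
assert (decoded.input_tokens, decoded.output_tokens) == (12, 34)
assert decoded.cache_read_tokens == 0 and decoded.reasoning_tokens == 0
```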
mirascope/llm/providers/openai/completions/_utils/encode.py

@@ -22,8 +22,11 @@ from .....messages import AssistantMessage, Message, UserMessage
 from .....tools import FORMAT_TOOL_NAME, AnyToolSchema, BaseToolkit
 from ....base import Params, _utils as _base_utils
 from ...model_id import OpenAIModelId, model_name
-from ...model_info import MODELS_WITHOUT_AUDIO_SUPPORT
-from ...shared import _utils as _shared_utils
+from ...model_info import (
+    MODELS_WITHOUT_AUDIO_SUPPORT,
+    MODELS_WITHOUT_JSON_OBJECT_SUPPORT,
+    MODELS_WITHOUT_JSON_SCHEMA_SUPPORT,
+)
 
 
 class ChatCompletionCreateKwargs(TypedDict, total=False):
@@ -233,7 +236,7 @@ def _convert_tool_to_tool_param(
     """Convert a single Mirascope `Tool` to OpenAI ChatCompletionToolParam with caching."""
     schema_dict = tool.parameters.model_dump(by_alias=True, exclude_none=True)
     schema_dict["type"] = "object"
-    _shared_utils.ensure_additional_properties_false(schema_dict)
+    _base_utils.ensure_additional_properties_false(schema_dict)
     return openai_types.ChatCompletionToolParam(
         type="function",
         function={
@@ -258,7 +261,7 @@ def _create_strict_response_format(
     """
     schema = format.schema.copy()
 
-    _shared_utils.ensure_additional_properties_false(schema)
+    _base_utils.ensure_additional_properties_false(schema)
 
     json_schema = JSONSchema(
         name=format.name,
@@ -321,9 +324,7 @@ def encode_request(
 
     openai_tools = [_convert_tool_to_tool_param(tool) for tool in tools]
 
-    model_supports_strict = (
-        base_model_name not in _shared_utils.MODELS_WITHOUT_JSON_SCHEMA_SUPPORT
-    )
+    model_supports_strict = base_model_name not in MODELS_WITHOUT_JSON_SCHEMA_SUPPORT
     default_mode = "strict" if model_supports_strict else "tool"
     format = resolve_format(format, default_mode=default_mode)
     if format is not None:
@@ -348,7 +349,7 @@
             openai_tools.append(_convert_tool_to_tool_param(format_tool_schema))
         elif (
             format.mode == "json"
-            and base_model_name not in _shared_utils.MODELS_WITHOUT_JSON_OBJECT_SUPPORT
+            and base_model_name not in MODELS_WITHOUT_JSON_OBJECT_SUPPORT
         ):
             kwargs["response_format"] = {"type": "json_object"}
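The strict/tool default-mode selection in `encode_request` now reduces to membership checks against the `model_info` sets. A worked illustration follows; whether a given model name appears in these sets is stated as an assumption here, not read from `model_info.py`:

```python
# Assume this model is NOT listed in MODELS_WITHOUT_JSON_SCHEMA_SUPPORT.
base_model_name = "gpt-4o"
model_supports_strict = base_model_name not in MODELS_WITHOUT_JSON_SCHEMA_SUPPORT
default_mode = "strict" if model_supports_strict else "tool"  # -> "strict"
# A model that IS listed falls back to default_mode == "tool"; plain "json"
# mode is only encoded as {"type": "json_object"} when the model is not in
# MODELS_WITHOUT_JSON_OBJECT_SUPPORT.
```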