mirascope 2.0.0a2__py3-none-any.whl → 2.0.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. mirascope/__init__.py +2 -2
  2. mirascope/api/__init__.py +6 -0
  3. mirascope/api/_generated/README.md +207 -0
  4. mirascope/api/_generated/__init__.py +85 -0
  5. mirascope/api/_generated/client.py +155 -0
  6. mirascope/api/_generated/core/__init__.py +52 -0
  7. mirascope/api/_generated/core/api_error.py +23 -0
  8. mirascope/api/_generated/core/client_wrapper.py +58 -0
  9. mirascope/api/_generated/core/datetime_utils.py +30 -0
  10. mirascope/api/_generated/core/file.py +70 -0
  11. mirascope/api/_generated/core/force_multipart.py +16 -0
  12. mirascope/api/_generated/core/http_client.py +619 -0
  13. mirascope/api/_generated/core/http_response.py +55 -0
  14. mirascope/api/_generated/core/jsonable_encoder.py +102 -0
  15. mirascope/api/_generated/core/pydantic_utilities.py +310 -0
  16. mirascope/api/_generated/core/query_encoder.py +60 -0
  17. mirascope/api/_generated/core/remove_none_from_dict.py +11 -0
  18. mirascope/api/_generated/core/request_options.py +35 -0
  19. mirascope/api/_generated/core/serialization.py +282 -0
  20. mirascope/api/_generated/docs/__init__.py +4 -0
  21. mirascope/api/_generated/docs/client.py +95 -0
  22. mirascope/api/_generated/docs/raw_client.py +132 -0
  23. mirascope/api/_generated/environment.py +9 -0
  24. mirascope/api/_generated/errors/__init__.py +7 -0
  25. mirascope/api/_generated/errors/bad_request_error.py +15 -0
  26. mirascope/api/_generated/health/__init__.py +7 -0
  27. mirascope/api/_generated/health/client.py +96 -0
  28. mirascope/api/_generated/health/raw_client.py +129 -0
  29. mirascope/api/_generated/health/types/__init__.py +8 -0
  30. mirascope/api/_generated/health/types/health_check_response.py +24 -0
  31. mirascope/api/_generated/health/types/health_check_response_status.py +5 -0
  32. mirascope/api/_generated/reference.md +167 -0
  33. mirascope/api/_generated/traces/__init__.py +55 -0
  34. mirascope/api/_generated/traces/client.py +162 -0
  35. mirascope/api/_generated/traces/raw_client.py +168 -0
  36. mirascope/api/_generated/traces/types/__init__.py +95 -0
  37. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item.py +36 -0
  38. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource.py +31 -0
  39. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item.py +25 -0
  40. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value.py +54 -0
  41. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_array_value.py +23 -0
  42. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value.py +28 -0
  43. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_resource_attributes_item_value_kvlist_value_values_item.py +24 -0
  44. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item.py +35 -0
  45. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope.py +35 -0
  46. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item.py +27 -0
  47. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value.py +54 -0
  48. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_array_value.py +23 -0
  49. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value.py +28 -0
  50. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_scope_attributes_item_value_kvlist_value_values_item.py +24 -0
  51. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item.py +60 -0
  52. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item.py +29 -0
  53. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value.py +54 -0
  54. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_array_value.py +23 -0
  55. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value.py +28 -0
  56. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_attributes_item_value_kvlist_value_values_item.py +24 -0
  57. mirascope/api/_generated/traces/types/traces_create_request_resource_spans_item_scope_spans_item_spans_item_status.py +24 -0
  58. mirascope/api/_generated/traces/types/traces_create_response.py +27 -0
  59. mirascope/api/_generated/traces/types/traces_create_response_partial_success.py +28 -0
  60. mirascope/api/_generated/types/__init__.py +21 -0
  61. mirascope/api/_generated/types/http_api_decode_error.py +31 -0
  62. mirascope/api/_generated/types/http_api_decode_error_tag.py +5 -0
  63. mirascope/api/_generated/types/issue.py +44 -0
  64. mirascope/api/_generated/types/issue_tag.py +17 -0
  65. mirascope/api/_generated/types/property_key.py +7 -0
  66. mirascope/api/_generated/types/property_key_tag.py +29 -0
  67. mirascope/api/_generated/types/property_key_tag_tag.py +5 -0
  68. mirascope/api/client.py +255 -0
  69. mirascope/api/settings.py +81 -0
  70. mirascope/llm/__init__.py +41 -11
  71. mirascope/llm/calls/calls.py +81 -57
  72. mirascope/llm/calls/decorator.py +121 -115
  73. mirascope/llm/content/__init__.py +3 -2
  74. mirascope/llm/context/_utils.py +19 -6
  75. mirascope/llm/exceptions.py +30 -16
  76. mirascope/llm/formatting/_utils.py +9 -5
  77. mirascope/llm/formatting/format.py +2 -2
  78. mirascope/llm/formatting/from_call_args.py +2 -2
  79. mirascope/llm/messages/message.py +13 -5
  80. mirascope/llm/models/__init__.py +2 -2
  81. mirascope/llm/models/models.py +189 -81
  82. mirascope/llm/prompts/__init__.py +13 -12
  83. mirascope/llm/prompts/_utils.py +27 -24
  84. mirascope/llm/prompts/decorator.py +133 -204
  85. mirascope/llm/prompts/prompts.py +424 -0
  86. mirascope/llm/prompts/protocols.py +25 -59
  87. mirascope/llm/providers/__init__.py +38 -0
  88. mirascope/llm/{clients → providers}/_missing_import_stubs.py +8 -6
  89. mirascope/llm/providers/anthropic/__init__.py +24 -0
  90. mirascope/llm/{clients → providers}/anthropic/_utils/decode.py +5 -4
  91. mirascope/llm/{clients → providers}/anthropic/_utils/encode.py +31 -10
  92. mirascope/llm/providers/anthropic/model_id.py +40 -0
  93. mirascope/llm/{clients/anthropic/clients.py → providers/anthropic/provider.py} +33 -418
  94. mirascope/llm/{clients → providers}/base/__init__.py +3 -3
  95. mirascope/llm/{clients → providers}/base/_utils.py +10 -7
  96. mirascope/llm/{clients/base/client.py → providers/base/base_provider.py} +255 -126
  97. mirascope/llm/providers/google/__init__.py +21 -0
  98. mirascope/llm/{clients → providers}/google/_utils/decode.py +6 -4
  99. mirascope/llm/{clients → providers}/google/_utils/encode.py +30 -24
  100. mirascope/llm/providers/google/model_id.py +28 -0
  101. mirascope/llm/providers/google/provider.py +438 -0
  102. mirascope/llm/providers/load_provider.py +48 -0
  103. mirascope/llm/providers/mlx/__init__.py +24 -0
  104. mirascope/llm/providers/mlx/_utils.py +107 -0
  105. mirascope/llm/providers/mlx/encoding/__init__.py +8 -0
  106. mirascope/llm/providers/mlx/encoding/base.py +69 -0
  107. mirascope/llm/providers/mlx/encoding/transformers.py +131 -0
  108. mirascope/llm/providers/mlx/mlx.py +237 -0
  109. mirascope/llm/providers/mlx/model_id.py +17 -0
  110. mirascope/llm/providers/mlx/provider.py +411 -0
  111. mirascope/llm/providers/model_id.py +16 -0
  112. mirascope/llm/providers/openai/__init__.py +6 -0
  113. mirascope/llm/providers/openai/completions/__init__.py +20 -0
  114. mirascope/llm/{clients/openai/responses → providers/openai/completions}/_utils/__init__.py +2 -0
  115. mirascope/llm/{clients → providers}/openai/completions/_utils/decode.py +5 -3
  116. mirascope/llm/{clients → providers}/openai/completions/_utils/encode.py +33 -23
  117. mirascope/llm/providers/openai/completions/provider.py +456 -0
  118. mirascope/llm/providers/openai/model_id.py +31 -0
  119. mirascope/llm/providers/openai/model_info.py +246 -0
  120. mirascope/llm/providers/openai/provider.py +386 -0
  121. mirascope/llm/providers/openai/responses/__init__.py +21 -0
  122. mirascope/llm/{clients → providers}/openai/responses/_utils/decode.py +5 -3
  123. mirascope/llm/{clients → providers}/openai/responses/_utils/encode.py +28 -17
  124. mirascope/llm/providers/openai/responses/provider.py +470 -0
  125. mirascope/llm/{clients → providers}/openai/shared/_utils.py +7 -3
  126. mirascope/llm/providers/provider_id.py +13 -0
  127. mirascope/llm/providers/provider_registry.py +167 -0
  128. mirascope/llm/responses/base_response.py +10 -5
  129. mirascope/llm/responses/base_stream_response.py +10 -5
  130. mirascope/llm/responses/response.py +24 -13
  131. mirascope/llm/responses/root_response.py +7 -12
  132. mirascope/llm/responses/stream_response.py +35 -23
  133. mirascope/llm/tools/__init__.py +9 -2
  134. mirascope/llm/tools/_utils.py +12 -3
  135. mirascope/llm/tools/protocols.py +4 -4
  136. mirascope/llm/tools/tool_schema.py +44 -9
  137. mirascope/llm/tools/tools.py +10 -9
  138. mirascope/ops/__init__.py +156 -0
  139. mirascope/ops/_internal/__init__.py +5 -0
  140. mirascope/ops/_internal/closure.py +1118 -0
  141. mirascope/ops/_internal/configuration.py +126 -0
  142. mirascope/ops/_internal/context.py +76 -0
  143. mirascope/ops/_internal/exporters/__init__.py +26 -0
  144. mirascope/ops/_internal/exporters/exporters.py +342 -0
  145. mirascope/ops/_internal/exporters/processors.py +104 -0
  146. mirascope/ops/_internal/exporters/types.py +165 -0
  147. mirascope/ops/_internal/exporters/utils.py +29 -0
  148. mirascope/ops/_internal/instrumentation/__init__.py +8 -0
  149. mirascope/ops/_internal/instrumentation/llm/__init__.py +8 -0
  150. mirascope/ops/_internal/instrumentation/llm/encode.py +238 -0
  151. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/__init__.py +38 -0
  152. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_input_messages.py +31 -0
  153. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_output_messages.py +38 -0
  154. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/gen_ai_system_instructions.py +18 -0
  155. mirascope/ops/_internal/instrumentation/llm/gen_ai_types/shared.py +100 -0
  156. mirascope/ops/_internal/instrumentation/llm/llm.py +1288 -0
  157. mirascope/ops/_internal/propagation.py +198 -0
  158. mirascope/ops/_internal/protocols.py +51 -0
  159. mirascope/ops/_internal/session.py +139 -0
  160. mirascope/ops/_internal/spans.py +232 -0
  161. mirascope/ops/_internal/traced_calls.py +371 -0
  162. mirascope/ops/_internal/traced_functions.py +394 -0
  163. mirascope/ops/_internal/tracing.py +276 -0
  164. mirascope/ops/_internal/types.py +13 -0
  165. mirascope/ops/_internal/utils.py +75 -0
  166. mirascope/ops/_internal/versioned_calls.py +512 -0
  167. mirascope/ops/_internal/versioned_functions.py +346 -0
  168. mirascope/ops/_internal/versioning.py +303 -0
  169. mirascope/ops/exceptions.py +21 -0
  170. {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a3.dist-info}/METADATA +76 -1
  171. mirascope-2.0.0a3.dist-info/RECORD +206 -0
  172. {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a3.dist-info}/WHEEL +1 -1
  173. mirascope/graphs/__init__.py +0 -22
  174. mirascope/graphs/finite_state_machine.py +0 -625
  175. mirascope/llm/agents/__init__.py +0 -15
  176. mirascope/llm/agents/agent.py +0 -97
  177. mirascope/llm/agents/agent_template.py +0 -45
  178. mirascope/llm/agents/decorator.py +0 -176
  179. mirascope/llm/calls/base_call.py +0 -33
  180. mirascope/llm/clients/__init__.py +0 -34
  181. mirascope/llm/clients/anthropic/__init__.py +0 -25
  182. mirascope/llm/clients/anthropic/model_ids.py +0 -8
  183. mirascope/llm/clients/google/__init__.py +0 -20
  184. mirascope/llm/clients/google/clients.py +0 -853
  185. mirascope/llm/clients/google/model_ids.py +0 -15
  186. mirascope/llm/clients/openai/__init__.py +0 -25
  187. mirascope/llm/clients/openai/completions/__init__.py +0 -28
  188. mirascope/llm/clients/openai/completions/_utils/model_features.py +0 -81
  189. mirascope/llm/clients/openai/completions/clients.py +0 -833
  190. mirascope/llm/clients/openai/completions/model_ids.py +0 -8
  191. mirascope/llm/clients/openai/responses/__init__.py +0 -26
  192. mirascope/llm/clients/openai/responses/_utils/model_features.py +0 -87
  193. mirascope/llm/clients/openai/responses/clients.py +0 -832
  194. mirascope/llm/clients/openai/responses/model_ids.py +0 -8
  195. mirascope/llm/clients/providers.py +0 -175
  196. mirascope-2.0.0a2.dist-info/RECORD +0 -102
  197. /mirascope/llm/{clients → providers}/anthropic/_utils/__init__.py +0 -0
  198. /mirascope/llm/{clients → providers}/base/kwargs.py +0 -0
  199. /mirascope/llm/{clients → providers}/base/params.py +0 -0
  200. /mirascope/llm/{clients → providers}/google/_utils/__init__.py +0 -0
  201. /mirascope/llm/{clients → providers}/google/message.py +0 -0
  202. /mirascope/llm/{clients/openai/completions → providers/openai/responses}/_utils/__init__.py +0 -0
  203. /mirascope/llm/{clients → providers}/openai/shared/__init__.py +0 -0
  204. {mirascope-2.0.0a2.dist-info → mirascope-2.0.0a3.dist-info}/licenses/LICENSE +0 -0
mirascope/llm/providers/mlx/__init__.py
@@ -0,0 +1,24 @@
+ """MLX client implementation."""
+
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from .model_id import MLXModelId
+     from .provider import MLXProvider
+ else:
+     try:
+         from .model_id import MLXModelId
+         from .provider import MLXProvider
+     except ImportError:  # pragma: no cover
+         from .._missing_import_stubs import (
+             create_import_error_stub,
+             create_provider_stub,
+         )
+
+         MLXProvider = create_provider_stub("mlx", "MLXProvider")
+         MLXModelId = str
+
+ __all__ = [
+     "MLXModelId",
+     "MLXProvider",
+ ]
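For readers unfamiliar with the deferred-import pattern in the mlx/__init__.py hunk above: `create_provider_stub` (defined in `_missing_import_stubs.py`, whose body is not shown in this diff) presumably returns a placeholder class that raises an informative error only when the provider is actually used without the optional MLX dependencies installed. A minimal sketch of that pattern, under that assumption, might look like:

def create_provider_stub(provider_name: str, class_name: str) -> type:
    # Hypothetical sketch only; the real _missing_import_stubs implementation is not shown here.
    class _MissingProviderStub:
        def __init__(self, *args: object, **kwargs: object) -> None:
            raise ImportError(
                f"{class_name} requires the optional '{provider_name}' dependencies "
                "(e.g. mlx, mlx-lm); install them to use this provider."
            )

    _MissingProviderStub.__name__ = class_name
    return _MissingProviderStub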
mirascope/llm/providers/mlx/_utils.py
@@ -0,0 +1,107 @@
+ from collections.abc import Callable
+ from typing import TypeAlias, TypedDict
+
+ import mlx.core as mx
+ from mlx_lm.generate import GenerationResponse
+ from mlx_lm.sample_utils import make_sampler
+
+ from ...responses import FinishReason
+ from ..base import Params, _utils as _base_utils
+
+ Sampler: TypeAlias = Callable[[mx.array], mx.array]
+
+
+ class MakeSamplerKwargs(TypedDict, total=False):
+     """Keyword arguments to be used for `mlx_lm`'s `make_sampler` function.
+
+     Some of these settings directly match the generic client parameters
+     as defined in the `Params` class. See mirascope.llm.providers.Params for
+     more details.
+     """
+
+     temp: float
+     "The temperature for sampling; if 0, the argmax is used."
+
+     top_p: float
+     "Nucleus sampling; higher means the model considers more less-likely words."
+
+     min_p: float
+     """The minimum value (scaled by the top token's probability) that a token
+     probability must have to be considered."""
+
+     min_tokens_to_keep: int
+     "Minimum number of tokens that cannot be filtered by min_p sampling."
+
+     top_k: int
+     "The top k tokens ranked by probability to constrain the sampling to."
+
+     xtc_probability: float
+     "The probability of applying XTC sampling."
+
+     xtc_threshold: float
+     "The threshold the probs need to reach for being sampled."
+
+     xtc_special_tokens: list[int]
+     "List of special token IDs to be excluded from XTC sampling."
+
+
+ class StreamGenerateKwargs(TypedDict, total=False):
+     """Keyword arguments for the `mlx_lm.stream_generate` function."""
+
+     max_tokens: int
+     "The maximum number of tokens to generate."
+
+     sampler: Sampler
+     "A sampler for sampling tokens from a vector of logits."
+
+
+ def encode_params(params: Params) -> tuple[int | None, StreamGenerateKwargs]:
+     """Convert generic params to mlx-lm stream_generate kwargs.
+
+     Args:
+         params: The generic parameters.
+
+     Returns:
+         The seed (if any) and the mlx-lm specific stream_generate keyword arguments.
+     """
+     kwargs: StreamGenerateKwargs = {}
+
+     with _base_utils.ensure_all_params_accessed(
+         params=params,
+         provider_id="mlx",
+         unsupported_params=["stop_sequences", "thinking", "encode_thoughts_as_text"],
+     ) as param_accessor:
+         if param_accessor.max_tokens is not None:
+             kwargs["max_tokens"] = param_accessor.max_tokens
+         else:
+             kwargs["max_tokens"] = -1
+
+         sampler_kwargs = MakeSamplerKwargs({})
+         if param_accessor.temperature is not None:
+             sampler_kwargs["temp"] = param_accessor.temperature
+         if param_accessor.top_k is not None:
+             sampler_kwargs["top_k"] = param_accessor.top_k
+         if param_accessor.top_p is not None:
+             sampler_kwargs["top_p"] = param_accessor.top_p
+
+         kwargs["sampler"] = make_sampler(**sampler_kwargs)
+
+     return param_accessor.seed, kwargs
+
+
+ def extract_finish_reason(response: GenerationResponse | None) -> FinishReason | None:
+     """Extract the finish reason from an MLX generation response.
+
+     Args:
+         response: The MLX generation response to extract from.
+
+     Returns:
+         The normalized finish reason, or None if not applicable.
+     """
+     if response is None:
+         return None
+
+     if response.finish_reason == "length":
+         return FinishReason.MAX_TOKENS
+
+     return None
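As a usage sketch: `encode_params` above splits a generic `Params` object into a seed plus `stream_generate` keyword arguments, building the sampler via `mlx_lm`'s `make_sampler`. Assuming `Params` accepts these fields as keyword arguments (its definition lives in `providers/base/params.py` and is not reproduced in this diff), the conversion looks roughly like:

from mirascope.llm.providers.base import Params  # assumed public import path
from mirascope.llm.providers.mlx import _utils

# Hypothetical Params values; encode_params reads max_tokens, temperature,
# top_k, top_p, and seed, and flags stop_sequences/thinking as unsupported.
params = Params(max_tokens=256, temperature=0.7, top_p=0.9, seed=42)
seed, kwargs = _utils.encode_params(params)
# seed == 42
# kwargs == {"max_tokens": 256, "sampler": <callable from make_sampler(temp=0.7, top_p=0.9)>}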
mirascope/llm/providers/mlx/encoding/__init__.py
@@ -0,0 +1,8 @@
+ from .base import BaseEncoder, TokenIds
+ from .transformers import TransformersEncoder
+
+ __all__ = [
+     "BaseEncoder",
+     "TokenIds",
+     "TransformersEncoder",
+ ]
mirascope/llm/providers/mlx/encoding/base.py
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ import abc
+ from collections.abc import Iterable, Sequence
+ from typing import TypeAlias
+
+ from mlx_lm.generate import GenerationResponse
+
+ from ....formatting import Format, FormattableT
+ from ....messages import AssistantContent, Message
+ from ....responses import ChunkIterator
+ from ....tools import AnyToolSchema, BaseToolkit
+
+ TokenIds: TypeAlias = list[int]
+
+
+ class BaseEncoder(abc.ABC):
+     """Abstract base class for Mirascope <> MLX encoding and decoding."""
+
+     @abc.abstractmethod
+     def encode_request(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
+         """Encode the request messages into a format suitable for the model.
+
+         Args:
+             messages: The sequence of messages to encode.
+             tools: Optional sequence of tool schemas or toolkit for the model.
+             format: Optional format specification for structured outputs.
+
+         Returns:
+             A tuple containing:
+                 - The processed messages
+                 - The format specification (if provided)
+                 - The encoded prompt as token IDs
+         """
+         ...
+
+     @abc.abstractmethod
+     def decode_response(
+         self, stream: Iterable[GenerationResponse]
+     ) -> tuple[AssistantContent, GenerationResponse | None]:
+         """Decode a stream of MLX generation responses into assistant content.
+
+         Args:
+             stream: An iterable of MLX generation responses.
+
+         Returns:
+             A tuple containing:
+                 - The decoded assistant content
+                 - The final generation response (if available)
+         """
+         ...
+
+     @abc.abstractmethod
+     def decode_stream(self, stream: Iterable[GenerationResponse]) -> ChunkIterator:
+         """Decode a stream of MLX generation responses into an iterable of chunks.
+
+         Args:
+             stream: An iterable of MLX generation responses.
+
+         Returns:
+             A ChunkIterator yielding content chunks for streaming responses.
+         """
+         ...
mirascope/llm/providers/mlx/encoding/transformers.py
@@ -0,0 +1,131 @@
+ import io
+ from collections.abc import Iterable, Sequence
+ from dataclasses import dataclass
+ from typing import Literal, cast
+ from typing_extensions import TypedDict
+
+ from mlx_lm.generate import GenerationResponse
+ from transformers import PreTrainedTokenizer
+
+ from ....content import ContentPart, TextChunk, TextEndChunk, TextStartChunk
+ from ....formatting import Format, FormattableT
+ from ....messages import AssistantContent, Message
+ from ....responses import ChunkIterator, FinishReasonChunk, RawStreamEventChunk
+ from ....tools import AnyToolSchema, BaseToolkit
+ from .. import _utils
+ from .base import BaseEncoder, TokenIds
+
+ HFRole = Literal["system", "user", "assistant"] | str
+
+
+ class TransformersMessage(TypedDict):
+     """Message in Transformers format."""
+
+     role: HFRole
+     content: str
+
+
+ def _encode_content(content: Sequence[ContentPart]) -> str:
+     """Encode content parts into a string.
+
+     Args:
+         content: The sequence of content parts to encode.
+
+     Returns:
+         The encoded content as a string.
+
+     Raises:
+         NotImplementedError: If content contains non-text parts.
+     """
+     if len(content) == 1 and content[0].type == "text":
+         return content[0].text
+
+     raise NotImplementedError("Only text content is supported in this example.")
+
+
+ def _encode_message(message: Message) -> TransformersMessage:
+     """Encode a Mirascope message into Transformers format.
+
+     Args:
+         message: The message to encode.
+
+     Returns:
+         The encoded message in Transformers format.
+
+     Raises:
+         ValueError: If the message role is not supported.
+     """
+     if message.role == "system":
+         return TransformersMessage(role="system", content=message.content.text)
+     elif message.role == "assistant" or message.role == "user":
+         return TransformersMessage(
+             role=message.role, content=_encode_content(message.content)
+         )
+     else:
+         raise ValueError(f"Unsupported message type: {type(message)}")
+
+
+ @dataclass(frozen=True)
+ class TransformersEncoder(BaseEncoder):
+     """Encoder for Transformers models."""
+
+     tokenizer: PreTrainedTokenizer
+     """The tokenizer to use for encoding."""
+
+     def encode_request(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+     ) -> tuple[Sequence[Message], Format[FormattableT] | None, TokenIds]:
+         """Encode a request into a format suitable for the model."""
+         tool_schemas = tools.tools if isinstance(tools, BaseToolkit) else tools or []
+         if len(tool_schemas) > 0:
+             raise NotImplementedError("Tool usage is not supported.")
+         if format is not None:
+             raise NotImplementedError("Formatting is not supported.")
+
+         hf_messages: list[TransformersMessage] = [
+             _encode_message(msg) for msg in messages
+         ]
+         prompt_text = cast(
+             str,
+             self.tokenizer.apply_chat_template(  # pyright: ignore[reportUnknownMemberType]
+                 cast(list[dict[str, str]], hf_messages),
+                 tokenize=False,
+                 add_generation_prompt=True,
+             ),
+         )
+         return (
+             messages,
+             format,
+             self.tokenizer.encode(prompt_text, add_special_tokens=False),  # pyright: ignore[reportUnknownMemberType]
+         )
+
+     def decode_response(
+         self, stream: Iterable[GenerationResponse]
+     ) -> tuple[AssistantContent, GenerationResponse | None]:
+         """Decode a response into a format suitable for the model."""
+         with io.StringIO() as buffer:
+             last_response: GenerationResponse | None = None
+             for response in stream:
+                 buffer.write(response.text)
+                 last_response = response
+
+             return buffer.getvalue(), last_response
+
+     def decode_stream(self, stream: Iterable[GenerationResponse]) -> ChunkIterator:
+         """Decode a stream of responses into a format suitable for the model."""
+         yield TextStartChunk()
+
+         response: GenerationResponse | None = None
+         for response in stream:
+             yield RawStreamEventChunk(raw_stream_event=response)
+             yield TextChunk(delta=response.text)
+
+         assert response is not None
+         finish_reason = _utils.extract_finish_reason(response)
+         if finish_reason is not None:
+             yield FinishReasonChunk(finish_reason=finish_reason)
+         else:
+             yield TextEndChunk()
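The `TransformersEncoder` above leans entirely on the Hugging Face tokenizer's chat template to turn Mirascope messages into a token-ID prompt. Stripped of the Mirascope types, the underlying transformers calls it makes are roughly as follows (the model name is just an illustrative mlx-community checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen3-8B-4bit-DWQ-053125")
hf_messages = [{"role": "user", "content": "Hello!"}]

# Render the chat template to text, then tokenize without extra special tokens,
# mirroring TransformersEncoder.encode_request.
prompt_text = tokenizer.apply_chat_template(
    hf_messages, tokenize=False, add_generation_prompt=True
)
token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)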
mirascope/llm/providers/mlx/mlx.py
@@ -0,0 +1,237 @@
+ import asyncio
+ import threading
+ from collections.abc import Iterable, Sequence
+ from dataclasses import dataclass, field
+ from typing_extensions import Unpack
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm import stream_generate  # type: ignore[reportPrivateImportUsage]
+ from mlx_lm.generate import GenerationResponse
+ from transformers import PreTrainedTokenizer
+
+ from ...formatting import Format, FormattableT
+ from ...messages import AssistantMessage, Message, assistant
+ from ...responses import AsyncChunkIterator, ChunkIterator, StreamResponseChunk
+ from ...tools import AnyToolSchema, BaseToolkit
+ from ..base import Params
+ from . import _utils
+ from .encoding import BaseEncoder, TokenIds
+ from .model_id import MLXModelId
+
+
+ def _consume_sync_stream_into_queue(
+     generation_stream: ChunkIterator,
+     loop: asyncio.AbstractEventLoop,
+     queue: asyncio.Queue[StreamResponseChunk | Exception | None],
+ ) -> None:
+     """Consume a synchronous stream and put chunks into an async queue.
+
+     Args:
+         generation_stream: The synchronous chunk iterator to consume.
+         loop: The event loop for scheduling queue operations.
+         queue: The async queue to put chunks into.
+     """
+     try:
+         for response in generation_stream:
+             asyncio.run_coroutine_threadsafe(queue.put(response), loop)
+     except Exception as e:
+         asyncio.run_coroutine_threadsafe(queue.put(e), loop)
+
+     asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+
+
+ @dataclass(frozen=True)
+ class MLX:
+     """MLX model wrapper for synchronous and asynchronous generation.
+
+     Args:
+         model_id: The MLX model identifier.
+         model: The underlying MLX model.
+         tokenizer: The tokenizer for the model.
+         encoder: The encoder for prompts and responses.
+     """
+
+     model_id: MLXModelId
+     """The MLX model identifier."""
+
+     model: nn.Module
+     """The underlying MLX model."""
+
+     tokenizer: PreTrainedTokenizer
+     """The tokenizer for the model."""
+
+     encoder: BaseEncoder
+     """The encoder for prompts and responses."""
+
+     _lock: threading.Lock = field(default_factory=threading.Lock)
+     """The lock for thread-safety."""
+
+     def _stream_generate(
+         self,
+         prompt: TokenIds,
+         seed: int | None,
+         **kwargs: Unpack[_utils.StreamGenerateKwargs],
+     ) -> Iterable[GenerationResponse]:
+         """Generator that streams generation responses.
+
+         Using this generator instead of calling stream_generate directly ensures
+         thread-safety when using the model in a multi-threaded context.
+         """
+         with self._lock:
+             if seed is not None:
+                 mx.random.seed(seed)
+
+             return stream_generate(
+                 self.model,
+                 self.tokenizer,
+                 prompt,
+                 **kwargs,
+             )
+
+     async def _stream_generate_async(
+         self,
+         prompt: TokenIds,
+         seed: int | None,
+         **kwargs: Unpack[_utils.StreamGenerateKwargs],
+     ) -> AsyncChunkIterator:
+         """Async generator that streams generation responses.
+
+         Note that, while stream_generate returns an iterable of GenerationResponse,
+         here we return an `AsyncChunkIterator` in order to avoid having to implement
+         both synchronous and asynchronous versions of BaseEncoder.decode_stream.
+         This makes sense because, in this case, there is nothing to gain from
+         consuming the generation asynchronously.
+         """
+         loop = asyncio.get_running_loop()
+         generation_queue: asyncio.Queue[StreamResponseChunk | Exception | None] = (
+             asyncio.Queue()
+         )
+
+         sync_stream = self.encoder.decode_stream(
+             self._stream_generate(
+                 prompt,
+                 seed,
+                 **kwargs,
+             )
+         )
+
+         consume_task = asyncio.create_task(
+             asyncio.to_thread(
+                 _consume_sync_stream_into_queue, sync_stream, loop, generation_queue
+             ),
+         )
+         while item := await generation_queue.get():
+             if isinstance(item, Exception):
+                 raise item
+
+             yield item
+
+         await consume_task
+
+     def stream(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+         params: Params,
+     ) -> tuple[Sequence[Message], Format[FormattableT] | None, ChunkIterator]:
+         """Stream response chunks synchronously.
+
+         Args:
+             messages: The input messages.
+             tools: Optional tools for the model.
+             format: Optional response format.
+             params: Generation parameters.
+
+         Returns:
+             Tuple of messages, format, and chunk iterator.
+         """
+         messages, format, prompt = self.encoder.encode_request(messages, tools, format)
+         seed, kwargs = _utils.encode_params(params)
+
+         stream = self._stream_generate(prompt, seed, **kwargs)
+         return messages, format, self.encoder.decode_stream(stream)
+
+     async def stream_async(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+         params: Params,
+     ) -> tuple[Sequence[Message], Format[FormattableT] | None, AsyncChunkIterator]:
+         """Stream response chunks asynchronously.
+
+         Args:
+             messages: The input messages.
+             tools: Optional tools for the model.
+             format: Optional response format.
+             params: Generation parameters.
+
+         Returns:
+             Tuple of messages, format, and async chunk iterator.
+         """
+         messages, format, prompt = await asyncio.to_thread(
+             self.encoder.encode_request, messages, tools, format
+         )
+         seed, kwargs = _utils.encode_params(params)
+
+         chunk_iterator = self._stream_generate_async(prompt, seed, **kwargs)
+         return messages, format, chunk_iterator
+
+     def generate(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+         params: Params,
+     ) -> tuple[
+         Sequence[Message],
+         Format[FormattableT] | None,
+         AssistantMessage,
+         GenerationResponse | None,
+     ]:
+         """Generate a response synchronously.
+
+         Args:
+             messages: The input messages.
+             tools: Optional tools for the model.
+             format: Optional response format.
+             params: Generation parameters.
+
+         Returns:
+             Tuple of messages, format, assistant message, and last generation response.
+         """
+         messages, format, prompt = self.encoder.encode_request(messages, tools, format)
+         seed, kwargs = _utils.encode_params(params)
+
+         stream = self._stream_generate(prompt, seed, **kwargs)
+         assistant_content, last_response = self.encoder.decode_response(stream)
+         assistant_message = assistant(
+             content=assistant_content,
+             model_id=self.model_id,
+             provider_id="mlx",
+             raw_message=None,
+             name=None,
+         )
+         return messages, format, assistant_message, last_response
+
+     async def generate_async(
+         self,
+         messages: Sequence[Message],
+         tools: Sequence[AnyToolSchema] | BaseToolkit[AnyToolSchema] | None,
+         format: type[FormattableT] | Format[FormattableT] | None,
+         params: Params,
+     ) -> tuple[
+         Sequence[Message],
+         Format[FormattableT] | None,
+         AssistantMessage,
+         GenerationResponse | None,
+     ]:
+         """Generate a response asynchronously.
+
+         Args:
+             messages: The input messages.
+             tools: Optional tools for the model.
+             format: Optional response format.
+             params: Generation parameters.
+
+         Returns:
+             Tuple of messages, format, assistant message, and last generation response.
+         """
+         return await asyncio.to_thread(self.generate, messages, tools, format, params)
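To see how the pieces in this diff might fit together: the `MLX` wrapper expects a loaded model, its tokenizer, and an encoder. In the released package this wiring presumably happens inside `provider.py` (+411 above, not reproduced here); a hand-rolled construction under that assumption could look like:

from mlx_lm import load

model_id = "mlx-community/Qwen3-8B-4bit-DWQ-053125"
model, tokenizer = load(model_id)  # fetches the checkpoint from Hugging Face if needed

mlx = MLX(
    model_id=model_id,
    model=model,
    tokenizer=tokenizer,  # mlx-lm returns a tokenizer wrapper around a transformers tokenizer
    encoder=TransformersEncoder(tokenizer=tokenizer),
)
# mlx.generate(...) / mlx.stream(...) then take Mirascope messages, tools, format, and Params.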
mirascope/llm/providers/mlx/model_id.py
@@ -0,0 +1,17 @@
+ from typing import TypeAlias
+
+ # TODO: Add more explicit literals
+ # TODO: Ensure automatic model downloads are supported.
+ # TODO: Ensure instructions are clear for examples that run as copied
+ MLXModelId: TypeAlias = str
+ """The identifier of the MLX model to be loaded by the MLX client.
+
+ An MLX model identifier may be a local path to a model's files or a Hugging Face
+ repository such as:
+ - "mlx-community/Qwen3-8B-4bit-DWQ-053125"
+ - "mlx-community/gpt-oss-20b-MXFP4-Q8"
+
+ For more details, see:
+ - https://github.com/ml-explore/mlx-lm/?tab=readme-ov-file#supported-models
+ - https://huggingface.co/mlx-community
+ """