prompture-0.0.40.dev1-py3-none-any.whl → prompture-0.0.43.dev1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/_version.py +2 -2
- prompture/agent.py +11 -11
- prompture/async_agent.py +11 -11
- prompture/async_groups.py +63 -0
- prompture/cost_mixin.py +25 -0
- prompture/discovery.py +20 -0
- prompture/drivers/__init__.py +39 -0
- prompture/drivers/async_azure_driver.py +3 -2
- prompture/drivers/async_modelscope_driver.py +286 -0
- prompture/drivers/async_moonshot_driver.py +312 -0
- prompture/drivers/async_openai_driver.py +3 -2
- prompture/drivers/async_openrouter_driver.py +192 -3
- prompture/drivers/async_registry.py +30 -0
- prompture/drivers/async_zai_driver.py +303 -0
- prompture/drivers/azure_driver.py +3 -2
- prompture/drivers/modelscope_driver.py +303 -0
- prompture/drivers/moonshot_driver.py +342 -0
- prompture/drivers/openai_driver.py +3 -2
- prompture/drivers/openrouter_driver.py +244 -40
- prompture/drivers/zai_driver.py +318 -0
- prompture/groups.py +42 -0
- prompture/model_rates.py +2 -0
- prompture/settings.py +16 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/METADATA +1 -1
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/RECORD +29 -23
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/WHEEL +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.40.dev1.dist-info → prompture-0.0.43.dev1.dist-info}/top_level.txt +0 -0
prompture/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.0.40.dev1'
-__version_tuple__ = version_tuple = (0, 0, 40, 'dev1')
+__version__ = version = '0.0.43.dev1'
+__version_tuple__ = version_tuple = (0, 0, 43, 'dev1')

 __commit_id__ = commit_id = None
prompture/agent.py
CHANGED
@@ -188,7 +188,7 @@ class Agent(Generic[DepsType]):
         for fn in tools:
             self._tools.register(fn)

-        self.
+        self._lifecycle = AgentState.idle
         self._stop_requested = False

    # ------------------------------------------------------------------
@@ -206,7 +206,7 @@ class Agent(Generic[DepsType]):
    @property
    def state(self) -> AgentState:
        """Current lifecycle state of the agent."""
-        return self.
+        return self._lifecycle

    def stop(self) -> None:
        """Request graceful shutdown after the current iteration."""
@@ -265,16 +265,16 @@ class Agent(Generic[DepsType]):
            prompt: The user prompt to send.
            deps: Optional dependencies injected into :class:`RunContext`.
        """
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            result = self._execute(prompt, steps, deps)
-            self.
+            self._lifecycle = AgentState.idle
            return result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

    # ------------------------------------------------------------------
@@ -722,7 +722,7 @@ class Agent(Generic[DepsType]):

    def _execute_iter(self, prompt: str, deps: Any) -> Generator[AgentStep, None, AgentResult]:
        """Generator that executes the agent loop and yields each step."""
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

@@ -730,10 +730,10 @@ class Agent(Generic[DepsType]):
            result = self._execute(prompt, steps, deps)
            # Yield each step one at a time
            yield from result.steps
-            self.
+            self._lifecycle = AgentState.idle
            return result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

    # ------------------------------------------------------------------
@@ -757,7 +757,7 @@ class Agent(Generic[DepsType]):

    def _execute_stream(self, prompt: str, deps: Any) -> Generator[StreamEvent, None, AgentResult]:
        """Generator that executes the agent loop and yields stream events."""
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

@@ -853,10 +853,10 @@ class Agent(Generic[DepsType]):
                data=result,
            )

-            self.
+            self._lifecycle = AgentState.idle
            return result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

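Agent now tracks its lifecycle in a `_lifecycle` attribute that backs the public `state` property, so callers can observe the idle → running → idle/errored transitions around a run. A minimal sketch of how that might be used; the constructor arguments, import paths, and the `run()` entry point are assumptions for illustration, not taken from this diff:

    from prompture.agent import Agent, AgentState  # import paths assumed

    agent = Agent(model="openai/gpt-4o-mini")  # constructor args assumed
    assert agent.state is AgentState.idle      # set via self._lifecycle in __init__

    result = agent.run("Summarize the release notes")  # run() name assumed
    assert agent.state is AgentState.idle      # back to idle on success; errored if it raised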
prompture/async_agent.py
CHANGED
@@ -182,7 +182,7 @@ class AsyncAgent(Generic[DepsType]):
        for fn in tools:
            self._tools.register(fn)

-        self.
+        self._lifecycle = AgentState.idle
        self._stop_requested = False

    # ------------------------------------------------------------------
@@ -197,7 +197,7 @@ class AsyncAgent(Generic[DepsType]):
    @property
    def state(self) -> AgentState:
        """Current lifecycle state of the agent."""
-        return self.
+        return self._lifecycle

    def stop(self) -> None:
        """Request graceful shutdown after the current iteration."""
@@ -264,16 +264,16 @@ class AsyncAgent(Generic[DepsType]):
        Creates a fresh conversation, sends the prompt, handles tool calls,
        and optionally parses the final response into ``output_type``.
        """
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

        try:
            result = await self._execute(prompt, steps, deps)
-            self.
+            self._lifecycle = AgentState.idle
            return result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

    async def iter(self, prompt: str, *, deps: Any = None) -> AsyncAgentIterator:
@@ -714,7 +714,7 @@ class AsyncAgent(Generic[DepsType]):

    async def _execute_iter(self, prompt: str, deps: Any) -> AsyncGenerator[AgentStep, None]:
        """Async generator that executes the agent loop and yields each step."""
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

@@ -722,11 +722,11 @@ class AsyncAgent(Generic[DepsType]):
            result = await self._execute(prompt, steps, deps)
            for step in result.steps:
                yield step
-            self.
+            self._lifecycle = AgentState.idle
            # Store result on the generator for retrieval
            self._last_iter_result = result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

    # ------------------------------------------------------------------
@@ -735,7 +735,7 @@ class AsyncAgent(Generic[DepsType]):

    async def _execute_stream(self, prompt: str, deps: Any) -> AsyncGenerator[StreamEvent, None]:
        """Async generator that executes the agent loop and yields stream events."""
-        self.
+        self._lifecycle = AgentState.running
        self._stop_requested = False
        steps: list[AgentStep] = []

@@ -803,10 +803,10 @@ class AsyncAgent(Generic[DepsType]):

            yield StreamEvent(event_type=StreamEventType.output, data=result)

-            self.
+            self._lifecycle = AgentState.idle
            self._last_stream_result = result
        except Exception:
-            self.
+            self._lifecycle = AgentState.errored
            raise

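AsyncAgent mirrors the same lifecycle transitions around its awaitable entry points. A rough async sketch under the same assumptions (import path, constructor args, and `run()` name assumed):

    import asyncio
    from prompture.async_agent import AsyncAgent  # import path assumed

    async def main() -> None:
        agent = AsyncAgent(model="openai/gpt-4o-mini")  # constructor args assumed
        await agent.run("Ping")                         # run() name assumed
        print(agent.state)                              # AgentState.idle after a clean run

    asyncio.run(main())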
prompture/async_groups.py
CHANGED
@@ -70,6 +70,27 @@ class ParallelGroup:
        """Request graceful shutdown."""
        self._stop_requested = True

+    @property
+    def shared_state(self) -> dict[str, Any]:
+        """Return a copy of the current shared execution state."""
+        return dict(self._state)
+
+    def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
+        """Merge external key-value pairs into this group's shared state.
+
+        Existing keys are NOT overwritten (uses setdefault semantics).
+
+        Args:
+            state: Key-value pairs to inject.
+            recursive: If True, also inject into nested sub-groups.
+        """
+        for k, v in state.items():
+            self._state.setdefault(k, v)
+        if recursive:
+            for agent, _ in self._agents:
+                if hasattr(agent, "inject_state"):
+                    agent.inject_state(state, recursive=True)
+
    async def run_async(self, prompt: str = "") -> GroupResult:
        """Execute all agents concurrently."""
        self._stop_requested = False
@@ -213,6 +234,27 @@ class AsyncSequentialGroup:
    def stop(self) -> None:
        self._stop_requested = True

+    @property
+    def shared_state(self) -> dict[str, Any]:
+        """Return a copy of the current shared execution state."""
+        return dict(self._state)
+
+    def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
+        """Merge external key-value pairs into this group's shared state.
+
+        Existing keys are NOT overwritten (uses setdefault semantics).
+
+        Args:
+            state: Key-value pairs to inject.
+            recursive: If True, also inject into nested sub-groups.
+        """
+        for k, v in state.items():
+            self._state.setdefault(k, v)
+        if recursive:
+            for agent, _ in self._agents:
+                if hasattr(agent, "inject_state"):
+                    agent.inject_state(state, recursive=True)
+
    async def run(self, prompt: str = "") -> GroupResult:
        """Execute all agents in sequence (async)."""
        self._stop_requested = False
@@ -351,6 +393,27 @@ class AsyncLoopGroup:
    def stop(self) -> None:
        self._stop_requested = True

+    @property
+    def shared_state(self) -> dict[str, Any]:
+        """Return a copy of the current shared execution state."""
+        return dict(self._state)
+
+    def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
+        """Merge external key-value pairs into this group's shared state.
+
+        Existing keys are NOT overwritten (uses setdefault semantics).
+
+        Args:
+            state: Key-value pairs to inject.
+            recursive: If True, also inject into nested sub-groups.
+        """
+        for k, v in state.items():
+            self._state.setdefault(k, v)
+        if recursive:
+            for agent, _ in self._agents:
+                if hasattr(agent, "inject_state"):
+                    agent.inject_state(state, recursive=True)
+
    async def run(self, prompt: str = "") -> GroupResult:
        """Execute the loop (async)."""
        self._stop_requested = False
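All three group classes gain the same `shared_state` snapshot and `inject_state` merge. Because injection uses setdefault, pre-seeding a group never overwrites values it already holds. A rough sketch; the group constructor and agent objects are assumptions, while `shared_state`, `inject_state`, and `run_async` come from this diff:

    # Inside an async function:
    group = ParallelGroup(agents=[research_agent, drafting_agent])  # constructor assumed

    # Seed shared context before running; keys already present in the group win.
    group.inject_state({"customer_id": "42", "locale": "en-US"}, recursive=True)

    result = await group.run_async("Prepare the onboarding email")
    print(group.shared_state)  # copy of the merged state; mutating it does not touch the group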
prompture/cost_mixin.py
CHANGED
@@ -2,9 +2,34 @@

 from __future__ import annotations

+import copy
 from typing import Any


+def prepare_strict_schema(schema: dict[str, Any]) -> dict[str, Any]:
+    """Prepare a JSON schema for OpenAI strict structured-output mode.
+
+    OpenAI's ``strict: true`` requires every object to have
+    ``"additionalProperties": false`` and a ``"required"`` array listing
+    all property keys. This function recursively patches a schema copy
+    so callers don't need to worry about these constraints.
+    """
+    schema = copy.deepcopy(schema)
+    _patch_strict(schema)
+    return schema
+
+
+def _patch_strict(node: dict[str, Any]) -> None:
+    """Recursively add strict-mode constraints to an object schema node."""
+    if node.get("type") == "object" and "properties" in node:
+        node.setdefault("additionalProperties", False)
+        node.setdefault("required", list(node["properties"].keys()))
+        for prop in node["properties"].values():
+            _patch_strict(prop)
+    elif node.get("type") == "array" and isinstance(node.get("items"), dict):
+        _patch_strict(node["items"])
+
+
 class CostMixin:
     """Mixin that provides ``_calculate_cost`` to sync and async drivers.

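The helper simply normalizes an ordinary JSON schema into the shape OpenAI's strict mode expects; the input schema below is illustrative:

    from prompture.cost_mixin import prepare_strict_schema

    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "tags": {
                "type": "array",
                "items": {"type": "object", "properties": {"label": {"type": "string"}}},
            },
        },
    }

    strict = prepare_strict_schema(schema)
    # strict["additionalProperties"] is False and strict["required"] == ["name", "tags"];
    # the nested object under "items" is patched the same way, and the original
    # `schema` dict is left untouched thanks to the deepcopy.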
prompture/discovery.py
CHANGED
@@ -10,6 +10,7 @@ from typing import Any, overload
 import requests

 from .drivers import (
+    AirLLMDriver,
     AzureDriver,
     ClaudeDriver,
     GoogleDriver,
@@ -17,9 +18,12 @@ from .drivers import (
     GroqDriver,
     LMStudioDriver,
     LocalHTTPDriver,
+    ModelScopeDriver,
+    MoonshotDriver,
     OllamaDriver,
     OpenAIDriver,
     OpenRouterDriver,
+    ZaiDriver,
 )
 from .settings import settings

@@ -71,6 +75,10 @@ def get_available_models(
        "ollama": OllamaDriver,
        "lmstudio": LMStudioDriver,
        "local_http": LocalHTTPDriver,
+        "moonshot": MoonshotDriver,
+        "zai": ZaiDriver,
+        "modelscope": ModelScopeDriver,
+        "airllm": AirLLMDriver,
    }

    for provider, driver_cls in provider_classes.items():
@@ -102,6 +110,18 @@
        elif provider == "grok":
            if settings.grok_api_key or os.getenv("GROK_API_KEY"):
                is_configured = True
+        elif provider == "moonshot":
+            if settings.moonshot_api_key or os.getenv("MOONSHOT_API_KEY"):
+                is_configured = True
+        elif provider == "zai":
+            if settings.zhipu_api_key or os.getenv("ZHIPU_API_KEY"):
+                is_configured = True
+        elif provider == "modelscope":
+            if settings.modelscope_api_key or os.getenv("MODELSCOPE_API_KEY"):
+                is_configured = True
+        elif provider == "airllm":
+            # AirLLM runs locally, always considered configured
+            is_configured = True
        elif (
            provider == "ollama"
            or provider == "lmstudio"
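Provider detection for the new backends follows the existing pattern: a provider counts as configured when its key is present either on `settings` or in the environment, and AirLLM is always configured because it runs locally. A hedged sketch; the exact `get_available_models()` signature is not shown in this diff:

    import os
    from prompture.discovery import get_available_models

    # Any one of these (or the matching settings field) marks its provider as configured.
    os.environ["MOONSHOT_API_KEY"] = "sk-..."    # placeholder values
    os.environ["ZHIPU_API_KEY"] = "sk-..."
    os.environ["MODELSCOPE_API_KEY"] = "ms-..."

    models = get_available_models()  # call with no arguments is an assumption
    print(models)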
prompture/drivers/__init__.py
CHANGED
@@ -37,10 +37,13 @@ from .async_groq_driver import AsyncGroqDriver
 from .async_hugging_driver import AsyncHuggingFaceDriver
 from .async_lmstudio_driver import AsyncLMStudioDriver
 from .async_local_http_driver import AsyncLocalHTTPDriver
+from .async_modelscope_driver import AsyncModelScopeDriver
+from .async_moonshot_driver import AsyncMoonshotDriver
 from .async_ollama_driver import AsyncOllamaDriver
 from .async_openai_driver import AsyncOpenAIDriver
 from .async_openrouter_driver import AsyncOpenRouterDriver
 from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
+from .async_zai_driver import AsyncZaiDriver
 from .azure_driver import AzureDriver
 from .claude_driver import ClaudeDriver
 from .google_driver import GoogleDriver
@@ -48,6 +51,8 @@ from .grok_driver import GrokDriver
 from .groq_driver import GroqDriver
 from .lmstudio_driver import LMStudioDriver
 from .local_http_driver import LocalHTTPDriver
+from .modelscope_driver import ModelScopeDriver
+from .moonshot_driver import MoonshotDriver
 from .ollama_driver import OllamaDriver
 from .openai_driver import OpenAIDriver
 from .openrouter_driver import OpenRouterDriver
@@ -65,6 +70,7 @@ from .registry import (
     unregister_async_driver,
     unregister_driver,
 )
+from .zai_driver import ZaiDriver

 # Register built-in sync drivers
 register_driver(
@@ -123,6 +129,33 @@ register_driver(
    lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
    overwrite=True,
 )
+register_driver(
+    "moonshot",
+    lambda model=None: MoonshotDriver(
+        api_key=settings.moonshot_api_key,
+        model=model or settings.moonshot_model,
+        endpoint=settings.moonshot_endpoint,
+    ),
+    overwrite=True,
+)
+register_driver(
+    "modelscope",
+    lambda model=None: ModelScopeDriver(
+        api_key=settings.modelscope_api_key,
+        model=model or settings.modelscope_model,
+        endpoint=settings.modelscope_endpoint,
+    ),
+    overwrite=True,
+)
+register_driver(
+    "zai",
+    lambda model=None: ZaiDriver(
+        api_key=settings.zhipu_api_key,
+        model=model or settings.zhipu_model,
+        endpoint=settings.zhipu_endpoint,
+    ),
+    overwrite=True,
+)
 register_driver(
    "airllm",
    lambda model=None: AirLLMDriver(
@@ -197,9 +230,12 @@ __all__ = [
    "AsyncHuggingFaceDriver",
    "AsyncLMStudioDriver",
    "AsyncLocalHTTPDriver",
+    "AsyncModelScopeDriver",
+    "AsyncMoonshotDriver",
    "AsyncOllamaDriver",
    "AsyncOpenAIDriver",
    "AsyncOpenRouterDriver",
+    "AsyncZaiDriver",
    "AzureDriver",
    "ClaudeDriver",
    "GoogleDriver",
@@ -207,9 +243,12 @@ __all__ = [
    "GroqDriver",
    "LMStudioDriver",
    "LocalHTTPDriver",
+    "ModelScopeDriver",
+    "MoonshotDriver",
    "OllamaDriver",
    "OpenAIDriver",
    "OpenRouterDriver",
+    "ZaiDriver",
    "get_async_driver",
    "get_async_driver_for_model",
    # Factory functions
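The three new providers are registered at import time with the same factory pattern as the existing drivers, so both the built-in lookups and custom registrations should work unchanged. A sketch under the assumption that a sync `get_driver(name)` accessor exists alongside `register_driver`; only `register_driver` and the driver classes appear in this diff, and the key, model name, and endpoint below are placeholders:

    from prompture.drivers import MoonshotDriver, get_driver, register_driver

    # Resolve the built-in registration added in this release (accessor name assumed).
    driver = get_driver("moonshot")

    # Or register a variant under a custom name, mirroring the pattern above.
    register_driver(
        "moonshot-long-context",
        lambda model=None: MoonshotDriver(
            api_key="sk-...",                       # placeholder
            model=model or "moonshot-v1-128k",      # model name is an assumption
            endpoint="https://api.moonshot.cn/v1",  # endpoint is an assumption
        ),
        overwrite=True,
    )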
prompture/drivers/async_azure_driver.py
CHANGED

@@ -11,7 +11,7 @@ except Exception:
     AsyncAzureOpenAI = None

 from ..async_driver import AsyncDriver
-from ..cost_mixin import CostMixin
+from ..cost_mixin import CostMixin, prepare_strict_schema
 from .azure_driver import AzureDriver


@@ -89,12 +89,13 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
        if options.get("json_mode"):
            json_schema = options.get("json_schema")
            if json_schema:
+                schema_copy = prepare_strict_schema(json_schema)
                kwargs["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "extraction",
                        "strict": True,
-                        "schema":
+                        "schema": schema_copy,
                    },
                }
            else:
|