prompture 0.0.40.dev1__py3-none-any.whl → 0.0.43.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prompture/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.0.40.dev1'
32
- __version_tuple__ = version_tuple = (0, 0, 40, 'dev1')
31
+ __version__ = version = '0.0.43.dev1'
32
+ __version_tuple__ = version_tuple = (0, 0, 43, 'dev1')
33
33
 
34
34
  __commit_id__ = commit_id = None
prompture/agent.py CHANGED
@@ -188,7 +188,7 @@ class Agent(Generic[DepsType]):
188
188
  for fn in tools:
189
189
  self._tools.register(fn)
190
190
 
191
- self._state = AgentState.idle
191
+ self._lifecycle = AgentState.idle
192
192
  self._stop_requested = False
193
193
 
194
194
  # ------------------------------------------------------------------
@@ -206,7 +206,7 @@ class Agent(Generic[DepsType]):
206
206
  @property
207
207
  def state(self) -> AgentState:
208
208
  """Current lifecycle state of the agent."""
209
- return self._state
209
+ return self._lifecycle
210
210
 
211
211
  def stop(self) -> None:
212
212
  """Request graceful shutdown after the current iteration."""
@@ -265,16 +265,16 @@ class Agent(Generic[DepsType]):
265
265
  prompt: The user prompt to send.
266
266
  deps: Optional dependencies injected into :class:`RunContext`.
267
267
  """
268
- self._state = AgentState.running
268
+ self._lifecycle = AgentState.running
269
269
  self._stop_requested = False
270
270
  steps: list[AgentStep] = []
271
271
 
272
272
  try:
273
273
  result = self._execute(prompt, steps, deps)
274
- self._state = AgentState.idle
274
+ self._lifecycle = AgentState.idle
275
275
  return result
276
276
  except Exception:
277
- self._state = AgentState.errored
277
+ self._lifecycle = AgentState.errored
278
278
  raise
279
279
 
280
280
  # ------------------------------------------------------------------
@@ -722,7 +722,7 @@ class Agent(Generic[DepsType]):
722
722
 
723
723
  def _execute_iter(self, prompt: str, deps: Any) -> Generator[AgentStep, None, AgentResult]:
724
724
  """Generator that executes the agent loop and yields each step."""
725
- self._state = AgentState.running
725
+ self._lifecycle = AgentState.running
726
726
  self._stop_requested = False
727
727
  steps: list[AgentStep] = []
728
728
 
@@ -730,10 +730,10 @@ class Agent(Generic[DepsType]):
730
730
  result = self._execute(prompt, steps, deps)
731
731
  # Yield each step one at a time
732
732
  yield from result.steps
733
- self._state = AgentState.idle
733
+ self._lifecycle = AgentState.idle
734
734
  return result
735
735
  except Exception:
736
- self._state = AgentState.errored
736
+ self._lifecycle = AgentState.errored
737
737
  raise
738
738
 
739
739
  # ------------------------------------------------------------------
@@ -757,7 +757,7 @@ class Agent(Generic[DepsType]):
757
757
 
758
758
  def _execute_stream(self, prompt: str, deps: Any) -> Generator[StreamEvent, None, AgentResult]:
759
759
  """Generator that executes the agent loop and yields stream events."""
760
- self._state = AgentState.running
760
+ self._lifecycle = AgentState.running
761
761
  self._stop_requested = False
762
762
  steps: list[AgentStep] = []
763
763
 
@@ -853,10 +853,10 @@ class Agent(Generic[DepsType]):
853
853
  data=result,
854
854
  )
855
855
 
856
- self._state = AgentState.idle
856
+ self._lifecycle = AgentState.idle
857
857
  return result
858
858
  except Exception:
859
- self._state = AgentState.errored
859
+ self._lifecycle = AgentState.errored
860
860
  raise
861
861
 
862
862
 
prompture/async_agent.py CHANGED
@@ -182,7 +182,7 @@ class AsyncAgent(Generic[DepsType]):
182
182
  for fn in tools:
183
183
  self._tools.register(fn)
184
184
 
185
- self._state = AgentState.idle
185
+ self._lifecycle = AgentState.idle
186
186
  self._stop_requested = False
187
187
 
188
188
  # ------------------------------------------------------------------
@@ -197,7 +197,7 @@ class AsyncAgent(Generic[DepsType]):
197
197
  @property
198
198
  def state(self) -> AgentState:
199
199
  """Current lifecycle state of the agent."""
200
- return self._state
200
+ return self._lifecycle
201
201
 
202
202
  def stop(self) -> None:
203
203
  """Request graceful shutdown after the current iteration."""
@@ -264,16 +264,16 @@ class AsyncAgent(Generic[DepsType]):
264
264
  Creates a fresh conversation, sends the prompt, handles tool calls,
265
265
  and optionally parses the final response into ``output_type``.
266
266
  """
267
- self._state = AgentState.running
267
+ self._lifecycle = AgentState.running
268
268
  self._stop_requested = False
269
269
  steps: list[AgentStep] = []
270
270
 
271
271
  try:
272
272
  result = await self._execute(prompt, steps, deps)
273
- self._state = AgentState.idle
273
+ self._lifecycle = AgentState.idle
274
274
  return result
275
275
  except Exception:
276
- self._state = AgentState.errored
276
+ self._lifecycle = AgentState.errored
277
277
  raise
278
278
 
279
279
  async def iter(self, prompt: str, *, deps: Any = None) -> AsyncAgentIterator:
@@ -714,7 +714,7 @@ class AsyncAgent(Generic[DepsType]):
714
714
 
715
715
  async def _execute_iter(self, prompt: str, deps: Any) -> AsyncGenerator[AgentStep, None]:
716
716
  """Async generator that executes the agent loop and yields each step."""
717
- self._state = AgentState.running
717
+ self._lifecycle = AgentState.running
718
718
  self._stop_requested = False
719
719
  steps: list[AgentStep] = []
720
720
 
@@ -722,11 +722,11 @@ class AsyncAgent(Generic[DepsType]):
722
722
  result = await self._execute(prompt, steps, deps)
723
723
  for step in result.steps:
724
724
  yield step
725
- self._state = AgentState.idle
725
+ self._lifecycle = AgentState.idle
726
726
  # Store result on the generator for retrieval
727
727
  self._last_iter_result = result
728
728
  except Exception:
729
- self._state = AgentState.errored
729
+ self._lifecycle = AgentState.errored
730
730
  raise
731
731
 
732
732
  # ------------------------------------------------------------------
@@ -735,7 +735,7 @@ class AsyncAgent(Generic[DepsType]):
735
735
 
736
736
  async def _execute_stream(self, prompt: str, deps: Any) -> AsyncGenerator[StreamEvent, None]:
737
737
  """Async generator that executes the agent loop and yields stream events."""
738
- self._state = AgentState.running
738
+ self._lifecycle = AgentState.running
739
739
  self._stop_requested = False
740
740
  steps: list[AgentStep] = []
741
741
 
@@ -803,10 +803,10 @@ class AsyncAgent(Generic[DepsType]):
803
803
 
804
804
  yield StreamEvent(event_type=StreamEventType.output, data=result)
805
805
 
806
- self._state = AgentState.idle
806
+ self._lifecycle = AgentState.idle
807
807
  self._last_stream_result = result
808
808
  except Exception:
809
- self._state = AgentState.errored
809
+ self._lifecycle = AgentState.errored
810
810
  raise
811
811
 
812
812
 
prompture/async_groups.py CHANGED
@@ -70,6 +70,27 @@ class ParallelGroup:
70
70
  """Request graceful shutdown."""
71
71
  self._stop_requested = True
72
72
 
73
+ @property
74
+ def shared_state(self) -> dict[str, Any]:
75
+ """Return a copy of the current shared execution state."""
76
+ return dict(self._state)
77
+
78
+ def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
79
+ """Merge external key-value pairs into this group's shared state.
80
+
81
+ Existing keys are NOT overwritten (uses setdefault semantics).
82
+
83
+ Args:
84
+ state: Key-value pairs to inject.
85
+ recursive: If True, also inject into nested sub-groups.
86
+ """
87
+ for k, v in state.items():
88
+ self._state.setdefault(k, v)
89
+ if recursive:
90
+ for agent, _ in self._agents:
91
+ if hasattr(agent, "inject_state"):
92
+ agent.inject_state(state, recursive=True)
93
+
73
94
  async def run_async(self, prompt: str = "") -> GroupResult:
74
95
  """Execute all agents concurrently."""
75
96
  self._stop_requested = False
@@ -213,6 +234,27 @@ class AsyncSequentialGroup:
213
234
  def stop(self) -> None:
214
235
  self._stop_requested = True
215
236
 
237
+ @property
238
+ def shared_state(self) -> dict[str, Any]:
239
+ """Return a copy of the current shared execution state."""
240
+ return dict(self._state)
241
+
242
+ def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
243
+ """Merge external key-value pairs into this group's shared state.
244
+
245
+ Existing keys are NOT overwritten (uses setdefault semantics).
246
+
247
+ Args:
248
+ state: Key-value pairs to inject.
249
+ recursive: If True, also inject into nested sub-groups.
250
+ """
251
+ for k, v in state.items():
252
+ self._state.setdefault(k, v)
253
+ if recursive:
254
+ for agent, _ in self._agents:
255
+ if hasattr(agent, "inject_state"):
256
+ agent.inject_state(state, recursive=True)
257
+
216
258
  async def run(self, prompt: str = "") -> GroupResult:
217
259
  """Execute all agents in sequence (async)."""
218
260
  self._stop_requested = False
@@ -351,6 +393,27 @@ class AsyncLoopGroup:
351
393
  def stop(self) -> None:
352
394
  self._stop_requested = True
353
395
 
396
+ @property
397
+ def shared_state(self) -> dict[str, Any]:
398
+ """Return a copy of the current shared execution state."""
399
+ return dict(self._state)
400
+
401
+ def inject_state(self, state: dict[str, Any], *, recursive: bool = False) -> None:
402
+ """Merge external key-value pairs into this group's shared state.
403
+
404
+ Existing keys are NOT overwritten (uses setdefault semantics).
405
+
406
+ Args:
407
+ state: Key-value pairs to inject.
408
+ recursive: If True, also inject into nested sub-groups.
409
+ """
410
+ for k, v in state.items():
411
+ self._state.setdefault(k, v)
412
+ if recursive:
413
+ for agent, _ in self._agents:
414
+ if hasattr(agent, "inject_state"):
415
+ agent.inject_state(state, recursive=True)
416
+
354
417
  async def run(self, prompt: str = "") -> GroupResult:
355
418
  """Execute the loop (async)."""
356
419
  self._stop_requested = False
prompture/cost_mixin.py CHANGED
@@ -2,9 +2,34 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import copy
5
6
  from typing import Any
6
7
 
7
8
 
9
+ def prepare_strict_schema(schema: dict[str, Any]) -> dict[str, Any]:
10
+ """Prepare a JSON schema for OpenAI strict structured-output mode.
11
+
12
+ OpenAI's ``strict: true`` requires every object to have
13
+ ``"additionalProperties": false`` and a ``"required"`` array listing
14
+ all property keys. This function recursively patches a schema copy
15
+ so callers don't need to worry about these constraints.
16
+ """
17
+ schema = copy.deepcopy(schema)
18
+ _patch_strict(schema)
19
+ return schema
20
+
21
+
22
+ def _patch_strict(node: dict[str, Any]) -> None:
23
+ """Recursively add strict-mode constraints to an object schema node."""
24
+ if node.get("type") == "object" and "properties" in node:
25
+ node.setdefault("additionalProperties", False)
26
+ node.setdefault("required", list(node["properties"].keys()))
27
+ for prop in node["properties"].values():
28
+ _patch_strict(prop)
29
+ elif node.get("type") == "array" and isinstance(node.get("items"), dict):
30
+ _patch_strict(node["items"])
31
+
32
+
8
33
  class CostMixin:
9
34
  """Mixin that provides ``_calculate_cost`` to sync and async drivers.
10
35
 
prompture/discovery.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, overload
10
10
  import requests
11
11
 
12
12
  from .drivers import (
13
+ AirLLMDriver,
13
14
  AzureDriver,
14
15
  ClaudeDriver,
15
16
  GoogleDriver,
@@ -17,9 +18,12 @@ from .drivers import (
17
18
  GroqDriver,
18
19
  LMStudioDriver,
19
20
  LocalHTTPDriver,
21
+ ModelScopeDriver,
22
+ MoonshotDriver,
20
23
  OllamaDriver,
21
24
  OpenAIDriver,
22
25
  OpenRouterDriver,
26
+ ZaiDriver,
23
27
  )
24
28
  from .settings import settings
25
29
 
@@ -71,6 +75,10 @@ def get_available_models(
71
75
  "ollama": OllamaDriver,
72
76
  "lmstudio": LMStudioDriver,
73
77
  "local_http": LocalHTTPDriver,
78
+ "moonshot": MoonshotDriver,
79
+ "zai": ZaiDriver,
80
+ "modelscope": ModelScopeDriver,
81
+ "airllm": AirLLMDriver,
74
82
  }
75
83
 
76
84
  for provider, driver_cls in provider_classes.items():
@@ -102,6 +110,18 @@ def get_available_models(
102
110
  elif provider == "grok":
103
111
  if settings.grok_api_key or os.getenv("GROK_API_KEY"):
104
112
  is_configured = True
113
+ elif provider == "moonshot":
114
+ if settings.moonshot_api_key or os.getenv("MOONSHOT_API_KEY"):
115
+ is_configured = True
116
+ elif provider == "zai":
117
+ if settings.zhipu_api_key or os.getenv("ZHIPU_API_KEY"):
118
+ is_configured = True
119
+ elif provider == "modelscope":
120
+ if settings.modelscope_api_key or os.getenv("MODELSCOPE_API_KEY"):
121
+ is_configured = True
122
+ elif provider == "airllm":
123
+ # AirLLM runs locally, always considered configured
124
+ is_configured = True
105
125
  elif (
106
126
  provider == "ollama"
107
127
  or provider == "lmstudio"
@@ -37,10 +37,13 @@ from .async_groq_driver import AsyncGroqDriver
37
37
  from .async_hugging_driver import AsyncHuggingFaceDriver
38
38
  from .async_lmstudio_driver import AsyncLMStudioDriver
39
39
  from .async_local_http_driver import AsyncLocalHTTPDriver
40
+ from .async_modelscope_driver import AsyncModelScopeDriver
41
+ from .async_moonshot_driver import AsyncMoonshotDriver
40
42
  from .async_ollama_driver import AsyncOllamaDriver
41
43
  from .async_openai_driver import AsyncOpenAIDriver
42
44
  from .async_openrouter_driver import AsyncOpenRouterDriver
43
45
  from .async_registry import ASYNC_DRIVER_REGISTRY, get_async_driver, get_async_driver_for_model
46
+ from .async_zai_driver import AsyncZaiDriver
44
47
  from .azure_driver import AzureDriver
45
48
  from .claude_driver import ClaudeDriver
46
49
  from .google_driver import GoogleDriver
@@ -48,6 +51,8 @@ from .grok_driver import GrokDriver
48
51
  from .groq_driver import GroqDriver
49
52
  from .lmstudio_driver import LMStudioDriver
50
53
  from .local_http_driver import LocalHTTPDriver
54
+ from .modelscope_driver import ModelScopeDriver
55
+ from .moonshot_driver import MoonshotDriver
51
56
  from .ollama_driver import OllamaDriver
52
57
  from .openai_driver import OpenAIDriver
53
58
  from .openrouter_driver import OpenRouterDriver
@@ -65,6 +70,7 @@ from .registry import (
65
70
  unregister_async_driver,
66
71
  unregister_driver,
67
72
  )
73
+ from .zai_driver import ZaiDriver
68
74
 
69
75
  # Register built-in sync drivers
70
76
  register_driver(
@@ -123,6 +129,33 @@ register_driver(
123
129
  lambda model=None: GrokDriver(api_key=settings.grok_api_key, model=model or settings.grok_model),
124
130
  overwrite=True,
125
131
  )
132
+ register_driver(
133
+ "moonshot",
134
+ lambda model=None: MoonshotDriver(
135
+ api_key=settings.moonshot_api_key,
136
+ model=model or settings.moonshot_model,
137
+ endpoint=settings.moonshot_endpoint,
138
+ ),
139
+ overwrite=True,
140
+ )
141
+ register_driver(
142
+ "modelscope",
143
+ lambda model=None: ModelScopeDriver(
144
+ api_key=settings.modelscope_api_key,
145
+ model=model or settings.modelscope_model,
146
+ endpoint=settings.modelscope_endpoint,
147
+ ),
148
+ overwrite=True,
149
+ )
150
+ register_driver(
151
+ "zai",
152
+ lambda model=None: ZaiDriver(
153
+ api_key=settings.zhipu_api_key,
154
+ model=model or settings.zhipu_model,
155
+ endpoint=settings.zhipu_endpoint,
156
+ ),
157
+ overwrite=True,
158
+ )
126
159
  register_driver(
127
160
  "airllm",
128
161
  lambda model=None: AirLLMDriver(
@@ -197,9 +230,12 @@ __all__ = [
197
230
  "AsyncHuggingFaceDriver",
198
231
  "AsyncLMStudioDriver",
199
232
  "AsyncLocalHTTPDriver",
233
+ "AsyncModelScopeDriver",
234
+ "AsyncMoonshotDriver",
200
235
  "AsyncOllamaDriver",
201
236
  "AsyncOpenAIDriver",
202
237
  "AsyncOpenRouterDriver",
238
+ "AsyncZaiDriver",
203
239
  "AzureDriver",
204
240
  "ClaudeDriver",
205
241
  "GoogleDriver",
@@ -207,9 +243,12 @@ __all__ = [
207
243
  "GroqDriver",
208
244
  "LMStudioDriver",
209
245
  "LocalHTTPDriver",
246
+ "ModelScopeDriver",
247
+ "MoonshotDriver",
210
248
  "OllamaDriver",
211
249
  "OpenAIDriver",
212
250
  "OpenRouterDriver",
251
+ "ZaiDriver",
213
252
  "get_async_driver",
214
253
  "get_async_driver_for_model",
215
254
  # Factory functions
@@ -11,7 +11,7 @@ except Exception:
11
11
  AsyncAzureOpenAI = None
12
12
 
13
13
  from ..async_driver import AsyncDriver
14
- from ..cost_mixin import CostMixin
14
+ from ..cost_mixin import CostMixin, prepare_strict_schema
15
15
  from .azure_driver import AzureDriver
16
16
 
17
17
 
@@ -89,12 +89,13 @@ class AsyncAzureDriver(CostMixin, AsyncDriver):
89
89
  if options.get("json_mode"):
90
90
  json_schema = options.get("json_schema")
91
91
  if json_schema:
92
+ schema_copy = prepare_strict_schema(json_schema)
92
93
  kwargs["response_format"] = {
93
94
  "type": "json_schema",
94
95
  "json_schema": {
95
96
  "name": "extraction",
96
97
  "strict": True,
97
- "schema": json_schema,
98
+ "schema": schema_copy,
98
99
  },
99
100
  }
100
101
  else: