stratum-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stratum/__init__.py +120 -0
- stratum/_config.py +48 -0
- stratum/budget.py +61 -0
- stratum/compiler.py +160 -0
- stratum/concurrency.py +215 -0
- stratum/contracts.py +232 -0
- stratum/decorators.py +377 -0
- stratum/exceptions.py +126 -0
- stratum/executor.py +564 -0
- stratum/exporters/__init__.py +5 -0
- stratum/exporters/otlp.py +149 -0
- stratum/flow_scope.py +31 -0
- stratum/hitl.py +170 -0
- stratum/trace.py +51 -0
- stratum/types.py +108 -0
- stratum_py-0.1.0.dist-info/METADATA +12 -0
- stratum_py-0.1.0.dist-info/RECORD +18 -0
- stratum_py-0.1.0.dist-info/WHEEL +4 -0
stratum/__init__.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Stratum — LLM calls that behave like the rest of your code.
|
|
3
|
+
|
|
4
|
+
Public API surface (v1):
|
|
5
|
+
|
|
6
|
+
Decorators: @contract, @infer, @compute, @flow, @refine
|
|
7
|
+
Types: Budget, opaque, Probabilistic, Success, Failure, HumanDecision, HumanReviewContext
|
|
8
|
+
HITL: await_human, ReviewSink, ConsoleReviewSink, PendingReview
|
|
9
|
+
Concurrency: parallel, debate, race
|
|
10
|
+
Utilities: configure, run
|
|
11
|
+
Errors: StratumError and all subclasses
|
|
12
|
+
Trace: TraceRecord, all_records, clear_traces
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
from .contracts import contract, opaque, is_registered
|
|
21
|
+
from .budget import Budget
|
|
22
|
+
from .exceptions import (
|
|
23
|
+
StratumError,
|
|
24
|
+
StratumCompileError,
|
|
25
|
+
PreconditionFailed,
|
|
26
|
+
PostconditionFailed,
|
|
27
|
+
ParseFailure,
|
|
28
|
+
BudgetExceeded,
|
|
29
|
+
ConvergenceFailure,
|
|
30
|
+
ConsensusFailure,
|
|
31
|
+
ParallelValidationFailed,
|
|
32
|
+
HITLTimeoutError,
|
|
33
|
+
StabilityAssertionError,
|
|
34
|
+
)
|
|
35
|
+
from .decorators import infer, compute, flow, refine
|
|
36
|
+
from .trace import TraceRecord, all_records, clear as clear_traces
|
|
37
|
+
from ._config import configure
|
|
38
|
+
from .types import Probabilistic, HumanDecision, HumanReviewContext, Success, Failure
|
|
39
|
+
from .hitl import await_human, ReviewSink, ConsoleReviewSink, PendingReview
|
|
40
|
+
from .concurrency import parallel, debate, race
|
|
41
|
+
from .flow_scope import FlowScope
|
|
42
|
+
from . import exporters
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def run(coro: Any) -> Any:
    """
    Synchronous shim for non-async contexts (scripts, notebooks).

    Runs ``coro`` to completion on a fresh event loop managed by
    ``asyncio.run`` and returns its result. Each call gets its own loop, so
    ``run()`` can be called repeatedly from the same thread.

    Args:
        coro: the coroutine (or other awaitable) to drive to completion.

    Returns:
        Whatever ``coro`` returns.

    Raises:
        RuntimeError: if called from inside an already-running event loop.
        Any exception raised by ``coro`` propagates unchanged.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so it is safe to start one.
        pass
    else:
        # Close the un-started coroutine so Python does not emit a
        # "coroutine was never awaited" warning on top of our error.
        if asyncio.iscoroutine(coro):
            coro.close()
        raise RuntimeError(
            "stratum.run() must not be called from inside a running event loop. "
            "Use 'await' directly instead."
        )

    async def _drive() -> Any:
        # asyncio.run() only accepts coroutines; wrapping keeps support for
        # any awaitable, as loop.run_until_complete() allowed before.
        return await coro

    return asyncio.run(_drive())
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Names re-exported by ``from stratum import *`` — grouped by area, same
# order as the public API summary in the module docstring.
__all__ = [
    # Decorators
    "contract", "infer", "compute", "flow", "refine",
    # Types
    "Budget", "opaque", "Probabilistic", "Success", "Failure",
    "HumanDecision", "HumanReviewContext",
    # HITL
    "await_human", "ReviewSink", "ConsoleReviewSink", "PendingReview",
    # Concurrency
    "parallel", "debate", "race",
    # Flow context
    "FlowScope",
    # Configuration
    "configure", "run",
    # Trace
    "TraceRecord", "all_records", "clear_traces",
    # Errors
    "StratumError", "StratumCompileError",
    "PreconditionFailed", "PostconditionFailed",
    "ParseFailure", "BudgetExceeded",
    "ConvergenceFailure", "ConsensusFailure",
    "ParallelValidationFailed", "HITLTimeoutError",
    "StabilityAssertionError",
    # Registry
    "is_registered",
    # Exporters
    "exporters",
]
|
stratum/_config.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""Global Stratum configuration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Process-wide settings store. ``configure()`` writes into it and
# ``get_config()`` hands out this same dict object.
_config: dict[str, Any] = dict(
    client=None,  # None -> litellm is used directly
    review_sink=None,  # None -> ConsoleReviewSink
    tracer=None,  # None -> no OTel export
    default_model="claude-sonnet-4-6",
    test_mode=False,  # True -> sample sample_n times for Probabilistic[T]
    sample_n=5,  # samples per @infer call in test_mode
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def configure(
    client: Any = None,
    review_sink: Any = None,
    tracer: Any = None,
    default_model: str | None = None,
    test_mode: bool | None = None,
    sample_n: int | None = None,
) -> None:
    """
    Update the global Stratum configuration.

    Only arguments that are not None are written; every other setting keeps
    its current value. As a consequence a setting cannot be reset to None
    through this function.

    Configuration is global and meant to be set once at startup. Per-function
    decorator annotations take precedence over these global defaults.
    """
    requested = {
        "client": client,
        "review_sink": review_sink,
        "tracer": tracer,
        "default_model": default_model,
        "test_mode": test_mode,
        "sample_n": sample_n,
    }
    # None means "argument not supplied", so those entries are skipped.
    _config.update(
        {key: value for key, value in requested.items() if value is not None}
    )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_config() -> dict[str, Any]:
    """
    Return the live global configuration dict.

    The module-level dict itself is returned, not a copy, so mutating the
    result changes the global configuration.
    """
    return _config
|
stratum/budget.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Budget dataclass — time and cost envelope for @infer and @flow calls."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Budget:
    """
    Time and/or cost envelope for one @infer or @flow invocation.

    ``ms`` caps wall-clock time and ``usd`` caps spend. Either, both or
    neither may be given; an axis left as None is unbounded.
    """

    ms: int | None = None  # wall-clock limit in milliseconds
    usd: float | None = None  # spend ceiling in USD

    # Internal bookkeeping — excluded from __init__, repr and equality.
    _start_ms: float = field(
        init=False,
        repr=False,
        compare=False,
        default_factory=lambda: time.monotonic() * 1000,
    )
    _spent_usd: float = field(init=False, repr=False, compare=False, default=0.0)

    def remaining_seconds(self) -> float | None:
        """
        Return the wall-clock time left, in seconds.

        Returns None when no ``ms`` limit is set, and 0.0 once the limit has
        been used up (the value never goes negative).
        """
        limit_ms = self.ms
        if limit_ms is None:
            return None
        now_ms = time.monotonic() * 1000
        left_ms = limit_ms - (now_ms - self._start_ms)
        seconds = left_ms / 1000.0
        return seconds if seconds > 0.0 else 0.0

    def record_cost(self, usd: float) -> None:
        """Add ``usd`` to the running total charged against this budget."""
        self._spent_usd = self._spent_usd + usd

    def is_cost_exceeded(self) -> bool:
        """Return True once accumulated cost is at or above the ``usd`` ceiling.

        Always False when no ``usd`` ceiling is set.
        """
        return self.usd is not None and self._spent_usd >= self.usd

    def clone(self) -> Budget:
        """
        Return a new Budget with the same limits but a fresh clock and zero
        spend. Used by @flow to create a per-execution envelope.
        """
        return Budget(usd=self.usd, ms=self.ms)
|
stratum/compiler.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""Prompt compiler — deterministic assembly and SHA-256 hash."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _format_value(value: Any) -> str:
    """
    Render a value for inline display in the prompt.

    Strings are returned unchanged. Any other object that has a ``__dict__``
    (e.g. a contract instance) is rendered as the ``str()`` of a dict of its
    attributes whose names do not start with an underscore. Everything else
    falls back to ``str(value)``.
    """
    is_text = isinstance(value, str)
    if not is_text and hasattr(value, "__dict__"):
        visible = {}
        for name, attr in value.__dict__.items():
            if name.startswith("_"):
                continue  # private attributes are not shown to the model
            visible[name] = attr
        return str(visible)
    return value if is_text else str(value)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def compile_prompt(
    intent: str,
    context: list[str],
    inputs: dict[str, Any],
    opaque_fields: set[str],
    retry_reasons: list[str],
) -> str:
    """
    Build the full prompt string sent to the LLM.

    Sections are joined with newlines in this order (per spec §4.1):
      1. intent
      2. context annotations, in declaration order (empty ones skipped)
      3. non-opaque input bindings
      4. retry context (only when retry_reasons is non-empty)
      5. opaque data reference (only when opaque_fields is non-empty)

    The output schema is NOT part of the prompt — it is enforced via the
    structured outputs API (tool_choice).

    Raises:
        StratumCompileError: (spec §4.2) if an opaque field name appears as
            an inline ``{field}`` reference in the intent or a context string.
    """
    # Spec §4.2: opaque fields must never be interpolated inline.
    if opaque_fields:
        from .exceptions import StratumCompileError

        for text in (intent, *context):
            hits = [name for name in opaque_fields if f"{{{name}}}" in text]
            if hits:
                raise StratumCompileError(
                    f"opaque field '{hits[0]}' must not appear in inline "
                    "string interpolation (intent or context). "
                    "Opaque fields are passed as structured attachments only."
                )

    # 1 + 2. Intent followed by the non-empty context annotations.
    sections: list[str] = [intent]
    sections.extend(note for note in context if note)

    # 3. Inputs the model is allowed to see inline.
    visible = [
        (name, value) for name, value in inputs.items() if name not in opaque_fields
    ]
    if visible:
        sections.append("Inputs:")
        sections.extend(f" {name}: {_format_value(value)}" for name, value in visible)

    # 4. Feedback from earlier failed attempts (retries only).
    if retry_reasons:
        sections.append("Previous attempt failed:")
        sections.extend(f" - {why}" for why in retry_reasons)
        sections.append("Fix these issues specifically.")

    # 5. Pointer to the opaque attachment; names sorted for determinism.
    if opaque_fields:
        sections.append("See attached data for: " + ", ".join(sorted(opaque_fields)))

    return "\n".join(sections)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def compile_prompt_stable(
    intent: str,
    context: list[str],
    opaque_fields: set[str],
) -> str:
    """
    Build the stable, cacheable prefix of the prompt: intent + context.

    Performs the opaque-field inline-reference check (spec §4.2) before
    assembling. Used by the executor to build Anthropic prompt-cache content
    blocks.

    Raises:
        StratumCompileError: if an opaque field name appears as an inline
            ``{field}`` reference in the intent or a context string.
    """
    if opaque_fields:
        from .exceptions import StratumCompileError

        for text in (intent, *context):
            hits = [name for name in opaque_fields if f"{{{name}}}" in text]
            if hits:
                raise StratumCompileError(
                    f"opaque field '{hits[0]}' must not appear in inline "
                    "string interpolation (intent or context). "
                    "Opaque fields are passed as structured attachments only."
                )

    # The intent always leads; empty context entries are dropped.
    return "\n".join([intent, *(note for note in context if note)])
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def compile_prompt_variable(
    inputs: dict[str, Any],
    opaque_fields: set[str],
    retry_reasons: list[str],
) -> str:
    """
    Build the variable suffix of the prompt: inputs + retry + opaque reference.

    Each section is emitted only when it has content; an empty string is
    returned when none do. Used by the executor to build the uncached portion
    of Anthropic content blocks.
    """
    lines: list[str] = []

    # Inputs the model may see inline (opaque ones are attached separately).
    visible = [
        (name, value) for name, value in inputs.items() if name not in opaque_fields
    ]
    if visible:
        lines.append("Inputs:")
        lines.extend(f" {name}: {_format_value(value)}" for name, value in visible)

    # Feedback from earlier failed attempts.
    if retry_reasons:
        lines += [
            "Previous attempt failed:",
            *(f" - {why}" for why in retry_reasons),
            "Fix these issues specifically.",
        ]

    # Pointer to the opaque attachment; names sorted for determinism.
    if opaque_fields:
        lines.append("See attached data for: " + ", ".join(sorted(opaque_fields)))

    return "\n".join(lines)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def prompt_hash(prompt: str) -> str:
    """Return a short fingerprint of ``prompt``: the first 12 hex characters
    of the SHA-256 digest of its UTF-8 encoding."""
    digest = hashlib.sha256()
    digest.update(prompt.encode("utf-8"))
    return digest.hexdigest()[:12]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def build_opaque_attachment(
    inputs: dict[str, Any],
    opaque_fields: set[str],
) -> dict[str, Any] | None:
    """
    Collect the opaque inputs into a ``{field_name: value}`` dict.

    Returns None when no input is named in ``opaque_fields`` (including when
    ``opaque_fields`` is empty).
    """
    attachment: dict[str, Any] = {}
    for name, value in inputs.items():
        if name in opaque_fields:
            attachment[name] = value
    return attachment or None
|
stratum/concurrency.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Concurrency primitives: parallel, race, debate."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Any, Callable
|
|
7
|
+
|
|
8
|
+
from .exceptions import ConsensusFailure, ParallelValidationFailed
|
|
9
|
+
from .types import Failure, Success
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
async def parallel(
|
|
13
|
+
*coros: Any,
|
|
14
|
+
require: str | int = "all",
|
|
15
|
+
validate: Callable[[list], bool] | None = None,
|
|
16
|
+
) -> Any:
|
|
17
|
+
"""
|
|
18
|
+
Run coroutines concurrently with configurable success semantics.
|
|
19
|
+
|
|
20
|
+
require="all" → all must succeed; any failure cancels rest and re-raises.
|
|
21
|
+
Returns a tuple matching input order.
|
|
22
|
+
require="any" → first success wins, rest cancelled. Returns single result.
|
|
23
|
+
require=N: int → at least N must succeed. Returns list of N results.
|
|
24
|
+
require=0 → collect all regardless of failure. Returns list[Success|Failure].
|
|
25
|
+
|
|
26
|
+
validate → optional callable on collected results; False → ParallelValidationFailed.
|
|
27
|
+
"""
|
|
28
|
+
if require == "all":
|
|
29
|
+
async with asyncio.TaskGroup() as tg:
|
|
30
|
+
tasks = [tg.create_task(c) for c in coros]
|
|
31
|
+
results = [t.result() for t in tasks]
|
|
32
|
+
|
|
33
|
+
if validate is not None and not validate(results):
|
|
34
|
+
raise ParallelValidationFailed()
|
|
35
|
+
|
|
36
|
+
return tuple(results)
|
|
37
|
+
|
|
38
|
+
if require == "any":
|
|
39
|
+
if not coros:
|
|
40
|
+
raise ValueError("parallel(require='any'): requires at least one coroutine")
|
|
41
|
+
tasks = [asyncio.create_task(c) for c in coros]
|
|
42
|
+
pending: set = set(tasks)
|
|
43
|
+
last_exc: Exception | None = None
|
|
44
|
+
|
|
45
|
+
while pending:
|
|
46
|
+
done, pending = await asyncio.wait(
|
|
47
|
+
pending, return_when=asyncio.FIRST_COMPLETED
|
|
48
|
+
)
|
|
49
|
+
winner = None
|
|
50
|
+
for d in done:
|
|
51
|
+
if d.exception() is None:
|
|
52
|
+
winner = d
|
|
53
|
+
break
|
|
54
|
+
last_exc = d.exception()
|
|
55
|
+
|
|
56
|
+
if winner is not None:
|
|
57
|
+
# Cancel remaining pending tasks
|
|
58
|
+
for p in pending:
|
|
59
|
+
p.cancel()
|
|
60
|
+
try:
|
|
61
|
+
await p
|
|
62
|
+
except (asyncio.CancelledError, Exception):
|
|
63
|
+
pass
|
|
64
|
+
# Drain other done tasks so their exceptions are retrieved
|
|
65
|
+
for d in done:
|
|
66
|
+
if d is not winner:
|
|
67
|
+
try:
|
|
68
|
+
d.exception()
|
|
69
|
+
except (asyncio.CancelledError, Exception):
|
|
70
|
+
pass
|
|
71
|
+
result = winner.result()
|
|
72
|
+
if validate is not None and not validate([result]):
|
|
73
|
+
raise ParallelValidationFailed()
|
|
74
|
+
return result
|
|
75
|
+
|
|
76
|
+
if last_exc is not None:
|
|
77
|
+
raise last_exc
|
|
78
|
+
raise RuntimeError("parallel: all coroutines failed with no exception recorded")
|
|
79
|
+
|
|
80
|
+
if isinstance(require, int) and require == 0:
|
|
81
|
+
# Collect all regardless of failure — wrap in Success/Failure
|
|
82
|
+
raw = await asyncio.gather(*coros, return_exceptions=True)
|
|
83
|
+
results = [
|
|
84
|
+
Success(r) if not isinstance(r, Exception) else Failure(r)
|
|
85
|
+
for r in raw
|
|
86
|
+
]
|
|
87
|
+
if validate is not None and not validate(results):
|
|
88
|
+
raise ParallelValidationFailed()
|
|
89
|
+
return results
|
|
90
|
+
|
|
91
|
+
if isinstance(require, int) and require > 0:
|
|
92
|
+
# At least require many must succeed
|
|
93
|
+
all_results = await asyncio.gather(*coros, return_exceptions=True)
|
|
94
|
+
successes = [r for r in all_results if not isinstance(r, Exception)]
|
|
95
|
+
failures = [r for r in all_results if isinstance(r, Exception)]
|
|
96
|
+
|
|
97
|
+
if len(successes) < require:
|
|
98
|
+
if failures:
|
|
99
|
+
raise failures[0]
|
|
100
|
+
raise RuntimeError(
|
|
101
|
+
f"parallel: needed {require} successes, got {len(successes)}"
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
results = successes[:require]
|
|
105
|
+
if validate is not None and not validate(results):
|
|
106
|
+
raise ParallelValidationFailed()
|
|
107
|
+
return results
|
|
108
|
+
|
|
109
|
+
raise ValueError(f"parallel: invalid require value: {require!r}")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
async def race(*coros: Any) -> Any:
    """
    Submit all coroutines concurrently; the first to complete without raising
    wins and its result is returned.

    The remaining coroutines are cancelled and awaited before returning. The
    same cleanup runs if the caller itself is cancelled, so no child task is
    left running in the background. A child that ends up cancelled counts as
    a failure. If every coroutine fails, the last observed error is re-raised.

    Raises:
        ValueError: if called with no coroutines.
    """
    if not coros:
        raise ValueError("race: requires at least one coroutine")
    tasks = [asyncio.create_task(c) for c in coros]
    last_exc: BaseException | None = None

    try:
        pending: set = set(tasks)
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for finished in done:
                # Task.exception() raises on a cancelled task, so check for
                # cancellation first and record it as a failure.
                if finished.cancelled():
                    last_exc = asyncio.CancelledError()
                    continue
                exc = finished.exception()
                if exc is not None:
                    last_exc = exc
                    continue
                return finished.result()

        if last_exc is not None:
            raise last_exc
        raise RuntimeError("race: all coroutines failed")
    finally:
        # Runs on every exit path (winner, all failed, caller cancelled):
        # stop the stragglers and wait for them to unwind. gather() also
        # retrieves the losers' exceptions.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
async def debate(
    agents: list[Callable],
    topic: Any,
    rounds: int = 2,
    *,
    synthesize: Callable,
) -> Any:
    """
    Multi-agent debate protocol.

    Round 1: every agent is invoked concurrently as ``agent(topic=topic)``.
    Rounds 2..N: every agent is invoked concurrently with ``topic`` plus
    ``previous_arguments`` — the other agents' arguments from the prior round.
    After all rounds, convergence is computed on the last round and
    ``synthesize(topic=..., arguments=<per-round history>, converged=...)``
    is awaited and its result returned.

    Convergence: the last round's arguments are compared by their ``str()``.
    If any agent's ``_stratum_spec`` declares a truthy ``agree_on``, the first
    such field name is used and only that field is compared (dict key for
    dict arguments, attribute otherwise; the argument itself if it has
    neither).

    synthesize is required.

    Raises:
        ValueError: if ``agents`` is empty.
    """
    if not agents:
        raise ValueError("debate: agents list must not be empty")

    # Round 1 — initial arguments (concurrent).
    arguments: list[Any] = list(
        await asyncio.gather(*(agent(topic=topic) for agent in agents))
    )
    history: list[list[Any]] = [arguments]

    # Rebuttal rounds — all agents run concurrently within each round.
    for _round in range(1, rounds):
        rebuttal_coros = [
            agent(
                topic=topic,
                # Everyone's previous argument except this agent's own.
                previous_arguments=[
                    arg for j, arg in enumerate(arguments) if j != i
                ],
            )
            for i, agent in enumerate(agents)
        ]
        arguments = list(await asyncio.gather(*rebuttal_coros))
        history.append(list(arguments))

    # Pick the comparison field: the first agent that declares agree_on wins.
    agree_on_field: str | None = None
    for agent in agents:
        spec = getattr(agent, "_stratum_spec", None)
        if spec is not None and getattr(spec, "agree_on", None):
            agree_on_field = spec.agree_on
            break

    def _get_field(obj: Any, name: str) -> Any:
        # Dicts are checked first so a key such as "values" or "items" is
        # read as data instead of resolving to the dict method of that name.
        if isinstance(obj, dict):
            return obj.get(name)
        if hasattr(obj, name):
            return getattr(obj, name)
        return obj

    last_round = history[-1]
    if agree_on_field is not None:
        comparison_values = {str(_get_field(a, agree_on_field)) for a in last_round}
    else:
        comparison_values = {str(a) for a in last_round}

    converged = len(comparison_values) == 1

    return await synthesize(topic=topic, arguments=history, converged=converged)
|