stratum-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stratum/__init__.py ADDED
@@ -0,0 +1,120 @@
1
+ """
2
+ Stratum — LLM calls that behave like the rest of your code.
3
+
4
+ Public API surface (v1):
5
+
6
+ Decorators: @contract, @infer, @compute, @flow, @refine
7
+ Types: Budget, opaque, Probabilistic, HumanDecision, HumanReviewContext
8
+ HITL: await_human
9
+ Concurrency: parallel, debate, race
10
+ Utilities: configure, run
11
+ Errors: StratumError and all subclasses
12
+ Trace: TraceRecord, all_records, clear_traces
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ from typing import Any
19
+
20
+ from .contracts import contract, opaque, is_registered
21
+ from .budget import Budget
22
+ from .exceptions import (
23
+ StratumError,
24
+ StratumCompileError,
25
+ PreconditionFailed,
26
+ PostconditionFailed,
27
+ ParseFailure,
28
+ BudgetExceeded,
29
+ ConvergenceFailure,
30
+ ConsensusFailure,
31
+ ParallelValidationFailed,
32
+ HITLTimeoutError,
33
+ StabilityAssertionError,
34
+ )
35
+ from .decorators import infer, compute, flow, refine
36
+ from .trace import TraceRecord, all_records, clear as clear_traces
37
+ from ._config import configure
38
+ from .types import Probabilistic, HumanDecision, HumanReviewContext, Success, Failure
39
+ from .hitl import await_human, ReviewSink, ConsoleReviewSink, PendingReview
40
+ from .concurrency import parallel, debate, race
41
+ from .flow_scope import FlowScope
42
+ from . import exporters
43
+
44
+
45
def run(coro: Any) -> Any:
    """
    Synchronous shim for non-async contexts (scripts, notebooks).

    Runs ``coro`` (any awaitable) to completion on a fresh event loop via
    ``asyncio.run`` and returns its result. That loop is closed and the
    thread's current-loop slot is cleared afterwards, so repeated calls each
    get a clean loop instead of tripping over a previously closed one.

    MUST NOT be called from inside an already-running event loop.

    Raises:
        RuntimeError: if an event loop is already running in this thread.
            A coroutine argument is closed first, so it does not also emit a
            "coroutine ... was never awaited" warning.
        Any exception raised by ``coro`` propagates unchanged.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass  # No loop running in this thread: the supported case.
    else:
        if asyncio.iscoroutine(coro):
            coro.close()
        raise RuntimeError(
            "stratum.run() must not be called from inside a running event loop. "
            "Use 'await' directly instead."
        )

    async def _await_it() -> Any:
        # asyncio.run() only accepts coroutine objects; wrapping keeps run()
        # accepting any awaitable, as loop.run_until_complete() did.
        return await coro

    return asyncio.run(_await_it())
69
+
70
+
71
# Public API of the package: the names exported by ``from stratum import *``.
__all__ = [
    # Decorators
    "contract", "infer", "compute", "flow", "refine",
    # Types
    "Budget", "opaque", "Probabilistic", "Success", "Failure",
    "HumanDecision", "HumanReviewContext",
    # Human-in-the-loop (HITL)
    "await_human", "ReviewSink", "ConsoleReviewSink", "PendingReview",
    # Concurrency
    "parallel", "debate", "race",
    # Flow context
    "FlowScope",
    # Configuration
    "configure", "run",
    # Trace
    "TraceRecord", "all_records", "clear_traces",
    # Errors
    "StratumError", "StratumCompileError", "PreconditionFailed",
    "PostconditionFailed", "ParseFailure", "BudgetExceeded",
    "ConvergenceFailure", "ConsensusFailure", "ParallelValidationFailed",
    "HITLTimeoutError", "StabilityAssertionError",
    # Registry
    "is_registered",
    # Exporters (submodule)
    "exporters",
]
stratum/_config.py ADDED
@@ -0,0 +1,48 @@
1
+ """Global Stratum configuration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
8
# Process-wide settings; mutated in place by configure() and exposed by
# get_config(). Key order is part of the dict's observable value.
_config: dict[str, Any] = dict(
    client=None,  # None → call litellm directly
    review_sink=None,  # None → ConsoleReviewSink
    tracer=None,  # None → no OTel export
    default_model="claude-sonnet-4-6",
    test_mode=False,  # True → sample sample_n times for Probabilistic[T]
    sample_n=5,  # samples per @infer call while test_mode is on
)
16
+
17
+
18
+ def configure(
19
+ client: Any = None,
20
+ review_sink: Any = None,
21
+ tracer: Any = None,
22
+ default_model: str | None = None,
23
+ test_mode: bool | None = None,
24
+ sample_n: int | None = None,
25
+ ) -> None:
26
+ """
27
+ Set global Stratum configuration.
28
+
29
+ Configuration is global and set once at startup. Per-function decorator
30
+ annotations take precedence over global defaults.
31
+ """
32
+ if client is not None:
33
+ _config["client"] = client
34
+ if review_sink is not None:
35
+ _config["review_sink"] = review_sink
36
+ if tracer is not None:
37
+ _config["tracer"] = tracer
38
+ if default_model is not None:
39
+ _config["default_model"] = default_model
40
+ if test_mode is not None:
41
+ _config["test_mode"] = test_mode
42
+ if sample_n is not None:
43
+ _config["sample_n"] = sample_n
44
+
45
+
46
def get_config() -> dict[str, Any]:
    """
    Expose the live, shared configuration mapping.

    The returned object is the module-level dict itself, not a copy, so
    mutating it changes the global configuration.
    """
    shared = _config
    return shared
stratum/budget.py ADDED
@@ -0,0 +1,61 @@
1
+ """Budget dataclass — time and cost envelope for @infer and @flow calls."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass, field
7
+
8
+
9
@dataclass
class Budget:
    """
    Time and/or cost envelope for an @infer or @flow invocation.

    ``ms`` bounds wall-clock time in milliseconds and ``usd`` bounds spend in
    US dollars. Leaving either as ``None`` makes that axis unbounded. The
    wall-clock measurement starts when the instance is constructed.
    """

    ms: int | None = None  # wall-clock limit in milliseconds
    usd: float | None = None  # spend ceiling in USD

    # Internal runtime state below: excluded from __init__, repr and equality.
    # Monotonic timestamp (ms) captured at construction time.
    _start_ms: float = field(
        default_factory=lambda: time.monotonic() * 1000,
        init=False,
        repr=False,
        compare=False,
    )
    # Running total of charges accumulated through record_cost().
    _spent_usd: float = field(default=0.0, init=False, repr=False, compare=False)

    def remaining_seconds(self) -> float | None:
        """
        Seconds left before the ``ms`` limit, clamped at 0.0.

        Returns None when no ``ms`` limit was declared.
        """
        if self.ms is None:
            return None
        now_ms = time.monotonic() * 1000
        left_ms = self.ms - (now_ms - self._start_ms)
        left_s = left_ms / 1000.0
        return left_s if left_s > 0.0 else 0.0

    def record_cost(self, usd: float) -> None:
        """Add ``usd`` to the running spend total."""
        self._spent_usd = self._spent_usd + usd

    def is_cost_exceeded(self) -> bool:
        """True once cumulative spend reaches or passes ``usd``; always False without a ceiling."""
        ceiling = self.usd
        return ceiling is not None and self._spent_usd >= ceiling

    def clone(self) -> Budget:
        """
        New Budget with the same ``ms``/``usd`` limits, a freshly started
        clock and zero spend. Used by @flow to create a per-execution envelope.
        """
        fresh = Budget(usd=self.usd, ms=self.ms)
        return fresh
stratum/compiler.py ADDED
@@ -0,0 +1,160 @@
1
+ """Prompt compiler — deterministic assembly and SHA-256 hash."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import json
7
+ from typing import Any
8
+
9
+
10
+ def _format_value(value: Any) -> str:
11
+ """
12
+ Format a value for inline prompt display.
13
+
14
+ Contract instances are rendered as a dict of their public attributes.
15
+ """
16
+ if isinstance(value, str):
17
+ return value
18
+ if hasattr(value, "__dict__"):
19
+ public = {k: v for k, v in value.__dict__.items() if not k.startswith("_")}
20
+ return str(public)
21
+ return str(value)
22
+
23
+
24
def compile_prompt(
    intent: str,
    context: list[str],
    inputs: dict[str, Any],
    opaque_fields: set[str],
    retry_reasons: list[str],
) -> str:
    """
    Assemble the prompt string sent to the LLM.

    Assembly order (per spec §4.1), joined with newlines:
      1. intent
      2. context annotations (in declaration order; empty entries skipped)
      3. non-opaque input bindings
      4. retry context (only when retry_reasons is non-empty)
      5. opaque data reference (only when opaque_fields is non-empty)

    The output schema is NOT included — it is enforced via the structured
    outputs API (tool_choice).

    Raises:
        StratumCompileError: (spec §4.2) if an opaque field name appears as
            an inline ``{field}`` reference in ``intent`` or in a non-empty
            ``context`` entry.
    """
    if opaque_fields:
        for text in [intent, *context]:
            if not text:
                # Same entries the assembly below skips; a None entry would
                # otherwise make the `in` test raise TypeError.
                continue
            # Sorted so the reported field is deterministic when several match.
            for field_name in sorted(opaque_fields):
                if f"{{{field_name}}}" in text:
                    from .exceptions import StratumCompileError
                    raise StratumCompileError(
                        f"opaque field '{field_name}' must not appear in inline "
                        "string interpolation (intent or context). "
                        "Opaque fields are passed as structured attachments only."
                    )

    # 1-2. Intent, then non-empty context annotations.
    parts: list[str] = [intent]
    parts.extend(ctx for ctx in context if ctx)

    # 3. Non-opaque input bindings.
    non_opaque = {k: v for k, v in inputs.items() if k not in opaque_fields}
    if non_opaque:
        parts.append("Inputs:")
        for key, value in non_opaque.items():
            parts.append(f" {key}: {_format_value(value)}")

    # 4. Retry context — only on retries.
    if retry_reasons:
        parts.append("Previous attempt failed:")
        parts.extend(f" - {reason}" for reason in retry_reasons)
        parts.append("Fix these issues specifically.")

    # 5. Opaque field reference (names only; values travel as attachments).
    if opaque_fields:
        parts.append(f"See attached data for: {', '.join(sorted(opaque_fields))}")

    return "\n".join(parts)
89
+
90
+
91
def compile_prompt_stable(
    intent: str,
    context: list[str],
    opaque_fields: set[str],
) -> str:
    """
    Return the stable, cacheable prefix of the compiled prompt: intent + context.

    Non-empty context entries follow the intent, one per line. Runs the
    opaque-field inline-reference check (spec §4.2) first. Used by the
    executor to build Anthropic prompt-cache content blocks.

    Raises:
        StratumCompileError: if ``{field}`` for any opaque field appears in
            ``intent`` or in a non-empty ``context`` entry.
    """
    if opaque_fields:
        for text in [intent, *context]:
            if not text:
                # Skipped by the assembly below as well; a None entry would
                # otherwise make the `in` test raise TypeError.
                continue
            # Sorted so the reported field is deterministic when several match.
            for field_name in sorted(opaque_fields):
                if f"{{{field_name}}}" in text:
                    from .exceptions import StratumCompileError
                    raise StratumCompileError(
                        f"opaque field '{field_name}' must not appear in inline "
                        "string interpolation (intent or context). "
                        "Opaque fields are passed as structured attachments only."
                    )
    return "\n".join([intent, *(ctx for ctx in context if ctx)])
117
+
118
+
119
def compile_prompt_variable(
    inputs: dict[str, Any],
    opaque_fields: set[str],
    retry_reasons: list[str],
) -> str:
    """
    Build the per-call (uncached) tail of the prompt.

    Sections, each emitted only when non-empty, in this order:
      * ``Inputs:`` followed by one line per non-opaque input,
      * the retry block listing every entry of ``retry_reasons``,
      * a pointer naming the opaque fields in sorted order.

    Returns the lines joined by newlines (the empty string if no section
    applies). Used by the executor to build the uncached portion of
    Anthropic content blocks.
    """
    visible = {name: val for name, val in inputs.items() if name not in opaque_fields}
    lines: list[str] = []
    if visible:
        lines.append("Inputs:")
        lines.extend(f" {name}: {_format_value(val)}" for name, val in visible.items())
    if retry_reasons:
        lines.append("Previous attempt failed:")
        lines.extend(f" - {reason}" for reason in retry_reasons)
        lines.append("Fix these issues specifically.")
    if opaque_fields:
        lines.append(f"See attached data for: {', '.join(sorted(opaque_fields))}")
    return "\n".join(lines)
144
+
145
+
146
def prompt_hash(prompt: str) -> str:
    """Fingerprint a prompt: the leading 12 hex digits of the SHA-256 of its UTF-8 bytes."""
    digest = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
    return digest[:12]
149
+
150
+
151
+ def build_opaque_attachment(
152
+ inputs: dict[str, Any],
153
+ opaque_fields: set[str],
154
+ ) -> dict[str, Any] | None:
155
+ """
156
+ Return a dict of {field_name: value} for all opaque fields, or None if
157
+ there are no opaque fields.
158
+ """
159
+ result = {k: v for k, v in inputs.items() if k in opaque_fields}
160
+ return result if result else None
stratum/concurrency.py ADDED
@@ -0,0 +1,215 @@
1
+ """Concurrency primitives: parallel, race, debate."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import Any, Callable
7
+
8
+ from .exceptions import ConsensusFailure, ParallelValidationFailed
9
+ from .types import Failure, Success
10
+
11
+
12
+ async def parallel(
13
+ *coros: Any,
14
+ require: str | int = "all",
15
+ validate: Callable[[list], bool] | None = None,
16
+ ) -> Any:
17
+ """
18
+ Run coroutines concurrently with configurable success semantics.
19
+
20
+ require="all" → all must succeed; any failure cancels rest and re-raises.
21
+ Returns a tuple matching input order.
22
+ require="any" → first success wins, rest cancelled. Returns single result.
23
+ require=N: int → at least N must succeed. Returns list of N results.
24
+ require=0 → collect all regardless of failure. Returns list[Success|Failure].
25
+
26
+ validate → optional callable on collected results; False → ParallelValidationFailed.
27
+ """
28
+ if require == "all":
29
+ async with asyncio.TaskGroup() as tg:
30
+ tasks = [tg.create_task(c) for c in coros]
31
+ results = [t.result() for t in tasks]
32
+
33
+ if validate is not None and not validate(results):
34
+ raise ParallelValidationFailed()
35
+
36
+ return tuple(results)
37
+
38
+ if require == "any":
39
+ if not coros:
40
+ raise ValueError("parallel(require='any'): requires at least one coroutine")
41
+ tasks = [asyncio.create_task(c) for c in coros]
42
+ pending: set = set(tasks)
43
+ last_exc: Exception | None = None
44
+
45
+ while pending:
46
+ done, pending = await asyncio.wait(
47
+ pending, return_when=asyncio.FIRST_COMPLETED
48
+ )
49
+ winner = None
50
+ for d in done:
51
+ if d.exception() is None:
52
+ winner = d
53
+ break
54
+ last_exc = d.exception()
55
+
56
+ if winner is not None:
57
+ # Cancel remaining pending tasks
58
+ for p in pending:
59
+ p.cancel()
60
+ try:
61
+ await p
62
+ except (asyncio.CancelledError, Exception):
63
+ pass
64
+ # Drain other done tasks so their exceptions are retrieved
65
+ for d in done:
66
+ if d is not winner:
67
+ try:
68
+ d.exception()
69
+ except (asyncio.CancelledError, Exception):
70
+ pass
71
+ result = winner.result()
72
+ if validate is not None and not validate([result]):
73
+ raise ParallelValidationFailed()
74
+ return result
75
+
76
+ if last_exc is not None:
77
+ raise last_exc
78
+ raise RuntimeError("parallel: all coroutines failed with no exception recorded")
79
+
80
+ if isinstance(require, int) and require == 0:
81
+ # Collect all regardless of failure — wrap in Success/Failure
82
+ raw = await asyncio.gather(*coros, return_exceptions=True)
83
+ results = [
84
+ Success(r) if not isinstance(r, Exception) else Failure(r)
85
+ for r in raw
86
+ ]
87
+ if validate is not None and not validate(results):
88
+ raise ParallelValidationFailed()
89
+ return results
90
+
91
+ if isinstance(require, int) and require > 0:
92
+ # At least require many must succeed
93
+ all_results = await asyncio.gather(*coros, return_exceptions=True)
94
+ successes = [r for r in all_results if not isinstance(r, Exception)]
95
+ failures = [r for r in all_results if isinstance(r, Exception)]
96
+
97
+ if len(successes) < require:
98
+ if failures:
99
+ raise failures[0]
100
+ raise RuntimeError(
101
+ f"parallel: needed {require} successes, got {len(successes)}"
102
+ )
103
+
104
+ results = successes[:require]
105
+ if validate is not None and not validate(results):
106
+ raise ParallelValidationFailed()
107
+ return results
108
+
109
+ raise ValueError(f"parallel: invalid require value: {require!r}")
110
+
111
+
112
async def race(*coros: Any) -> Any:
    """
    Submit all coroutines concurrently. First to complete without raising wins.

    Remaining coroutines are cancelled and awaited before returning. If all
    fail, re-raises the last failure observed (or raises RuntimeError when
    every child was cancelled externally, so there is no exception to
    re-raise). If the caller is cancelled while waiting, the child tasks are
    cancelled and awaited before the cancellation propagates.

    Raises:
        ValueError: if no coroutines are given.
    """
    if not coros:
        raise ValueError("race: requires at least one coroutine")
    tasks = [asyncio.create_task(c) for c in coros]
    pending: set = set(tasks)
    last_exc: BaseException | None = None
    try:
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                if task.cancelled():
                    # Cancelled externally: counts as a failure with no
                    # exception object (task.exception() would raise).
                    continue
                exc = task.exception()
                if exc is None:
                    return task.result()
                last_exc = exc
        if last_exc is not None:
            raise last_exc
        raise RuntimeError("race: all coroutines failed")
    finally:
        # Reap every child on all exit paths (win, total failure, caller
        # cancellation). gather(return_exceptions=True) retrieves the
        # children's exceptions without letting them escape here, and does
        # not swallow a cancellation aimed at this coroutine.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
153
+
154
+
155
async def debate(
    agents: list[Callable],
    topic: Any,
    rounds: int = 2,
    *,
    synthesize: Callable,
) -> Any:
    """
    Multi-agent debate protocol.

    Round 1 calls every agent concurrently as ``agent(topic=topic)``. Each
    further round (``rounds`` in total, with round 1 always run) calls every
    agent concurrently as ``agent(topic=topic, previous_arguments=[...])``.
    The list holds the other agents' outputs from the preceding round, in
    agent order.

    Convergence: if some agent's ``_stratum_spec`` has a truthy ``agree_on``
    (the first such agent wins), final-round outputs are compared on that
    field (attribute first, then dict key, else the whole output).
    Otherwise the outputs are compared whole. Comparison is by ``str()``, and
    the debate has converged when all final-round outputs compare equal.

    Returns ``await synthesize(topic=topic, arguments=history,
    converged=converged)``, where ``history`` holds one list of outputs per
    round. synthesize is required.

    Raises:
        ValueError: if ``agents`` is empty.
    """
    if not agents:
        raise ValueError("debate: agents list must not be empty")

    # Opening round: every agent sees only the topic.
    opening = await asyncio.gather(*(agent(topic=topic) for agent in agents))
    current: list[Any] = list(opening)
    history: list[list[Any]] = [current]

    # Rebuttal rounds: each agent sees everyone else's latest output.
    for _ in range(rounds - 1):
        others_per_agent = [
            [arg for j, arg in enumerate(current) if j != i]
            for i in range(len(agents))
        ]
        rebuttals = await asyncio.gather(
            *(
                agent(topic=topic, previous_arguments=others)
                for agent, others in zip(agents, others_per_agent)
            )
        )
        current = list(rebuttals)
        history.append(list(current))

    specs = (getattr(agent, "_stratum_spec", None) for agent in agents)
    agree_on_field: str | None = next(
        (
            spec.agree_on
            for spec in specs
            if spec is not None and getattr(spec, "agree_on", None)
        ),
        None,
    )

    def _project(output: Any) -> Any:
        # Reduce an output to the value used for the convergence comparison.
        if agree_on_field is None:
            return output
        if hasattr(output, agree_on_field):
            return getattr(output, agree_on_field)
        if isinstance(output, dict):
            return output.get(agree_on_field)
        return output

    final_round = history[-1]
    converged = len({str(_project(out)) for out in final_round}) == 1

    return await synthesize(topic=topic, arguments=history, converged=converged)