tigrbl-kernel 0.4.0.dev2__py3-none-any.whl → 0.4.1.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tigrbl_kernel/__init__.py CHANGED
@@ -25,6 +25,7 @@ from .measure import (
25
25
  )
26
26
 
27
27
  _LAZY_EXPORTS = {
28
+ "BatchOpPlan": "models",
28
29
  "Kernel": "core",
29
30
  "OpView": "models",
30
31
  "PackedKernel": "models",
@@ -83,6 +84,7 @@ __all__ = [
83
84
  "Kernel",
84
85
  "RustBackendConfig",
85
86
  "RustPlan",
87
+ "BatchOpPlan",
86
88
  "OpView",
87
89
  "PackedKernel",
88
90
  "SchemaIn",
tigrbl_kernel/_build.py CHANGED
@@ -14,6 +14,7 @@ from tigrbl_atoms import StepFn
14
14
  from tigrbl_atoms.atoms.sys.phase_db import run as _bind_phase_db
15
15
  from tigrbl_atoms.phases import phase_info
16
16
  from tigrbl_atoms.types import EdgeTarget, PhaseTreeEdge, PhaseTreeNode, error_phase_for
17
+ from tigrbl_core.config.resolver import resolve_cfg
17
18
  from tigrbl_typing.phases import normalize_phase
18
19
 
19
20
  from . import events as _ev
@@ -24,7 +25,7 @@ from .atoms import (
24
25
  _is_persistent,
25
26
  _wrap_atom,
26
27
  )
27
- from .models import HotOpPlan, KernelPlan, OpKey, OpView, PackedKernel
28
+ from .models import BatchOpPlan, HotOpPlan, KernelPlan, OpKey, OpView, PackedKernel
28
29
  from .measure import (
29
30
  load_packed_kernel_hot_sections,
30
31
  measure_packed_kernel,
@@ -121,6 +122,11 @@ def _phase_stamp(self: Any, model: type, alias: str) -> tuple[Any, ...]:
121
122
  specs = getattr(getattr(model, "ops", SimpleNamespace()), "by_alias", {})
122
123
  sp_list = specs.get(alias) or ()
123
124
  sp = sp_list[0] if sp_list else None
125
+ if sp is None:
126
+ sp = next(
127
+ (item for item in _opspecs(model) if getattr(item, "alias", None) == alias),
128
+ None,
129
+ )
124
130
  phase_lists = tuple(
125
131
  (
126
132
  phase,
@@ -185,6 +191,75 @@ def _prepend_phase_db_binding(
185
191
  chains[phase] = [_phase_db_step(), *steps]
186
192
 
187
193
 
194
def _batch_policy_for_op(sp: Any | None, *, alias: str, target: str, persistent: bool) -> Mapping[str, Any]:
    """Resolve the raw batch policy mapping configured for one operation.

    The ``"batch"`` entry of ``resolve_cfg(op=alias, opspec=sp).as_dict()`` is
    consulted. When it is missing or not a mapping, ``{"enabled": False}`` is
    returned. Otherwise a shallow copy of that mapping is returned, with
    ``enabled`` forced to ``False`` when ``allow_reads`` is not truthy and the
    op is either a ``read``/``list`` target or non-persistent.

    Args:
        sp: The OpSpec for the op, or ``None``.
        alias: Op alias passed to ``resolve_cfg``.
        target: Lower-cased op target (e.g. ``"read"``, ``"create"``).
        persistent: Whether the op is treated as persistent by the caller.

    Returns:
        A policy mapping; callers read ``enabled`` and the tuning keys from it.
    """
    batch = resolve_cfg(op=alias, opspec=sp).as_dict().get("batch")
    if not isinstance(batch, Mapping):
        return {"enabled": False}

    policy = dict(batch)
    reads_allowed = bool(policy.get("allow_reads", False))
    # Reads/lists and non-persistent ops only batch when reads are explicitly allowed.
    if (target in {"read", "list"} or not persistent) and not reads_allowed:
        policy["enabled"] = False
    return policy
206
+
207
+
208
+ def _is_batch_domain_run(run: Any) -> bool:
209
+ mod = getattr(run, "__module__", "") or ""
210
+ parts = mod.split(".")
211
+ try:
212
+ index = parts.index("atoms")
213
+ except ValueError:
214
+ return False
215
+ domain = parts[index + 1] if index + 1 < len(parts) else None
216
+ return domain in {"transport", "intent", "batch", "fanout"}
217
+
218
+
219
def _batch_op_plan_from_policy(policy: Mapping[str, Any] | None) -> BatchOpPlan:
    """Freeze a raw batch-policy mapping into a :class:`BatchOpPlan`.

    Anything that is not a mapping (``None`` included) is treated as an empty
    policy, so every field takes its default. Integer settings are clamped to
    ``>= 0``, and a value that ``int()`` cannot convert falls back to its
    default. String settings pass through ``str()``, and the two flags pass
    through ``bool()``.
    """
    source: Mapping[str, Any] = policy if isinstance(policy, Mapping) else {}

    def non_negative(key: str, fallback: int) -> int:
        # Deliberately best-effort: any conversion error yields the default.
        try:
            return max(0, int(source.get(key, fallback)))
        except Exception:
            return fallback

    int_defaults = {
        "max_size": 64,
        "max_bytes": 1_048_576,
        "max_delay_ms": 1,
        "admission_timeout_ms": 5,
        "max_queue_depth": 1024,
        "max_in_flight": 16,
    }
    str_defaults = {
        "conflict_policy": "single_fallback",
        "overflow_policy": "backpressure",
        "result_fanout": "by_admission",
    }
    return BatchOpPlan(
        enabled=bool(source.get("enabled", False)),
        allow_reads=bool(source.get("allow_reads", False)),
        **{key: str(source.get(key, default)) for key, default in str_defaults.items()},
        **{key: non_negative(key, default) for key, default in int_defaults.items()},
    )
241
+
242
+
243
def _batch_policy_for_meta(meta: Any) -> BatchOpPlan:
    """Compile the :class:`BatchOpPlan` for one packed program's op metadata.

    ``model``, ``alias`` and ``target`` are read from ``meta`` (an entry of
    ``plan.opmeta`` in ``_pack_kernel_plan``). The function then locates the
    matching OpSpec and feeds it through ``_batch_policy_for_op`` and
    ``_batch_op_plan_from_policy``.

    The OpSpec is resolved the same way ``_build_op`` and ``_phase_stamp``
    resolve it: ``model.ops.by_alias`` first, with ``_opspecs(model)`` only as a
    fallback. This keeps the packed plan derived from the same spec that drove
    atom injection.

    NOTE(review): unlike ``_build_op``, this cannot OR in
    ``_is_persistent(chains)`` because no chains are available here. An op that
    is persistent only through its chains may therefore get a disabled packed
    plan unless ``allow_reads`` is set.
    """
    model = getattr(meta, "model", None)
    alias = str(getattr(meta, "alias", "") or "")
    target = str(getattr(meta, "target", alias) or alias).lower()

    # Same lookup order as _build_op / _phase_stamp.
    specs = getattr(getattr(model, "ops", SimpleNamespace()), "by_alias", {})
    sp_list = specs.get(alias) or ()
    sp = sp_list[0] if sp_list else None
    if sp is None:
        sp = next(
            (item for item in _opspecs(model) if getattr(item, "alias", None) == alias),
            None,
        )

    persist_policy = getattr(sp, "persist", "default")
    persistent = persist_policy != "skip" and target not in {"read", "list"}
    return _batch_op_plan_from_policy(
        _batch_policy_for_op(
            sp,
            alias=alias,
            target=target,
            persistent=persistent,
        )
    )
261
+
262
+
188
263
  def _build_op(self, model: type, alias: str) -> Dict[str, List[StepFn]]:
189
264
  from .core import DEFAULT_PHASE_ORDER
190
265
 
@@ -201,11 +276,22 @@ def _build_op(self, model: type, alias: str) -> Dict[str, List[StepFn]]:
201
276
  specs = getattr(getattr(model, "ops", SimpleNamespace()), "by_alias", {})
202
277
  sp_list = specs.get(alias) or ()
203
278
  sp = sp_list[0] if sp_list else None
279
+ if sp is None:
280
+ sp = next(
281
+ (item for item in _opspecs(model) if getattr(item, "alias", None) == alias),
282
+ None,
283
+ )
204
284
  target = (getattr(sp, "target", alias) or "").lower()
205
285
  persist_policy = getattr(sp, "persist", "default")
206
286
  persistent = (
207
287
  persist_policy != "skip" and target not in {"read", "list"}
208
288
  ) or _is_persistent(chains)
289
+ batch_policy = _batch_policy_for_op(
290
+ sp,
291
+ alias=alias,
292
+ target=target,
293
+ persistent=persistent,
294
+ )
209
295
 
210
296
  try:
211
297
  _inject_atoms(
@@ -213,6 +299,7 @@ def _build_op(self, model: type, alias: str) -> Dict[str, List[StepFn]]:
213
299
  self._atoms() or (),
214
300
  persistent=persistent,
215
301
  target=target,
302
+ batch_policy=batch_policy,
216
303
  )
217
304
  except Exception:
218
305
  logger.exception(
@@ -240,6 +327,8 @@ def _build_ingress(self, app: Any) -> Dict[str, List[StepFn]]:
240
327
  order = {name: idx for idx, name in enumerate(_ev.all_events_ordered())}
241
328
  ingress_atoms: Dict[str, List[tuple[str, Any]]] = {}
242
329
  for anchor, run in self._atoms() or ():
330
+ if _is_batch_domain_run(run):
331
+ continue
243
332
  if not _ev.is_valid_event(anchor):
244
333
  continue
245
334
  phase = _ev.phase_for_event(anchor)
@@ -262,6 +351,8 @@ def _build_egress(self, app: Any) -> Dict[str, List[StepFn]]:
262
351
  order = {name: idx for idx, name in enumerate(_ev.all_events_ordered())}
263
352
  egress_atoms: Dict[str, List[tuple[str, Any]]] = {}
264
353
  for anchor, run in self._atoms() or ():
354
+ if _is_batch_domain_run(run):
355
+ continue
265
356
  if not _ev.is_valid_event(anchor):
266
357
  continue
267
358
  phase = _ev.phase_for_event(anchor)
@@ -939,8 +1030,18 @@ def _pack_kernel_plan(
939
1030
  error_profile_segment_refs: list[int] = []
940
1031
  program_error_profile_ids: list[int] = []
941
1032
  program_hot_runner_ids: list[int] = []
1033
+ batch_policy_index: dict[BatchOpPlan, int] = {}
1034
+ batch_policy_table: list[BatchOpPlan] = []
1035
+ program_batch_policy_ids: list[int] = []
942
1036
  for program_id, _meta in enumerate(plan.opmeta):
943
1037
  meta = plan.opmeta[program_id]
1038
+ batch_plan = _batch_policy_for_meta(meta)
1039
+ batch_policy_id = batch_policy_index.get(batch_plan)
1040
+ if batch_policy_id is None:
1041
+ batch_policy_id = len(batch_policy_table)
1042
+ batch_policy_index[batch_plan] = batch_policy_id
1043
+ batch_policy_table.append(batch_plan)
1044
+ program_batch_policy_ids.append(batch_policy_id)
944
1045
  seg_offset = op_segment_offsets[program_id]
945
1046
  seg_length = op_segment_lengths[program_id]
946
1047
  by_phase: dict[str, list[int]] = {}
@@ -1106,6 +1207,8 @@ def _pack_kernel_plan(
1106
1207
  program_hot_runner_id=hot_runner_id,
1107
1208
  param_shape_id=param_shape_id,
1108
1209
  transport_kind_id=transport_kind_id,
1210
+ batch_policy_id=batch_policy_id,
1211
+ batch=batch_plan,
1109
1212
  compiled_param_phase_steps=compiled_param_phase_steps,
1110
1213
  websocket_path=websocket_fast_path[0] if websocket_fast_path is not None else "",
1111
1214
  websocket_protocol=websocket_fast_path[1] if websocket_fast_path is not None else "",
@@ -1142,6 +1245,8 @@ def _pack_kernel_plan(
1142
1245
  param_shape_header_hashes=tuple(param_shape_header_hashes),
1143
1246
  program_param_shape_ids=tuple(program_param_shape_ids),
1144
1247
  program_transport_kind_ids=tuple(program_transport_kind_ids),
1248
+ batch_policy_table=tuple(batch_policy_table),
1249
+ program_batch_policy_ids=tuple(program_batch_policy_ids),
1145
1250
  segment_offsets=tuple(segment_offsets),
1146
1251
  segment_lengths=tuple(segment_lengths),
1147
1252
  segment_step_ids=tuple(segment_step_ids),
tigrbl_kernel/atoms.py CHANGED
@@ -4,6 +4,7 @@ import importlib
4
4
  import inspect
5
5
  import logging
6
6
  import pkgutil
7
+ from functools import lru_cache
7
8
  from types import SimpleNamespace
8
9
  from typing import (
9
10
  Any,
@@ -16,6 +17,7 @@ from typing import (
16
17
  Sequence,
17
18
  cast,
18
19
  )
20
+ from collections.abc import Mapping as AbcMapping
19
21
 
20
22
  from tigrbl_typing.phases import HOOK_PHASES as HOOK_PHASES
21
23
 
@@ -45,6 +47,7 @@ _COMPILED_PHASE_DB_REQUIRED_ATOM_NAMES = frozenset(
45
47
  "sys.commit_tx",
46
48
  }
47
49
  )
50
+ _BATCH_DOMAINS = frozenset({"transport", "intent", "batch", "fanout"})
48
51
 
49
52
 
50
53
  def _is_async_callable(run: _AtomRun) -> bool:
@@ -128,6 +131,18 @@ def _policy_atom_name(
128
131
 
129
132
 
130
133
def _use_two_args_for(run: _AtomRun) -> bool:
    """Return the ``_use_two_args_for_uncached`` verdict for ``run``, memoised.

    Hashable callables go through a bounded LRU cache. Unhashable ones make the
    cache raise ``TypeError``, and those fall back to a direct, uncached
    computation.
    """
    try:
        verdict = _use_two_args_for_cached(run)
    except TypeError:
        # lru_cache hashes its argument; unhashable callables end up here.
        verdict = _use_two_args_for_uncached(run)
    return verdict


@lru_cache(maxsize=2048)
def _use_two_args_for_cached(run: _AtomRun) -> bool:
    """Cache wrapper around ``_use_two_args_for_uncached`` (at most 2048 entries).

    The cache holds strong references to the callables it has seen.
    """
    return _use_two_args_for_uncached(run)
143
+
144
+
145
+ def _use_two_args_for_uncached(run: _AtomRun) -> bool:
131
146
  try:
132
147
  params = tuple(inspect.signature(run).parameters.values())
133
148
  positional = [
@@ -291,6 +306,7 @@ def _inject_atoms(
291
306
  *,
292
307
  persistent: bool,
293
308
  target: str | None = None,
309
+ batch_policy: Mapping[str, Any] | None = None,
294
310
  ) -> None:
295
311
  order = {name: i for i, name in enumerate(_ev.all_events_ordered())}
296
312
 
@@ -324,7 +340,13 @@ def _inject_atoms(
324
340
  continue
325
341
 
326
342
  domain, _subject = _infer_domain_subject(run)
327
- if not persistent and persist_tied:
343
+ batch_enabled = bool(
344
+ isinstance(batch_policy, AbcMapping)
345
+ and bool(batch_policy.get("enabled", False))
346
+ )
347
+ if domain in _BATCH_DOMAINS and not batch_enabled:
348
+ continue
349
+ if not persistent and persist_tied and not (domain in _BATCH_DOMAINS and batch_enabled):
328
350
  if not (
329
351
  domain == "sys"
330
352
  and isinstance(_subject, str)
tigrbl_kernel/helpers.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from dataclasses import fields
3
4
  import inspect
4
5
  import logging
5
6
  from typing import Any, Iterable, Mapping, Optional, Sequence
6
7
 
7
- from tigrbl_atoms.types import AtomFailure, FailedCtx, build_error_ctx
8
+ from tigrbl_atoms.types import AtomFailure, BaseCtx, FailedCtx, build_error_ctx
8
9
  from tigrbl_typing.phases import normalize_phase
9
10
 
10
11
  try:
@@ -15,6 +16,26 @@ except Exception: # pragma: no cover
15
16
  logger = logging.getLogger(__name__)
16
17
 
17
18
 
19
+ def _merge_promoted_ctx(ctx: Any, promoted: BaseCtx[Any, Any]) -> None:
20
+ for field_info in fields(type(promoted)):
21
+ name = field_info.name
22
+ value = getattr(promoted, name)
23
+ if name == "bag":
24
+ if isinstance(value, Mapping):
25
+ for bag_key, bag_value in value.items():
26
+ if bag_key == "result" and hasattr(promoted, "result"):
27
+ continue
28
+ ctx[bag_key] = bag_value
29
+ continue
30
+ if name == "temp":
31
+ ctx.temp = dict(value or {})
32
+ continue
33
+ if name in {"env", "error", "phase", "current_phase", "error_phase"}:
34
+ setattr(ctx, name, value)
35
+ continue
36
+ ctx[name] = value
37
+
38
+
18
39
  def _normalize_payload(payload: Any) -> Any:
19
40
  if isinstance(payload, (str, int, float, bool)) or payload is None:
20
41
  return payload
@@ -109,7 +130,9 @@ async def _run_chain(ctx: Any, chain: Optional[Iterable[Any]], *, phase: str) ->
109
130
  if isinstance(err, BaseException):
110
131
  raise err
111
132
  raise AtomFailure(err)
112
- if rv is not None and rv is not ctx:
133
+ if isinstance(rv, BaseCtx):
134
+ _merge_promoted_ctx(ctx, rv)
135
+ elif rv is not None and rv is not ctx:
113
136
  ctx.result = rv
114
137
  if trace_active:
115
138
  _trace.end(ctx, seq, status=_trace.OK)
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable
4
+
5
+
6
_RECEIVE_CAPABILITIES = frozenset({"recv", "receive"})


def _as_str_tuple(values: Iterable[str] | str | None) -> tuple[str, ...]:
    """Normalise an iterable of names, or a single bare name, to a tuple of ``str``.

    A bare ``str`` counts as one name instead of being iterated character by
    character. Iterating it would otherwise turn ``"recv"`` into
    ``{"r", "e", "c", "v"}``. Falsy input (``None``, ``()``, ``""``) gives ``()``.
    """
    if isinstance(values, str):
        values = (values,) if values else ()
    return tuple(str(value) for value in values or ())


def select_loop_mode(
    *,
    binding: str,
    subevent_handlers: Iterable[str] = (),
    explicit_mode: str | None = None,
    capabilities: Iterable[str] = (),
) -> str:
    """Choose the loop mode, ``"owner"`` or ``"dispatch"``, for a binding.

    Args:
        binding: Name of the binding being planned. Accepted for interface
            stability; the selection below does not consult it.
        subevent_handlers: Sub-event handler names. Any handler implies
            dispatch mode unless ``explicit_mode`` says otherwise. A bare
            string counts as one handler.
        explicit_mode: ``None`` (infer from handlers), ``"owner"`` or
            ``"dispatch"``.
        capabilities: Capability names of the binding. When this is non-empty
            and dispatch is requested explicitly, it must contain ``"recv"`` or
            ``"receive"``. A bare string counts as one capability.

    Returns:
        ``"dispatch"`` or ``"owner"``.

    Raises:
        ValueError: on an unknown ``explicit_mode``; on ``"owner"`` combined
            with handlers; on ``"dispatch"`` without handlers; or on
            ``"dispatch"`` whose declared capabilities lack a receive
            capability.
    """
    handlers = _as_str_tuple(subevent_handlers)
    caps = set(_as_str_tuple(capabilities))
    if explicit_mode not in {None, "owner", "dispatch"}:
        raise ValueError(f"unsupported loop mode {explicit_mode!r}")

    if explicit_mode == "owner":
        if handlers:
            raise ValueError("owner loop mode cannot be selected with dispatch handlers")
        return "owner"
    if explicit_mode == "dispatch":
        if not handlers:
            raise ValueError("dispatch loop mode requires subevent handlers")
        # An empty capability set skips this check (capabilities not declared).
        if caps and caps.isdisjoint(_RECEIVE_CAPABILITIES):
            raise ValueError("dispatch loop mode requires receive capability")
        return "dispatch"
    return "dispatch" if handlers else "owner"


__all__ = ["select_loop_mode"]
tigrbl_kernel/models.py CHANGED
@@ -74,6 +74,21 @@ class OpMeta:
74
74
  target: str
75
75
 
76
76
 
77
@dataclass(frozen=True, slots=True)
class BatchOpPlan:
    """Immutable, compiled batching settings for a single operation.

    The class is frozen with the default ``eq=True``, so instances are hashable.
    ``_pack_kernel_plan`` relies on this to intern identical plans in a
    ``dict[BatchOpPlan, int]`` and to share one entry of
    ``PackedKernel.batch_policy_table`` between programs. The field order below
    is also the positional constructor order. The defaults match the fallbacks
    used by ``_batch_op_plan_from_policy`` in ``_build.py``.
    """

    # Default plan is disabled; HotOpPlan.batch uses BatchOpPlan() as its default.
    enabled: bool = False
    # Size/byte caps; the names suggest per-batch limits (TODO confirm in the batch runtime).
    max_size: int = 64
    max_bytes: int = 1_048_576
    # Timing knobs in milliseconds, per the field names (semantics defined by the runtime).
    max_delay_ms: int = 1
    admission_timeout_ms: int = 5
    # Policy names kept as plain strings; valid values are defined outside this module.
    conflict_policy: str = "single_fallback"
    overflow_policy: str = "backpressure"
    result_fanout: str = "by_admission"
    # When True, _batch_policy_for_op does not force-disable batching for
    # read/list targets or non-persistent ops.
    allow_reads: bool = False
    # Queue/concurrency limits, per the field names (TODO confirm in the batch runtime).
    max_queue_depth: int = 1024
    max_in_flight: int = 16
90
+
91
+
77
92
  @dataclass(frozen=True, slots=True)
78
93
  class CompiledPhase:
79
94
  name: str
@@ -112,6 +127,8 @@ class HotOpPlan:
112
127
  program_hot_runner_id: int = 0
113
128
  param_shape_id: int = -1
114
129
  transport_kind_id: int = 0
130
+ batch_policy_id: int = -1
131
+ batch: BatchOpPlan = field(default_factory=BatchOpPlan)
115
132
  compiled_param_phase_steps: tuple[tuple[str, tuple[Any, ...]], ...] = ()
116
133
  websocket_path: str = ""
117
134
  websocket_protocol: str = ""
@@ -283,6 +300,8 @@ class PackedKernel:
283
300
  param_shape_header_hashes: tuple[int, ...] = ()
284
301
  program_param_shape_ids: tuple[int, ...] = ()
285
302
  program_transport_kind_ids: tuple[int, ...] = ()
303
+ batch_policy_table: tuple[BatchOpPlan, ...] = ()
304
+ program_batch_policy_ids: tuple[int, ...] = ()
286
305
 
287
306
  segment_offsets: tuple[int, ...] = ()
288
307
  segment_lengths: tuple[int, ...] = ()
tigrbl_kernel/ordering.py CHANGED
@@ -18,21 +18,51 @@ from .labels import Label
18
18
  _PREF: Dict[str, Tuple[str, ...]] = {
19
19
  _ev.INGRESS_CTX_INIT: ("ingress:ctx_init",),
20
20
  _ev.INGRESS_TRANSPORT_EXTRACT: ("ingress:transport_extract",),
21
+ _ev.BATCH_TRANSPORT_UNIT_CAPTURE: ("transport:unit_capture",),
22
+ _ev.BATCH_TRANSPORT_SINK_BIND: ("transport:sink_bind",),
21
23
  _ev.INGRESS_INPUT_PREPARE: ("ingress:input_prepare",),
22
24
  _ev.DISPATCH_BINDING_MATCH: ("dispatch:binding_match",),
23
25
  _ev.DISPATCH_BINDING_PARSE: ("dispatch:binding_parse",),
24
26
  _ev.DISPATCH_INPUT_NORMALIZE: ("dispatch:input_normalize",),
25
27
  _ev.DISPATCH_OP_RESOLVE: ("dispatch:op_resolve",),
28
+ _ev.BATCH_INTENT_BUILD: ("intent:build",),
29
+ _ev.BATCH_PREKEY: ("intent:prekey",),
26
30
  _ev.DEP_SECURITY: (_ev.DEP_SECURITY,),
27
31
  _ev.DEP_EXTRA: (_ev.DEP_EXTRA,),
32
+ _ev.BATCH_GROUP_KEY: ("intent:final_group_key",),
33
+ _ev.BATCH_ADMIT: ("batch:admit",),
34
+ _ev.BATCH_DEDUPE: ("batch:dedupe",),
35
+ _ev.BATCH_SEAL_CHECK: ("batch:seal_check",),
36
+ _ev.BATCH_AWAIT_SEAL: ("batch:await_seal",),
28
37
  _ev.SCHEMA_COLLECT_IN: ("schema:collect_in",),
38
+ _ev.BATCH_PREPARE_EXECUTE: ("batch:prepare_execute",),
29
39
  _ev.IN_VALIDATE: ("wire:build_in", "wire:validate_in"),
30
- _ev.SYS_TX_BEGIN: ("sys:txn:begin",),
40
+ _ev.SYS_TX_BEGIN: ("sys:start_tx",),
41
+ _ev.BATCH_TX_BEGIN: ("batch:tx_begin",),
31
42
  _ev.RESOLVE_VALUES: ("resolve:assemble", "resolve:paired_gen"),
32
43
  _ev.PRE_FLUSH: ("storage:to_stored",),
33
44
  _ev.EMIT_ALIASES_PRE: ("emit:paired_pre",),
34
- _ev.SYS_HANDLER_PERSISTENCE: ("sys:handler:crud",),
35
- _ev.SYS_TX_COMMIT: ("sys:txn:commit",),
45
+ _ev.BATCH_EXECUTE: ("batch:execute",),
46
+ _ev.SYS_HANDLER_PERSISTENCE: (
47
+ "sys:handler_create",
48
+ "sys:handler_read",
49
+ "sys:handler_update",
50
+ "sys:handler_replace",
51
+ "sys:handler_merge",
52
+ "sys:handler_delete",
53
+ "sys:handler_list",
54
+ "sys:handler_clear",
55
+ "sys:handler_bulk_create",
56
+ "sys:handler_bulk_update",
57
+ "sys:handler_bulk_replace",
58
+ "sys:handler_bulk_merge",
59
+ "sys:handler_bulk_delete",
60
+ "sys:handler_persistence",
61
+ ),
62
+ _ev.BATCH_RESULT_SLOTS: ("batch:result_slots",),
63
+ _ev.BATCH_PRECOMMIT_VALIDATE: ("batch:precommit_validate",),
64
+ _ev.SYS_TX_COMMIT: ("sys:commit_tx",),
65
+ _ev.BATCH_COMMIT: ("batch:commit",),
36
66
  _ev.POST_FLUSH: ("refresh:demand",),
37
67
  _ev.EMIT_ALIASES_POST: ("emit:paired_post",),
38
68
  _ev.SCHEMA_COLLECT_OUT: ("schema:collect_out",),
@@ -43,7 +73,10 @@ _PREF: Dict[str, Tuple[str, ...]] = {
43
73
  _ev.EGRESS_HEADERS_APPLY: ("egress:headers_apply",),
44
74
  _ev.EGRESS_HTTP_FINALIZE: ("egress:http_finalize",),
45
75
  _ev.EGRESS_TO_TRANSPORT_RESPONSE: ("egress:to_transport_response",),
76
+ _ev.BATCH_EGRESS_SHAPE: ("fanout:shape",),
46
77
  _ev.EGRESS_ASGI_SEND: ("egress:asgi_send",),
78
+ _ev.BATCH_FANOUT_EMIT: ("fanout:emit_many",),
79
+ _ev.BATCH_CLEANUP: ("batch:cleanup",),
47
80
  _ev.EMIT_ALIASES_READ: ("emit:readtime_alias",),
48
81
  _ev.OUT_DUMP: (
49
82
  "wire:dump",
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
def compile_completion_fence(binding: dict[str, Any]) -> dict[str, Any]:
    """Describe the runtime-owned completion fence for a transport binding.

    A ``POST_EMIT`` fence that requires an explicit acknowledgement is compiled
    unless all three of the following hold:

    * ``transport`` is not one of stream/websocket/webtransport/datagram;
    * ``phase`` is not ``"EMIT"``;
    * ``send_completion`` is ``"synchronous"``.

    Args:
        binding: Mapping read through ``.get`` for ``transport``, ``phase``
            and ``send_completion``.

    Returns:
        A dict with ``completion_fence`` (``"POST_EMIT"`` or ``None``),
        ``runtime_owned`` (always True), ``public_hook_phase`` (always False),
        ``after_phase`` (the binding's ``phase``, defaulting to ``"EMIT"``) and
        ``explicit_ack_required`` (True exactly when a fence is compiled).
    """
    fenced = bool(
        binding.get("transport") in {"stream", "websocket", "webtransport", "datagram"}
        or binding.get("phase") == "EMIT"
        or binding.get("send_completion") != "synchronous"
    )
    return {
        "completion_fence": "POST_EMIT" if fenced else None,
        "runtime_owned": True,
        "public_hook_phase": False,
        "after_phase": binding.get("phase", "EMIT"),
        "explicit_ack_required": fenced,
    }
24
+
25
+
26
def validate_completion_hook_phase(phase: str) -> None:
    """Reject ``POST_EMIT`` as a user-registered hook phase.

    ``POST_EMIT`` is the completion fence compiled by the runtime, so hooks may
    not target it. Any other phase passes silently.

    Raises:
        ValueError: if ``phase`` is ``"POST_EMIT"``.
    """
    if phase != "POST_EMIT":
        return
    raise ValueError("POST_EMIT is a runtime-owned completion fence hook phase")
29
+
30
+
31
+ __all__ = ["compile_completion_fence", "validate_completion_hook_phase"]
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable, Mapping
4
+ from typing import Any
5
+
6
+
7
+ def _tuple_values(segment: Mapping[str, Any], key: str) -> tuple[Any, ...]:
8
+ value = segment.get(key, ())
9
+ if value is None:
10
+ return ()
11
+ return tuple(value)
12
+
13
+
14
+ def _barrier(segment: Mapping[str, Any]) -> bool:
15
+ return bool(segment.get("barrier")) or str(segment.get("barrier_kind", "")) in {
16
+ "transaction",
17
+ "transport",
18
+ "error",
19
+ "completion",
20
+ }
21
+
22
+
23
def fuse_protocol_segments(
    segments: Iterable[Mapping[str, Any]],
    *,
    force: bool = False,
) -> list[dict[str, object]]:
    """Merge each run of consecutive non-barrier segments into one segment.

    Barrier segments (see ``_barrier``) are emitted unchanged, as copies, and
    split the input into runs. A run of one segment is copied through, and
    gains ``err_targets = (err_target,)`` when it has an ``err_target``. A
    longer run becomes a new segment with these keys:

    * ``segment_id``: the member ids joined with ``"+"``;
    * ``source_segments``: the member ids, as a tuple;
    * ``anchors``: all member anchors, concatenated;
    * ``atoms`` and ``err_targets``: the concatenated member values, present
      only when non-empty.

    Args:
        segments: Segment mappings, in execution order. Members of a multi-item
            run must carry ``segment_id``.
        force: When True, first validate that the input has at most one
            barrier. The fusion itself is the same either way.

    Raises:
        ValueError: if ``force`` is set and more than one barrier is present.
        KeyError: if a member of a multi-segment run lacks ``segment_id``.
    """
    items = [dict(segment) for segment in segments]
    if force:
        barrier_count = sum(1 for segment in items if _barrier(segment))
        if barrier_count > 1:
            raise ValueError("protocol segment fusion cannot cross transaction, transport, or error barriers")

    def fuse_run(run: list[dict[str, Any]]) -> dict[str, object]:
        # Collapse one non-empty run of non-barrier segments into a single segment.
        if len(run) == 1:
            solo = dict(run[0])
            if "err_target" in solo:
                solo["err_targets"] = (solo["err_target"],)
            return solo
        ids = tuple(str(member["segment_id"]) for member in run)
        merged: dict[str, object] = {
            "segment_id": "+".join(ids),
            "source_segments": ids,
            "anchors": tuple(anchor for member in run for anchor in _tuple_values(member, "anchors")),
        }
        atoms = tuple(atom for member in run for atom in _tuple_values(member, "atoms"))
        if atoms:
            merged["atoms"] = atoms
        err_targets = tuple(member["err_target"] for member in run if "err_target" in member)
        if err_targets:
            merged["err_targets"] = err_targets
        return merged

    fused: list[dict[str, object]] = []
    pending: list[dict[str, Any]] = []
    for segment in items:
        if not _barrier(segment):
            pending.append(segment)
            continue
        if pending:
            fused.append(fuse_run(pending))
            pending = []
        fused.append(segment)
    if pending:
        fused.append(fuse_run(pending))
    return fused
67
+
68
+
69
def linearize_segment_anchors(segments: Iterable[Mapping[str, Any]]) -> tuple[Any, ...]:
    """Concatenate the ``anchors`` of every segment, in order, into one tuple."""
    ordered: list[Any] = []
    for segment in segments:
        ordered.extend(_tuple_values(segment, "anchors"))
    return tuple(ordered)
71
+
72
+
73
+ __all__ = ["fuse_protocol_segments", "linearize_segment_anchors"]
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterable, Mapping
4
+ from typing import Any
5
+
6
+
7
# Keys every legality-matrix row must define; validate_legality_matrix rejects
# a row that lacks any of them.
REQUIRED_COLUMNS = set(
    (
        "binding",
        "subevent",
        "phase",
        "segment",
        "atom",
        "legality",
        "transaction_unit",
        "hookable",
        "emits_bytes",
        "requires_channel",
        "ok_target",
        "err_target",
        "err_ctx",
        "err_kind",
        "rollback_required",
        "terminal_policy",
    )
)
25
+
26
+
27
+ def _row(binding: str, subevent: str, atom: str, **overrides: Any) -> dict[str, Any]:
28
+ row = {
29
+ "binding": binding,
30
+ "subevent": subevent,
31
+ "phase": "EMIT" if "emit" in atom else "HANDLER",
32
+ "segment": "transport" if atom.startswith("transport.") else "handler",
33
+ "atom": atom,
34
+ "legality": "required",
35
+ "transaction_unit": "request",
36
+ "hookable": False,
37
+ "emits_bytes": atom == "transport.emit",
38
+ "requires_channel": binding not in {"http.rest", "http.jsonrpc"},
39
+ "ok_target": "NEXT",
40
+ "err_target": "ON_PROTOCOL_ERROR",
41
+ "err_ctx": "ErrorCtx",
42
+ "err_kind": "protocol_error",
43
+ "rollback_required": False,
44
+ "terminal_policy": "continue",
45
+ }
46
+ row.update(overrides)
47
+ return row
48
+
49
+
50
def generate_legality_matrix() -> list[dict[str, Any]]:
    """Return the built-in legality-matrix rows as a fresh list on every call.

    Each entry below is ``(binding, subevent, atom, overrides)``, expanded
    through ``_row``.
    """
    table = (
        ("http.rest", "request.received", "CALL_HANDLER", {"requires_channel": False}),
        ("http.jsonrpc", "rpc.request", "CALL_HANDLER", {"requires_channel": False}),
        ("http.stream", "stream.chunk.emit", "transport.emit", {}),
        ("http.sse", "message.emit", "transport.emit", {}),
        ("http.sse", "message.emit_complete", "transport.emit_complete", {}),
        ("websocket", "session.open", "transport.accept", {"phase": "PRE_HANDLER"}),
        ("websocket", "session.open", "framing.decode", {"phase": "PRE_HANDLER"}),
        ("webtransport.stream", "stream.chunk.received", "framing.decode", {}),
        ("webtransport.datagram", "datagram.received", "framing.decode", {}),
        ("webtransport.app_frame", "message.received", "framing.decode", {}),
    )
    return [_row(binding, subevent, atom, **extra) for binding, subevent, atom, extra in table]
64
+
65
+
66
def validate_legality_matrix(rows: Iterable[Mapping[str, Any]]) -> dict[str, bool]:
    """Check every row of a legality matrix and return ``{"passed": True}``.

    Raises:
        ValueError: if a row lacks any of ``REQUIRED_COLUMNS``, or if a row
            marks the ``transport.emit`` atom as ``forbidden``.
    """
    for row in rows:
        absent = REQUIRED_COLUMNS.difference(row)
        if absent:
            raise ValueError(f"legality matrix row missing required columns: {sorted(absent)}")
        forbids_emit = row.get("legality") == "forbidden" and row.get("atom") == "transport.emit"
        if forbids_emit:
            raise ValueError("legality matrix forbids illegal transport atom")
    return {"passed": True}
74
+
75
+
76
def validate_protocol_plan(*, binding: str, subevent: str, phase: str, atoms: tuple[str, ...]) -> None:
    """Validate a planned atom list against the built-in legality matrix.

    Only rows that match ``binding``, ``subevent`` and ``phase`` are
    considered. A combination with no matching rows passes, because nothing is
    allowed and nothing is required.

    Raises:
        ValueError: if ``atoms`` contains ``transport.emit`` but no matching row
            allows it, or if any matching ``required`` atom is absent from
            ``atoms``.
    """
    matching = [
        row
        for row in generate_legality_matrix()
        if row["binding"] == binding and row["subevent"] == subevent and row["phase"] == phase
    ]
    allowed_atoms = {row["atom"] for row in matching}
    if "transport.emit" in atoms and "transport.emit" not in allowed_atoms:
        raise ValueError("forbidden atom violates legality matrix")

    required_atoms = {row["atom"] for row in matching if row["legality"] == "required"}
    missing = required_atoms.difference(atoms)
    if missing:
        raise ValueError(f"required atom missing from legality plan: {sorted(missing)}")
89
+
90
+
91
def _key(row: Mapping[str, Any]) -> tuple[Any, Any, Any, Any]:
    """Return the identity of a legality-matrix row: ``(binding, subevent, phase, atom)``."""
    return (row.get("binding"), row.get("subevent"), row.get("phase"), row.get("atom"))


def diff_legality_matrix(
    old: Iterable[Mapping[str, Any]],
    new: Iterable[Mapping[str, Any]],
) -> dict[str, list[dict[str, Any]]]:
    """Compare two legality matrices row by row, keyed by ``_key``.

    Returns:
        ``{"added": [...], "removed": [...], "changed": [...]}``, where
        ``added`` holds rows only in ``new`` and ``removed`` holds rows only in
        ``old``. ``changed`` holds ``{"old": row, "new": row}`` pairs for keys
        present in both whose rows differ. Every list follows input order:
        ``added`` in the order of ``new``, and ``removed``/``changed`` in the
        order of ``old``. When one input repeats a key, its last row wins.
    """
    old_map = {_key(row): dict(row) for row in old}
    new_map = {_key(row): dict(row) for row in new}
    # Walk the insertion-ordered dicts instead of ``keys()`` set algebra: set
    # iteration order depends on per-process string hash randomisation, which
    # made the report order nondeterministic between runs.
    return {
        "added": [row for key, row in new_map.items() if key not in old_map],
        "removed": [row for key, row in old_map.items() if key not in new_map],
        "changed": [
            {"old": row, "new": new_map[key]}
            for key, row in old_map.items()
            if key in new_map and row != new_map[key]
        ],
    }
110
+
111
+
112
+ __all__ = [
113
+ "diff_legality_matrix",
114
+ "generate_legality_matrix",
115
+ "validate_legality_matrix",
116
+ "validate_protocol_plan",
117
+ ]