@tt-a1i/mco 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import threading
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Callable, Dict, List, Optional, Set, Tuple
8
+
9
+ from .retry import RetryPolicy
10
+ from .types import AttemptResult, ErrorKind, RunResult, TaskState, WarningKind
11
+
12
+
13
# Error kinds that qualify a failed attempt for another retry pass; every
# other ErrorKind is treated as terminal by run_with_retry.
RETRYABLE_ERRORS = {
    ErrorKind.RETRYABLE_RATE_LIMIT,
    ErrorKind.RETRYABLE_TIMEOUT,
    ErrorKind.RETRYABLE_TRANSIENT_NETWORK,
}
18
+
19
+
20
# Task lifecycle graph: each state maps to the set of states it may move to.
# Terminal states (no outgoing edges) are generated together at the bottom.
VALID_TRANSITIONS: Dict[TaskState, Set[TaskState]] = {
    TaskState.DRAFT: {TaskState.QUEUED},
    TaskState.QUEUED: {TaskState.DISPATCHED, TaskState.CANCELLED, TaskState.EXPIRED},
    TaskState.DISPATCHED: {TaskState.RUNNING, TaskState.CANCELLED, TaskState.EXPIRED},
    TaskState.RUNNING: {
        TaskState.RETRYING,
        TaskState.AGGREGATING,
        TaskState.FAILED,
        TaskState.CANCELLED,
        TaskState.EXPIRED,
        TaskState.PARTIAL_SUCCESS,
    },
    TaskState.RETRYING: {TaskState.RUNNING, TaskState.FAILED, TaskState.EXPIRED},
    TaskState.AGGREGATING: {TaskState.COMPLETED, TaskState.PARTIAL_SUCCESS, TaskState.FAILED},
    # Terminal states admit no further transitions.
    **{
        terminal: set()
        for terminal in (
            TaskState.COMPLETED,
            TaskState.PARTIAL_SUCCESS,
            TaskState.FAILED,
            TaskState.CANCELLED,
            TaskState.EXPIRED,
        )
    },
}
40
+
41
+
42
@dataclass
class TaskStateMachine:
    """Minimal guard around the lifecycle graph in ``VALID_TRANSITIONS``.

    Holds the current ``state`` (starting at DRAFT) and only allows moves
    along declared edges.
    """

    state: TaskState = TaskState.DRAFT

    def transition(self, next_state: TaskState) -> None:
        """Advance to ``next_state``; raise ValueError on an undeclared edge."""
        allowed = VALID_TRANSITIONS[self.state]
        if next_state in allowed:
            self.state = next_state
            return
        raise ValueError(f"illegal transition {self.state} -> {next_state}")
50
+
51
+
52
class OrchestratorRuntime:
    """Coordinates task execution: idempotent submission, retry-driven
    dispatch with per-dispatch-key result caching, de-duplicated terminal
    notifications, and optional JSON persistence of all three indexes.

    Thread-safety: public methods acquire ``self._lock`` (an RLock, so the
    nested ``_persist_state`` calls re-enter safely).
    NOTE(review): ``run_with_retry`` releases the lock while the runner
    executes, so two threads racing on the same ``dispatch_key`` can both
    run before either caches a result — confirm callers serialize per key.
    """

    def __init__(self, retry_policy: Optional[RetryPolicy] = None, state_file: Optional[str] = None) -> None:
        # Default policy when none is supplied.
        self.retry_policy = retry_policy or RetryPolicy()
        # dispatch_key -> final RunResult; repeat dispatches are served from here.
        self.dispatch_cache: Dict[str, RunResult] = {}
        # idempotency_key -> task_id; makes submit() replay-safe.
        self.idempotency_index: Dict[str, str] = {}
        # (task_id, state value, channel) triples already notified once.
        self.sent_notifications: Set[Tuple[str, str, str]] = set()
        self.state_file = Path(state_file) if state_file else None
        self._lock = threading.RLock()
        # Rehydrate persisted state eagerly so lookups work immediately.
        if self.state_file:
            self._load_state()

    def _load_state(self) -> None:
        """Rehydrate the three indexes from ``self.state_file``.

        No-op when persistence is disabled or the file does not exist.
        NOTE(review): malformed JSON propagates from ``json.loads`` —
        confirm crashing on a corrupt state file is intended.
        """
        with self._lock:
            if not self.state_file:
                return
            if not self.state_file.exists():
                return
            data = json.loads(self.state_file.read_text(encoding="utf-8"))
            self.idempotency_index = dict(data.get("idempotency_index", {}))
            # Notifications are persisted as dicts; rebuild the tuple set.
            self.sent_notifications = {
                (item["task_id"], item["state"], item["channel"]) for item in data.get("sent_notifications", [])
            }

            self.dispatch_cache = {}
            for key, value in data.get("dispatch_cache", {}).items():
                # Enum fields are persisted by value; rebuild the members.
                warnings = [WarningKind(w) for w in value.get("warnings", [])]
                final_error = value.get("final_error")
                self.dispatch_cache[key] = RunResult(
                    task_id=value["task_id"],
                    provider=value["provider"],
                    dispatch_key=value["dispatch_key"],
                    success=value["success"],
                    attempts=value["attempts"],
                    delays_seconds=value.get("delays_seconds", []),
                    output=value.get("output"),
                    final_error=ErrorKind(final_error) if final_error else None,
                    warnings=warnings,
                    # deduped_dispatch is per-call metadata, never persisted.
                    deduped_dispatch=False,
                )

    def _persist_state(self) -> None:
        """Atomically write the three indexes to ``self.state_file`` as JSON.

        No-op when persistence is disabled.  Writes to a ``.tmp`` sibling
        then renames over the target so readers never see a partial file.
        """
        with self._lock:
            if not self.state_file:
                return
            if not self.state_file.parent.exists():
                self.state_file.parent.mkdir(parents=True, exist_ok=True)

            # Serialize RunResult values to plain JSON types (enums by value).
            dispatch_cache = {}
            for key, value in self.dispatch_cache.items():
                dispatch_cache[key] = {
                    "task_id": value.task_id,
                    "provider": value.provider,
                    "dispatch_key": value.dispatch_key,
                    "success": value.success,
                    "attempts": value.attempts,
                    "delays_seconds": value.delays_seconds,
                    "output": value.output,
                    "final_error": value.final_error.value if value.final_error else None,
                    "warnings": [w.value for w in value.warnings],
                }

            payload = {
                "idempotency_index": self.idempotency_index,
                "dispatch_cache": dispatch_cache,
                # Sorted for a deterministic file (stable diffs across runs).
                "sent_notifications": [
                    {"task_id": task_id, "state": state, "channel": channel}
                    for task_id, state, channel in sorted(self.sent_notifications)
                ],
            }

            tmp = self.state_file.with_suffix(self.state_file.suffix + ".tmp")
            tmp.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
            # Path.replace is an atomic rename on the same filesystem.
            tmp.replace(self.state_file)

    def submit(self, task_id: str, idempotency_key: str) -> Tuple[bool, str]:
        """Returns (created_new, task_id).

        A repeat ``idempotency_key`` returns the originally registered
        task_id with ``created_new=False``.  New registrations are
        persisted immediately.
        NOTE(review): the truthiness check means a falsy stored task_id
        (e.g. ``""``) would be treated as absent — confirm task ids are
        always non-empty.
        """
        with self._lock:
            existing = self.idempotency_index.get(idempotency_key)
            if existing:
                return (False, existing)
            self.idempotency_index[idempotency_key] = task_id
            self._persist_state()
            return (True, task_id)

    def run_with_retry(
        self,
        task_id: str,
        provider: str,
        dispatch_key: str,
        runner: Callable[[int], AttemptResult],
    ) -> RunResult:
        """Execute ``runner`` with retries and cache the final result.

        A cached ``dispatch_key`` short-circuits to a copy of the cached
        result marked ``deduped_dispatch=True``; otherwise ``runner`` is
        invoked with the 1-based attempt number until it succeeds or the
        error is non-retryable / retries are exhausted.  Warnings from all
        attempts are accumulated.  Delays are computed and recorded but
        not slept here — presumably the caller applies them; verify.
        """
        with self._lock:
            if dispatch_key in self.dispatch_cache:
                cached = self.dispatch_cache[dispatch_key]
                # Return a copy (lists included) so callers cannot mutate
                # the cached entry; flag it as a deduplicated dispatch.
                return RunResult(
                    task_id=cached.task_id,
                    provider=cached.provider,
                    dispatch_key=cached.dispatch_key,
                    success=cached.success,
                    attempts=cached.attempts,
                    delays_seconds=list(cached.delays_seconds),
                    output=cached.output,
                    final_error=cached.final_error,
                    warnings=list(cached.warnings),
                    deduped_dispatch=True,
                )

        # Lock intentionally released: the runner may block for a long time.
        attempts = 0
        delays: List[float] = []
        all_warnings = []
        final_error: Optional[ErrorKind] = None
        output = None

        while True:
            attempts += 1
            result = runner(attempts)
            all_warnings.extend(result.warnings)

            if result.success:
                output = result.output
                final = RunResult(
                    task_id=task_id,
                    provider=provider,
                    dispatch_key=dispatch_key,
                    success=True,
                    attempts=attempts,
                    delays_seconds=delays,
                    output=output,
                    final_error=None,
                    warnings=all_warnings,
                )
                with self._lock:
                    self.dispatch_cache[dispatch_key] = final
                    self._persist_state()
                return final

            # Missing error kind is normalized to NORMALIZATION_ERROR,
            # which is non-retryable.
            final_error = result.error_kind or ErrorKind.NORMALIZATION_ERROR
            # attempts <= max_retries allows max_retries retries after the
            # first attempt (max_retries + 1 attempts total).
            should_retry = final_error in RETRYABLE_ERRORS and attempts <= self.retry_policy.max_retries
            if not should_retry:
                final = RunResult(
                    task_id=task_id,
                    provider=provider,
                    dispatch_key=dispatch_key,
                    success=False,
                    attempts=attempts,
                    delays_seconds=delays,
                    output=result.output,
                    final_error=final_error,
                    warnings=all_warnings,
                )
                with self._lock:
                    self.dispatch_cache[dispatch_key] = final
                    self._persist_state()
                return final

            # retry_index is 1-based: the first retry uses the base delay.
            retry_index = attempts
            delays.append(self.retry_policy.compute_delay(retry_index))

    def send_terminal_notification(self, task_id: str, state: TaskState, channel: str) -> bool:
        """Record a (task, state, channel) notification exactly once.

        Returns True when this call registered the notification, False when
        it was already sent.  Newly registered keys are persisted.
        """
        with self._lock:
            key = (task_id, state.value, channel)
            if key in self.sent_notifications:
                return False
            self.sent_notifications.add(key)
            self._persist_state()
            return True

    def evaluate_terminal_state(self, required_provider_success: Dict[str, bool]) -> TaskState:
        """Map per-provider outcomes to a terminal state.

        Empty input or zero successes -> FAILED; all successes -> COMPLETED;
        anything in between -> PARTIAL_SUCCESS.
        """
        if not required_provider_success:
            return TaskState.FAILED
        successes = sum(1 for ok in required_provider_success.values() if ok)
        if successes == 0:
            return TaskState.FAILED
        if successes == len(required_provider_success):
            return TaskState.COMPLETED
        return TaskState.PARTIAL_SUCCESS

    @staticmethod
    def should_expire(
        elapsed_seconds: float,
        timeout_seconds: float,
        grace_seconds: float,
        heartbeat_age_seconds: float,
        heartbeat_ttl_seconds: float,
    ) -> bool:
        """Return True when a task should be expired.

        Either condition suffices: total runtime exceeded timeout plus
        grace, or the last heartbeat is older than its TTL.  All
        comparisons are strict (exactly at the boundary does not expire).
        """
        if elapsed_seconds > (timeout_seconds + grace_seconds):
            return True
        if heartbeat_age_seconds > heartbeat_ttl_seconds:
            return True
        return False
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass(frozen=True)
class RetryPolicy:
    """Immutable exponential-backoff configuration.

    ``max_retries`` bounds how many retries follow the first attempt;
    ``compute_delay`` yields the wait before each retry.
    """

    max_retries: int = 2
    base_delay_seconds: float = 1.0
    backoff_multiplier: float = 2.0

    def compute_delay(self, retry_index: int) -> float:
        """Delay before the ``retry_index``-th retry (1-based).

        The first retry waits the base delay; each subsequent retry is a
        further multiple of ``backoff_multiplier``.
        """
        exponent = retry_index - 1
        return self.base_delay_seconds * self.backoff_multiplier ** exponent
15
+