synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/balance.py +3 -15
  3. synth_ai/config/base_url.py +47 -0
  4. synth_ai/http.py +102 -0
  5. synth_ai/inference/__init__.py +7 -0
  6. synth_ai/inference/client.py +20 -0
  7. synth_ai/jobs/client.py +246 -0
  8. synth_ai/learning/__init__.py +24 -0
  9. synth_ai/learning/client.py +149 -0
  10. synth_ai/learning/config.py +43 -0
  11. synth_ai/learning/constants.py +29 -0
  12. synth_ai/learning/ft_client.py +59 -0
  13. synth_ai/learning/health.py +43 -0
  14. synth_ai/learning/jobs.py +205 -0
  15. synth_ai/learning/rl_client.py +256 -0
  16. synth_ai/learning/sse.py +58 -0
  17. synth_ai/learning/validators.py +48 -0
  18. synth_ai/lm/core/main_v3.py +13 -0
  19. synth_ai/lm/core/synth_models.py +48 -0
  20. synth_ai/lm/core/vendor_clients.py +9 -6
  21. synth_ai/lm/vendors/core/openai_api.py +31 -3
  22. synth_ai/lm/vendors/openai_standard.py +45 -14
  23. synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
  24. synth_ai/lm/vendors/synth_client.py +372 -28
  25. synth_ai/rl/__init__.py +30 -0
  26. synth_ai/rl/contracts.py +32 -0
  27. synth_ai/rl/env_keys.py +137 -0
  28. synth_ai/rl/secrets.py +19 -0
  29. synth_ai/scripts/verify_rewards.py +100 -0
  30. synth_ai/task/__init__.py +10 -0
  31. synth_ai/task/contracts.py +120 -0
  32. synth_ai/task/health.py +28 -0
  33. synth_ai/task/validators.py +12 -0
  34. synth_ai/tracing_v3/hooks.py +3 -1
  35. synth_ai/tracing_v3/session_tracer.py +123 -2
  36. synth_ai/tracing_v3/turso/manager.py +218 -0
  37. synth_ai/tracing_v3/turso/models.py +53 -0
  38. synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
  39. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/RECORD +43 -25
  40. synth_ai/tui/__init__.py +0 -1
  41. synth_ai/tui/__main__.py +0 -13
  42. synth_ai/tui/cli/__init__.py +0 -1
  43. synth_ai/tui/cli/query_experiments.py +0 -164
  44. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  45. synth_ai/tui/dashboard.py +0 -340
  46. synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
  47. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
  48. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
  49. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
  50. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional
5
+
6
+
7
@dataclass
class FTJobConfig:
    """Settings for a supervised fine-tuning (SFT) job.

    Attributes:
        model: Base model identifier to fine-tune.
        training_file_id: Identifier of the training file to train on.
        n_epochs: Number of epochs; validated (>= 1) by ``hyperparameters()``.
        batch_size: Batch size; validated (>= 1) by ``hyperparameters()``.
        upload_to_wasabi: Flag reported through ``metadata()``.
    """

    model: str
    training_file_id: str
    n_epochs: int = 1
    batch_size: int = 1
    upload_to_wasabi: bool = True

    def hyperparameters(self) -> Dict[str, Any]:
        """Return the validated hyperparameter mapping for the job request.

        Returns:
            ``{"n_epochs": int, "batch_size": int}``.

        Raises:
            ValueError: if ``n_epochs`` or ``batch_size`` is below 1 (checked in
                that order).
        """
        settings = (("n_epochs", self.n_epochs), ("batch_size", self.batch_size))
        for label, value in settings:
            if value < 1:
                raise ValueError(f"{label} must be >= 1")
        return {label: int(value) for label, value in settings}

    def metadata(self) -> Dict[str, Any]:
        """Return the job metadata mapping (currently only the upload flag)."""
        return {"upload_to_wasabi": bool(self.upload_to_wasabi)}
24
+
25
+
26
@dataclass
class RLJobConfig:
    """Settings for a reinforcement-learning (RL) job.

    Attributes:
        model: Model identifier to train.
        task_app_url: URL of the task app serving the environment.
        trainer_id: Identifier of the trainer service.
        batch_size: Trainer batch size; validated (>= 1) by ``trainer_dict()``.
        group_size: Trainer group size; validated (>= 2) by ``trainer_dict()``.
        job_config_id: Optional identifier of a stored job configuration.
        inline_config: Optional inline configuration mapping.
    """

    model: str
    task_app_url: str
    trainer_id: str
    batch_size: int = 1
    group_size: int = 2
    job_config_id: Optional[str] = None
    inline_config: Optional[Dict[str, Any]] = None

    def trainer_dict(self) -> Dict[str, Any]:
        """Return the validated trainer settings.

        Returns:
            ``{"batch_size": int, "group_size": int}``.

        Raises:
            ValueError: if ``batch_size`` < 1 or ``group_size`` < 2 (checked in
                that order).
        """
        current = {"batch_size": self.batch_size, "group_size": self.group_size}
        minimums = {"batch_size": 1, "group_size": 2}
        for key, floor in minimums.items():
            if current[key] < floor:
                raise ValueError(f"{key} must be >= {floor}")
        return {key: int(val) for key, val in current.items()}
42
+
43
+
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
# Lower-cased job status strings treated as final for both fine-tuning (FT)
# and RL jobs. Both spellings of "cancelled"/"canceled" are accepted.
TERMINAL_STATUSES = set(
    ("succeeded", "failed", "cancelled", "canceled", "error", "completed")
)

# Event types (lower-cased) signalling that an FT or RL job finished successfully.
TERMINAL_EVENT_SUCCESS = set(
    (
        "sft.completed",
        "sft.workflow.completed",
        "rl.job.completed",
        "rl.train.completed",
        "workflow.completed",
    )
)

# Event types (lower-cased) signalling that an FT or RL job failed.
# NOTE(review): there is no "rl.train.failed" counterpart to "rl.train.completed";
# confirm against the backend's event catalogue whether one exists.
TERMINAL_EVENT_FAILURE = set(
    (
        "sft.failed",
        "sft.workflow.failed",
        "rl.job.failed",
        "workflow.failed",
    )
)
28
+
29
+
@@ -0,0 +1,59 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional
5
+
6
+ from ..http import AsyncHttpClient, HTTPError
7
+
8
+
9
class FtClient:
    """Async client for the supervised fine-tuning endpoints of the learning API.

    Every public call opens its own short-lived ``AsyncHttpClient`` against the
    configured base URL.
    """

    def __init__(self, base_url: str, api_key: str, *, timeout: float = 30.0) -> None:
        """Store connection settings; trailing slashes are removed from ``base_url``."""
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    def _http(self) -> AsyncHttpClient:
        """Build a fresh HTTP client using this instance's settings."""
        return AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout)

    async def upload_training_file(self, path: str | Path, *, purpose: str = "fine-tune") -> str:
        """Upload a local file as multipart form data and return its server id.

        The MIME type is guessed from the file name via ``_infer_content_type``.

        Raises:
            HTTPError: with ``message="invalid_upload_response"`` when the
                response is not a dict containing ``"id"``.
        """
        file_path = Path(path)
        payload = file_path.read_bytes()
        async with self._http() as client:
            form_fields = {"purpose": purpose}
            file_parts = {"file": (file_path.name, payload, _infer_content_type(file_path.name))}
            reply = await client.post_multipart("/api/learning/files", data=form_fields, files=file_parts)
            if isinstance(reply, dict) and "id" in reply:
                return str(reply["id"])
            raise HTTPError(status=500, url="/api/learning/files", message="invalid_upload_response", body_snippet=str(reply)[:200])

    async def create_sft_job(
        self,
        *,
        model: str,
        training_file_id: str,
        hyperparameters: Dict[str, Any],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create an offline SFT job and return the server's JSON response.

        ``hyperparameters`` and ``metadata`` are shallow-copied; ``None`` becomes ``{}``.
        """
        request_body = dict(
            training_type="sft_offline",
            model=model,
            training_file_id=training_file_id,
            hyperparameters=dict(hyperparameters or {}),
            metadata=dict(metadata or {}),
        )
        async with self._http() as client:
            return await client.post_json("/api/learning/jobs", json=request_body)

    async def start_job(self, job_id: str) -> Dict[str, Any]:
        """Start a previously created job; returns the server's JSON response."""
        async with self._http() as client:
            return await client.post_json(f"/api/learning/jobs/{job_id}/start", json={})
47
+
48
+
49
+ def _infer_content_type(filename: str) -> str:
50
+ name = filename.lower()
51
+ if name.endswith(".jsonl"):
52
+ return "application/jsonl"
53
+ if name.endswith(".json"):
54
+ return "application/json"
55
+ if name.endswith(".txt"):
56
+ return "text/plain"
57
+ return "application/octet-stream"
58
+
59
+
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, Optional
4
+ import aiohttp
5
+
6
+ from ..http import AsyncHttpClient
7
+
8
+
9
+ def _api_base(b: str) -> str:
10
+ b = (b or "").rstrip("/")
11
+ return b if b.endswith("/api") else f"{b}/api"
12
+
13
+
14
async def backend_health(base_url: str, api_key: str) -> Dict[str, Any]:
    """GET ``<base>/api/health`` and wrap the raw response.

    Returns:
        ``{"ok": True, "raw": <response>}``; request errors propagate to the caller.
    """
    async with AsyncHttpClient(base_url, api_key, timeout=15.0) as client:
        health_url = f"{_api_base(base_url)}/health"
        payload = await client.get(health_url)
        return {"ok": True, "raw": payload}
18
+
19
+
20
async def task_app_health(task_app_url: str) -> Dict[str, Any]:
    """Return the health report for the task app at ``task_app_url``.

    Delegates to ``synth_ai.task.health.task_app_health`` so there is a single
    implementation; the import is deferred until call time.
    """
    from synth_ai.task.health import task_app_health as delegate

    report = await delegate(task_app_url)
    return report
25
+
26
+
27
async def pricing_preflight(base_url: str, api_key: str, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int) -> Dict[str, Any]:
    """POST a pricing preflight request to ``<base>/api/v1/pricing/preflight``.

    Falsy ``estimated_seconds`` is sent as ``0.0`` and falsy ``container_count``
    (including 0) is sent as ``1``.

    Returns:
        The response dict, or ``{"raw": <response>}`` when it is not a dict.
    """
    seconds = float(estimated_seconds or 0.0)
    containers = int(container_count or 1)
    request_body = {
        "job_type": job_type,
        "gpu_type": gpu_type,
        "estimated_seconds": seconds,
        "container_count": containers,
    }
    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as client:
        reply = await client.post_json(f"{_api_base(base_url)}/v1/pricing/preflight", json=request_body)
        if isinstance(reply, dict):
            return reply
        return {"raw": reply}
37
+
38
+
39
async def balance_autumn_normalized(base_url: str, api_key: str) -> Dict[str, Any]:
    """GET ``<base>/api/v1/balance/autumn-normalized``.

    Returns:
        The response dict, or ``{"raw": <response>}`` when it is not a dict.
    """
    async with AsyncHttpClient(base_url, api_key, timeout=30.0) as client:
        reply = await client.get(f"{_api_base(base_url)}/v1/balance/autumn-normalized")
        if isinstance(reply, dict):
            return reply
        return {"raw": reply}
43
+
@@ -0,0 +1,205 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Dict, List, Optional
4
+ import time
5
+
6
+ from .constants import TERMINAL_EVENT_FAILURE, TERMINAL_EVENT_SUCCESS, TERMINAL_STATUSES
7
+ from ..http import AsyncHttpClient, sleep
8
+
9
+
10
+ def _api_base(b: str) -> str:
11
+ b = (b or "").rstrip("/")
12
+ return b if b.endswith("/api") else f"{b}/api"
13
+
14
+
15
class JobsApiResolver:
    """Build the job status/events/metrics URLs queried by the job poller.

    In strict mode only ``learning/*`` endpoints are used; otherwise fallback
    ``rl/*`` and ``orchestration/*`` endpoints are listed after it.
    """

    def __init__(self, base_url: str, *, strict: bool) -> None:
        self._base = _api_base(base_url)
        self._strict = strict

    def status_urls(self, job_id: str) -> List[str]:
        """Return candidate status URLs, in the order they should be tried."""
        families = ("learning",) if self._strict else ("learning", "rl", "orchestration")
        return [f"{self._base}/{family}/jobs/{job_id}" for family in families]

    def events_urls(self, job_id: str, since: int) -> List[str]:
        """Return candidate JSON event-list URLs for events after sequence ``since``."""
        # The rl/jobs/{id}/events endpoint is SSE in the backend, so it is never
        # listed here (this resolver feeds a JSON poller).
        families = ("learning",) if self._strict else ("learning", "orchestration")
        query = f"events?since_seq={since}&limit=200"
        return [f"{self._base}/{family}/jobs/{job_id}/{query}" for family in families]

    def metrics_url(self, job_id: str, after_step: int) -> str:
        """Return the metrics URL for points after ``after_step``."""
        return f"{self._base}/learning/jobs/{job_id}/metrics?after_step={after_step}&limit=200"
40
+
41
+
42
class JobHandle:
    """Poll a learning job (fine-tuning or RL) until it reaches a terminal state.

    Each poll cycle (1) reads the job status, (2) drains new events from the
    job's event stream and from any ``linked_job_id`` stream reported by the
    status payload, (3) drains new metric points, and (4) returns on a terminal
    event or terminal status; otherwise it sleeps and repeats.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        job_id: str,
        *,
        strict: bool = True,
        timeout: float = 600.0,
    ) -> None:
        """Create a handle.

        Args:
            base_url: Backend URL; trailing slashes are stripped.
            api_key: Credential passed to ``AsyncHttpClient``.
            job_id: Identifier of the job to poll.
            strict: When True only ``learning/*`` endpoints are queried; when
                False ``rl/*`` and ``orchestration/*`` fallbacks are tried too.
            timeout: Timeout handed to ``AsyncHttpClient``.
        """
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.job_id = job_id
        self.strict = strict
        self.timeout = timeout

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Convert ``value`` to int, returning ``default`` for None or unparsable input.

        Unlike ``int(value or default)``, a legitimate ``0`` is preserved; the old
        form turned metric step 0 into -1, which caused step-0 points to be skipped.
        """
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _fine_tuned_model_of(obj: Any) -> Optional[str]:
        """Return ``obj["fine_tuned_model"]`` if obj is a dict holding a non-empty str."""
        if not isinstance(obj, dict):
            return None
        ftm = obj.get("fine_tuned_model")
        return ftm if isinstance(ftm, str) and ftm else None

    @staticmethod
    def _extract_events(payload: Any) -> List[Dict[str, Any]]:
        """Return the dict entries of ``payload["events"]`` (or ``payload["data"]``).

        Non-dict payloads, non-list event fields and non-dict entries are ignored.
        """
        if not isinstance(payload, dict):
            return []
        events = payload.get("events") or payload.get("data") or []
        if not isinstance(events, list):
            return []
        return [e for e in events if isinstance(e, dict)]

    @staticmethod
    def _notify(callback: Optional[Callable[[Dict[str, Any]], None]], payload: Dict[str, Any]) -> None:
        """Invoke an optional user callback; exceptions it raises are ignored (best-effort)."""
        if callback is None:
            return
        try:
            callback(payload)
        except Exception:
            pass

    async def _fetch_status(self, http: AsyncHttpClient, resolver: JobsApiResolver) -> Optional[Dict[str, Any]]:
        """Return the first dict response among the resolver's status URLs, else None.

        Failed requests and non-dict bodies fall through to the next URL.
        """
        for url in resolver.status_urls(self.job_id):
            try:
                data = await http.get(url)
            except Exception:
                continue
            if isinstance(data, dict):
                return data
        return None

    async def _refetch_fine_tuned_model(self, http: AsyncHttpClient, resolver: JobsApiResolver) -> Optional[str]:
        """Re-read status URLs looking for a persisted ``fine_tuned_model`` (best-effort)."""
        for url in resolver.status_urls(self.job_id):
            try:
                data = await http.get(url)
            except Exception:
                continue
            ftm = self._fine_tuned_model_of(data)
            if ftm:
                return ftm
        return None

    async def poll_until_terminal(
        self,
        *,
        interval_seconds: float = 2.0,
        max_seconds: float | None = None,
        empty_polls_threshold: int = 5,
        startup_deadline_s: int = 45,
        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Poll until a terminal event or status is observed.

        Args:
            interval_seconds: Sleep between poll cycles.
            max_seconds: Overall time budget; ``None`` means unbounded.
            empty_polls_threshold: Accepted for signature compatibility; this
                poller does not abort on consecutive empty polls.
            startup_deadline_s: Seconds allowed before the first event arrives.
            on_event: Called with each new event and with synthetic
                ``{"type": "job.status", "message": status}`` events on status
                changes. Exceptions it raises are ignored.
            on_metric: Called with each new metric point. Exceptions it raises
                are ignored.

        Returns:
            ``{"status": ..., "job_id": ...}`` plus ``"fine_tuned_model"`` when one
            was found in a status payload or event data.

        Raises:
            AssertionError: no event was observed within ``startup_deadline_s``.
            TimeoutError: ``max_seconds`` elapsed without a terminal state.
        """
        last_seq_by_stream: Dict[str, int] = {}
        events_job_id: Optional[str] = None
        last_status: Optional[str] = None
        last_step_by_name: Dict[str, int] = {}
        saw_any_event = False
        start_t = time.time()
        resolver = JobsApiResolver(self.base_url, strict=self.strict)
        detected_fine_tuned_model: Optional[str] = None

        async with AsyncHttpClient(self.base_url, self.api_key, timeout=self.timeout) as http:
            while True:
                # 1) Status.
                status_data = await self._fetch_status(http, resolver)
                status = str((status_data or {}).get("status") or "").lower()
                if status_data:
                    # Also poll events for a linked job id reported by the status payload.
                    linked = status_data.get("linked_job_id")
                    if isinstance(linked, str) and linked and linked != events_job_id:
                        events_job_id = linked
                    if not detected_fine_tuned_model:
                        detected_fine_tuned_model = self._fine_tuned_model_of(status_data)
                if status and status != last_status:
                    last_status = status
                    self._notify(on_event, {"type": "job.status", "message": status})

                # 2) Events.
                stream_ids = [self.job_id]
                if events_job_id and events_job_id not in stream_ids:
                    stream_ids.append(events_job_id)
                terminal_event_status: Optional[str] = None
                for ev_id in stream_ids:
                    since = last_seq_by_stream.get(ev_id, 0)
                    for eu in resolver.events_urls(ev_id, since):
                        try:
                            ev_js = await http.get(eu)
                        except Exception:
                            continue
                        events = self._extract_events(ev_js)
                        if events:
                            saw_any_event = True
                        for e in events:
                            # Events whose seq is missing or not above the last seen
                            # seq for this stream are treated as already processed.
                            seq_val = self._as_int(e.get("seq"), 0)
                            if seq_val <= last_seq_by_stream.get(ev_id, 0):
                                continue
                            last_seq_by_stream[ev_id] = seq_val
                            self._notify(on_event, e)
                            if not detected_fine_tuned_model:
                                detected_fine_tuned_model = self._fine_tuned_model_of(e.get("data"))
                            et = str(e.get("type") or e.get("event_type") or "").lower()
                            if et in TERMINAL_EVENT_SUCCESS:
                                terminal_event_status = "succeeded"
                            elif et in TERMINAL_EVENT_FAILURE:
                                terminal_event_status = "failed"

                # 3) Metrics (best-effort). Only points after the highest step seen
                # for any metric name are requested. NOTE(review): a metric name
                # lagging behind the others may have points skipped server-side.
                after = max(last_step_by_name.values()) if last_step_by_name else -1
                try:
                    md = await http.get(resolver.metrics_url(self.job_id, after))
                except Exception:
                    md = None
                points = md.get("points") if isinstance(md, dict) else None
                for p in points if isinstance(points, list) else []:
                    if not isinstance(p, dict):
                        continue
                    name = str(p.get("name") or "")
                    step = self._as_int(p.get("step"), -1)
                    if step <= last_step_by_name.get(name, -1):
                        continue
                    last_step_by_name[name] = step
                    self._notify(on_metric, p)

                # 4) Terminal decision: a terminal event takes precedence over the status.
                if terminal_event_status is not None or status in TERMINAL_STATUSES:
                    final_res: Dict[str, Any] = {
                        "status": terminal_event_status or status or "completed",
                        "job_id": self.job_id,
                    }
                    if not detected_fine_tuned_model:
                        detected_fine_tuned_model = await self._refetch_fine_tuned_model(http, resolver)
                    if detected_fine_tuned_model:
                        final_res["fine_tuned_model"] = detected_fine_tuned_model
                    return final_res

                if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
                    raise AssertionError(
                        f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming."
                    )
                await sleep(interval_seconds)
                if max_seconds is not None and (time.time() - start_t) >= max_seconds:
                    raise TimeoutError(f"Polling timed out after {max_seconds}s for job {self.job_id}")
204
+
205
+
@@ -0,0 +1,256 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict, List, Optional, Callable
4
+ import os
5
+ import time
6
+
7
+ from ..http import AsyncHttpClient, HTTPError, sleep
8
+
9
+
10
+ def _api_base(b: str) -> str:
11
+ b = (b or "").rstrip("/")
12
+ return b if b.endswith("/api") else f"{b}/api"
13
+
14
+
15
class RlClient:
    """Lightweight RL client for provider-agnostic job control.

    Notes:
    - Uses learning/* for status/events/metrics and rl/* for creation/start.
    - Trainer endpoints are resolved server-side via trainer_id.
    - Each request opens its own short-lived ``AsyncHttpClient``.
    """

    def __init__(self, base_url: str, api_key: str, *, timeout: float = 600.0) -> None:
        """Store connection settings; trailing slashes are removed from ``base_url``."""
        self._base_url = base_url.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    @staticmethod
    def _log(message: str) -> None:
        """Print a diagnostic line; failures while printing are ignored (best-effort)."""
        try:
            print(message)
        except Exception:
            pass

    @staticmethod
    def _as_int(value: Any, default: int) -> int:
        """Convert ``value`` to int, returning ``default`` for None or unparsable input.

        Unlike ``int(value or default)``, a legitimate ``0`` is preserved; the old
        form turned metric step 0 into -1, which caused step-0 points to be skipped.
        """
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    async def resolve_trainer_start_url(self, trainer_id: str) -> str:
        """GET /api/rl/services/{id} → { training_start_url }

        Raises:
            HTTPError: if the response is not a dict or lacks a non-empty
                ``training_start_url`` string.
        """
        path = f"/api/rl/services/{trainer_id}"
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            js = await http.get(path)
        if not isinstance(js, dict):
            raise HTTPError(status=500, url=path, message="invalid_service_response", body_snippet=str(js)[:200])
        start_url = js.get("training_start_url")
        if not isinstance(start_url, str) or not start_url:
            raise HTTPError(status=500, url=path, message="missing_training_start_url", body_snippet=str(js)[:200])
        return start_url

    async def create_job(
        self,
        *,
        model: str,
        task_app_url: str,
        trainer: Dict[str, Any],
        trainer_id: Optional[str] = None,
        job_config_id: Optional[str] = None,
        inline_config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Create an RL job via POST ``<base>/api/rl/jobs`` and return the response dict.

        ``trainer["group_size"]`` is clamped to at least 2. ``job_config_id`` and
        ``inline_config`` are only sent when truthy.
        NOTE(review): ``trainer_id`` is accepted but not included in the request
        body; confirm whether the backend expects it.

        Raises:
            HTTPError: if the response is not a dict.
        """
        body = {
            "job_type": "rl",
            "data": {
                "model": model,
                "endpoint_base_url": task_app_url,
                **({"job_config_id": job_config_id} if job_config_id else {}),
                **({"config": inline_config} if inline_config else {}),
                "trainer": {
                    "batch_size": int(trainer.get("batch_size", 1)),
                    "group_size": max(2, int(trainer.get("group_size", 2))),
                },
            },
        }
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http:
            js = await http.post_json(f"{_api_base(self._base_url)}/rl/jobs", json=body)
        if not isinstance(js, dict):
            raise HTTPError(status=500, url="/api/rl/jobs", message="invalid_create_response", body_snippet=str(js)[:200])
        return js

    async def start_job_if_supported(self, job_id: str) -> Optional[Dict[str, Any]]:
        """POST the RL start endpoint; returns None if the backend answers 404."""
        path = f"{_api_base(self._base_url)}/rl/jobs/{job_id}/start"
        try:
            async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
                return await http.post_json(path, json={})
        except HTTPError as he:  # noqa: PERF203
            if he.status == 404:
                return None
            raise

    async def get_job(self, job_id: str) -> Dict[str, Any]:
        """GET ``<base>/api/learning/jobs/{job_id}`` and return the raw response."""
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            return await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}")

    async def get_events(self, job_id: str, *, since_seq: int = 0, limit: int = 200) -> List[Dict[str, Any]]:
        """Return events after ``since_seq`` from the learning events endpoint.

        The list is taken from ``"events"`` (or ``"data"``); any other shape yields
        ``[]``. HTTP errors are logged and re-raised.
        """
        params = {"since_seq": since_seq, "limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            try:
                js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/events", params=params)
            except HTTPError as he:
                self._log(
                    f"[poll] events HTTPError status={he.status} url={he.url} since_seq={since_seq} body={(he.body_snippet or '')[:200]}"
                )
                raise
        if isinstance(js, dict):
            evs = js.get("events") or js.get("data")
            if isinstance(evs, list):
                return evs
        return []

    async def get_metrics(self, job_id: str, *, after_step: int = -1, limit: int = 200) -> List[Dict[str, Any]]:
        """Return metric points after ``after_step``; any unexpected shape yields ``[]``."""
        params = {"after_step": after_step, "limit": limit}
        async with AsyncHttpClient(self._base_url, self._api_key, timeout=30.0) as http:
            js = await http.get(f"{_api_base(self._base_url)}/learning/jobs/{job_id}/metrics", params=params)
        if isinstance(js, dict) and isinstance(js.get("points"), list):
            return js["points"]
        return []

    async def poll_until_terminal(
        self,
        job_id: str,
        *,
        interval_seconds: float = 2.0,
        max_seconds: float | None = None,
        empty_polls_threshold: int = 5,
        startup_deadline_s: int = 45,
        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
        on_metric: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> Dict[str, Any]:
        """Poll an RL job until a terminal event or status is observed.

        Args:
            job_id: Job to poll.
            interval_seconds: Sleep between poll cycles.
            max_seconds: Overall time budget; ``None`` means unbounded.
            empty_polls_threshold: Consecutive event-less polls (min 1) that abort.
            startup_deadline_s: Seconds allowed before the first event arrives.
            on_event: Called with each new event and with synthetic
                ``{"type": "rl.status", "message": status}`` events on status
                changes. Exceptions it raises are ignored.
            on_metric: Called with each new metric point. Exceptions it raises
                are ignored.

        Returns:
            ``{"status": ..., "job_id": job_id}``.

        Raises:
            AssertionError: too many consecutive empty polls, or no event within
                ``startup_deadline_s``.
            TimeoutError: ``max_seconds`` elapsed without a terminal state.
        """
        last_seq_by_stream: Dict[str, int] = {}
        events_job_id: Optional[str] = None
        last_status: Optional[str] = None
        last_step_by_name: Dict[str, int] = {}
        empty_polls = 0
        saw_any_event = False
        start_t = time.time()
        # Same values as TERMINAL_STATUSES in the learning constants module.
        terminal = {"succeeded", "failed", "cancelled", "canceled", "error", "completed"}

        while True:
            # 1) Status. A non-dict body is treated like a failed fetch instead of
            # crashing on ``.get`` below.
            try:
                raw_status = await self.get_job(job_id)
            except Exception:
                raw_status = None
            status_data: Optional[Dict[str, Any]] = raw_status if isinstance(raw_status, dict) else None
            if status_data is None:
                self._log(f"[poll] get_job returned None base={self._base_url} job_id={job_id}")
            status = str((status_data or {}).get("status") or "").lower()
            if status_data:
                linked = status_data.get("linked_job_id")
                if isinstance(linked, str) and linked and linked != events_job_id:
                    events_job_id = linked
                    self._log(f"[poll] discovered linked_job_id stream={events_job_id}")
            if status and status != last_status:
                last_status = status
                # Report status transitions only, to avoid log spam.
                if on_event:
                    try:
                        on_event({"type": "rl.status", "message": status})
                    except Exception:
                        pass

            # 2) Events from the job stream and any linked stream.
            stream_ids = [job_id]
            if events_job_id and events_job_id not in stream_ids:
                stream_ids.append(events_job_id)
            self._log(
                f"[poll] streams={stream_ids} intervals={interval_seconds}s since_map={last_seq_by_stream} empty_polls={empty_polls}"
            )
            total_events_this_cycle = 0
            terminal_event_status: Optional[str] = None
            for ev_id in stream_ids:
                since = last_seq_by_stream.get(ev_id, 0)
                try:
                    events = await self.get_events(ev_id, since_seq=since, limit=200)
                except HTTPError as he:
                    self._log(
                        f"[poll] get_events error status={he.status} url={he.url} since={since} body={(he.body_snippet or '')[:200]}"
                    )
                    events = []
                except Exception as e:
                    self._log(f"[poll] get_events unexpected error ev_id={ev_id} since={since} err={type(e).__name__}: {e}")
                    events = []
                # Skip malformed (non-dict) entries rather than crashing on ``.get``.
                events = [ev for ev in events if isinstance(ev, dict)]
                total_events_this_cycle += len(events)
                if events:
                    saw_any_event = True
                for ev in events:
                    # Events whose seq is missing or not above the last seen seq
                    # for this stream are treated as already processed.
                    seq_val = self._as_int(ev.get("seq"), 0)
                    if seq_val <= last_seq_by_stream.get(ev_id, 0):
                        continue
                    last_seq_by_stream[ev_id] = seq_val
                    if on_event:
                        try:
                            on_event(ev)
                        except Exception:
                            pass
                    et = str(ev.get("type") or ev.get("event_type") or "").lower()
                    if et in ("rl.job.completed", "workflow.completed", "rl.train.completed"):
                        terminal_event_status = "succeeded"
                    elif et in ("rl.job.failed", "workflow.failed"):
                        terminal_event_status = "failed"

            # 3) Metrics (best-effort). Only points after the highest step seen for
            # any name are requested. NOTE(review): a lagging metric name may have
            # points skipped server-side.
            try:
                after = max(last_step_by_name.values()) if last_step_by_name else -1
                points = await self.get_metrics(job_id, after_step=after, limit=200)
            except Exception:
                points = []
            for p in points:
                if not isinstance(p, dict):
                    continue
                name = str(p.get("name") or "")
                step = self._as_int(p.get("step"), -1)
                if step <= last_step_by_name.get(name, -1):
                    continue
                last_step_by_name[name] = step
                if on_metric:
                    try:
                        on_metric(p)
                    except Exception:
                        pass

            # 4) Terminal decision: a terminal event takes precedence over the status.
            if terminal_event_status is not None:
                return {"status": terminal_event_status, "job_id": job_id}
            if status and status in terminal:
                return {"status": status, "job_id": job_id}

            # 5) Liveness guards.
            empty_polls = empty_polls + 1 if total_events_this_cycle == 0 else 0
            if empty_polls >= max(1, int(empty_polls_threshold)):
                self._log(
                    f"[poll] threshold hit: empty_polls={empty_polls} >= {empty_polls_threshold} streams={stream_ids} last_seq_map={last_seq_by_stream}"
                )
                raise AssertionError(f"No new events detected for {empty_polls_threshold} consecutive polls. Check event ingestion.")

            if not saw_any_event and (time.time() - start_t) > int(startup_deadline_s):
                self._log(
                    f"[poll] startup window exceeded: {startup_deadline_s}s base={self._base_url} job={job_id} streams={stream_ids} last_seq_map={last_seq_by_stream}"
                )
                raise AssertionError(f"No events observed within startup window ({startup_deadline_s}s). Investigate event streaming.")

            await sleep(interval_seconds)
            if max_seconds is not None and (time.time() - start_t) >= max_seconds:
                raise TimeoutError(f"Polling timed out after {max_seconds}s for job {job_id}")
255
+
256
+