synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/balance.py +3 -15
  3. synth_ai/config/base_url.py +47 -0
  4. synth_ai/http.py +102 -0
  5. synth_ai/inference/__init__.py +7 -0
  6. synth_ai/inference/client.py +20 -0
  7. synth_ai/jobs/client.py +246 -0
  8. synth_ai/learning/__init__.py +24 -0
  9. synth_ai/learning/client.py +149 -0
  10. synth_ai/learning/config.py +43 -0
  11. synth_ai/learning/constants.py +29 -0
  12. synth_ai/learning/ft_client.py +59 -0
  13. synth_ai/learning/health.py +43 -0
  14. synth_ai/learning/jobs.py +205 -0
  15. synth_ai/learning/rl_client.py +256 -0
  16. synth_ai/learning/sse.py +58 -0
  17. synth_ai/learning/validators.py +48 -0
  18. synth_ai/lm/core/main_v3.py +13 -0
  19. synth_ai/lm/core/synth_models.py +48 -0
  20. synth_ai/lm/core/vendor_clients.py +9 -6
  21. synth_ai/lm/vendors/core/openai_api.py +31 -3
  22. synth_ai/lm/vendors/openai_standard.py +45 -14
  23. synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
  24. synth_ai/lm/vendors/synth_client.py +372 -28
  25. synth_ai/rl/__init__.py +30 -0
  26. synth_ai/rl/contracts.py +32 -0
  27. synth_ai/rl/env_keys.py +137 -0
  28. synth_ai/rl/secrets.py +19 -0
  29. synth_ai/scripts/verify_rewards.py +100 -0
  30. synth_ai/task/__init__.py +10 -0
  31. synth_ai/task/contracts.py +120 -0
  32. synth_ai/task/health.py +28 -0
  33. synth_ai/task/validators.py +12 -0
  34. synth_ai/tracing_v3/hooks.py +3 -1
  35. synth_ai/tracing_v3/session_tracer.py +123 -2
  36. synth_ai/tracing_v3/turso/manager.py +218 -0
  37. synth_ai/tracing_v3/turso/models.py +53 -0
  38. synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
  39. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/RECORD +43 -25
  40. synth_ai/tui/__init__.py +0 -1
  41. synth_ai/tui/__main__.py +0 -13
  42. synth_ai/tui/cli/__init__.py +0 -1
  43. synth_ai/tui/cli/query_experiments.py +0 -164
  44. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  45. synth_ai/tui/dashboard.py +0 -340
  46. synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
  47. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
  48. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
  49. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
  50. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Verify reward persistence in a traces database.
4
+
5
+ Usage:
6
+ uv run python -m synth_ai.scripts.verify_rewards --db /path/to/db.sqlite --min-reward 1
7
+ """
8
+
9
+ import argparse
10
+ import asyncio
11
+ import os
12
+ from typing import Dict
13
+
14
+ from sqlalchemy import text
15
+
16
+ from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
17
+
18
+
19
+ async def verify(db_path: str, min_reward: int) -> int:
20
+ db_url = db_path
21
+ if not db_url.startswith("sqlite+aiosqlite:///"):
22
+ db_url = f"sqlite+aiosqlite:///{os.path.abspath(db_path)}"
23
+
24
+ mgr = AsyncSQLTraceManager(db_url=db_url)
25
+ await mgr.initialize()
26
+
27
+ try:
28
+ async with mgr.session() as session:
29
+ # Sessions with outcome_rewards
30
+ q_good = text(
31
+ """
32
+ SELECT session_id, MAX(total_reward) as total_reward
33
+ FROM outcome_rewards
34
+ GROUP BY session_id
35
+ """
36
+ )
37
+ res = await session.execute(q_good)
38
+ outcomes = {row[0]: int(row[1]) for row in res.fetchall()}
39
+
40
+ # Sessions without outcome_rewards
41
+ q_missing = text(
42
+ """
43
+ SELECT s.session_id
44
+ FROM session_traces s
45
+ LEFT JOIN outcome_rewards o ON s.session_id = o.session_id
46
+ WHERE o.session_id IS NULL
47
+ """
48
+ )
49
+ res2 = await session.execute(q_missing)
50
+ missing = [row[0] for row in res2.fetchall()]
51
+
52
+ # Aggregate event_rewards per session (informational)
53
+ q_event = text(
54
+ """
55
+ SELECT session_id, COALESCE(SUM(reward_value), 0.0) as sum_rewards
56
+ FROM event_rewards
57
+ GROUP BY session_id
58
+ """
59
+ )
60
+ res3 = await session.execute(q_event)
61
+ event_sums: Dict[str, float] = {row[0]: float(row[1]) for row in res3.fetchall()}
62
+
63
+ print(f"Sessions with outcome_rewards: {len(outcomes)}")
64
+ print(f"Sessions missing outcome_rewards: {len(missing)}")
65
+ if missing:
66
+ print("Missing session_ids:", ", ".join(missing[:10]) + (" ..." if len(missing) > 10 else ""))
67
+
68
+ # Threshold check
69
+ qualifying = {sid: r for sid, r in outcomes.items() if r >= min_reward}
70
+ print(f"Sessions with total_reward >= {min_reward}: {len(qualifying)}")
71
+
72
+ # Show a small comparison snapshot
73
+ sample = list(qualifying.items())[:5]
74
+ for sid, tot in sample:
75
+ er = event_sums.get(sid, 0.0)
76
+ print(f" {sid}: outcome={tot}, sum(event_rewards)={er:.2f}")
77
+
78
+ # Exit non-zero if any sessions are missing outcome rewards
79
+ if missing:
80
+ return 2
81
+ if min_reward > 0 and not qualifying:
82
+ return 3
83
+ return 0
84
+ finally:
85
+ await mgr.close()
86
+
87
+
88
def main() -> int:
    """Parse command-line options and run the reward verification.

    Returns:
        The status code produced by ``verify`` for use as the exit code.
    """
    parser = argparse.ArgumentParser(description="Verify reward persistence in traces DB")
    parser.add_argument(
        "--db",
        required=True,
        help="Path to traces SQLite DB (aiosqlite)",
    )
    parser.add_argument(
        "--min-reward",
        type=int,
        default=0,
        help="Minimum total_reward to consider qualifying",
    )
    options = parser.parse_args()

    status = asyncio.run(verify(options.db, options.min_reward))
    return status
95
+
96
+
97
if __name__ == "__main__":
    # Surface main()'s status code as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)
99
+
100
+
@@ -0,0 +1,10 @@
1
+ from .validators import validate_task_app_url
2
+ from .health import task_app_health
3
+ from .contracts import TaskAppContract, TaskAppEndpoints
4
+
5
+ __all__ = [
6
+ "validate_task_app_url",
7
+ "task_app_health",
8
+ "TaskAppContract",
9
+ "TaskAppEndpoints",
10
+ ]
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Any, Dict, List
5
+ from pydantic import BaseModel
6
+
7
+
8
@dataclass(frozen=True)
class TaskAppEndpoints:
    """Canonical URL paths a Task App exposes to RL trainers.

    A Task App is an HTTP service (often deployed on Modal). Expected routes:

    * ``GET /health`` – requires the ``X-API-Key`` header when
      ``ENVIRONMENT_API_KEY`` is configured; returns ``{ healthy: true }``.
    * Environment lifecycle:
        - ``POST /env/{env_name}/initialize`` → ``{ env_id, observation }``
        - ``POST /env/{env_name}/step`` → ``{ observation, reward, done, info }``
        - ``POST /env/{env_name}/terminate`` → ``{ ok: true }``
    * ``POST /rollout`` (optional, unified schema)
      → ``{ run_id, trajectories[], metrics, ... }``
    * ``POST /proxy/v1/chat/completions`` (optional) – for direct OpenAI
      calls from the Task App.

    The ``env_*`` paths contain a literal ``{env_name}`` placeholder that the
    caller substitutes.
    """

    # Liveness probe.
    health: str = "/health"
    # Optional single-call rollout endpoint.
    rollout: str = "/rollout"
    # Optional OpenAI-compatible proxy.
    proxy_chat_completions: str = "/proxy/v1/chat/completions"
    # Environment lifecycle templates (contain "{env_name}").
    env_initialize: str = "/env/{env_name}/initialize"
    env_step: str = "/env/{env_name}/step"
    env_terminate: str = "/env/{env_name}/terminate"
32
+
33
+
34
@dataclass(frozen=True)
class TaskAppContract:
    """What an RL trainer expects from the Task App it talks to.

    * Auth – ``ENVIRONMENT_API_KEY`` must be set in the Task App environment;
      requests include an ``X-API-Key`` header.
    * Health – ``/health`` returns 200 with a JSON body and may verify the
      ``X-API-Key`` header.
    * Env API – initialize/step/terminate exist for the target env
      (e.g. ``CrafterClassic``).
    * Rollout API – optional; a single-call rollout for convenience/testing.
    * Inference routing – the policy config passes an ``inference_url``
      (Synth backend or OpenAI proxy).
    * URL – the base must be reachable via HTTPS and should be under
      ``.modal.run`` in production.
    """

    # Root URL of the Task App service.
    base_url: str
    # Target environment name; None when not specified.
    env_name: Optional[str] = None
    # Whether requests are expected to carry the X-API-Key header.
    requires_api_key_header: bool = True
49
+
50
+
51
+ # --- Unified rollout schema used by Task App services and SDK utilities ---
52
+
53
+ class RolloutEnvSpec(BaseModel):
54
+ env_id: Optional[str] = None
55
+ env_name: Optional[str] = None
56
+ config: Dict[str, Any] = {}
57
+ seed: Optional[int] = None
58
+
59
+
60
+ class RolloutPolicySpec(BaseModel):
61
+ policy_id: Optional[str] = None
62
+ policy_name: Optional[str] = None
63
+ config: Dict[str, Any] = {}
64
+
65
+
66
+ class RolloutRecordConfig(BaseModel):
67
+ trajectories: bool = True
68
+ logprobs: bool = False
69
+ value: bool = False
70
+
71
+
72
+ class RolloutSafetyConfig(BaseModel):
73
+ max_ops: int = 100000
74
+ max_time_s: float = 3600.0
75
+
76
+
77
+ class RolloutRequest(BaseModel):
78
+ run_id: str
79
+ env: RolloutEnvSpec
80
+ policy: RolloutPolicySpec
81
+ ops: List[Dict[str, Any]] | List[str]
82
+ record: RolloutRecordConfig = RolloutRecordConfig()
83
+ on_done: str = "reset"
84
+ safety: RolloutSafetyConfig = RolloutSafetyConfig()
85
+ training_session_id: Optional[str] = None
86
+ synth_base_url: Optional[str] = None
87
+
88
+
89
+ class RolloutStep(BaseModel):
90
+ obs: Dict[str, Any]
91
+ tool_calls: List[Dict[str, Any]]
92
+ reward: Optional[float] = None
93
+ done: bool = False
94
+ truncated: Optional[bool] = None
95
+ info: Optional[Dict[str, Any]] = None
96
+
97
+
98
+ class RolloutTrajectory(BaseModel):
99
+ env_id: str
100
+ policy_id: str
101
+ steps: List[RolloutStep]
102
+ final: Optional[Dict[str, Any]] = None
103
+ length: int
104
+
105
+
106
+ class RolloutMetrics(BaseModel):
107
+ episode_returns: List[float]
108
+ mean_return: float
109
+ num_steps: int
110
+ num_episodes: int = 0
111
+
112
+
113
+ class RolloutResponse(BaseModel):
114
+ run_id: str
115
+ trajectories: List[RolloutTrajectory]
116
+ branches: Dict[str, List[str]] = {}
117
+ metrics: RolloutMetrics
118
+ aborted: bool = False
119
+ ops_executed: int = 0
120
+
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict
4
+ import aiohttp
5
+
6
+
7
async def task_app_health(task_app_url: str) -> Dict[str, Any]:
    """Probe a Task App base URL for basic reachability.

    Behavior:
    - Try HEAD first (following redirects).
    - Fall back to GET if HEAD raises a client/transport error or answers
      outside 2xx/3xx (e.g. 405 from servers that do not implement HEAD).
    - One ``ClientSession`` is shared by both probes.

    Returns:
        ``{"ok": True, "status": int}`` on a 2xx/3xx response;
        ``{"ok": False, "status": int}`` when GET answered with another status;
        ``{"ok": False, "error": str}`` when the probe raised.
    """
    try:
        async with aiohttp.ClientSession() as session:
            try:
                async with session.head(task_app_url, allow_redirects=True) as r:
                    if 200 <= r.status < 400:
                        return {"ok": True, "status": r.status}
            except aiohttp.ClientError:
                # Some servers reset or reject HEAD outright; GET is the
                # documented fallback, so try it before reporting failure.
                pass
            async with session.get(task_app_url, allow_redirects=True) as r2:
                if 200 <= r2.status < 400:
                    return {"ok": True, "status": r2.status}
                # Report the real status instead of None so callers can tell
                # 401/404/5xx apart.
                return {"ok": False, "status": r2.status}
    except Exception as e:
        return {"ok": False, "error": f"{type(e).__name__}: {e}"}
27
+
28
+
@@ -0,0 +1,12 @@
1
+ from __future__ import annotations
2
+
3
+ from urllib.parse import urlparse
4
+
5
+
6
def validate_task_app_url(url: str, *, name: str = "TASK_APP_BASE_URL") -> None:
    """Validate a Task App base URL (http/https scheme, host, well-formed port).

    Args:
        url: The URL to check.
        name: Label used in the error message (e.g. the env var it came from).

    Raises:
        ValueError: if the scheme is not http/https, the host is missing
            (``http://`` or ``http://:8000``), or the port is not a valid
            integer in range.
    """
    message = f"Invalid {name}: malformed: {url}"
    p = urlparse(url)
    # netloc alone accepts ":8000" or "user@", which have no host to connect to.
    if p.scheme not in ("http", "https") or not p.netloc or not p.hostname:
        raise ValueError(message)
    try:
        # Accessing .port raises ValueError for non-numeric/out-of-range ports.
        _ = p.port
    except ValueError as exc:
        raise ValueError(message) from exc
12
+
@@ -200,7 +200,9 @@ def create_default_hooks() -> HookManager:
200
200
 
201
201
  # Example: Log session starts - useful for debugging and monitoring
202
202
  async def log_session_start(session_id: str, metadata: dict[str, Any]):
203
- print(f"Session started: {session_id}")
203
+ import os
204
+ if os.getenv("SYNTH_TRACE_VERBOSE", "0") in ("1", "true", "True"):
205
+ print(f"Session started: {session_id}")
204
206
 
205
207
  # Example: Validate events before recording - ensures data quality
206
208
  def validate_event(event_obj: BaseEvent) -> bool:
@@ -107,6 +107,10 @@ class SessionTracer:
107
107
  if self.auto_save and self.db is None:
108
108
  await self.initialize()
109
109
 
110
+ # Ensure session row exists for incremental writes
111
+ if self.db:
112
+ await self.db.ensure_session(session_id, created_at=self._current_trace.created_at, metadata=metadata or {})
113
+
110
114
  # Trigger hooks
111
115
  await self.hooks.trigger(
112
116
  "session_start", session_id=session_id, metadata=metadata or {}
@@ -152,6 +156,17 @@ class SessionTracer:
152
156
  "timestep_start", step=step, session_id=self._current_trace.session_id
153
157
  )
154
158
 
159
+ # Ensure timestep row exists in DB for incremental linkage
160
+ if self.db:
161
+ await self.db.ensure_timestep(
162
+ self._current_trace.session_id,
163
+ step_id=step.step_id,
164
+ step_index=step.step_index,
165
+ turn_number=turn_number,
166
+ started_at=step.timestamp,
167
+ metadata=metadata or {},
168
+ )
169
+
155
170
  return step
156
171
 
157
172
  async def end_timestep(self, step_id: str | None = None):
@@ -180,7 +195,7 @@ class SessionTracer:
180
195
  if step == self._current_step:
181
196
  self._current_step = None
182
197
 
183
- async def record_event(self, event: BaseEvent):
198
+ async def record_event(self, event: BaseEvent) -> int | None:
184
199
  """Record an event.
185
200
 
186
201
  Args:
@@ -201,6 +216,46 @@ class SessionTracer:
201
216
  if self._current_step:
202
217
  self._current_step.events.append(event)
203
218
 
219
+ # Persist incrementally if DB is available; return DB event id
220
+ if self.db:
221
+ timestep_db_id = None
222
+ if self._current_step:
223
+ # ensure timestep exists and get id
224
+ timestep_db_id = await self.db.ensure_timestep(
225
+ self._current_trace.session_id,
226
+ step_id=self._current_step.step_id,
227
+ step_index=self._current_step.step_index,
228
+ turn_number=self._current_step.turn_number,
229
+ started_at=self._current_step.timestamp,
230
+ completed_at=self._current_step.completed_at,
231
+ metadata=self._current_step.step_metadata,
232
+ )
233
+ event_id = await self.db.insert_event_row(
234
+ self._current_trace.session_id,
235
+ timestep_db_id=timestep_db_id,
236
+ event=event,
237
+ )
238
+ # Auto-insert an event reward if EnvironmentEvent carries reward
239
+ try:
240
+ from .abstractions import EnvironmentEvent # local import to avoid cycles
241
+
242
+ if isinstance(event, EnvironmentEvent) and event.reward is not None:
243
+ await self.record_event_reward(
244
+ event_id=event_id,
245
+ message_id=None,
246
+ turn_number=self._current_step.turn_number if self._current_step else None,
247
+ reward_value=float(event.reward),
248
+ reward_type="sparse",
249
+ key=None,
250
+ annotation=getattr(event, "event_metadata", None),
251
+ source="environment",
252
+ )
253
+ except Exception:
254
+ # Do not fail tracing if reward recording fails
255
+ pass
256
+ return event_id
257
+ return None
258
+
204
259
  async def record_message(
205
260
  self,
206
261
  content: str,
@@ -208,7 +263,7 @@ class SessionTracer:
208
263
  event_time: float | None = None,
209
264
  message_time: int | None = None,
210
265
  metadata: dict[str, Any] | None = None,
211
- ):
266
+ ) -> int | None:
212
267
  """Record a message.
213
268
 
214
269
  Args:
@@ -242,6 +297,31 @@ class SessionTracer:
242
297
  if self._current_step:
243
298
  self._current_step.markov_blanket_messages.append(msg)
244
299
 
300
+ # Persist incrementally and return DB message id
301
+ if self.db:
302
+ timestep_db_id = None
303
+ if self._current_step:
304
+ timestep_db_id = await self.db.ensure_timestep(
305
+ self._current_trace.session_id,
306
+ step_id=self._current_step.step_id,
307
+ step_index=self._current_step.step_index,
308
+ turn_number=self._current_step.turn_number,
309
+ started_at=self._current_step.timestamp,
310
+ completed_at=self._current_step.completed_at,
311
+ metadata=self._current_step.step_metadata,
312
+ )
313
+ message_id = await self.db.insert_message_row(
314
+ self._current_trace.session_id,
315
+ timestep_db_id=timestep_db_id,
316
+ message_type=message_type,
317
+ content=content,
318
+ event_time=msg.time_record.event_time,
319
+ message_time=msg.time_record.message_time,
320
+ metadata=msg.metadata,
321
+ )
322
+ return message_id
323
+ return None
324
+
245
325
  async def end_session(self, save: bool = None) -> SessionTrace:
246
326
  """End the current session.
247
327
 
@@ -341,3 +421,44 @@ class SessionTracer:
341
421
  if self.db:
342
422
  await self.db.close()
343
423
  self.db = None
424
+
425
+ # -------------------------------
426
+ # Reward recording helpers
427
+ # -------------------------------
428
+
429
+ async def record_outcome_reward(self, *, total_reward: int, achievements_count: int, total_steps: int) -> int | None:
430
+ """Record an episode-level outcome reward for the current session."""
431
+ if self._current_trace is None:
432
+ raise RuntimeError("No active session")
433
+ if self.db is None:
434
+ await self.initialize()
435
+ if self.db:
436
+ return await self.db.insert_outcome_reward(
437
+ self._current_trace.session_id,
438
+ total_reward=total_reward,
439
+ achievements_count=achievements_count,
440
+ total_steps=total_steps,
441
+ )
442
+ return None
443
+
444
+ # StepMetrics removed in favor of event_rewards; use record_event_reward for per-turn shaped values
445
+
446
+ async def record_event_reward(self, *, event_id: int, message_id: int | None = None, turn_number: int | None = None, reward_value: float = 0.0, reward_type: str | None = None, key: str | None = None, annotation: dict[str, Any] | None = None, source: str | None = None) -> int | None:
447
+ """Record a first-class event-level reward with optional annotations."""
448
+ if self._current_trace is None:
449
+ raise RuntimeError("No active session")
450
+ if self.db is None:
451
+ await self.initialize()
452
+ if self.db:
453
+ return await self.db.insert_event_reward(
454
+ self._current_trace.session_id,
455
+ event_id=event_id,
456
+ message_id=message_id,
457
+ turn_number=turn_number,
458
+ reward_value=reward_value,
459
+ reward_type=reward_type,
460
+ key=key,
461
+ annotation=annotation,
462
+ source=source,
463
+ )
464
+ return None