synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +1 -1
- synth_ai/cli/balance.py +3 -15
- synth_ai/config/base_url.py +47 -0
- synth_ai/http.py +102 -0
- synth_ai/inference/__init__.py +7 -0
- synth_ai/inference/client.py +20 -0
- synth_ai/jobs/client.py +246 -0
- synth_ai/learning/__init__.py +24 -0
- synth_ai/learning/client.py +149 -0
- synth_ai/learning/config.py +43 -0
- synth_ai/learning/constants.py +29 -0
- synth_ai/learning/ft_client.py +59 -0
- synth_ai/learning/health.py +43 -0
- synth_ai/learning/jobs.py +205 -0
- synth_ai/learning/rl_client.py +256 -0
- synth_ai/learning/sse.py +58 -0
- synth_ai/learning/validators.py +48 -0
- synth_ai/lm/core/main_v3.py +13 -0
- synth_ai/lm/core/synth_models.py +48 -0
- synth_ai/lm/core/vendor_clients.py +9 -6
- synth_ai/lm/vendors/core/openai_api.py +31 -3
- synth_ai/lm/vendors/openai_standard.py +45 -14
- synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
- synth_ai/lm/vendors/synth_client.py +372 -28
- synth_ai/rl/__init__.py +30 -0
- synth_ai/rl/contracts.py +32 -0
- synth_ai/rl/env_keys.py +137 -0
- synth_ai/rl/secrets.py +19 -0
- synth_ai/scripts/verify_rewards.py +100 -0
- synth_ai/task/__init__.py +10 -0
- synth_ai/task/contracts.py +120 -0
- synth_ai/task/health.py +28 -0
- synth_ai/task/validators.py +12 -0
- synth_ai/tracing_v3/hooks.py +3 -1
- synth_ai/tracing_v3/session_tracer.py +123 -2
- synth_ai/tracing_v3/turso/manager.py +218 -0
- synth_ai/tracing_v3/turso/models.py +53 -0
- synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/RECORD +43 -25
- synth_ai/tui/__init__.py +0 -1
- synth_ai/tui/__main__.py +0 -13
- synth_ai/tui/cli/__init__.py +0 -1
- synth_ai/tui/cli/query_experiments.py +0 -164
- synth_ai/tui/cli/query_experiments_v3.py +0 -164
- synth_ai/tui/dashboard.py +0 -340
- synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ import pandas as pd
|
|
30
30
|
from sqlalchemy import select, text, update
|
31
31
|
from sqlalchemy.exc import IntegrityError
|
32
32
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
33
|
+
from sqlalchemy import event
|
33
34
|
from sqlalchemy.orm import selectinload, sessionmaker
|
34
35
|
from sqlalchemy.pool import NullPool
|
35
36
|
|
@@ -59,6 +60,12 @@ from .models import (
|
|
59
60
|
from .models import (
|
60
61
|
SessionTrace as DBSessionTrace,
|
61
62
|
)
|
63
|
+
from .models import (
|
64
|
+
OutcomeReward as DBOutcomeReward,
|
65
|
+
)
|
66
|
+
from .models import (
|
67
|
+
EventReward as DBEventReward,
|
68
|
+
)
|
62
69
|
|
63
70
|
logger = logging.getLogger(__name__)
|
64
71
|
|
@@ -125,6 +132,18 @@ class AsyncSQLTraceManager:
|
|
125
132
|
connect_args=connect_args,
|
126
133
|
echo=CONFIG.echo_sql,
|
127
134
|
)
|
135
|
+
# Ensure PRAGMA foreign_keys=ON for every connection
|
136
|
+
try:
|
137
|
+
@event.listens_for(self.engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):  # type: ignore[no-redef]
    """Issue ``PRAGMA foreign_keys=ON`` on each newly opened DBAPI connection.

    Registered as a "connect" listener on the engine's sync engine.
    This is best-effort: any failure is logged at debug level and never
    propagated, so connection setup is not interrupted.

    Args:
        dbapi_connection: Raw DBAPI connection handed over by the pool.
        connection_record: Pool record for the connection (unused).
    """
    try:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            # Release the cursor even when the PRAGMA itself is rejected.
            cursor.close()
    except Exception:
        # Keep the original non-fatal behaviour, but leave a trace for debugging.
        logger.debug("Could not enable PRAGMA foreign_keys", exc_info=True)
|
145
|
+
except Exception:
|
146
|
+
pass
|
128
147
|
else:
|
129
148
|
connect_args = CONFIG.get_connect_args()
|
130
149
|
engine_kwargs = CONFIG.get_engine_kwargs()
|
@@ -538,3 +557,202 @@ class AsyncSQLTraceManager:
|
|
538
557
|
self.engine = None
|
539
558
|
self.SessionLocal = None
|
540
559
|
self._schema_ready = False
|
560
|
+
|
561
|
+
# -------------------------------
|
562
|
+
# Incremental insert helpers
|
563
|
+
# -------------------------------
|
564
|
+
|
565
|
+
async def ensure_session(
    self,
    session_id: str,
    *,
    created_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
):
    """Create the session-trace row for ``session_id`` unless one already exists.

    Args:
        session_id: Value matched against, and stored in,
            ``DBSessionTrace.session_id``.
        created_at: Timestamp for a newly created row; ``datetime.utcnow()``
            is used when this is falsy.
        metadata: Stored as ``session_metadata`` on a new row; ``{}`` when falsy.

    An existing row is left untouched. A new row starts with its timestep,
    event and message counters at zero.
    """
    async with self.session() as db:
        lookup = select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
        found = (await db.execute(lookup)).scalar_one_or_none()
        if not found:
            db.add(
                DBSessionTrace(
                    session_id=session_id,
                    created_at=created_at or datetime.utcnow(),
                    num_timesteps=0,
                    num_events=0,
                    num_messages=0,
                    session_metadata=metadata or {},
                )
            )
            await db.commit()
|
582
|
+
|
583
|
+
async def ensure_timestep(
    self,
    session_id: str,
    *,
    step_id: str,
    step_index: int,
    turn_number: int | None = None,
    started_at: datetime | None = None,
    completed_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Ensure a timestep row exists for ``(session_id, step_id)``; return its DB id.

    If a matching row is found its id is returned and nothing is written.
    Otherwise a new timestep row is inserted with zeroed event/message
    counters, and the owning session's ``num_timesteps`` is incremented in
    the same transaction.

    Args:
        session_id: Session the timestep belongs to.
        step_id: Identifier of the step within the session.
        step_index: Stored as ``step_index`` on a new row.
        turn_number: Optional turn number stored on a new row.
        started_at: Start timestamp; ``datetime.utcnow()`` when falsy.
        completed_at: Optional completion timestamp.
        metadata: Stored as ``step_metadata``; ``{}`` when falsy.

    Returns:
        The primary-key id of the existing or newly inserted timestep row.
    """
    async with self.session() as sess:
        result = await sess.execute(
            select(DBSessionTimestep).where(
                DBSessionTimestep.session_id == session_id,
                DBSessionTimestep.step_id == step_id,
            )
        )
        row = result.scalar_one_or_none()
        if row:
            return row.id

        row = DBSessionTimestep(
            session_id=session_id,
            step_id=step_id,
            step_index=step_index,
            turn_number=turn_number,
            started_at=started_at or datetime.utcnow(),
            completed_at=completed_at,
            num_events=0,
            num_messages=0,
            step_metadata=metadata or {},
        )
        sess.add(row)
        await sess.flush()
        # Read the id while the row is still loaded. After commit() the
        # attributes may be expired (depends on the session factory's
        # expire_on_commit setting, not visible here), and an implicit
        # refresh is not possible on an AsyncSession.
        timestep_db_id = row.id

        # Increment the session's num_timesteps counter.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
        )
        await sess.commit()
        return timestep_db_id
|
613
|
+
|
614
|
+
async def insert_message_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    message_type: str,
    content: str,
    event_time: float | None = None,
    message_time: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Insert one message row and return its id.

    The owning session's ``num_messages`` counter is incremented in the same
    transaction as the insert.

    Args:
        session_id: Session the message belongs to.
        timestep_db_id: DB id of the owning timestep row, or ``None``.
        message_type: Stored as ``message_type``.
        content: Stored as ``content``.
        event_time: Optional event time stored on the row.
        message_time: Optional message time stored on the row.
        metadata: Stored as ``message_metadata``; ``{}`` when falsy.

    Returns:
        The primary-key id of the inserted message row.
    """
    async with self.session() as sess:
        db_msg = DBMessage(
            session_id=session_id,
            timestep_id=timestep_db_id,
            message_type=message_type,
            content=content,
            event_time=event_time,
            message_time=message_time,
            message_metadata=metadata or {},
        )
        sess.add(db_msg)
        await sess.flush()
        # Read the id before commit(): afterwards the attributes may be
        # expired (depends on the session factory's expire_on_commit
        # setting, not visible here), and an AsyncSession cannot refresh
        # them implicitly.
        message_db_id = db_msg.id

        # Increment the session's num_messages counter.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_messages=DBSessionTrace.num_messages + 1)
        )
        await sess.commit()
        return message_db_id
|
636
|
+
|
637
|
+
async def insert_event_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
    metadata_override: dict[str, Any] | None = None,
) -> int:
    """Insert one event row and return its id.

    Columns shared by every event type are filled first; type-specific
    columns are then added according to the concrete event class. The owning
    session's ``num_events`` counter is incremented in the same transaction.

    Args:
        session_id: Session the event belongs to.
        timestep_db_id: DB id of the owning timestep row, or ``None``.
        event: The event to persist.
        metadata_override: When truthy, stored as the event metadata instead
            of ``event.metadata``. For runtime events the ``"actions"`` key
            is always taken from ``event.actions``.

    Returns:
        The primary-key id of the inserted event row.
    """

    def to_cents(cost: float | None) -> int | None:
        # Round rather than truncate: 0.29 * 100 is 28.999999999999996 as a
        # binary float, which int() alone would store as 28.
        return int(round(cost * 100)) if cost is not None else None

    base_metadata = metadata_override or event.metadata or {}

    event_data: dict[str, Any] = {
        "session_id": session_id,
        "timestep_id": timestep_db_id,
        "system_instance_id": event.system_instance_id,
        "event_time": event.time_record.event_time,
        "message_time": event.time_record.message_time,
        "event_metadata_json": base_metadata,
        "event_extra_metadata": getattr(event, "event_metadata", None),
    }

    if isinstance(event, LMCAISEvent):
        call_records_data = None
        if getattr(event, "call_records", None):
            from dataclasses import asdict

            call_records_data = [asdict(record) for record in event.call_records]
        event_data.update(
            {
                "event_type": "cais",
                "model_name": event.model_name,
                "provider": event.provider,
                "input_tokens": event.input_tokens,
                "output_tokens": event.output_tokens,
                "total_tokens": event.total_tokens,
                # The value written to the cost_usd column is in cents.
                "cost_usd": to_cents(event.cost_usd),
                "latency_ms": event.latency_ms,
                "span_id": event.span_id,
                "trace_id": event.trace_id,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
                "call_records": call_records_data,
            }
        )
    elif isinstance(event, EnvironmentEvent):
        event_data.update(
            {
                "event_type": "environment",
                "reward": event.reward,
                "terminated": event.terminated,
                "truncated": event.truncated,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
            }
        )
    elif isinstance(event, RuntimeEvent):
        event_data.update(
            {
                "event_type": "runtime",
                # Build on base_metadata so metadata_override is honoured for
                # runtime events too, not replaced by event.metadata.
                "event_metadata_json": {**base_metadata, "actions": event.actions},
            }
        )
    else:
        event_data["event_type"] = event.__class__.__name__.lower()

    async with self.session() as sess:
        db_event = DBEvent(**event_data)
        sess.add(db_event)
        await sess.flush()
        # Read the id before commit(): afterwards the attributes may be
        # expired (depends on the session factory's expire_on_commit
        # setting, not visible here), and an AsyncSession cannot refresh
        # them implicitly.
        event_db_id = db_event.id

        # Increment the session's num_events counter.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_events=DBSessionTrace.num_events + 1)
        )
        await sess.commit()
        return event_db_id
|
701
|
+
|
702
|
+
# -------------------------------
|
703
|
+
# Reward helpers
|
704
|
+
# -------------------------------
|
705
|
+
|
706
|
+
async def insert_outcome_reward(
    self,
    session_id: str,
    *,
    total_reward: int,
    achievements_count: int,
    total_steps: int,
) -> int:
    """Insert one outcome-reward row for a session and return its id.

    Args:
        session_id: Session the outcome belongs to.
        total_reward: Stored as ``total_reward``.
        achievements_count: Stored as ``achievements_count``.
        total_steps: Stored as ``total_steps``.

    Returns:
        The primary-key id of the inserted row.
    """
    async with self.session() as sess:
        row = DBOutcomeReward(
            session_id=session_id,
            total_reward=total_reward,
            achievements_count=achievements_count,
            total_steps=total_steps,
        )
        sess.add(row)
        await sess.flush()
        # Read the id before commit(): afterwards the attributes may be
        # expired (depends on the session factory's expire_on_commit
        # setting, not visible here), and an AsyncSession cannot refresh
        # them implicitly.
        reward_db_id = row.id
        await sess.commit()
        return reward_db_id
|
718
|
+
|
719
|
+
async def insert_event_reward(
    self,
    session_id: str,
    *,
    event_id: int,
    message_id: int | None = None,
    turn_number: int | None = None,
    reward_value: float = 0.0,
    reward_type: str | None = None,
    key: str | None = None,
    annotation: dict[str, Any] | None = None,
    source: str | None = None,
) -> int:
    """Insert one event-level reward row and return its id.

    Args:
        session_id: Session the reward belongs to.
        event_id: DB id of the event the reward is attached to.
        message_id: Optional DB id of a related message.
        turn_number: Optional turn number stored on the row.
        reward_value: Stored as ``reward_value``.
        reward_type: Optional reward type label.
        key: Optional reward key.
        annotation: Stored as ``annotation``; ``{}`` when falsy.
        source: Optional label for where the reward came from.

    Returns:
        The primary-key id of the inserted row.
    """
    async with self.session() as sess:
        row = DBEventReward(
            event_id=event_id,
            session_id=session_id,
            message_id=message_id,
            turn_number=turn_number,
            reward_value=reward_value,
            reward_type=reward_type,
            key=key,
            annotation=annotation or {},
            source=source,
        )
        sess.add(row)
        await sess.flush()
        # Read the id before commit(): afterwards the attributes may be
        # expired (depends on the session factory's expire_on_commit
        # setting, not visible here), and an AsyncSession cannot refresh
        # them implicitly.
        reward_db_id = row.id
        await sess.commit()
        return reward_db_id
|
736
|
+
|
737
|
+
async def get_outcome_rewards(self) -> list[dict[str, Any]]:
    """Return every outcome-reward row as a plain dictionary.

    Each dictionary carries the keys ``id``, ``session_id``, ``total_reward``,
    ``achievements_count``, ``total_steps`` and ``created_at``, in that order.
    The query has no filter and no explicit ordering.
    """
    exported_fields = (
        "id",
        "session_id",
        "total_reward",
        "achievements_count",
        "total_steps",
        "created_at",
    )
    async with self.session() as db:
        fetched = await db.execute(select(DBOutcomeReward))
        records = fetched.scalars().all()
        payload = []
        for record in records:
            payload.append({field: getattr(record, field) for field in exported_fields})
        return payload
|
752
|
+
|
753
|
+
async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
    """Return the session ids of outcome rows whose total reward is at least ``min_reward``.

    The comparison is inclusive (``>=``). The query applies no DISTINCT, so a
    session id is returned once per matching outcome row.
    """
    async with self.session() as db:
        query = select(DBOutcomeReward.session_id).where(
            DBOutcomeReward.total_reward >= min_reward
        )
        matches = await db.execute(query)
        # Each result row holds a single column: the session id.
        return [matched_session_id for (matched_session_id,) in matches.all()]
|
@@ -408,3 +408,56 @@ analytics_views = {
|
|
408
408
|
GROUP BY e.experiment_id
|
409
409
|
""",
|
410
410
|
}
|
411
|
+
|
412
|
+
|
413
|
+
# Reward persistence tables
|
414
|
+
|
415
|
+
|
416
|
+
class OutcomeReward(Base):
    """Per-session, episode-level outcome summary.

    One row records an episode's ``total_reward`` (e.g., unique achievements)
    together with ``achievements_count`` and ``total_steps``. The table is
    indexed on ``session_id`` and ``total_reward`` so that episodes can be
    filtered by outcome.
    """

    __tablename__ = "outcome_rewards"

    # Surrogate primary key.
    id = Column(
        Integer,
        primary_key=True,
        autoincrement=True,
    )
    # Owning session.
    session_id = Column(
        String,
        ForeignKey("session_traces.session_id"),
        nullable=False,
    )
    # Episode summary values.
    total_reward = Column(
        Integer,
        nullable=False,
    )
    achievements_count = Column(
        Integer,
        nullable=False,
        default=0,
    )
    total_steps = Column(
        Integer,
        nullable=False,
        default=0,
    )
    # Defaults to the database's current timestamp at insert time.
    created_at = Column(
        DateTime,
        default=func.current_timestamp(),
        nullable=False,
    )

    __table_args__ = (
        Index("idx_outcome_rewards_session", "session_id"),
        Index("idx_outcome_rewards_total", "total_reward"),
    )
|
436
|
+
|
437
|
+
|
438
|
+
class EventReward(Base):
    """Event-level reward record with optional annotations.

    Every row references one event and one session; the link to a message
    (``message_id``) is optional.
    """

    __tablename__ = "event_rewards"

    # Surrogate primary key.
    id = Column(
        Integer,
        primary_key=True,
        autoincrement=True,
    )
    # Required links to the rewarded event and its session.
    event_id = Column(
        Integer,
        ForeignKey("events.id"),
        nullable=False,
    )
    session_id = Column(
        String,
        ForeignKey("session_traces.session_id"),
        nullable=False,
    )
    # Optional link to a message.
    message_id = Column(
        Integer,
        ForeignKey("messages.id"),
        nullable=True,
    )
    turn_number = Column(
        Integer,
        nullable=True,
    )
    reward_value = Column(
        Float,
        nullable=False,
        default=0.0,
    )
    # shaped | sparse | achievement | penalty | evaluator | human
    reward_type = Column(
        String,
        nullable=True,
    )
    # e.g., achievement name
    key = Column(
        String,
        nullable=True,
    )
    # free-form JSON
    annotation = Column(JSONText)
    # environment | runner | evaluator | human
    source = Column(
        String,
        nullable=True,
    )
    # Defaults to the database's current timestamp at insert time.
    created_at = Column(
        DateTime,
        default=func.current_timestamp(),
        nullable=False,
    )

    __table_args__ = (
        Index("idx_event_rewards_session", "session_id"),
        Index("idx_event_rewards_event", "event_id"),
        Index("idx_event_rewards_type", "reward_type"),
        Index("idx_event_rewards_key", "key"),
    )
|