synth-ai 0.2.4.dev7__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. synth_ai/__init__.py +1 -1
  2. synth_ai/cli/balance.py +3 -15
  3. synth_ai/config/base_url.py +47 -0
  4. synth_ai/http.py +102 -0
  5. synth_ai/inference/__init__.py +7 -0
  6. synth_ai/inference/client.py +20 -0
  7. synth_ai/jobs/client.py +246 -0
  8. synth_ai/learning/__init__.py +24 -0
  9. synth_ai/learning/client.py +149 -0
  10. synth_ai/learning/config.py +43 -0
  11. synth_ai/learning/constants.py +29 -0
  12. synth_ai/learning/ft_client.py +59 -0
  13. synth_ai/learning/health.py +43 -0
  14. synth_ai/learning/jobs.py +205 -0
  15. synth_ai/learning/rl_client.py +256 -0
  16. synth_ai/learning/sse.py +58 -0
  17. synth_ai/learning/validators.py +48 -0
  18. synth_ai/lm/core/main_v3.py +13 -0
  19. synth_ai/lm/core/synth_models.py +48 -0
  20. synth_ai/lm/core/vendor_clients.py +9 -6
  21. synth_ai/lm/vendors/core/openai_api.py +31 -3
  22. synth_ai/lm/vendors/openai_standard.py +45 -14
  23. synth_ai/lm/vendors/supported/custom_endpoint.py +12 -2
  24. synth_ai/lm/vendors/synth_client.py +372 -28
  25. synth_ai/rl/__init__.py +30 -0
  26. synth_ai/rl/contracts.py +32 -0
  27. synth_ai/rl/env_keys.py +137 -0
  28. synth_ai/rl/secrets.py +19 -0
  29. synth_ai/scripts/verify_rewards.py +100 -0
  30. synth_ai/task/__init__.py +10 -0
  31. synth_ai/task/contracts.py +120 -0
  32. synth_ai/task/health.py +28 -0
  33. synth_ai/task/validators.py +12 -0
  34. synth_ai/tracing_v3/hooks.py +3 -1
  35. synth_ai/tracing_v3/session_tracer.py +123 -2
  36. synth_ai/tracing_v3/turso/manager.py +218 -0
  37. synth_ai/tracing_v3/turso/models.py +53 -0
  38. synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
  39. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/RECORD +43 -25
  40. synth_ai/tui/__init__.py +0 -1
  41. synth_ai/tui/__main__.py +0 -13
  42. synth_ai/tui/cli/__init__.py +0 -1
  43. synth_ai/tui/cli/query_experiments.py +0 -164
  44. synth_ai/tui/cli/query_experiments_v3.py +0 -164
  45. synth_ai/tui/dashboard.py +0 -340
  46. synth_ai-0.2.4.dev7.dist-info/METADATA +0 -193
  47. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
  48. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
  49. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
  50. {synth_ai-0.2.4.dev7.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ import pandas as pd
30
30
  from sqlalchemy import select, text, update
31
31
  from sqlalchemy.exc import IntegrityError
32
32
  from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
33
+ from sqlalchemy import event
33
34
  from sqlalchemy.orm import selectinload, sessionmaker
34
35
  from sqlalchemy.pool import NullPool
35
36
 
@@ -59,6 +60,12 @@ from .models import (
59
60
  from .models import (
60
61
  SessionTrace as DBSessionTrace,
61
62
  )
63
+ from .models import (
64
+ OutcomeReward as DBOutcomeReward,
65
+ )
66
+ from .models import (
67
+ EventReward as DBEventReward,
68
+ )
62
69
 
63
70
  logger = logging.getLogger(__name__)
64
71
 
@@ -125,6 +132,18 @@ class AsyncSQLTraceManager:
125
132
  connect_args=connect_args,
126
133
  echo=CONFIG.echo_sql,
127
134
  )
135
+ # Ensure PRAGMA foreign_keys=ON for every connection
136
+ try:
137
+ @event.listens_for(self.engine.sync_engine, "connect")
138
+ def _set_sqlite_pragma(dbapi_connection, connection_record): # type: ignore[no-redef]
139
+ try:
140
+ cursor = dbapi_connection.cursor()
141
+ cursor.execute("PRAGMA foreign_keys=ON")
142
+ cursor.close()
143
+ except Exception:
144
+ pass
145
+ except Exception:
146
+ pass
128
147
  else:
129
148
  connect_args = CONFIG.get_connect_args()
130
149
  engine_kwargs = CONFIG.get_engine_kwargs()
@@ -538,3 +557,202 @@ class AsyncSQLTraceManager:
538
557
  self.engine = None
539
558
  self.SessionLocal = None
540
559
  self._schema_ready = False
560
+
561
+ # -------------------------------
562
+ # Incremental insert helpers
563
+ # -------------------------------
564
+
565
async def ensure_session(self, session_id: str, *, created_at: datetime | None = None, metadata: dict[str, Any] | None = None):
    """Ensure a ``DBSessionTrace`` row exists for ``session_id``.

    If a row with this ``session_id`` already exists, nothing is changed.
    Otherwise a new row is inserted with all counters (``num_timesteps``,
    ``num_events``, ``num_messages``) set to zero.

    A concurrent caller may insert the same ``session_id`` between our SELECT
    and COMMIT. The resulting ``IntegrityError`` is rolled back and treated as
    success as long as the row is visible afterwards. Otherwise it is re-raised.

    Args:
        session_id: Identifier of the session trace row.
        created_at: Creation timestamp; defaults to ``datetime.utcnow()``.
        metadata: Stored as ``session_metadata``; defaults to ``{}``.

    Raises:
        IntegrityError: If the insert fails and no row for ``session_id``
            exists afterwards.
    """
    lookup = select(DBSessionTrace.session_id).where(DBSessionTrace.session_id == session_id)
    async with self.session() as sess:
        result = await sess.execute(lookup)
        if result.scalar_one_or_none() is not None:
            return
        sess.add(
            DBSessionTrace(
                session_id=session_id,
                created_at=created_at or datetime.utcnow(),
                num_timesteps=0,
                num_events=0,
                num_messages=0,
                session_metadata=metadata or {},
            )
        )
        try:
            await sess.commit()
        except IntegrityError:
            # Lost the race against another writer: the row is there now,
            # which is exactly what the caller asked for.
            await sess.rollback()
            result = await sess.execute(lookup)
            if result.scalar_one_or_none() is None:
                raise
582
+
583
async def ensure_timestep(self, session_id: str, *, step_id: str, step_index: int, turn_number: int | None = None, started_at: datetime | None = None, completed_at: datetime | None = None, metadata: dict[str, Any] | None = None) -> int:
    """Ensure a timestep row exists for ``(session_id, step_id)``; return its DB id.

    If a matching ``DBSessionTimestep`` already exists, its id is returned and
    nothing is modified. Otherwise a new row is inserted with zeroed
    ``num_events`` / ``num_messages`` counters. The parent
    ``DBSessionTrace.num_timesteps`` is incremented in the same transaction.

    If the insert fails with ``IntegrityError`` (e.g. a concurrent writer
    created the same timestep), the transaction is rolled back. The existing
    row's id is then returned if one is now visible; otherwise the error is
    re-raised.

    Args:
        session_id: Owning session trace id.
        step_id: Caller-supplied step identifier, unique within the session.
        step_index: Ordinal position of the step.
        turn_number: Optional turn number.
        started_at: Start timestamp; defaults to ``datetime.utcnow()``.
        completed_at: Optional completion timestamp.
        metadata: Stored as ``step_metadata``; defaults to ``{}``.

    Returns:
        The ``DBSessionTimestep.id`` of the existing or newly created row.
    """
    lookup = select(DBSessionTimestep.id).where(
        DBSessionTimestep.session_id == session_id,
        DBSessionTimestep.step_id == step_id,
    )
    async with self.session() as sess:
        existing_id = (await sess.execute(lookup)).scalar_one_or_none()
        if existing_id is not None:
            return existing_id
        row = DBSessionTimestep(
            session_id=session_id,
            step_id=step_id,
            step_index=step_index,
            turn_number=turn_number,
            started_at=started_at or datetime.utcnow(),
            completed_at=completed_at,
            num_events=0,
            num_messages=0,
            step_metadata=metadata or {},
        )
        sess.add(row)
        try:
            await sess.flush()
            # Read the autoincrement PK now: after commit the instance may be
            # expired (expire_on_commit), and an implicit refresh is not
            # permitted on an AsyncSession.
            new_id = row.id
            await sess.execute(
                update(DBSessionTrace)
                .where(DBSessionTrace.session_id == session_id)
                .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
            )
            await sess.commit()
        except IntegrityError:
            await sess.rollback()
            existing_id = (await sess.execute(lookup)).scalar_one_or_none()
            if existing_id is None:
                raise
            return existing_id
    return new_id
613
+
614
async def insert_message_row(self, session_id: str, *, timestep_db_id: int | None, message_type: str, content: str, event_time: float | None = None, message_time: int | None = None, metadata: dict[str, Any] | None = None) -> int:
    """Insert one message row for ``session_id`` and return its id.

    The owning ``DBSessionTrace.num_messages`` counter is incremented in the
    same transaction. The timestep's own counters are not touched here.

    Args:
        session_id: Owning session trace id.
        timestep_db_id: ``DBSessionTimestep.id`` to attach to, or ``None``.
        message_type: Stored verbatim as ``message_type``.
        content: Message body.
        event_time: Optional event timestamp.
        message_time: Optional message ordinal/time.
        metadata: Stored as ``message_metadata``; defaults to ``{}``.

    Returns:
        The autoincremented ``DBMessage.id``.
    """
    async with self.session() as sess:
        db_msg = DBMessage(
            session_id=session_id,
            timestep_id=timestep_db_id,
            message_type=message_type,
            content=content,
            event_time=event_time,
            message_time=message_time,
            message_metadata=metadata or {},
        )
        sess.add(db_msg)
        await sess.flush()
        # Capture the PK before commit; the instance may be expired afterwards
        # and a lazy refresh is not allowed on an AsyncSession.
        message_id = db_msg.id
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_messages=DBSessionTrace.num_messages + 1)
        )
        await sess.commit()
    return message_id
636
+
637
async def insert_event_row(self, session_id: str, *, timestep_db_id: int | None, event: EnvironmentEvent | LMCAISEvent | RuntimeEvent, metadata_override: dict[str, Any] | None = None) -> int:
    """Insert one event row and return its id.

    Common columns are filled for every event. Type-specific columns are
    added depending on whether ``event`` is an ``LMCAISEvent`` ("cais"),
    ``EnvironmentEvent`` ("environment") or ``RuntimeEvent`` ("runtime").
    Any other class is stored with its lowercased class name as
    ``event_type``. The owning ``DBSessionTrace.num_events`` is incremented
    in the same transaction; timestep counters are not touched.

    Args:
        session_id: Owning session trace id.
        timestep_db_id: ``DBSessionTimestep.id`` to attach to, or ``None``.
        event: The tracing event to persist.
        metadata_override: When truthy, used instead of ``event.metadata``
            for ``event_metadata_json``. For runtime events, ``actions`` is
            merged on top of whichever metadata was selected.

    Returns:
        The autoincremented ``DBEvent.id``.
    """

    def to_cents(cost: float | None) -> int | None:
        # round() instead of int(): int(0.29 * 100) == 28 because of binary
        # floating-point error, which silently drops a cent.
        return round(cost * 100) if cost is not None else None

    event_data: dict[str, Any] = {
        "session_id": session_id,
        "timestep_id": timestep_db_id,
        "system_instance_id": event.system_instance_id,
        "event_time": event.time_record.event_time,
        "message_time": event.time_record.message_time,
        "event_metadata_json": metadata_override or event.metadata or {},
        "event_extra_metadata": getattr(event, "event_metadata", None),
    }
    if isinstance(event, LMCAISEvent):
        call_records_data = None
        if getattr(event, "call_records", None):
            from dataclasses import asdict

            call_records_data = [asdict(record) for record in event.call_records]
        event_data.update({
            "event_type": "cais",
            "model_name": event.model_name,
            "provider": event.provider,
            "input_tokens": event.input_tokens,
            "output_tokens": event.output_tokens,
            "total_tokens": event.total_tokens,
            # Column is named cost_usd but holds integer cents.
            "cost_usd": to_cents(event.cost_usd),
            "latency_ms": event.latency_ms,
            "span_id": event.span_id,
            "trace_id": event.trace_id,
            "system_state_before": event.system_state_before,
            "system_state_after": event.system_state_after,
            "call_records": call_records_data,
        })
    elif isinstance(event, EnvironmentEvent):
        event_data.update({
            "event_type": "environment",
            "reward": event.reward,
            "terminated": event.terminated,
            "truncated": event.truncated,
            "system_state_before": event.system_state_before,
            "system_state_after": event.system_state_after,
        })
    elif isinstance(event, RuntimeEvent):
        event_data.update({
            "event_type": "runtime",
            # Build on the already-resolved metadata so an explicit
            # metadata_override is honoured for runtime events as well.
            "event_metadata_json": {**event_data["event_metadata_json"], "actions": event.actions},
        })
    else:
        event_data["event_type"] = event.__class__.__name__.lower()

    async with self.session() as sess:
        db_event = DBEvent(**event_data)
        sess.add(db_event)
        await sess.flush()
        # Capture the PK before commit; the instance may be expired afterwards
        # and a lazy refresh is not allowed on an AsyncSession.
        event_id = db_event.id
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_events=DBSessionTrace.num_events + 1)
        )
        await sess.commit()
    return event_id
701
+
702
+ # -------------------------------
703
+ # Reward helpers
704
+ # -------------------------------
705
+
706
async def insert_outcome_reward(self, session_id: str, *, total_reward: int, achievements_count: int, total_steps: int) -> int:
    """Insert an episode-level ``DBOutcomeReward`` row and return its id.

    No uniqueness check is performed here. Calling this twice for the same
    session inserts two rows.

    Args:
        session_id: Session the outcome belongs to.
        total_reward: Episode total reward.
        achievements_count: Number of achievements in the episode.
        total_steps: Number of steps in the episode.

    Returns:
        The autoincremented ``DBOutcomeReward.id``.
    """
    async with self.session() as sess:
        row = DBOutcomeReward(
            session_id=session_id,
            total_reward=total_reward,
            achievements_count=achievements_count,
            total_steps=total_steps,
        )
        sess.add(row)
        await sess.flush()
        # Capture the PK before commit; the instance may be expired afterwards
        # and a lazy refresh is not allowed on an AsyncSession.
        reward_id = row.id
        await sess.commit()
    return reward_id
718
+
719
async def insert_event_reward(self, session_id: str, *, event_id: int, message_id: int | None = None, turn_number: int | None = None, reward_value: float = 0.0, reward_type: str | None = None, key: str | None = None, annotation: dict[str, Any] | None = None, source: str | None = None) -> int:
    """Insert a ``DBEventReward`` row linked to an event and return its id.

    Args:
        session_id: Session the reward belongs to.
        event_id: ``DBEvent.id`` the reward is attached to.
        message_id: Optional ``DBMessage.id``.
        turn_number: Optional turn number.
        reward_value: Reward amount; defaults to ``0.0``.
        reward_type: Optional free-form category string.
        key: Optional key (e.g. an achievement name).
        annotation: Free-form JSON-serialisable dict; defaults to ``{}``.
        source: Optional origin label.

    Returns:
        The autoincremented ``DBEventReward.id``.
    """
    async with self.session() as sess:
        row = DBEventReward(
            event_id=event_id,
            session_id=session_id,
            message_id=message_id,
            turn_number=turn_number,
            reward_value=reward_value,
            reward_type=reward_type,
            key=key,
            annotation=annotation or {},
            source=source,
        )
        sess.add(row)
        await sess.flush()
        # Capture the PK before commit; the instance may be expired afterwards
        # and a lazy refresh is not allowed on an AsyncSession.
        reward_id = row.id
        await sess.commit()
    return reward_id
736
+
737
async def get_outcome_rewards(self) -> list[dict[str, Any]]:
    """Return every ``DBOutcomeReward`` row as a plain dict.

    Each dict carries ``id``, ``session_id``, ``total_reward``,
    ``achievements_count``, ``total_steps`` and ``created_at``. Values are
    read while the session is still open.
    """
    columns = ("id", "session_id", "total_reward", "achievements_count", "total_steps", "created_at")
    async with self.session() as sess:
        records = (await sess.execute(select(DBOutcomeReward))).scalars().all()
        return [{name: getattr(record, name) for name in columns} for record in records]
752
+
753
async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
    """Return session ids whose outcome ``total_reward`` is >= ``min_reward``.

    One id is returned per matching ``DBOutcomeReward`` row, so a session
    with several qualifying outcome rows appears several times.
    """
    query = select(DBOutcomeReward.session_id).where(DBOutcomeReward.total_reward >= min_reward)
    async with self.session() as sess:
        result = await sess.execute(query)
        return list(result.scalars().all())
@@ -408,3 +408,56 @@ analytics_views = {
408
408
  GROUP BY e.experiment_id
409
409
  """,
410
410
  }
411
+
412
+
413
+ # Reward persistence tables
414
+
415
+
416
class OutcomeReward(Base):
    """Episode-level rewards/outcomes per session.

    Stores per-episode summary including total_reward (e.g., unique achievements),
    achievements_count, and total_steps. Used for filtering episodes by outcome.

    Each row references a ``session_traces.session_id``. The model declares no
    unique constraint on ``session_id``, so several outcome rows per session
    are possible. NOTE(review): confirm whether that is intended.
    """

    __tablename__ = "outcome_rewards"

    # Surrogate autoincrement primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning session (FK to session_traces.session_id); required.
    session_id = Column(String, ForeignKey("session_traces.session_id"), nullable=False)
    # Episode total reward; required, no default.
    total_reward = Column(Integer, nullable=False)
    # Number of achievements; SQLAlchemy-side default 0.
    achievements_count = Column(Integer, nullable=False, default=0)
    # Number of steps in the episode; SQLAlchemy-side default 0.
    total_steps = Column(Integer, nullable=False, default=0)
    # Set via `default=` (not `server_default=`): CURRENT_TIMESTAMP is rendered
    # into SQLAlchemy-issued INSERTs; the table DDL itself has no DEFAULT.
    created_at = Column(DateTime, default=func.current_timestamp(), nullable=False)

    __table_args__ = (
        # Lookup of outcomes by session.
        Index("idx_outcome_rewards_session", "session_id"),
        # Supports range filters on total_reward (e.g. ">= min_reward").
        Index("idx_outcome_rewards_total", "total_reward"),
    )
436
+
437
+
438
class EventReward(Base):
    """First-class event-level rewards with annotations.

    Links to an event and session. `message_id` is optional.

    ``event_id`` references ``events.id``, ``session_id`` references
    ``session_traces.session_id``, and ``message_id`` (nullable) references
    ``messages.id``. ``reward_type`` and ``source`` are free-form strings;
    the inline comments list the values they are expected to take, but the
    schema does not enforce them.
    """

    __tablename__ = "event_rewards"

    # Surrogate autoincrement primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Event this reward is attached to (FK to events.id); required.
    event_id = Column(Integer, ForeignKey("events.id"), nullable=False)
    # Owning session (FK to session_traces.session_id); required.
    session_id = Column(String, ForeignKey("session_traces.session_id"), nullable=False)
    # Optional related message (FK to messages.id).
    message_id = Column(Integer, ForeignKey("messages.id"), nullable=True)
    # Optional turn number within the session.
    turn_number = Column(Integer, nullable=True)
    # Reward amount; SQLAlchemy-side default 0.0.
    reward_value = Column(Float, nullable=False, default=0.0)
    reward_type = Column(String, nullable=True) # shaped | sparse | achievement | penalty | evaluator | human
    key = Column(String, nullable=True) # e.g., achievement name
    # JSONText is a project column type, presumably JSON serialised to text.
    # TODO confirm against its definition.
    annotation = Column(JSONText) # free-form JSON
    source = Column(String, nullable=True) # environment | runner | evaluator | human
    # Set via `default=` (not `server_default=`): CURRENT_TIMESTAMP is rendered
    # into SQLAlchemy-issued INSERTs; the table DDL itself has no DEFAULT.
    created_at = Column(DateTime, default=func.current_timestamp(), nullable=False)

    __table_args__ = (
        # Lookups by session, by event, by reward category and by key.
        Index("idx_event_rewards_session", "session_id"),
        Index("idx_event_rewards_event", "event_id"),
        Index("idx_event_rewards_type", "reward_type"),
        Index("idx_event_rewards_key", "key"),
    )