synth-ai 0.2.4.dev6__py3-none-any.whl → 0.2.4.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. synth_ai/__init__.py +18 -9
  2. synth_ai/cli/__init__.py +10 -5
  3. synth_ai/cli/balance.py +25 -32
  4. synth_ai/cli/calc.py +2 -3
  5. synth_ai/cli/demo.py +3 -5
  6. synth_ai/cli/legacy_root_backup.py +58 -32
  7. synth_ai/cli/man.py +22 -19
  8. synth_ai/cli/recent.py +9 -8
  9. synth_ai/cli/root.py +58 -13
  10. synth_ai/cli/status.py +13 -6
  11. synth_ai/cli/traces.py +45 -21
  12. synth_ai/cli/watch.py +40 -37
  13. synth_ai/config/base_url.py +47 -2
  14. synth_ai/core/experiment.py +1 -2
  15. synth_ai/environments/__init__.py +2 -6
  16. synth_ai/environments/environment/artifacts/base.py +3 -1
  17. synth_ai/environments/environment/db/sqlite.py +1 -1
  18. synth_ai/environments/environment/registry.py +19 -20
  19. synth_ai/environments/environment/resources/sqlite.py +2 -3
  20. synth_ai/environments/environment/rewards/core.py +3 -2
  21. synth_ai/environments/environment/tools/__init__.py +6 -4
  22. synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
  23. synth_ai/environments/examples/crafter_classic/engine.py +13 -13
  24. synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
  25. synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
  26. synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
  27. synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
  28. synth_ai/environments/examples/crafter_classic/environment.py +16 -15
  29. synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
  30. synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
  31. synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
  32. synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
  33. synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
  34. synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
  35. synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
  36. synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
  37. synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
  38. synth_ai/environments/examples/crafter_custom/environment.py +13 -13
  39. synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
  40. synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
  41. synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
  42. synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
  43. synth_ai/environments/examples/enron/engine.py +18 -14
  44. synth_ai/environments/examples/enron/environment.py +12 -11
  45. synth_ai/environments/examples/enron/taskset.py +7 -7
  46. synth_ai/environments/examples/minigrid/__init__.py +6 -6
  47. synth_ai/environments/examples/minigrid/engine.py +6 -6
  48. synth_ai/environments/examples/minigrid/environment.py +6 -6
  49. synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
  50. synth_ai/environments/examples/minigrid/taskset.py +13 -13
  51. synth_ai/environments/examples/nethack/achievements.py +1 -1
  52. synth_ai/environments/examples/nethack/engine.py +8 -7
  53. synth_ai/environments/examples/nethack/environment.py +10 -9
  54. synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
  55. synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
  56. synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
  57. synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
  58. synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
  59. synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
  60. synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
  61. synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
  62. synth_ai/environments/examples/nethack/taskset.py +5 -5
  63. synth_ai/environments/examples/red/engine.py +9 -8
  64. synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
  65. synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
  66. synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
  67. synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
  68. synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
  69. synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
  70. synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
  71. synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
  72. synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
  73. synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
  74. synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
  75. synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
  76. synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
  77. synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
  78. synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
  79. synth_ai/environments/examples/red/environment.py +18 -15
  80. synth_ai/environments/examples/red/taskset.py +5 -3
  81. synth_ai/environments/examples/sokoban/engine.py +16 -13
  82. synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
  83. synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
  84. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
  85. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
  86. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
  87. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
  88. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
  89. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
  90. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
  91. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
  92. synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
  93. synth_ai/environments/examples/sokoban/environment.py +15 -14
  94. synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
  95. synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
  96. synth_ai/environments/examples/sokoban/taskset.py +13 -10
  97. synth_ai/environments/examples/tictactoe/engine.py +6 -6
  98. synth_ai/environments/examples/tictactoe/environment.py +8 -7
  99. synth_ai/environments/examples/tictactoe/taskset.py +6 -5
  100. synth_ai/environments/examples/verilog/engine.py +4 -3
  101. synth_ai/environments/examples/verilog/environment.py +11 -10
  102. synth_ai/environments/examples/verilog/taskset.py +14 -12
  103. synth_ai/environments/examples/wordle/__init__.py +5 -5
  104. synth_ai/environments/examples/wordle/engine.py +32 -25
  105. synth_ai/environments/examples/wordle/environment.py +21 -16
  106. synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +6 -6
  107. synth_ai/environments/examples/wordle/taskset.py +20 -12
  108. synth_ai/environments/reproducibility/core.py +1 -1
  109. synth_ai/environments/reproducibility/tree.py +21 -21
  110. synth_ai/environments/service/app.py +3 -2
  111. synth_ai/environments/service/core_routes.py +104 -110
  112. synth_ai/environments/service/external_registry.py +1 -2
  113. synth_ai/environments/service/registry.py +1 -1
  114. synth_ai/environments/stateful/core.py +1 -2
  115. synth_ai/environments/stateful/engine.py +1 -1
  116. synth_ai/environments/tasks/api.py +4 -4
  117. synth_ai/environments/tasks/core.py +14 -12
  118. synth_ai/environments/tasks/filters.py +6 -4
  119. synth_ai/environments/tasks/utils.py +13 -11
  120. synth_ai/evals/base.py +2 -3
  121. synth_ai/experimental/synth_oss.py +4 -4
  122. synth_ai/http.py +102 -0
  123. synth_ai/inference/__init__.py +7 -0
  124. synth_ai/inference/client.py +20 -0
  125. synth_ai/jobs/client.py +246 -0
  126. synth_ai/learning/__init__.py +24 -0
  127. synth_ai/learning/client.py +149 -0
  128. synth_ai/learning/config.py +43 -0
  129. synth_ai/learning/constants.py +29 -0
  130. synth_ai/learning/ft_client.py +59 -0
  131. synth_ai/learning/gateway.py +1 -3
  132. synth_ai/learning/health.py +43 -0
  133. synth_ai/learning/jobs.py +205 -0
  134. synth_ai/learning/prompts/banking77_injection_eval.py +15 -10
  135. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +26 -14
  136. synth_ai/learning/prompts/mipro.py +61 -52
  137. synth_ai/learning/prompts/random_search.py +42 -43
  138. synth_ai/learning/prompts/run_mipro_banking77.py +32 -20
  139. synth_ai/learning/prompts/run_random_search_banking77.py +71 -52
  140. synth_ai/learning/rl_client.py +256 -0
  141. synth_ai/learning/sse.py +58 -0
  142. synth_ai/learning/validators.py +48 -0
  143. synth_ai/lm/__init__.py +5 -5
  144. synth_ai/lm/caching/ephemeral.py +9 -9
  145. synth_ai/lm/caching/handler.py +20 -20
  146. synth_ai/lm/caching/persistent.py +10 -10
  147. synth_ai/lm/config.py +3 -3
  148. synth_ai/lm/constants.py +7 -7
  149. synth_ai/lm/core/all.py +17 -3
  150. synth_ai/lm/core/exceptions.py +0 -2
  151. synth_ai/lm/core/main.py +26 -41
  152. synth_ai/lm/core/main_v3.py +33 -10
  153. synth_ai/lm/core/synth_models.py +48 -0
  154. synth_ai/lm/core/vendor_clients.py +26 -22
  155. synth_ai/lm/injection.py +7 -8
  156. synth_ai/lm/overrides.py +21 -19
  157. synth_ai/lm/provider_support/__init__.py +1 -1
  158. synth_ai/lm/provider_support/anthropic.py +15 -15
  159. synth_ai/lm/provider_support/openai.py +23 -21
  160. synth_ai/lm/structured_outputs/handler.py +34 -32
  161. synth_ai/lm/structured_outputs/inject.py +24 -27
  162. synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
  163. synth_ai/lm/tools/base.py +17 -16
  164. synth_ai/lm/unified_interface.py +17 -18
  165. synth_ai/lm/vendors/base.py +20 -18
  166. synth_ai/lm/vendors/core/anthropic_api.py +36 -27
  167. synth_ai/lm/vendors/core/gemini_api.py +31 -36
  168. synth_ai/lm/vendors/core/mistral_api.py +19 -19
  169. synth_ai/lm/vendors/core/openai_api.py +42 -13
  170. synth_ai/lm/vendors/openai_standard.py +158 -101
  171. synth_ai/lm/vendors/openai_standard_responses.py +74 -61
  172. synth_ai/lm/vendors/retries.py +9 -1
  173. synth_ai/lm/vendors/supported/custom_endpoint.py +38 -28
  174. synth_ai/lm/vendors/supported/deepseek.py +10 -10
  175. synth_ai/lm/vendors/supported/grok.py +8 -8
  176. synth_ai/lm/vendors/supported/ollama.py +2 -1
  177. synth_ai/lm/vendors/supported/openrouter.py +11 -9
  178. synth_ai/lm/vendors/synth_client.py +425 -75
  179. synth_ai/lm/warmup.py +8 -7
  180. synth_ai/rl/__init__.py +30 -0
  181. synth_ai/rl/contracts.py +32 -0
  182. synth_ai/rl/env_keys.py +137 -0
  183. synth_ai/rl/secrets.py +19 -0
  184. synth_ai/scripts/verify_rewards.py +100 -0
  185. synth_ai/task/__init__.py +10 -0
  186. synth_ai/task/contracts.py +120 -0
  187. synth_ai/task/health.py +28 -0
  188. synth_ai/task/validators.py +12 -0
  189. synth_ai/tracing/__init__.py +22 -10
  190. synth_ai/tracing_v1/__init__.py +22 -20
  191. synth_ai/tracing_v3/__init__.py +7 -7
  192. synth_ai/tracing_v3/abstractions.py +56 -52
  193. synth_ai/tracing_v3/config.py +4 -2
  194. synth_ai/tracing_v3/db_config.py +6 -8
  195. synth_ai/tracing_v3/decorators.py +29 -30
  196. synth_ai/tracing_v3/examples/basic_usage.py +12 -12
  197. synth_ai/tracing_v3/hooks.py +24 -22
  198. synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
  199. synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
  200. synth_ai/tracing_v3/migration_helper.py +3 -5
  201. synth_ai/tracing_v3/replica_sync.py +30 -32
  202. synth_ai/tracing_v3/session_tracer.py +158 -31
  203. synth_ai/tracing_v3/storage/__init__.py +1 -1
  204. synth_ai/tracing_v3/storage/base.py +8 -7
  205. synth_ai/tracing_v3/storage/config.py +4 -4
  206. synth_ai/tracing_v3/storage/factory.py +4 -4
  207. synth_ai/tracing_v3/storage/utils.py +9 -9
  208. synth_ai/tracing_v3/turso/__init__.py +3 -3
  209. synth_ai/tracing_v3/turso/daemon.py +9 -9
  210. synth_ai/tracing_v3/turso/manager.py +278 -48
  211. synth_ai/tracing_v3/turso/models.py +77 -19
  212. synth_ai/tracing_v3/utils.py +5 -5
  213. synth_ai/v0/tracing/abstractions.py +28 -28
  214. synth_ai/v0/tracing/base_client.py +9 -9
  215. synth_ai/v0/tracing/client_manager.py +7 -7
  216. synth_ai/v0/tracing/config.py +7 -7
  217. synth_ai/v0/tracing/context.py +6 -6
  218. synth_ai/v0/tracing/decorators.py +6 -5
  219. synth_ai/v0/tracing/events/manage.py +1 -1
  220. synth_ai/v0/tracing/events/store.py +5 -4
  221. synth_ai/v0/tracing/immediate_client.py +4 -5
  222. synth_ai/v0/tracing/local.py +3 -3
  223. synth_ai/v0/tracing/log_client_base.py +4 -5
  224. synth_ai/v0/tracing/retry_queue.py +5 -6
  225. synth_ai/v0/tracing/trackers.py +25 -25
  226. synth_ai/v0/tracing/upload.py +6 -0
  227. synth_ai/v0/tracing_v1/__init__.py +1 -1
  228. synth_ai/v0/tracing_v1/abstractions.py +28 -28
  229. synth_ai/v0/tracing_v1/base_client.py +9 -9
  230. synth_ai/v0/tracing_v1/client_manager.py +7 -7
  231. synth_ai/v0/tracing_v1/config.py +7 -7
  232. synth_ai/v0/tracing_v1/context.py +6 -6
  233. synth_ai/v0/tracing_v1/decorators.py +7 -6
  234. synth_ai/v0/tracing_v1/events/manage.py +1 -1
  235. synth_ai/v0/tracing_v1/events/store.py +5 -4
  236. synth_ai/v0/tracing_v1/immediate_client.py +4 -5
  237. synth_ai/v0/tracing_v1/local.py +3 -3
  238. synth_ai/v0/tracing_v1/log_client_base.py +4 -5
  239. synth_ai/v0/tracing_v1/retry_queue.py +5 -6
  240. synth_ai/v0/tracing_v1/trackers.py +25 -25
  241. synth_ai/v0/tracing_v1/upload.py +25 -24
  242. synth_ai/zyk/__init__.py +1 -0
  243. synth_ai-0.2.4.dev8.dist-info/METADATA +635 -0
  244. synth_ai-0.2.4.dev8.dist-info/RECORD +317 -0
  245. synth_ai/tui/__init__.py +0 -1
  246. synth_ai/tui/__main__.py +0 -13
  247. synth_ai/tui/cli/__init__.py +0 -1
  248. synth_ai/tui/cli/query_experiments.py +0 -165
  249. synth_ai/tui/cli/query_experiments_v3.py +0 -165
  250. synth_ai/tui/dashboard.py +0 -329
  251. synth_ai-0.2.4.dev6.dist-info/METADATA +0 -203
  252. synth_ai-0.2.4.dev6.dist-info/RECORD +0 -299
  253. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/WHEEL +0 -0
  254. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/entry_points.txt +0 -0
  255. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/licenses/LICENSE +0 -0
  256. {synth_ai-0.2.4.dev6.dist-info → synth_ai-0.2.4.dev8.dist-info}/top_level.txt +0 -0
@@ -21,37 +21,50 @@ Performance Considerations:
21
21
  """
22
22
 
23
23
  import asyncio
24
+ import logging
24
25
  from contextlib import asynccontextmanager
25
26
  from datetime import datetime
26
- from typing import Any, Dict, List, Optional, Sequence, Union
27
+ from typing import Any
28
+
27
29
  import pandas as pd
28
- from sqlalchemy import select, insert, update, delete, text, and_, or_, func
29
- from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine, AsyncSession
30
- from sqlalchemy.orm import sessionmaker, selectinload, joinedload
31
- from sqlalchemy.pool import NullPool
30
+ from sqlalchemy import select, text, update
32
31
  from sqlalchemy.exc import IntegrityError
33
- import logging
32
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
33
+ from sqlalchemy import event
34
+ from sqlalchemy.orm import selectinload, sessionmaker
35
+ from sqlalchemy.pool import NullPool
34
36
 
35
- from ..config import CONFIG
36
37
  from ..abstractions import (
37
- SessionTrace,
38
- SessionTimeStep,
39
- BaseEvent,
40
- LMCAISEvent,
41
38
  EnvironmentEvent,
39
+ LMCAISEvent,
42
40
  RuntimeEvent,
41
+ SessionTrace,
43
42
  )
44
- from ..utils import json_dumps
43
+ from ..config import CONFIG
45
44
  from .models import (
46
45
  Base,
47
- SessionTrace as DBSessionTrace,
48
- SessionTimestep as DBSessionTimestep,
46
+ analytics_views,
47
+ )
48
+ from .models import (
49
49
  Event as DBEvent,
50
- Message as DBMessage,
50
+ )
51
+ from .models import (
51
52
  Experiment as DBExperiment,
52
- System as DBSystem,
53
- SystemVersion as DBSystemVersion,
54
- analytics_views,
53
+ )
54
+ from .models import (
55
+ Message as DBMessage,
56
+ )
57
+ from .models import (
58
+ SessionTimestep as DBSessionTimestep,
59
+ )
60
+ from .models import (
61
+ SessionTrace as DBSessionTrace,
62
+ )
63
+ from .models import (
64
+ OutcomeReward as DBOutcomeReward,
65
+ )
66
+ from .models import (
67
+ EventReward as DBEventReward,
55
68
  )
56
69
 
57
70
  logger = logging.getLogger(__name__)
@@ -59,10 +72,10 @@ logger = logging.getLogger(__name__)
59
72
 
60
73
  class AsyncSQLTraceManager:
61
74
  """Async trace storage manager using SQLAlchemy and Turso/sqld.
62
-
75
+
63
76
  Handles all database operations for the tracing system. Designed to work
64
77
  with both local SQLite (via aiosqlite) and remote Turso databases.
65
-
78
+
66
79
  The manager handles:
67
80
  - Connection lifecycle management
68
81
  - Schema creation and verification
@@ -71,22 +84,22 @@ class AsyncSQLTraceManager:
71
84
  - Analytics view creation
72
85
  """
73
86
 
74
- def __init__(self, db_url: Optional[str] = None):
87
+ def __init__(self, db_url: str | None = None):
75
88
  self.db_url = db_url or CONFIG.db_url
76
- self.engine: Optional[AsyncEngine] = None
77
- self.SessionLocal: Optional[sessionmaker] = None
89
+ self.engine: AsyncEngine | None = None
90
+ self.SessionLocal: sessionmaker | None = None
78
91
  self._schema_lock = asyncio.Lock()
79
92
  self._schema_ready = False
80
93
 
81
94
  async def initialize(self):
82
95
  """Initialize the database connection and schema.
83
-
96
+
84
97
  This method is idempotent and thread-safe. It:
85
98
  1. Creates the async engine with appropriate settings
86
99
  2. Verifies database file exists (for SQLite)
87
100
  3. Creates schema if needed
88
101
  4. Sets up analytics views
89
-
102
+
90
103
  The schema lock ensures only one worker creates the schema in
91
104
  concurrent scenarios.
92
105
  """
@@ -119,6 +132,18 @@ class AsyncSQLTraceManager:
119
132
  connect_args=connect_args,
120
133
  echo=CONFIG.echo_sql,
121
134
  )
135
+ # Ensure PRAGMA foreign_keys=ON for every connection
136
+ try:
137
+ @event.listens_for(self.engine.sync_engine, "connect")
138
+ def _set_sqlite_pragma(dbapi_connection, connection_record): # type: ignore[no-redef]
139
+ try:
140
+ cursor = dbapi_connection.cursor()
141
+ cursor.execute("PRAGMA foreign_keys=ON")
142
+ cursor.close()
143
+ except Exception:
144
+ pass
145
+ except Exception:
146
+ pass
122
147
  else:
123
148
  connect_args = CONFIG.get_connect_args()
124
149
  engine_kwargs = CONFIG.get_engine_kwargs()
@@ -136,7 +161,7 @@ class AsyncSQLTraceManager:
136
161
 
137
162
  async def _ensure_schema(self):
138
163
  """Ensure database schema is created.
139
-
164
+
140
165
  Uses a lock to prevent race conditions when multiple workers start
141
166
  simultaneously. The checkfirst=True parameter handles cases where
142
167
  another worker already created the schema.
@@ -154,7 +179,7 @@ class AsyncSQLTraceManager:
154
179
  await conn.run_sync(
155
180
  lambda sync_conn: Base.metadata.create_all(sync_conn, checkfirst=True)
156
181
  )
157
- #logger.info("✅ Database schema created/verified successfully")
182
+ # logger.info("✅ Database schema created/verified successfully")
158
183
  except Exception as e:
159
184
  # If tables already exist, that's fine - another worker created them
160
185
  if "already exists" not in str(e):
@@ -183,7 +208,7 @@ class AsyncSQLTraceManager:
183
208
  logger.warning(f"Could not create view {view_name}: {e}")
184
209
 
185
210
  self._schema_ready = True
186
- #logger.debug("🎯 Database ready for use!")
211
+ # logger.debug("🎯 Database ready for use!")
187
212
 
188
213
  @asynccontextmanager
189
214
  async def session(self):
@@ -195,18 +220,18 @@ class AsyncSQLTraceManager:
195
220
 
196
221
  async def insert_session_trace(self, trace: SessionTrace) -> str:
197
222
  """Insert a complete session trace.
198
-
223
+
199
224
  This method handles the complex task of inserting a complete session
200
225
  with all its timesteps, events, and messages. It uses a single
201
226
  transaction for atomicity and flushes after timesteps to get their
202
227
  auto-generated IDs for foreign keys.
203
-
228
+
204
229
  Args:
205
230
  trace: The complete session trace to store
206
-
231
+
207
232
  Returns:
208
233
  The session ID
209
-
234
+
210
235
  Raises:
211
236
  IntegrityError: If session ID already exists (handled gracefully)
212
237
  """
@@ -214,7 +239,7 @@ class AsyncSQLTraceManager:
214
239
  try:
215
240
  # Convert to cents for cost storage - avoids floating point
216
241
  # precision issues and allows for integer arithmetic
217
- def to_cents(cost: Optional[float]) -> Optional[int]:
242
+ def to_cents(cost: float | None) -> int | None:
218
243
  return int(cost * 100) if cost is not None else None
219
244
 
220
245
  # Insert session
@@ -230,7 +255,7 @@ class AsyncSQLTraceManager:
230
255
 
231
256
  # Track timestep IDs for foreign keys - we need these to link
232
257
  # events and messages to their respective timesteps
233
- step_id_map: Dict[str, int] = {}
258
+ step_id_map: dict[str, int] = {}
234
259
 
235
260
  # Insert timesteps
236
261
  for step in trace.session_time_steps:
@@ -270,8 +295,9 @@ class AsyncSQLTraceManager:
270
295
  call_records_data = None
271
296
  if event.call_records:
272
297
  from dataclasses import asdict
298
+
273
299
  call_records_data = [asdict(record) for record in event.call_records]
274
-
300
+
275
301
  event_data.update(
276
302
  {
277
303
  "event_type": "cais",
@@ -340,7 +366,7 @@ class AsyncSQLTraceManager:
340
366
  return trace.session_id # Return existing ID
341
367
  raise
342
368
 
343
- async def get_session_trace(self, session_id: str) -> Optional[Dict[str, Any]]:
369
+ async def get_session_trace(self, session_id: str) -> dict[str, Any] | None:
344
370
  """Retrieve a session trace by ID."""
345
371
  async with self.session() as sess:
346
372
  result = await sess.execute(
@@ -377,7 +403,9 @@ class AsyncSQLTraceManager:
377
403
  ],
378
404
  }
379
405
 
380
- async def query_traces(self, query: str, params: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
406
+ async def query_traces(
407
+ self, query: str, params: dict[str, Any] | None = None
408
+ ) -> pd.DataFrame:
381
409
  """Execute a query and return results as DataFrame."""
382
410
  async with self.session() as sess:
383
411
  result = await sess.execute(text(query), params or {})
@@ -385,7 +413,10 @@ class AsyncSQLTraceManager:
385
413
  return pd.DataFrame(rows)
386
414
 
387
415
  async def get_model_usage(
388
- self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, model_name: Optional[str] = None
416
+ self,
417
+ start_date: datetime | None = None,
418
+ end_date: datetime | None = None,
419
+ model_name: str | None = None,
389
420
  ) -> pd.DataFrame:
390
421
  """Get model usage statistics."""
391
422
  query = """
@@ -414,8 +445,8 @@ class AsyncSQLTraceManager:
414
445
  self,
415
446
  experiment_id: str,
416
447
  name: str,
417
- description: Optional[str] = None,
418
- configuration: Optional[Dict[str, Any]] = None,
448
+ description: str | None = None,
449
+ configuration: dict[str, Any] | None = None,
419
450
  ) -> str:
420
451
  """Create a new experiment."""
421
452
  async with self.session() as sess:
@@ -440,18 +471,18 @@ class AsyncSQLTraceManager:
440
471
  await sess.commit()
441
472
 
442
473
  async def batch_insert_sessions(
443
- self, traces: List[SessionTrace], batch_size: Optional[int] = None
444
- ) -> List[str]:
474
+ self, traces: list[SessionTrace], batch_size: int | None = None
475
+ ) -> list[str]:
445
476
  """Batch insert multiple session traces.
446
-
477
+
447
478
  Processes traces in batches to balance memory usage and performance.
448
479
  Each batch is inserted in a separate transaction to avoid holding
449
480
  locks for too long.
450
-
481
+
451
482
  Args:
452
483
  traces: List of session traces to insert
453
484
  batch_size: Number of traces per batch (defaults to config)
454
-
485
+
455
486
  Returns:
456
487
  List of inserted session IDs
457
488
  """
@@ -470,8 +501,8 @@ class AsyncSQLTraceManager:
470
501
  return inserted_ids
471
502
 
472
503
  async def get_sessions_by_experiment(
473
- self, experiment_id: str, limit: Optional[int] = None
474
- ) -> List[Dict[str, Any]]:
504
+ self, experiment_id: str, limit: int | None = None
505
+ ) -> list[dict[str, Any]]:
475
506
  """Get all sessions for an experiment."""
476
507
  async with self.session() as sess:
477
508
  query = (
@@ -515,7 +546,7 @@ class AsyncSQLTraceManager:
515
546
 
516
547
  async def close(self):
517
548
  """Close the database connection.
518
-
549
+
519
550
  Properly disposes of the engine and all connections. This is important
520
551
  for cleanup, especially with SQLite which can leave lock files.
521
552
  """
@@ -526,3 +557,202 @@ class AsyncSQLTraceManager:
526
557
  self.engine = None
527
558
  self.SessionLocal = None
528
559
  self._schema_ready = False
560
+
561
+ # -------------------------------
562
+ # Incremental insert helpers
563
+ # -------------------------------
564
+
565
+ async def ensure_session(self, session_id: str, *, created_at: datetime | None = None, metadata: dict[str, Any] | None = None):
566
+ """Ensure a DB session row exists for session_id."""
567
+ async with self.session() as sess:
568
+ result = await sess.execute(select(DBSessionTrace).where(DBSessionTrace.session_id == session_id))
569
+ existing = result.scalar_one_or_none()
570
+ if existing:
571
+ return
572
+ row = DBSessionTrace(
573
+ session_id=session_id,
574
+ created_at=created_at or datetime.utcnow(),
575
+ num_timesteps=0,
576
+ num_events=0,
577
+ num_messages=0,
578
+ session_metadata=metadata or {},
579
+ )
580
+ sess.add(row)
581
+ await sess.commit()
582
+
583
+ async def ensure_timestep(self, session_id: str, *, step_id: str, step_index: int, turn_number: int | None = None, started_at: datetime | None = None, completed_at: datetime | None = None, metadata: dict[str, Any] | None = None) -> int:
584
+ """Ensure a timestep row exists; return its DB id."""
585
+ async with self.session() as sess:
586
+ result = await sess.execute(
587
+ select(DBSessionTimestep).where(DBSessionTimestep.session_id == session_id, DBSessionTimestep.step_id == step_id)
588
+ )
589
+ row = result.scalar_one_or_none()
590
+ if row:
591
+ return row.id
592
+ row = DBSessionTimestep(
593
+ session_id=session_id,
594
+ step_id=step_id,
595
+ step_index=step_index,
596
+ turn_number=turn_number,
597
+ started_at=started_at or datetime.utcnow(),
598
+ completed_at=completed_at,
599
+ num_events=0,
600
+ num_messages=0,
601
+ step_metadata=metadata or {},
602
+ )
603
+ sess.add(row)
604
+ await sess.flush()
605
+ # increment session num_timesteps
606
+ await sess.execute(
607
+ update(DBSessionTrace)
608
+ .where(DBSessionTrace.session_id == session_id)
609
+ .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
610
+ )
611
+ await sess.commit()
612
+ return row.id
613
+
614
+ async def insert_message_row(self, session_id: str, *, timestep_db_id: int | None, message_type: str, content: str, event_time: float | None = None, message_time: int | None = None, metadata: dict[str, Any] | None = None) -> int:
615
+ """Insert a message and return its id."""
616
+ async with self.session() as sess:
617
+ db_msg = DBMessage(
618
+ session_id=session_id,
619
+ timestep_id=timestep_db_id,
620
+ message_type=message_type,
621
+ content=content,
622
+ event_time=event_time,
623
+ message_time=message_time,
624
+ message_metadata=metadata or {},
625
+ )
626
+ sess.add(db_msg)
627
+ await sess.flush()
628
+ # increment session num_messages
629
+ await sess.execute(
630
+ update(DBSessionTrace)
631
+ .where(DBSessionTrace.session_id == session_id)
632
+ .values(num_messages=DBSessionTrace.num_messages + 1)
633
+ )
634
+ await sess.commit()
635
+ return db_msg.id
636
+
637
+ async def insert_event_row(self, session_id: str, *, timestep_db_id: int | None, event: EnvironmentEvent | LMCAISEvent | RuntimeEvent, metadata_override: dict[str, Any] | None = None) -> int:
638
+ """Insert an event and return its id."""
639
+ def to_cents(cost: float | None) -> int | None:
640
+ return int(cost * 100) if cost is not None else None
641
+
642
+ event_data: dict[str, Any] = {
643
+ "session_id": session_id,
644
+ "timestep_id": timestep_db_id,
645
+ "system_instance_id": event.system_instance_id,
646
+ "event_time": event.time_record.event_time,
647
+ "message_time": event.time_record.message_time,
648
+ "event_metadata_json": metadata_override or event.metadata or {},
649
+ "event_extra_metadata": getattr(event, "event_metadata", None),
650
+ }
651
+ if isinstance(event, LMCAISEvent):
652
+ call_records_data = None
653
+ if getattr(event, "call_records", None):
654
+ from dataclasses import asdict
655
+
656
+ call_records_data = [asdict(record) for record in event.call_records]
657
+ event_data.update({
658
+ "event_type": "cais",
659
+ "model_name": event.model_name,
660
+ "provider": event.provider,
661
+ "input_tokens": event.input_tokens,
662
+ "output_tokens": event.output_tokens,
663
+ "total_tokens": event.total_tokens,
664
+ "cost_usd": to_cents(event.cost_usd),
665
+ "latency_ms": event.latency_ms,
666
+ "span_id": event.span_id,
667
+ "trace_id": event.trace_id,
668
+ "system_state_before": event.system_state_before,
669
+ "system_state_after": event.system_state_after,
670
+ "call_records": call_records_data,
671
+ })
672
+ elif isinstance(event, EnvironmentEvent):
673
+ event_data.update({
674
+ "event_type": "environment",
675
+ "reward": event.reward,
676
+ "terminated": event.terminated,
677
+ "truncated": event.truncated,
678
+ "system_state_before": event.system_state_before,
679
+ "system_state_after": event.system_state_after,
680
+ })
681
+ elif isinstance(event, RuntimeEvent):
682
+ event_data.update({
683
+ "event_type": "runtime",
684
+ "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
685
+ })
686
+ else:
687
+ event_data["event_type"] = event.__class__.__name__.lower()
688
+
689
+ async with self.session() as sess:
690
+ db_event = DBEvent(**event_data)
691
+ sess.add(db_event)
692
+ await sess.flush()
693
+ # increment session num_events
694
+ await sess.execute(
695
+ update(DBSessionTrace)
696
+ .where(DBSessionTrace.session_id == session_id)
697
+ .values(num_events=DBSessionTrace.num_events + 1)
698
+ )
699
+ await sess.commit()
700
+ return db_event.id
701
+
702
+ # -------------------------------
703
+ # Reward helpers
704
+ # -------------------------------
705
+
706
+ async def insert_outcome_reward(self, session_id: str, *, total_reward: int, achievements_count: int, total_steps: int) -> int:
707
+ async with self.session() as sess:
708
+ row = DBOutcomeReward(
709
+ session_id=session_id,
710
+ total_reward=total_reward,
711
+ achievements_count=achievements_count,
712
+ total_steps=total_steps,
713
+ )
714
+ sess.add(row)
715
+ await sess.flush()
716
+ await sess.commit()
717
+ return row.id
718
+
719
+ async def insert_event_reward(self, session_id: str, *, event_id: int, message_id: int | None = None, turn_number: int | None = None, reward_value: float = 0.0, reward_type: str | None = None, key: str | None = None, annotation: dict[str, Any] | None = None, source: str | None = None) -> int:
720
+ async with self.session() as sess:
721
+ row = DBEventReward(
722
+ event_id=event_id,
723
+ session_id=session_id,
724
+ message_id=message_id,
725
+ turn_number=turn_number,
726
+ reward_value=reward_value,
727
+ reward_type=reward_type,
728
+ key=key,
729
+ annotation=annotation or {},
730
+ source=source,
731
+ )
732
+ sess.add(row)
733
+ await sess.flush()
734
+ await sess.commit()
735
+ return row.id
736
+
737
+ async def get_outcome_rewards(self) -> list[dict[str, Any]]:
738
+ async with self.session() as sess:
739
+ result = await sess.execute(select(DBOutcomeReward))
740
+ rows = result.scalars().all()
741
+ return [
742
+ {
743
+ "id": r.id,
744
+ "session_id": r.session_id,
745
+ "total_reward": r.total_reward,
746
+ "achievements_count": r.achievements_count,
747
+ "total_steps": r.total_steps,
748
+ "created_at": r.created_at,
749
+ }
750
+ for r in rows
751
+ ]
752
+
753
+ async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
754
+ async with self.session() as sess:
755
+ result = await sess.execute(
756
+ select(DBOutcomeReward.session_id).where(DBOutcomeReward.total_reward >= min_reward)
757
+ )
758
+ return [row[0] for row in result.all()]
@@ -1,28 +1,25 @@
1
1
  """SQLAlchemy declarative models for tracing v3."""
2
2
 
3
- from datetime import datetime
4
- from typing import Optional
3
+ import json
4
+
5
5
  from sqlalchemy import (
6
+ Boolean,
7
+ CheckConstraint,
6
8
  Column,
7
- Integer,
8
- String,
9
9
  DateTime,
10
- JSON,
10
+ Float,
11
11
  ForeignKey,
12
12
  Index,
13
+ Integer,
14
+ String,
13
15
  Text,
14
- Float,
15
- Boolean,
16
- UniqueConstraint,
17
- CheckConstraint,
18
16
  TypeDecorator,
17
+ UniqueConstraint,
19
18
  )
20
19
  from sqlalchemy.ext.declarative import declarative_base
21
20
  from sqlalchemy.orm import relationship
22
21
  from sqlalchemy.sql import func
23
22
  from sqlalchemy.types import UserDefinedType
24
- import json
25
-
26
23
 
27
24
  Base = declarative_base()
28
25
 
@@ -81,10 +78,11 @@ class JSONText(TypeDecorator):
81
78
 
82
79
  class SessionTrace(Base):
83
80
  """Database model for session traces.
84
-
81
+
85
82
  Stores high-level information about tracing sessions including
86
83
  metadata, statistics, and relationships to timesteps and events.
87
84
  """
85
+
88
86
  __tablename__ = "session_traces"
89
87
 
90
88
  session_id = Column(String, primary_key=True)
@@ -114,10 +112,11 @@ class SessionTrace(Base):
114
112
 
115
113
  class SessionTimestep(Base):
116
114
  """Database model for session timesteps.
117
-
115
+
118
116
  Represents individual steps within a tracing session, with timing
119
117
  information and relationships to events and messages.
120
118
  """
119
+
121
120
  __tablename__ = "session_timesteps"
122
121
 
123
122
  id = Column(Integer, primary_key=True, autoincrement=True)
@@ -145,11 +144,12 @@ class SessionTimestep(Base):
145
144
 
146
145
  class Event(Base):
147
146
  """Database model for events.
148
-
147
+
149
148
  Stores all types of events (LM CAIS, environment, runtime) with
150
149
  type-specific fields and common metadata. Supports vector embeddings
151
150
  for similarity search.
152
151
  """
152
+
153
153
  __tablename__ = "events"
154
154
 
155
155
  id = Column(Integer, primary_key=True, autoincrement=True)
@@ -209,10 +209,11 @@ class Event(Base):
209
209
 
210
210
  class Message(Base):
211
211
  """Database model for messages.
212
-
212
+
213
213
  Stores conversational messages between users, assistants, and systems
214
214
  with support for embeddings and rich metadata.
215
215
  """
216
+
216
217
  __tablename__ = "messages"
217
218
 
218
219
  id = Column(Integer, primary_key=True, autoincrement=True)
@@ -247,10 +248,11 @@ class Message(Base):
247
248
 
248
249
  class Experiment(Base):
249
250
  """Database model for experiments.
250
-
251
+
251
252
  Groups related sessions and systems for experimental evaluation
252
253
  and comparison. Supports rich configuration and metadata.
253
254
  """
255
+
254
256
  __tablename__ = "experiments"
255
257
 
256
258
  experiment_id = Column(String, primary_key=True)
@@ -277,10 +279,11 @@ class Experiment(Base):
277
279
 
278
280
  class System(Base):
279
281
  """Database model for systems.
280
-
282
+
281
283
  Represents agents, environments, or runtime systems that participate
282
284
  in tracing sessions. Supports versioning and type classification.
283
285
  """
286
+
284
287
  __tablename__ = "systems"
285
288
 
286
289
  system_id = Column(String, primary_key=True)
@@ -302,10 +305,11 @@ class System(Base):
302
305
 
303
306
  class SystemVersion(Base):
304
307
  """Database model for system versions.
305
-
308
+
306
309
  Tracks different versions of systems with commit hashes,
307
310
  configuration changes, and relationships to experiments.
308
311
  """
312
+
309
313
  __tablename__ = "system_versions"
310
314
 
311
315
  version_id = Column(String, primary_key=True)
@@ -329,10 +333,11 @@ class SystemVersion(Base):
329
333
 
330
334
  class ExperimentalSystem(Base):
331
335
  """Database model for experiment-system relationships.
332
-
336
+
333
337
  Junction table linking experiments with specific system versions,
334
338
  allowing tracking of which systems participated in which experiments.
335
339
  """
340
+
336
341
  __tablename__ = "experimental_systems"
337
342
 
338
343
  id = Column(Integer, primary_key=True, autoincrement=True)
@@ -403,3 +408,56 @@ analytics_views = {
403
408
  GROUP BY e.experiment_id
404
409
  """,
405
410
  }
411
+
412
+
413
+ # Reward persistence tables
414
+
415
+
416
class OutcomeReward(Base):
    """Per-session, episode-level outcome summary.

    Each row records an episode's ``total_reward`` (e.g., the number of unique
    achievements) together with ``achievements_count`` and ``total_steps``,
    so that episodes can be filtered by outcome.
    """

    __tablename__ = "outcome_rewards"

    # Surrogate key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning session; references ``session_traces.session_id``.
    session_id = Column(
        String,
        ForeignKey("session_traces.session_id"),
        nullable=False,
    )
    # Required episode total; no default is supplied.
    total_reward = Column(Integer, nullable=False)
    # Both counters default to 0 when omitted.
    achievements_count = Column(Integer, nullable=False, default=0)
    total_steps = Column(Integer, nullable=False, default=0)
    # Insert-time timestamp supplied via ``func.current_timestamp()``.
    created_at = Column(
        DateTime,
        default=func.current_timestamp(),
        nullable=False,
    )

    # Indexes supporting lookup by session and filtering by total reward.
    __table_args__ = (
        Index("idx_outcome_rewards_session", "session_id"),
        Index("idx_outcome_rewards_total", "total_reward"),
    )
436
+
437
+
438
class EventReward(Base):
    """Event-level reward record with optional annotations.

    Every row references an event and a session; the link to a message
    (``message_id``) is optional.
    """

    __tablename__ = "event_rewards"

    # Surrogate key.
    id = Column(Integer, primary_key=True, autoincrement=True)

    # Required links to the rewarded event and its session.
    event_id = Column(Integer, ForeignKey("events.id"), nullable=False)
    session_id = Column(
        String,
        ForeignKey("session_traces.session_id"),
        nullable=False,
    )
    # Optional link to a message.
    message_id = Column(Integer, ForeignKey("messages.id"), nullable=True)
    turn_number = Column(Integer, nullable=True)

    reward_value = Column(Float, nullable=False, default=0.0)
    # shaped | sparse | achievement | penalty | evaluator | human
    reward_type = Column(String, nullable=True)
    # e.g., achievement name
    key = Column(String, nullable=True)
    # free-form JSON
    annotation = Column(JSONText)
    # environment | runner | evaluator | human
    source = Column(String, nullable=True)
    # Insert-time timestamp supplied via ``func.current_timestamp()``.
    created_at = Column(
        DateTime,
        default=func.current_timestamp(),
        nullable=False,
    )

    # Indexes supporting lookup by session, event, reward type and key.
    __table_args__ = (
        Index("idx_event_rewards_session", "session_id"),
        Index("idx_event_rewards_event", "event_id"),
        Index("idx_event_rewards_type", "reward_type"),
        Index("idx_event_rewards_key", "key"),
    )