synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (353) hide show
  1. examples/__init__.py +16 -0
  2. examples/crafter_debug_render.py +23 -17
  3. examples/qwen_coder/README.md +102 -0
  4. examples/qwen_coder/_shared.py +113 -0
  5. examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
  6. examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
  7. examples/qwen_coder/configs/coder_lora_small.toml +58 -0
  8. examples/qwen_coder/generate_dataset.py +98 -0
  9. examples/qwen_coder/infer_ft_smoke.py +64 -0
  10. examples/qwen_coder/infer_prod_proxy.py +73 -0
  11. examples/qwen_coder/infer_via_synth.py +87 -0
  12. examples/qwen_coder/scripts/infer_coder.sh +18 -0
  13. examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
  14. examples/qwen_coder/sft_full_17b.py +103 -0
  15. examples/qwen_coder/sft_lora_30b.py +110 -0
  16. examples/qwen_coder/subset_jsonl.py +38 -0
  17. examples/qwen_coder/validate_jsonl.py +59 -0
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +53 -52
  22. examples/rl/run_rl_and_save.py +29 -12
  23. examples/rl/task_app/math_single_step.py +180 -41
  24. examples/rl/task_app/math_task_app.py +14 -6
  25. examples/sft/README.md +139 -0
  26. examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
  27. examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
  28. examples/sft/evaluate.py +117 -0
  29. examples/sft/export_dataset.py +117 -0
  30. examples/sft/generate_traces.py +162 -0
  31. examples/swe/__init__.py +12 -0
  32. examples/swe/task_app/README.md +105 -0
  33. examples/swe/task_app/__init__.py +2 -0
  34. examples/swe/task_app/grpo_swe_mini.py +571 -0
  35. examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
  36. examples/swe/task_app/hosted/README.md +173 -0
  37. examples/swe/task_app/hosted/__init__.py +5 -0
  38. examples/swe/task_app/hosted/branching.py +143 -0
  39. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  40. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  41. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  42. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  43. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  44. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  45. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  46. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  47. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  48. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  49. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
  50. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  51. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  52. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  53. examples/swe/task_app/hosted/hosted_app.py +204 -0
  54. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  55. examples/swe/task_app/hosted/inference/openai_client.py +618 -0
  56. examples/swe/task_app/hosted/main.py +100 -0
  57. examples/swe/task_app/hosted/policy_routes.py +1079 -0
  58. examples/swe/task_app/hosted/registry.py +195 -0
  59. examples/swe/task_app/hosted/rollout.py +1869 -0
  60. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  61. examples/swe/task_app/hosted/storage/volume.py +211 -0
  62. examples/swe/task_app/hosted/test_agents.py +161 -0
  63. examples/swe/task_app/hosted/test_service.py +137 -0
  64. examples/swe/task_app/hosted/utils.py +62 -0
  65. examples/vlm/README.md +68 -0
  66. examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
  67. examples/vlm/crafter_image_only_agent.py +207 -0
  68. examples/vlm/crafter_openai_vlm_agent.py +277 -0
  69. examples/vlm/filter_image_rows.py +63 -0
  70. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  71. examples/warming_up_to_rl/analyze_trace_db.py +12 -10
  72. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
  73. examples/warming_up_to_rl/export_trace_sft.py +218 -36
  74. examples/warming_up_to_rl/groq_test.py +15 -8
  75. examples/warming_up_to_rl/manage_secrets.py +29 -25
  76. examples/warming_up_to_rl/readme.md +9 -2
  77. examples/warming_up_to_rl/run_eval.py +137 -61
  78. examples/warming_up_to_rl/run_fft_and_save.py +131 -60
  79. examples/warming_up_to_rl/run_local_rollout.py +88 -39
  80. examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
  81. examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
  82. examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
  83. examples/warming_up_to_rl/run_rl_and_save.py +35 -12
  84. examples/warming_up_to_rl/run_rollout_remote.py +44 -19
  85. examples/warming_up_to_rl/task_app/README.md +6 -2
  86. examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
  87. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
  98. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  99. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
  100. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  101. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
  102. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
  103. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
  104. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
  105. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
  106. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  107. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
  108. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
  109. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
  110. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
  111. synth/__init__.py +14 -0
  112. synth_ai/__init__.py +20 -4
  113. synth_ai/api/models/supported.py +376 -0
  114. synth_ai/api/train/builders.py +157 -26
  115. synth_ai/api/train/cli.py +213 -57
  116. synth_ai/api/train/config_finder.py +65 -5
  117. synth_ai/api/train/env_resolver.py +33 -15
  118. synth_ai/api/train/pollers.py +13 -4
  119. synth_ai/api/train/supported_algos.py +139 -0
  120. synth_ai/api/train/task_app.py +5 -3
  121. synth_ai/api/train/utils.py +33 -48
  122. synth_ai/cli/__init__.py +19 -4
  123. synth_ai/cli/_modal_wrapper.py +28 -0
  124. synth_ai/cli/_typer_patch.py +49 -0
  125. synth_ai/cli/balance.py +2 -3
  126. synth_ai/cli/calc.py +1 -1
  127. synth_ai/cli/demo.py +21 -6
  128. synth_ai/cli/recent.py +2 -2
  129. synth_ai/cli/rl_demo.py +77 -17
  130. synth_ai/cli/root.py +116 -39
  131. synth_ai/cli/status.py +2 -2
  132. synth_ai/cli/task_apps.py +1709 -243
  133. synth_ai/cli/traces.py +7 -4
  134. synth_ai/cli/turso.py +73 -0
  135. synth_ai/cli/watch.py +12 -18
  136. synth_ai/core/experiment.py +0 -2
  137. synth_ai/demo_registry.py +68 -31
  138. synth_ai/demos/core/cli.py +516 -194
  139. synth_ai/demos/demo_task_apps/__init__.py +3 -3
  140. synth_ai/demos/demo_task_apps/core.py +64 -28
  141. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  142. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
  143. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  144. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  145. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
  146. synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
  147. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
  148. synth_ai/environments/examples/bandit/engine.py +12 -4
  149. synth_ai/environments/examples/bandit/taskset.py +4 -4
  150. synth_ai/environments/examples/crafter_classic/environment.py +76 -1
  151. synth_ai/environments/reproducibility/tree.py +5 -6
  152. synth_ai/environments/service/app.py +11 -12
  153. synth_ai/environments/service/core_routes.py +10 -9
  154. synth_ai/environments/stateful/engine.py +1 -1
  155. synth_ai/environments/tasks/core.py +1 -0
  156. synth_ai/environments/tasks/filters.py +5 -6
  157. synth_ai/environments/tasks/utils.py +4 -5
  158. synth_ai/evals/base.py +0 -2
  159. synth_ai/handshake.py +11 -9
  160. synth_ai/http.py +1 -1
  161. synth_ai/http_client.py +43 -11
  162. synth_ai/inference/__init__.py +0 -2
  163. synth_ai/inference/client.py +20 -6
  164. synth_ai/jobs/client.py +103 -78
  165. synth_ai/learning/__init__.py +41 -6
  166. synth_ai/learning/algorithms.py +14 -0
  167. synth_ai/learning/client.py +121 -29
  168. synth_ai/learning/config.py +2 -40
  169. synth_ai/learning/constants.py +0 -2
  170. synth_ai/learning/ft_client.py +4 -56
  171. synth_ai/learning/health.py +13 -7
  172. synth_ai/learning/jobs.py +43 -47
  173. synth_ai/{rl → learning/rl}/__init__.py +14 -5
  174. synth_ai/learning/rl/client.py +267 -0
  175. synth_ai/learning/rl/config.py +31 -0
  176. synth_ai/{rl → learning/rl}/contracts.py +5 -10
  177. synth_ai/{rl → learning/rl}/env_keys.py +45 -16
  178. synth_ai/learning/rl/secrets.py +13 -0
  179. synth_ai/learning/rl_client.py +2 -253
  180. synth_ai/learning/sft/__init__.py +29 -0
  181. synth_ai/learning/sft/client.py +68 -0
  182. synth_ai/learning/sft/config.py +270 -0
  183. synth_ai/learning/sft/data.py +295 -0
  184. synth_ai/learning/sse.py +25 -26
  185. synth_ai/learning/validators.py +25 -24
  186. synth_ai/lm/__init__.py +21 -47
  187. synth_ai/task/__init__.py +26 -27
  188. synth_ai/task/apps/__init__.py +18 -19
  189. synth_ai/task/auth.py +35 -23
  190. synth_ai/task/client.py +15 -13
  191. synth_ai/task/contracts.py +37 -35
  192. synth_ai/task/datasets.py +9 -6
  193. synth_ai/task/errors.py +11 -10
  194. synth_ai/task/health.py +17 -11
  195. synth_ai/task/json.py +58 -24
  196. synth_ai/task/proxy.py +15 -14
  197. synth_ai/task/rubrics.py +22 -15
  198. synth_ai/task/server.py +43 -17
  199. synth_ai/task/tracing_utils.py +12 -7
  200. synth_ai/task/validators.py +0 -1
  201. synth_ai/task/vendors.py +5 -7
  202. synth_ai/tracing_v3/__init__.py +2 -0
  203. synth_ai/tracing_v3/abstractions.py +21 -4
  204. synth_ai/tracing_v3/db_config.py +26 -1
  205. synth_ai/tracing_v3/decorators.py +18 -15
  206. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  207. synth_ai/tracing_v3/hooks.py +6 -4
  208. synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
  209. synth_ai/tracing_v3/replica_sync.py +1 -0
  210. synth_ai/tracing_v3/session_tracer.py +63 -16
  211. synth_ai/tracing_v3/storage/base.py +89 -1
  212. synth_ai/tracing_v3/storage/config.py +21 -8
  213. synth_ai/tracing_v3/storage/factory.py +10 -8
  214. synth_ai/tracing_v3/storage/utils.py +4 -2
  215. synth_ai/tracing_v3/turso/daemon.py +7 -2
  216. synth_ai/tracing_v3/turso/models.py +5 -2
  217. synth_ai/tracing_v3/turso/native_manager.py +1173 -0
  218. synth_ai/tracing_v3/utils.py +4 -3
  219. synth_ai/v0/api/__init__.py +8 -0
  220. synth_ai/v0/api/models/__init__.py +8 -0
  221. synth_ai/v0/api/models/supported.py +8 -0
  222. synth_ai/v0/config/__init__.py +15 -0
  223. synth_ai/v0/config/base_url.py +12 -0
  224. synth_ai/v0/lm/__init__.py +51 -0
  225. synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
  226. synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
  227. synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
  228. synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
  229. synth_ai/{lm → v0/lm}/config.py +6 -1
  230. synth_ai/{lm → v0/lm}/core/all.py +9 -9
  231. synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
  232. synth_ai/{lm → v0/lm}/core/main.py +19 -7
  233. synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
  234. synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
  235. synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
  236. synth_ai/{lm → v0/lm}/overrides.py +4 -4
  237. synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
  238. synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
  239. synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
  240. synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
  241. synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
  242. synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
  243. synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
  244. synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
  245. synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
  246. synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
  247. synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
  248. synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
  249. synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
  250. synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
  251. synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
  252. synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
  253. synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
  254. synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
  255. synth_ai/v0/tracing/upload.py +32 -135
  256. synth_ai/v0/tracing_v3/__init__.py +10 -0
  257. synth_ai/v0/tracing_v3/abstractions.py +3 -0
  258. synth_ai/v0/tracing_v3/decorators.py +3 -0
  259. synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
  260. synth_ai/v0/tracing_v3/session_tracer.py +3 -0
  261. synth_ai-0.2.9.dev6.dist-info/METADATA +191 -0
  262. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/RECORD +291 -264
  263. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/top_level.txt +1 -0
  264. examples/common_old/backend.py +0 -21
  265. examples/evals_old/README.md +0 -98
  266. examples/evals_old/__init__.py +0 -6
  267. examples/evals_old/compare_models.py +0 -1037
  268. examples/evals_old/example_log.md +0 -145
  269. examples/evals_old/run_demo.sh +0 -126
  270. examples/evals_old/trace_analysis.py +0 -270
  271. examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
  272. examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
  273. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
  274. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
  275. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
  276. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
  277. examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
  278. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
  279. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
  280. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
  281. examples/finetuning_old/synth_qwen_v1/README.md +0 -68
  282. examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
  283. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
  284. examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
  285. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
  286. examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
  287. examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
  288. examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
  289. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
  290. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
  291. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
  292. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
  293. examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
  294. examples/finetuning_old/synth_qwen_v1/util.py +0 -147
  295. examples/rl_old/task_app.py +0 -962
  296. examples/warming_up_to_rl/old/event_rewards.md +0 -234
  297. examples/warming_up_to_rl/old/notes.md +0 -73
  298. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
  299. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
  300. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
  301. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  302. synth_ai/experimental/synth_oss.py +0 -446
  303. synth_ai/install_sqld.sh +0 -40
  304. synth_ai/learning/filtering.py +0 -0
  305. synth_ai/learning/offline/dpo.py +0 -0
  306. synth_ai/learning/offline/providers.py +0 -7
  307. synth_ai/learning/offline/sft.py +0 -0
  308. synth_ai/learning/offline/shared.py +0 -0
  309. synth_ai/learning/online/grpo.py +0 -0
  310. synth_ai/learning/online/irft.py +0 -0
  311. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  312. synth_ai/learning/prompts/gepa.py +0 -0
  313. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
  314. synth_ai/learning/prompts/mipro.py +0 -289
  315. synth_ai/learning/prompts/random_search.py +0 -246
  316. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  317. synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
  318. synth_ai/rl/secrets.py +0 -19
  319. synth_ai/scripts/verify_rewards.py +0 -100
  320. synth_ai/tracing/__init__.py +0 -30
  321. synth_ai/tracing_v1/__init__.py +0 -33
  322. synth_ai/tracing_v3/turso/__init__.py +0 -25
  323. synth_ai/tracing_v3/turso/manager.py +0 -774
  324. synth_ai/zyk/__init__.py +0 -30
  325. synth_ai-0.2.9.dev4.dist-info/METADATA +0 -131
  326. /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
  327. /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
  328. /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
  329. /synth_ai/{lm → v0/lm}/constants.py +0 -0
  330. /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
  331. /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
  332. /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
  333. /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
  334. /synth_ai/{lm → v0/lm}/injection.py +0 -0
  335. /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
  336. /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
  337. /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
  338. /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
  339. /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
  340. /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
  341. /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
  342. /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
  343. /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
  344. /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
  345. /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
  346. /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
  347. /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
  348. /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
  349. /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
  350. /synth_ai/{lm → v0/lm}/warmup.py +0 -0
  351. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/WHEEL +0 -0
  352. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/entry_points.txt +0 -0
  353. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev6.dist-info}/licenses/LICENSE +0 -0
@@ -1,774 +0,0 @@
1
- from __future__ import annotations
2
- """Async SQLAlchemy-based trace manager for Turso/sqld.
3
-
4
- This module provides the database interface for the tracing system using
5
- async SQLAlchemy with a Turso/sqld backend. It handles all database operations
6
- including schema creation, session storage, and analytics queries.
7
-
8
- Key Features:
9
- ------------
10
- - Async-first design using aiosqlite for local SQLite
11
- - Automatic schema creation and migration
12
- - Batch insert capabilities for high-throughput scenarios
13
- - Analytics views for efficient querying
14
- - Connection pooling and retry logic
15
-
16
- Performance Considerations:
17
- --------------------------
18
- - Uses NullPool for SQLite to avoid connection issues
19
- - Implements busy timeout for concurrent access
20
- - Batches inserts to reduce transaction overhead
21
- - Creates indexes for common query patterns
22
- """
23
-
24
- import asyncio
25
- import logging
26
- from contextlib import asynccontextmanager
27
- from datetime import datetime
28
- from typing import Any
29
-
30
- # Optional pandas import: fall back to records (list[dict]) if unavailable
31
- try: # pragma: no cover - exercised in environments without pandas
32
- import pandas as pd # type: ignore
33
- except Exception: # pragma: no cover
34
- pd = None # type: ignore[assignment]
35
- from sqlalchemy import select, text, update
36
- from sqlalchemy.exc import IntegrityError
37
- from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
38
- from sqlalchemy import event
39
- from sqlalchemy.orm import selectinload, sessionmaker
40
- from sqlalchemy.pool import NullPool
41
-
42
- from ..abstractions import (
43
- EnvironmentEvent,
44
- LMCAISEvent,
45
- RuntimeEvent,
46
- SessionTrace,
47
- )
48
- from ..config import CONFIG
49
- from .models import (
50
- Base,
51
- analytics_views,
52
- )
53
- from .models import (
54
- Event as DBEvent,
55
- )
56
- from .models import (
57
- Experiment as DBExperiment,
58
- )
59
- from .models import (
60
- Message as DBMessage,
61
- )
62
- from .models import (
63
- SessionTimestep as DBSessionTimestep,
64
- )
65
- from .models import (
66
- SessionTrace as DBSessionTrace,
67
- )
68
- from .models import (
69
- OutcomeReward as DBOutcomeReward,
70
- )
71
- from .models import (
72
- EventReward as DBEventReward,
73
- )
74
-
75
- logger = logging.getLogger(__name__)
76
-
77
-
78
- class AsyncSQLTraceManager:
79
- """Async trace storage manager using SQLAlchemy and Turso/sqld.
80
-
81
- Handles all database operations for the tracing system. Designed to work
82
- with both local SQLite (via aiosqlite) and remote Turso databases.
83
-
84
- The manager handles:
85
- - Connection lifecycle management
86
- - Schema creation and verification
87
- - Transaction management
88
- - Batch operations for efficiency
89
- - Analytics view creation
90
- """
91
-
92
def __init__(self, db_url: str | None = None):
    """Record the target database URL; no connection is opened here.

    Args:
        db_url: Database URL to use. A falsy value (``None`` or ``""``)
            selects ``CONFIG.db_url`` instead.
    """
    # The engine and session factory are built lazily by ``initialize()``.
    self.engine: AsyncEngine | None = None
    self.SessionLocal: sessionmaker | None = None
    # State used by ``_ensure_schema()`` so the schema is set up only once.
    self._schema_ready = False
    self._schema_lock = asyncio.Lock()
    self.db_url = db_url if db_url else CONFIG.db_url
98
-
99
async def initialize(self):
    """Create the async engine and session factory, then ensure the schema.

    Safe to call repeatedly: the engine and session factory are built only
    while ``self.engine`` is ``None``, and ``_ensure_schema()`` returns
    immediately once the schema has been marked ready.

    For URLs starting with ``sqlite`` the engine uses ``NullPool`` and a
    30 second busy timeout, and a ``connect`` listener issues
    ``PRAGMA foreign_keys=ON`` on every new DBAPI connection. For any other
    URL the connect arguments and engine keyword arguments come from
    ``CONFIG``.

    Note:
        Schema creation is serialized with an ``asyncio.Lock``. That
        coordinates coroutines on one event loop; it is not a thread lock.
    """
    import os

    if self.engine is None:
        logger.debug(f"🔗 Initializing database connection to: {self.db_url}")

        if self.db_url.startswith("sqlite"):
            # The file path is only used for the diagnostic messages below;
            # a missing file does not stop engine creation.
            db_path = self.db_url.replace("sqlite+aiosqlite:///", "")
            if not os.path.exists(db_path):
                logger.debug(f"⚠️ Database file not found: {db_path}")
                logger.debug(
                    "🔧 Make sure './serve.sh' is running to start the turso/sqld service"
                )
            else:
                logger.debug(f"✅ Found database file: {db_path}")

            # NullPool opens a fresh connection per operation, and the busy
            # timeout makes SQLite wait on a locked database instead of
            # failing immediately.
            connect_args = {"timeout": 30.0}
            self.engine = create_async_engine(
                self.db_url,  # Use instance db_url, not CONFIG
                poolclass=NullPool,
                connect_args=connect_args,
                echo=CONFIG.echo_sql,
            )

            # Foreign-key enforcement is a per-connection setting in SQLite,
            # so it is applied from a "connect" listener. This stays
            # best-effort, but failures are logged rather than discarded.
            try:

                @event.listens_for(self.engine.sync_engine, "connect")
                def _set_sqlite_pragma(dbapi_connection, connection_record):  # type: ignore[no-redef]
                    try:
                        cursor = dbapi_connection.cursor()
                        try:
                            cursor.execute("PRAGMA foreign_keys=ON")
                        finally:
                            cursor.close()
                    except Exception:
                        logger.debug(
                            "Could not enable PRAGMA foreign_keys on new connection",
                            exc_info=True,
                        )

            except Exception:
                logger.debug(
                    "Could not register the foreign_keys connect listener",
                    exc_info=True,
                )
        else:
            connect_args = CONFIG.get_connect_args()
            engine_kwargs = CONFIG.get_engine_kwargs()
            self.engine = create_async_engine(
                self.db_url,  # Use instance db_url, not CONFIG
                connect_args=connect_args,
                **engine_kwargs,
            )

        self.SessionLocal = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    # Called on every invocation: it is a no-op once the schema is ready,
    # and it lets a later call retry after an earlier schema failure.
    await self._ensure_schema()
166
-
167
async def _ensure_schema(self):
    """Create tables, apply PRAGMAs and create analytics views, once.

    Everything runs under ``self._schema_lock``; once ``self._schema_ready``
    is set, later calls return without touching the database.

    Raises:
        Exception: Whatever table creation raised, unless its message
            contains ``"already exists"`` (treated as success).
    """

    def _create_tables(sync_conn):
        # checkfirst=True skips tables that are already present.
        return Base.metadata.create_all(sync_conn, checkfirst=True)

    async with self._schema_lock:
        if self._schema_ready:
            return

        logger.debug("📊 Initializing database schema...")

        # One transaction covers table creation, PRAGMAs and view creation.
        async with self.engine.begin() as conn:
            try:
                await conn.run_sync(_create_tables)
            except Exception as e:
                if "already exists" in str(e):
                    # Another worker created the tables first; carry on.
                    logger.debug("✅ Database schema already exists")
                else:
                    logger.error(f"❌ Failed to create database schema: {e}")
                    raise

            # Both PRAGMAs are opt-in through CONFIG.
            if CONFIG.foreign_keys:
                await conn.execute(text("PRAGMA foreign_keys = ON"))

            if CONFIG.journal_mode:
                await conn.execute(text(f"PRAGMA journal_mode = {CONFIG.journal_mode}"))

            # A failing view is logged and skipped rather than aborting
            # setup; "already exists" failures are not even logged.
            for view_name, view_sql in analytics_views.items():
                try:
                    await conn.execute(text(view_sql))
                except Exception as e:
                    if "already exists" not in str(e):
                        logger.warning(f"Could not create view {view_name}: {e}")

        self._schema_ready = True
216
- # logger.debug("🎯 Database ready for use!")
217
-
218
@asynccontextmanager
async def session(self):
    """Yield a database session, initializing the manager on first use.

    If no session factory exists yet, ``initialize()`` is awaited first.
    The session is opened with ``async with`` and therefore closed when the
    caller's block exits.
    """
    if not self.SessionLocal:
        await self.initialize()
    session_factory = self.SessionLocal
    async with session_factory() as db_session:
        yield db_session
225
-
226
async def insert_session_trace(self, trace: SessionTrace) -> str:
    """Insert a complete session trace in a single transaction.

    Writes the session row, one row per timestep, then every event and
    every message. Each timestep is flushed right after it is added so its
    auto-generated primary key can serve as the ``timestep_id`` foreign key
    of the events and messages whose metadata names its ``step_id``.

    Args:
        trace: The complete session trace to store.

    Returns:
        ``trace.session_id``. This is also returned when a row with that
        session ID already exists: the duplicate insert is rolled back and
        treated as success.

    Raises:
        IntegrityError: For any integrity violation other than a duplicate
            ``session_traces.session_id``.
    """
    from dataclasses import asdict

    def to_cents(cost: float | None) -> int | None:
        # Costs are stored as integer cents. round() is required here:
        # int() truncates, so 0.29 * 100 == 28.999999999999996 would be
        # stored as 28 instead of 29.
        return int(round(cost * 100)) if cost is not None else None

    async with self.session() as sess:
        try:
            sess.add(
                DBSessionTrace(
                    session_id=trace.session_id,
                    created_at=trace.created_at,
                    num_timesteps=len(trace.session_time_steps),
                    num_events=len(trace.event_history),
                    num_messages=len(trace.markov_blanket_message_history),
                    session_metadata=trace.metadata or {},
                )
            )

            # Maps each step's external step_id to its database row id.
            step_id_map: dict[str, int] = {}

            for step in trace.session_time_steps:
                db_step = DBSessionTimestep(
                    session_id=trace.session_id,
                    step_id=step.step_id,
                    step_index=step.step_index,
                    turn_number=step.turn_number,
                    started_at=step.timestamp,
                    completed_at=step.completed_at,
                    num_events=len(step.events),
                    num_messages=len(step.markov_blanket_messages),
                    step_metadata=step.step_metadata or {},
                )
                sess.add(db_step)
                # Flush (not commit) to obtain the generated id while
                # keeping the whole insert inside one transaction.
                await sess.flush()
                step_id_map[step.step_id] = db_step.id

            for ev in trace.event_history:
                # Metadata may be None; normalize once so the lookups and
                # the merge below cannot fail on it.
                ev_metadata = ev.metadata or {}
                event_data = {
                    "session_id": trace.session_id,
                    "timestep_id": step_id_map.get(ev_metadata.get("step_id")),
                    "system_instance_id": ev.system_instance_id,
                    "event_time": ev.time_record.event_time,
                    "message_time": ev.time_record.message_time,
                    "event_metadata_json": ev_metadata,
                    "event_extra_metadata": ev.event_metadata,
                }

                if isinstance(ev, LMCAISEvent):
                    call_records_data = None
                    if ev.call_records:
                        call_records_data = [asdict(record) for record in ev.call_records]

                    event_data.update(
                        {
                            "event_type": "cais",
                            "model_name": ev.model_name,
                            "provider": ev.provider,
                            "input_tokens": ev.input_tokens,
                            "output_tokens": ev.output_tokens,
                            "total_tokens": ev.total_tokens,
                            "cost_usd": to_cents(ev.cost_usd),
                            "latency_ms": ev.latency_ms,
                            "span_id": ev.span_id,
                            "trace_id": ev.trace_id,
                            "system_state_before": ev.system_state_before,
                            "system_state_after": ev.system_state_after,
                            "call_records": call_records_data,
                        }
                    )
                elif isinstance(ev, EnvironmentEvent):
                    event_data.update(
                        {
                            "event_type": "environment",
                            "reward": ev.reward,
                            "terminated": ev.terminated,
                            "truncated": ev.truncated,
                            "system_state_before": ev.system_state_before,
                            "system_state_after": ev.system_state_after,
                        }
                    )
                elif isinstance(ev, RuntimeEvent):
                    # Runtime events fold their actions into the metadata
                    # JSON column.
                    event_data.update(
                        {
                            "event_type": "runtime",
                            "event_metadata_json": {**ev_metadata, "actions": ev.actions},
                        }
                    )
                else:
                    # Unknown event classes are tagged by lowercased class name.
                    event_data["event_type"] = ev.__class__.__name__.lower()

                sess.add(DBEvent(**event_data))

            for msg in trace.markov_blanket_message_history:
                # Messages may lack a metadata attribute, or carry None.
                msg_metadata = getattr(msg, "metadata", None) or {}
                sess.add(
                    DBMessage(
                        session_id=trace.session_id,
                        timestep_id=step_id_map.get(msg_metadata.get("step_id")),
                        message_type=msg.message_type,
                        content=msg.content,
                        event_time=msg.time_record.event_time,
                        message_time=msg.time_record.message_time,
                        message_metadata=msg_metadata,
                    )
                )

            await sess.commit()
            return trace.session_id
        except IntegrityError as e:
            # A duplicate session ID (e.g. from a retry) is treated as
            # success so the call stays idempotent.
            if "UNIQUE constraint failed: session_traces.session_id" in str(e):
                await sess.rollback()
                return trace.session_id
            raise
373
-
374
async def get_session_trace(self, session_id: str) -> dict[str, Any] | None:
    """Retrieve a summary of a stored session trace.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        A dict holding the session's id, creation time, counters, metadata
        and its timesteps ordered by ``step_index``, or ``None`` when no
        session with that id exists.
    """
    async with self.session() as sess:
        # Only the timesteps relationship is read below, so it is the only
        # one loaded eagerly. Events and messages are not part of the
        # returned summary; loading them would just issue extra SELECTs.
        result = await sess.execute(
            select(DBSessionTrace)
            .options(selectinload(DBSessionTrace.timesteps))
            .where(DBSessionTrace.session_id == session_id)
        )
        session = result.scalar_one_or_none()

        if not session:
            return None

        return {
            "session_id": session.session_id,
            "created_at": session.created_at,
            "num_timesteps": session.num_timesteps,
            "num_events": session.num_events,
            "num_messages": session.num_messages,
            "metadata": session.session_metadata,
            "timesteps": [
                {
                    "step_id": step.step_id,
                    "step_index": step.step_index,
                    "turn_number": step.turn_number,
                    "started_at": step.started_at,
                    "completed_at": step.completed_at,
                    "metadata": step.step_metadata,
                }
                for step in sorted(session.timesteps, key=lambda s: s.step_index)
            ],
        }
410
-
411
async def query_traces(
    self, query: str, params: dict[str, Any] | None = None
) -> Any:
    """Run a raw SQL statement and hand back every row it produces.

    Args:
        query: SQL text, optionally containing ``:name`` bind parameters.
        params: Values for the bind parameters; an empty dict is used when
            omitted.

    Returns:
        A pandas DataFrame when the module-level ``pd`` is available,
        otherwise a list of plain dicts. Callers must handle both shapes.
    """
    bound_values = params or {}
    async with self.session() as sess:
        outcome = await sess.execute(text(query), bound_values)
        records = outcome.mappings().all()
        if pd is None:
            return [dict(record) for record in records]
        return pd.DataFrame(records)
425
-
426
- async def get_model_usage(
427
- self,
428
- start_date: datetime | None = None,
429
- end_date: datetime | None = None,
430
- model_name: str | None = None,
431
- ) -> Any:
432
- """Get model usage statistics.
433
-
434
- Returns a pandas DataFrame when pandas is available; otherwise a list
435
- of dict records.
436
- """
437
- query = """
438
- SELECT * FROM model_usage_stats
439
- WHERE 1=1
440
- """
441
- params = {}
442
-
443
- if start_date:
444
- query += " AND last_used >= :start_date"
445
- params["start_date"] = start_date
446
-
447
- if end_date:
448
- query += " AND first_used <= :end_date"
449
- params["end_date"] = end_date
450
-
451
- if model_name:
452
- query += " AND model_name = :model_name"
453
- params["model_name"] = model_name
454
-
455
- query += " ORDER BY usage_count DESC"
456
-
457
- return await self.query_traces(query, params)
458
-
459
async def create_experiment(
    self,
    experiment_id: str,
    name: str,
    description: str | None = None,
    configuration: dict[str, Any] | None = None,
) -> str:
    """Persist a new experiment row and return its id.

    Args:
        experiment_id: Identifier stored on the new row.
        name: Experiment name.
        description: Optional description.
        configuration: Optional configuration; stored as ``{}`` when
            omitted or empty.

    Returns:
        The ``experiment_id`` that was passed in.
    """
    async with self.session() as sess:
        stored_config = configuration if configuration else {}
        record = DBExperiment(
            experiment_id=experiment_id,
            name=name,
            description=description,
            configuration=stored_config,
        )
        sess.add(record)
        await sess.commit()
        return experiment_id
477
-
478
async def link_session_to_experiment(self, session_id: str, experiment_id: str):
    """Set ``experiment_id`` on the session row matching ``session_id``.

    Issues a single UPDATE and commits it; nothing is returned.
    """
    async with self.session() as sess:
        assignment = (
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(experiment_id=experiment_id)
        )
        await sess.execute(assignment)
        await sess.commit()
487
-
488
async def batch_insert_sessions(
    self, traces: list[SessionTrace], batch_size: int | None = None
) -> list[str]:
    """Insert several session traces and collect their ids.

    The traces are walked in slices of ``batch_size``; every trace is
    stored through its own ``insert_session_trace`` call, in input order.

    Args:
        traces: Session traces to insert.
        batch_size: Slice length; falls back to ``CONFIG.batch_size`` when
            omitted or 0.

    Returns:
        The session ids returned by ``insert_session_trace``, in order.
    """
    chunk_len = batch_size or CONFIG.batch_size
    stored_ids = []

    for start in range(0, len(traces), chunk_len):
        # Slicing keeps the per-chunk structure available for a future
        # bulk-insert optimisation; today each trace is inserted singly.
        for trace in traces[start : start + chunk_len]:
            stored_ids.append(await self.insert_session_trace(trace))

    return stored_ids
517
-
518
async def get_sessions_by_experiment(
    self, experiment_id: str, limit: int | None = None
) -> list[dict[str, Any]]:
    """Return summaries of the sessions linked to an experiment, newest first.

    Args:
        experiment_id: Experiment whose sessions are listed.
        limit: Maximum number of sessions to return; ``None`` or 0 applies
            no limit.

    Returns:
        One dict per session with its id, creation time, counters and
        stored session metadata.
    """
    async with self.session() as sess:
        query = (
            select(DBSessionTrace)
            .where(DBSessionTrace.experiment_id == experiment_id)
            .order_by(DBSessionTrace.created_at.desc())
        )

        if limit:
            query = query.limit(limit)

        result = await sess.execute(query)
        sessions = result.scalars().all()

        return [
            {
                "session_id": s.session_id,
                "created_at": s.created_at,
                "num_timesteps": s.num_timesteps,
                "num_events": s.num_events,
                "num_messages": s.num_messages,
                # The row's metadata lives on ``session_metadata`` (the
                # attribute the other methods read and write). A bare
                # ``.metadata`` on a SQLAlchemy declarative model is the
                # reserved schema registry, not per-row data.
                "metadata": s.session_metadata,
            }
            for s in sessions
        ]
546
-
547
async def delete_session(self, session_id: str) -> bool:
    """Remove the session row matching ``session_id``.

    Returns:
        ``True`` when a row was found and its deletion committed,
        ``False`` when no such session exists.
    """
    async with self.session() as sess:
        # The row is loaded as an ORM object and deleted through the
        # session (rather than via a bulk DELETE) so that ORM-level
        # cascades can apply -- presumably configured on the model.
        lookup = await sess.execute(
            select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
        )
        target = lookup.scalar_one_or_none()

        if not target:
            return False

        await sess.delete(target)
        await sess.commit()
        return True
561
-
562
async def close(self):
    """Dispose of the engine and reset the manager's connection state.

    Does nothing when no engine is set. Otherwise the engine is disposed
    and ``engine``, ``SessionLocal`` and ``_schema_ready`` are cleared so
    the manager can be initialised again later.
    """
    active_engine = self.engine
    if not active_engine:
        return

    # Releases every pooled connection held by the engine.
    await active_engine.dispose()

    # Forget all connection state so a later re-initialisation starts clean.
    self.engine = None
    self.SessionLocal = None
    self._schema_ready = False
575
-
576
- # -------------------------------
577
- # Incremental insert helpers
578
- # -------------------------------
579
-
580
async def ensure_session(
    self,
    session_id: str,
    *,
    created_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
):
    """Create the session row for ``session_id`` unless one already exists.

    Args:
        session_id: Identifier of the session row.
        created_at: Creation timestamp; ``datetime.utcnow()`` when omitted.
        metadata: Session metadata; stored as ``{}`` when omitted or empty.

    Returns:
        ``None`` in both cases (row already present, or row created).
    """
    async with self.session() as sess:
        lookup = await sess.execute(
            select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
        )
        if lookup.scalar_one_or_none():
            return

        # A fresh row starts with all counters at zero; the incremental
        # insert helpers bump them as content arrives.
        fresh_row = DBSessionTrace(
            session_id=session_id,
            created_at=created_at or datetime.utcnow(),
            num_timesteps=0,
            num_events=0,
            num_messages=0,
            session_metadata=metadata or {},
        )
        sess.add(fresh_row)
        await sess.commit()
597
-
598
async def ensure_timestep(
    self,
    session_id: str,
    *,
    step_id: str,
    step_index: int,
    turn_number: int | None = None,
    started_at: datetime | None = None,
    completed_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Return the DB id of a timestep, creating the row when it is missing.

    A timestep is identified by the ``(session_id, step_id)`` pair. When a
    new row is created the owning session's ``num_timesteps`` counter is
    incremented in the same transaction.

    Args:
        session_id: Session the timestep belongs to.
        step_id: Timestep identifier within the session.
        step_index: Position of the timestep.
        turn_number: Optional turn number.
        started_at: Start time; ``datetime.utcnow()`` when omitted.
        completed_at: Optional completion time.
        metadata: Step metadata; stored as ``{}`` when omitted or empty.

    Returns:
        The primary-key id of the existing or newly created row.
    """
    async with self.session() as sess:
        lookup = await sess.execute(
            select(DBSessionTimestep).where(
                DBSessionTimestep.session_id == session_id,
                DBSessionTimestep.step_id == step_id,
            )
        )
        found = lookup.scalar_one_or_none()
        if found:
            return found.id

        created = DBSessionTimestep(
            session_id=session_id,
            step_id=step_id,
            step_index=step_index,
            turn_number=turn_number,
            started_at=started_at or datetime.utcnow(),
            completed_at=completed_at,
            num_events=0,
            num_messages=0,
            step_metadata=metadata or {},
        )
        sess.add(created)
        # Flush so the INSERT is emitted before the counter update below.
        await sess.flush()

        bump_counter = (
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
        )
        await sess.execute(bump_counter)
        await sess.commit()
        return created.id
628
-
629
async def insert_message_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    message_type: str,
    content: str,
    event_time: float | None = None,
    message_time: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Store one message and return its DB id.

    The owning session's ``num_messages`` counter is incremented in the
    same transaction.

    Args:
        session_id: Session the message belongs to.
        timestep_db_id: DB id of the owning timestep, or ``None``.
        message_type: Message type label.
        content: Message content.
        event_time: Optional event time.
        message_time: Optional message time.
        metadata: Message metadata; stored as ``{}`` when omitted or empty.

    Returns:
        The primary-key id of the inserted message row.
    """
    async with self.session() as sess:
        message_row = DBMessage(
            session_id=session_id,
            timestep_id=timestep_db_id,
            message_type=message_type,
            content=content,
            event_time=event_time,
            message_time=message_time,
            message_metadata=metadata or {},
        )
        sess.add(message_row)
        # Flush so the INSERT is emitted before the counter update below.
        await sess.flush()

        bump_counter = (
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_messages=DBSessionTrace.num_messages + 1)
        )
        await sess.execute(bump_counter)
        await sess.commit()
        return message_row.id
651
-
652
async def insert_event_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
    metadata_override: dict[str, Any] | None = None,
) -> int:
    """Store one event and return its DB id.

    The common columns are filled from the event, then type-specific
    columns are added depending on whether it is an ``LMCAISEvent``,
    ``EnvironmentEvent`` or ``RuntimeEvent``; any other class is stored
    with its lower-cased class name as ``event_type``. The owning
    session's ``num_events`` counter is incremented in the same
    transaction.

    Args:
        session_id: Session the event belongs to.
        timestep_db_id: DB id of the owning timestep, or ``None``.
        event: The event to persist.
        metadata_override: When non-empty, stored instead of
            ``event.metadata``. Runtime events always store their own
            metadata merged with ``actions``.

    Returns:
        The primary-key id of the inserted event row.
    """

    def to_cents(cost: float | None) -> int | None:
        # Round rather than truncate: binary floats make e.g. 0.29 * 100
        # evaluate to 28.999999999999996, which int() would store as 28.
        return round(cost * 100) if cost is not None else None

    event_data: dict[str, Any] = {
        "session_id": session_id,
        "timestep_id": timestep_db_id,
        "system_instance_id": event.system_instance_id,
        "event_time": event.time_record.event_time,
        "message_time": event.time_record.message_time,
        "event_metadata_json": metadata_override or event.metadata or {},
        "event_extra_metadata": getattr(event, "event_metadata", None),
    }
    if isinstance(event, LMCAISEvent):
        call_records_data = None
        if getattr(event, "call_records", None):
            from dataclasses import asdict

            # Call records are converted to plain dicts before storage.
            call_records_data = [asdict(record) for record in event.call_records]
        event_data.update(
            {
                "event_type": "cais",
                "model_name": event.model_name,
                "provider": event.provider,
                "input_tokens": event.input_tokens,
                "output_tokens": event.output_tokens,
                "total_tokens": event.total_tokens,
                # Stored as integer cents despite the column name.
                "cost_usd": to_cents(event.cost_usd),
                "latency_ms": event.latency_ms,
                "span_id": event.span_id,
                "trace_id": event.trace_id,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
                "call_records": call_records_data,
            }
        )
    elif isinstance(event, EnvironmentEvent):
        event_data.update(
            {
                "event_type": "environment",
                "reward": event.reward,
                "terminated": event.terminated,
                "truncated": event.truncated,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
            }
        )
    elif isinstance(event, RuntimeEvent):
        event_data.update(
            {
                "event_type": "runtime",
                "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
            }
        )
    else:
        event_data["event_type"] = event.__class__.__name__.lower()

    async with self.session() as sess:
        db_event = DBEvent(**event_data)
        sess.add(db_event)
        # Flush so the INSERT is emitted before the counter update below.
        await sess.flush()
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_events=DBSessionTrace.num_events + 1)
        )
        await sess.commit()
        return db_event.id
716
-
717
- # -------------------------------
718
- # Reward helpers
719
- # -------------------------------
720
-
721
async def insert_outcome_reward(
    self,
    session_id: str,
    *,
    total_reward: int,
    achievements_count: int,
    total_steps: int,
    reward_metadata: dict | None = None,
) -> int:
    """Store an outcome-reward row for a session and return its DB id.

    Args:
        session_id: Session the reward belongs to.
        total_reward: Total reward value.
        achievements_count: Number of achievements.
        total_steps: Number of steps.
        reward_metadata: Extra metadata; stored as ``{}`` when omitted or
            empty.

    Returns:
        The primary-key id of the inserted row.
    """
    async with self.session() as sess:
        reward_row = DBOutcomeReward(
            session_id=session_id,
            total_reward=total_reward,
            achievements_count=achievements_count,
            total_steps=total_steps,
            reward_metadata=reward_metadata if reward_metadata else {},
        )
        sess.add(reward_row)
        await sess.flush()
        await sess.commit()
        return reward_row.id
734
-
735
async def insert_event_reward(
    self,
    session_id: str,
    *,
    event_id: int,
    message_id: int | None = None,
    turn_number: int | None = None,
    reward_value: float = 0.0,
    reward_type: str | None = None,
    key: str | None = None,
    annotation: dict[str, Any] | None = None,
    source: str | None = None,
) -> int:
    """Store a reward attached to a single event and return its DB id.

    Args:
        session_id: Session the reward belongs to.
        event_id: DB id of the rewarded event.
        message_id: Optional DB id of a related message.
        turn_number: Optional turn number.
        reward_value: Reward value.
        reward_type: Optional reward type label.
        key: Optional key.
        annotation: Extra annotation; stored as ``{}`` when omitted or
            empty.
        source: Optional source label.

    Returns:
        The primary-key id of the inserted row.
    """
    async with self.session() as sess:
        reward_row = DBEventReward(
            event_id=event_id,
            session_id=session_id,
            message_id=message_id,
            turn_number=turn_number,
            reward_value=reward_value,
            reward_type=reward_type,
            key=key,
            annotation=annotation if annotation else {},
            source=source,
        )
        sess.add(reward_row)
        await sess.flush()
        await sess.commit()
        return reward_row.id
752
-
753
async def get_outcome_rewards(self) -> list[dict[str, Any]]:
    """Return every outcome-reward row as a plain dict.

    Each dict carries ``id``, ``session_id``, ``total_reward``,
    ``achievements_count``, ``total_steps`` and ``created_at``.
    """
    exported_fields = (
        "id",
        "session_id",
        "total_reward",
        "achievements_count",
        "total_steps",
        "created_at",
    )
    async with self.session() as sess:
        outcome = await sess.execute(select(DBOutcomeReward))
        reward_rows = outcome.scalars().all()
        return [
            {field: getattr(reward, field) for field in exported_fields}
            for reward in reward_rows
        ]
768
-
769
async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
    """Return the session ids whose outcome ``total_reward`` is at least ``min_reward``."""
    async with self.session() as sess:
        matching = select(DBOutcomeReward.session_id).where(
            DBOutcomeReward.total_reward >= min_reward
        )
        outcome = await sess.execute(matching)
        # Each result row holds the single selected column.
        return [session_id for (session_id,) in outcome.all()]