synth-ai 0.2.9.dev7__py3-none-any.whl → 0.2.9.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; see the package's registry page for more details.

Files changed (327)
  1. examples/__init__.py +16 -0
  2. examples/crafter_debug_render.py +8 -11
  3. examples/qwen_coder/README.md +102 -0
  4. examples/qwen_coder/_shared.py +113 -0
  5. examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
  6. examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
  7. examples/qwen_coder/configs/coder_lora_small.toml +58 -0
  8. examples/qwen_coder/generate_dataset.py +98 -0
  9. examples/qwen_coder/infer_ft_smoke.py +64 -0
  10. examples/qwen_coder/infer_prod_proxy.py +73 -0
  11. examples/qwen_coder/infer_via_synth.py +87 -0
  12. examples/qwen_coder/scripts/infer_coder.sh +18 -0
  13. examples/qwen_coder/scripts/train_coder_30b.sh +21 -0
  14. examples/qwen_coder/sft_full_17b.py +103 -0
  15. examples/qwen_coder/sft_lora_30b.py +110 -0
  16. examples/qwen_coder/subset_jsonl.py +38 -0
  17. examples/qwen_coder/validate_jsonl.py +59 -0
  18. examples/rl/run_eval.py +36 -37
  19. examples/rl/run_rl_and_save.py +5 -5
  20. examples/rl/task_app/math_single_step.py +65 -43
  21. examples/rl/task_app/math_task_app.py +3 -3
  22. examples/sft/README.md +139 -0
  23. examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
  24. examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
  25. examples/sft/evaluate.py +117 -0
  26. examples/sft/export_dataset.py +117 -0
  27. examples/sft/generate_traces.py +162 -0
  28. examples/swe/__init__.py +12 -0
  29. examples/swe/task_app/README.md +105 -0
  30. examples/swe/task_app/__init__.py +2 -0
  31. examples/swe/task_app/grpo_swe_mini.py +571 -0
  32. examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
  33. examples/swe/task_app/hosted/README.md +173 -0
  34. examples/swe/task_app/hosted/__init__.py +5 -0
  35. examples/swe/task_app/hosted/branching.py +143 -0
  36. examples/swe/task_app/hosted/environment_routes.py +1289 -0
  37. examples/swe/task_app/hosted/envs/__init__.py +1 -0
  38. examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
  39. examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
  40. examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
  41. examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
  42. examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
  43. examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
  44. examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
  45. examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
  46. examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
  47. examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
  48. examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
  49. examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
  50. examples/swe/task_app/hosted/hosted_app.py +204 -0
  51. examples/swe/task_app/hosted/inference/__init__.py +5 -0
  52. examples/swe/task_app/hosted/inference/openai_client.py +618 -0
  53. examples/swe/task_app/hosted/main.py +100 -0
  54. examples/swe/task_app/hosted/policy_routes.py +1079 -0
  55. examples/swe/task_app/hosted/registry.py +195 -0
  56. examples/swe/task_app/hosted/rollout.py +1869 -0
  57. examples/swe/task_app/hosted/storage/__init__.py +5 -0
  58. examples/swe/task_app/hosted/storage/volume.py +211 -0
  59. examples/swe/task_app/hosted/test_agents.py +161 -0
  60. examples/swe/task_app/hosted/test_service.py +137 -0
  61. examples/swe/task_app/hosted/utils.py +62 -0
  62. examples/vlm/README.md +68 -0
  63. examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
  64. examples/vlm/crafter_image_only_agent.py +207 -0
  65. examples/vlm/crafter_openai_vlm_agent.py +277 -0
  66. examples/vlm/filter_image_rows.py +63 -0
  67. examples/vlm/run_crafter_vlm_benchmark.py +316 -0
  68. examples/warming_up_to_rl/analyze_trace_db.py +5 -5
  69. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
  70. examples/warming_up_to_rl/export_trace_sft.py +78 -21
  71. examples/warming_up_to_rl/groq_test.py +4 -4
  72. examples/warming_up_to_rl/manage_secrets.py +13 -18
  73. examples/warming_up_to_rl/run_eval.py +42 -44
  74. examples/warming_up_to_rl/run_fft_and_save.py +11 -16
  75. examples/warming_up_to_rl/run_local_rollout.py +1 -3
  76. examples/warming_up_to_rl/run_local_rollout_modal.py +2 -4
  77. examples/warming_up_to_rl/run_local_rollout_parallel.py +1 -4
  78. examples/warming_up_to_rl/run_local_rollout_traced.py +3 -5
  79. examples/warming_up_to_rl/run_rl_and_save.py +5 -6
  80. examples/warming_up_to_rl/run_rollout_remote.py +8 -10
  81. examples/warming_up_to_rl/task_app/README.md +6 -2
  82. examples/warming_up_to_rl/task_app/grpo_crafter.py +234 -35
  83. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +2 -3
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +131 -114
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +101 -41
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +73 -51
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +14 -6
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +16 -16
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +32 -34
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +94 -31
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +303 -203
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +328 -225
  97. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +13 -13
  98. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
  99. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +1 -0
  100. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
  101. synth/__init__.py +14 -0
  102. synth_ai/__init__.py +26 -4
  103. synth_ai/api/models/supported.py +376 -0
  104. synth_ai/api/train/builders.py +128 -21
  105. synth_ai/api/train/cli.py +80 -64
  106. synth_ai/api/train/config_finder.py +7 -2
  107. synth_ai/api/train/env_resolver.py +1 -1
  108. synth_ai/api/train/pollers.py +2 -1
  109. synth_ai/api/train/supported_algos.py +139 -0
  110. synth_ai/api/train/task_app.py +1 -2
  111. synth_ai/api/train/utils.py +13 -44
  112. synth_ai/cli/__init__.py +8 -0
  113. synth_ai/cli/_modal_wrapper.py +28 -0
  114. synth_ai/cli/_typer_patch.py +49 -0
  115. synth_ai/cli/balance.py +1 -2
  116. synth_ai/cli/calc.py +1 -1
  117. synth_ai/cli/demo.py +2 -1
  118. synth_ai/cli/recent.py +2 -2
  119. synth_ai/cli/rl_demo.py +2 -1
  120. synth_ai/cli/root.py +11 -13
  121. synth_ai/cli/status.py +2 -2
  122. synth_ai/cli/task_apps.py +529 -179
  123. synth_ai/cli/traces.py +6 -4
  124. synth_ai/cli/watch.py +12 -18
  125. synth_ai/demo_registry.py +1 -1
  126. synth_ai/demos/core/cli.py +36 -43
  127. synth_ai/demos/demo_task_apps/__init__.py +3 -3
  128. synth_ai/demos/demo_task_apps/core.py +17 -25
  129. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +3 -4
  130. synth_ai/demos/demo_task_apps/math/app.py +2 -1
  131. synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -4
  132. synth_ai/demos/demo_task_apps/math/modal_task_app.py +16 -18
  133. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  134. synth_ai/environments/examples/crafter_classic/environment.py +76 -1
  135. synth_ai/environments/reproducibility/tree.py +2 -5
  136. synth_ai/environments/service/app.py +11 -12
  137. synth_ai/environments/service/core_routes.py +4 -7
  138. synth_ai/environments/stateful/engine.py +1 -1
  139. synth_ai/environments/tasks/core.py +1 -0
  140. synth_ai/environments/tasks/filters.py +5 -6
  141. synth_ai/environments/tasks/utils.py +4 -5
  142. synth_ai/handshake.py +9 -9
  143. synth_ai/http.py +1 -1
  144. synth_ai/http_client.py +18 -10
  145. synth_ai/inference/client.py +15 -5
  146. synth_ai/jobs/client.py +78 -83
  147. synth_ai/learning/__init__.py +41 -6
  148. synth_ai/learning/algorithms.py +14 -0
  149. synth_ai/learning/client.py +91 -24
  150. synth_ai/learning/config.py +2 -38
  151. synth_ai/learning/ft_client.py +4 -59
  152. synth_ai/learning/health.py +5 -6
  153. synth_ai/learning/jobs.py +31 -47
  154. synth_ai/{rl → learning/rl}/__init__.py +14 -4
  155. synth_ai/learning/rl/client.py +267 -0
  156. synth_ai/learning/rl/config.py +31 -0
  157. synth_ai/{rl → learning/rl}/contracts.py +5 -8
  158. synth_ai/{rl → learning/rl}/env_keys.py +39 -15
  159. synth_ai/learning/rl/secrets.py +13 -0
  160. synth_ai/learning/rl_client.py +2 -281
  161. synth_ai/learning/sft/__init__.py +29 -0
  162. synth_ai/learning/sft/client.py +68 -0
  163. synth_ai/learning/sft/config.py +270 -0
  164. synth_ai/learning/sft/data.py +295 -0
  165. synth_ai/learning/sse.py +25 -24
  166. synth_ai/learning/validators.py +25 -28
  167. synth_ai/lm/__init__.py +21 -47
  168. synth_ai/main.py +4 -0
  169. synth_ai/task/__init__.py +25 -27
  170. synth_ai/task/apps/__init__.py +7 -8
  171. synth_ai/task/auth.py +8 -8
  172. synth_ai/task/client.py +14 -14
  173. synth_ai/task/contracts.py +36 -35
  174. synth_ai/task/datasets.py +6 -5
  175. synth_ai/task/errors.py +10 -10
  176. synth_ai/task/health.py +17 -9
  177. synth_ai/task/json.py +58 -23
  178. synth_ai/task/proxy.py +13 -9
  179. synth_ai/task/rubrics.py +16 -15
  180. synth_ai/task/server.py +12 -12
  181. synth_ai/task/tracing_utils.py +4 -4
  182. synth_ai/task/vendors.py +5 -6
  183. synth_ai/tracing_v3/__init__.py +2 -0
  184. synth_ai/tracing_v3/abstractions.py +21 -4
  185. synth_ai/tracing_v3/decorators.py +18 -16
  186. synth_ai/tracing_v3/hooks.py +5 -5
  187. synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
  188. synth_ai/tracing_v3/session_tracer.py +40 -14
  189. synth_ai/tracing_v3/storage/base.py +85 -0
  190. synth_ai/tracing_v3/storage/config.py +21 -8
  191. synth_ai/tracing_v3/storage/factory.py +10 -7
  192. synth_ai/tracing_v3/storage/utils.py +4 -2
  193. synth_ai/tracing_v3/turso/daemon.py +7 -2
  194. synth_ai/tracing_v3/turso/models.py +2 -2
  195. synth_ai/tracing_v3/turso/native_manager.py +1173 -0
  196. synth_ai/tracing_v3/utils.py +4 -4
  197. synth_ai/v0/api/__init__.py +8 -0
  198. synth_ai/v0/api/models/__init__.py +8 -0
  199. synth_ai/v0/api/models/supported.py +8 -0
  200. synth_ai/v0/config/__init__.py +15 -0
  201. synth_ai/v0/config/base_url.py +12 -0
  202. synth_ai/v0/lm/__init__.py +51 -0
  203. synth_ai/{lm → v0/lm}/caching/ephemeral.py +2 -2
  204. synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
  205. synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
  206. synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
  207. synth_ai/{lm → v0/lm}/config.py +6 -1
  208. synth_ai/{lm → v0/lm}/core/all.py +9 -9
  209. synth_ai/{lm → v0/lm}/core/main.py +6 -6
  210. synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
  211. synth_ai/{lm → v0/lm}/core/synth_models.py +2 -14
  212. synth_ai/{lm → v0/lm}/core/vendor_clients.py +2 -2
  213. synth_ai/{lm → v0/lm}/overrides.py +2 -2
  214. synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
  215. synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
  216. synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
  217. synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
  218. synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +9 -9
  219. synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
  220. synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
  221. synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +10 -10
  222. synth_ai/{lm → v0/lm}/vendors/openai_standard.py +8 -8
  223. synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +2 -2
  224. synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +3 -3
  225. synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
  226. synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
  227. synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
  228. synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
  229. synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
  230. synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
  231. synth_ai/{lm → v0/lm}/vendors/synth_client.py +1 -1
  232. synth_ai/v0/tracing_v3/__init__.py +10 -0
  233. synth_ai/v0/tracing_v3/abstractions.py +3 -0
  234. synth_ai/v0/tracing_v3/decorators.py +3 -0
  235. synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
  236. synth_ai/v0/tracing_v3/session_tracer.py +3 -0
  237. synth_ai-0.2.9.dev8.dist-info/METADATA +191 -0
  238. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/RECORD +268 -238
  239. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/top_level.txt +1 -0
  240. examples/common_old/backend.py +0 -20
  241. examples/evals_old/README.md +0 -98
  242. examples/evals_old/__init__.py +0 -6
  243. examples/evals_old/compare_models.py +0 -1038
  244. examples/evals_old/example_log.md +0 -145
  245. examples/evals_old/run_demo.sh +0 -126
  246. examples/evals_old/trace_analysis.py +0 -270
  247. examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
  248. examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
  249. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
  250. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -243
  251. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
  252. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
  253. examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
  254. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
  255. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
  256. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -119
  257. examples/finetuning_old/synth_qwen_v1/README.md +0 -68
  258. examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
  259. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -243
  260. examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
  261. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
  262. examples/finetuning_old/synth_qwen_v1/infer.py +0 -36
  263. examples/finetuning_old/synth_qwen_v1/poll.py +0 -46
  264. examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
  265. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
  266. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1933
  267. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -210
  268. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -237
  269. examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
  270. examples/finetuning_old/synth_qwen_v1/util.py +0 -152
  271. examples/rl_old/task_app.py +0 -1131
  272. examples/warming_up_to_rl/old/event_rewards.md +0 -234
  273. examples/warming_up_to_rl/old/notes.md +0 -73
  274. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +0 -738
  275. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +0 -580
  276. synth_ai/experimental/synth_oss.py +0 -445
  277. synth_ai/learning/filtering.py +0 -0
  278. synth_ai/learning/offline/dpo.py +0 -0
  279. synth_ai/learning/offline/providers.py +0 -7
  280. synth_ai/learning/offline/sft.py +0 -0
  281. synth_ai/learning/offline/shared.py +0 -0
  282. synth_ai/learning/online/grpo.py +0 -0
  283. synth_ai/learning/online/irft.py +0 -0
  284. synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
  285. synth_ai/learning/prompts/gepa.py +0 -0
  286. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -211
  287. synth_ai/learning/prompts/mipro.py +0 -289
  288. synth_ai/learning/prompts/random_search.py +0 -249
  289. synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
  290. synth_ai/learning/prompts/run_random_search_banking77.py +0 -329
  291. synth_ai/rl/secrets.py +0 -19
  292. synth_ai/scripts/verify_rewards.py +0 -100
  293. synth_ai/tracing/__init__.py +0 -30
  294. synth_ai/tracing_v1/__init__.py +0 -33
  295. synth_ai/tracing_v3/turso/__init__.py +0 -25
  296. synth_ai/tracing_v3/turso/manager.py +0 -838
  297. synth_ai/zyk/__init__.py +0 -30
  298. synth_ai-0.2.9.dev7.dist-info/METADATA +0 -131
  299. /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
  300. /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
  301. /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
  302. /synth_ai/{lm → v0/lm}/constants.py +0 -0
  303. /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
  304. /synth_ai/{lm → v0/lm}/core/exceptions.py +0 -0
  305. /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
  306. /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
  307. /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
  308. /synth_ai/{lm → v0/lm}/injection.py +0 -0
  309. /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
  310. /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
  311. /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
  312. /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
  313. /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
  314. /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
  315. /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
  316. /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
  317. /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
  318. /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
  319. /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
  320. /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
  321. /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
  322. /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
  323. /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
  324. /synth_ai/{lm → v0/lm}/warmup.py +0 -0
  325. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/WHEEL +0 -0
  326. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/entry_points.txt +0 -0
  327. {synth_ai-0.2.9.dev7.dist-info → synth_ai-0.2.9.dev8.dist-info}/licenses/LICENSE +0 -0
@@ -1,838 +0,0 @@
1
- from __future__ import annotations
2
-
3
- """Async SQLAlchemy-based trace manager for Turso/sqld.
4
-
5
- This module provides the database interface for the tracing system using
6
- async SQLAlchemy with a Turso/sqld backend. It handles all database operations
7
- including schema creation, session storage, and analytics queries.
8
-
9
- Key Features:
10
- ------------
11
- - Async-first design using aiosqlite for local SQLite
12
- - Automatic schema creation and migration
13
- - Batch insert capabilities for high-throughput scenarios
14
- - Analytics views for efficient querying
15
- - Connection pooling and retry logic
16
-
17
- Performance Considerations:
18
- --------------------------
19
- - Uses NullPool for SQLite to avoid connection issues
20
- - Implements busy timeout for concurrent access
21
- - Batches inserts to reduce transaction overhead
22
- - Creates indexes for common query patterns
23
- """
24
-
25
- import asyncio
26
- import logging
27
- from contextlib import asynccontextmanager
28
- from datetime import datetime
29
- from typing import Any
30
-
31
- # Optional pandas import: fall back to records (list[dict]) if unavailable
32
- try: # pragma: no cover - exercised in environments without pandas
33
- import pandas as pd # type: ignore
34
- except Exception: # pragma: no cover
35
- pd = None # type: ignore[assignment]
36
- from sqlalchemy import select, text, update
37
- from sqlalchemy.exc import IntegrityError
38
- from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
39
- from sqlalchemy import event
40
- from sqlalchemy.orm import selectinload, sessionmaker
41
- from sqlalchemy.pool import NullPool
42
-
43
- from ..abstractions import (
44
- EnvironmentEvent,
45
- LMCAISEvent,
46
- RuntimeEvent,
47
- SessionTrace,
48
- )
49
- from ..config import CONFIG
50
- from .models import (
51
- Base,
52
- analytics_views,
53
- )
54
- from .models import (
55
- Event as DBEvent,
56
- )
57
- from .models import (
58
- Experiment as DBExperiment,
59
- )
60
- from .models import (
61
- Message as DBMessage,
62
- )
63
- from .models import (
64
- SessionTimestep as DBSessionTimestep,
65
- )
66
- from .models import (
67
- SessionTrace as DBSessionTrace,
68
- )
69
- from .models import (
70
- OutcomeReward as DBOutcomeReward,
71
- )
72
- from .models import (
73
- EventReward as DBEventReward,
74
- )
75
-
76
# Module-level logger used by AsyncSQLTraceManager for connection, schema and
# error diagnostics (mostly debug-level messages).
logger = logging.getLogger(__name__)
77
-
78
-
79
- class AsyncSQLTraceManager:
80
- """Async trace storage manager using SQLAlchemy and Turso/sqld.
81
-
82
- Handles all database operations for the tracing system. Designed to work
83
- with both local SQLite (via aiosqlite) and remote Turso databases.
84
-
85
- The manager handles:
86
- - Connection lifecycle management
87
- - Schema creation and verification
88
- - Transaction management
89
- - Batch operations for efficiency
90
- - Analytics view creation
91
- """
92
-
93
def __init__(self, db_url: str | None = None):
    """Set up an unconnected manager; no database work happens here.

    Args:
        db_url: SQLAlchemy database URL. When falsy, ``CONFIG.db_url`` is used.

    The engine and session factory stay ``None`` until ``initialize()`` runs.
    """
    # An explicit URL wins; otherwise defer to the global tracing config.
    self.db_url = db_url if db_url else CONFIG.db_url
    # Both populated lazily by initialize().
    self.engine: AsyncEngine | None = None
    self.SessionLocal: sessionmaker | None = None
    # Schema creation happens once; the lock serializes concurrent callers.
    self._schema_ready = False
    self._schema_lock = asyncio.Lock()
99
-
100
async def initialize(self):
    """Create the engine and session factory (once), then ensure the schema.

    Safe to call repeatedly: the engine is only built while ``self.engine``
    is ``None``, and ``_ensure_schema`` is guarded by ``self._schema_lock``
    and ``self._schema_ready``. There is no lock around engine creation, so
    this is safe for concurrent coroutines in one event loop (no ``await``
    happens between the check and the assignment) but not across threads.

    Steps:
    1. Create the async engine. SQLite URLs get NullPool plus a 30 s busy
       timeout; other URLs use ``CONFIG``-supplied connect/engine kwargs.
    2. For SQLite, log whether the database file already exists.
    3. Build the ``AsyncSession`` factory.
    4. Create tables and analytics views via ``_ensure_schema``.
    """
    if self.engine is None:
        logger.debug(f"🔗 Initializing database connection to: {self.db_url}")

        if self.db_url.startswith("sqlite"):
            import os

            from sqlalchemy.engine import make_url

            # Parse the URL instead of stripping one hard-coded prefix so that
            # "sqlite:///", query strings and in-memory URLs are handled too.
            db_path = make_url(self.db_url).database
            if db_path and db_path != ":memory:":
                if not os.path.exists(db_path):
                    logger.debug(f"⚠️ Database file not found: {db_path}")
                    logger.debug(
                        "🔧 Make sure './serve.sh' is running to start the turso/sqld service"
                    )
                else:
                    logger.debug(f"✅ Found database file: {db_path}")

            # NullPool: SQLite copes poorly with pooled concurrent connections,
            # so every operation opens its own. The busy timeout makes SQLite
            # wait on a locked database instead of failing immediately.
            self.engine = create_async_engine(
                self.db_url,  # Use instance db_url, not CONFIG
                poolclass=NullPool,
                connect_args={"timeout": 30.0},  # 30 second busy timeout
                echo=CONFIG.echo_sql,
            )

            # foreign_keys is a per-connection PRAGMA, and NullPool opens a new
            # connection per operation, so set it on every "connect" event.
            def _set_sqlite_pragma(dbapi_connection, connection_record):
                try:
                    cursor = dbapi_connection.cursor()
                    try:
                        cursor.execute("PRAGMA foreign_keys=ON")
                    finally:
                        cursor.close()
                except Exception as exc:
                    # Best effort: a failed PRAGMA must not break the connection.
                    logger.debug(f"Could not enable SQLite foreign keys: {exc}")

            try:
                event.listen(self.engine.sync_engine, "connect", _set_sqlite_pragma)
            except Exception as exc:
                logger.warning(f"Could not register SQLite PRAGMA listener: {exc}")
        else:
            self.engine = create_async_engine(
                self.db_url,  # Use instance db_url, not CONFIG
                connect_args=CONFIG.get_connect_args(),
                **CONFIG.get_engine_kwargs(),
            )

        self.SessionLocal = sessionmaker(
            self.engine, class_=AsyncSession, expire_on_commit=False
        )

    await self._ensure_schema()
168
-
169
async def _ensure_schema(self):
    """Create the ORM tables and analytics views once per manager instance.

    ``self._schema_lock`` serializes concurrent callers on this instance, and
    ``self._schema_ready`` turns later calls into no-ops. Table creation uses
    ``checkfirst=True``. "already exists" errors (e.g. another worker created
    the objects first) are tolerated for tables and views. Any other
    table-creation error is logged and re-raised; any other view error is
    only logged as a warning.
    """
    async with self._schema_lock:
        if self._schema_ready:
            return

        logger.debug("📊 Initializing database schema...")

        def _create_tables(sync_conn):
            # checkfirst=True skips tables that are already present.
            Base.metadata.create_all(sync_conn, checkfirst=True)

        # DDL, PRAGMAs and views all run on one connection inside engine.begin().
        async with self.engine.begin() as conn:
            try:
                await conn.run_sync(_create_tables)
            except Exception as exc:
                if "already exists" in str(exc):
                    logger.debug("✅ Database schema already exists")
                else:
                    logger.error(f"❌ Failed to create database schema: {exc}")
                    raise

            # Config-driven PRAGMAs. In SQLite, foreign_keys is per-connection,
            # so this only affects the connection used here.
            if CONFIG.foreign_keys:
                await conn.execute(text("PRAGMA foreign_keys = ON"))
            if CONFIG.journal_mode:
                await conn.execute(text(f"PRAGMA journal_mode = {CONFIG.journal_mode}"))

            for name, ddl in analytics_views.items():
                try:
                    await conn.execute(text(ddl))
                except Exception as exc:
                    # A view created concurrently by another worker is fine.
                    if "already exists" not in str(exc):
                        logger.warning(f"Could not create view {name}: {exc}")

        self._schema_ready = True
219
-
220
@asynccontextmanager
async def session(self):
    """Yield a database session, initializing the manager on first use.

    The session comes from ``self.SessionLocal()`` and is used as an async
    context manager, so its own ``__aexit__`` handles cleanup. Nothing is
    committed here; committing is left to the caller.
    """
    if not self.SessionLocal:
        # First use: build the engine, session factory and schema.
        await self.initialize()
    factory = self.SessionLocal
    async with factory() as db_session:
        yield db_session
227
-
228
async def insert_session_trace(self, trace: SessionTrace) -> str:
    """Insert a complete session trace in a single transaction.

    Writes the session row, its timesteps, events and messages. Timesteps are
    flushed (not committed) once, before events and messages are built, so
    their auto-generated primary keys can be used as ``timestep_id`` foreign
    keys while the whole insert stays atomic.

    Args:
        trace: The complete session trace to store.

    Returns:
        The session ID.

    Raises:
        IntegrityError: For any integrity failure other than a duplicate
            ``session_traces.session_id``. A duplicate is rolled back and the
            ID is returned, so re-inserting the same trace is a no-op.
    """
    from dataclasses import asdict

    def to_cents(cost: float | None) -> int | None:
        # Costs are stored as integer cents. Round instead of truncating:
        # 0.29 * 100 == 28.999999999999996, which int() alone turns into 28.
        return int(round(cost * 100)) if cost is not None else None

    def step_ref(metadata):
        # Events/messages reference their timestep through metadata["step_id"];
        # metadata may be None, which must not raise here.
        return (metadata or {}).get("step_id")

    async with self.session() as sess:
        try:
            sess.add(
                DBSessionTrace(
                    session_id=trace.session_id,
                    created_at=trace.created_at,
                    num_timesteps=len(trace.session_time_steps),
                    num_events=len(trace.event_history),
                    num_messages=len(trace.markov_blanket_message_history),
                    session_metadata=trace.metadata or {},
                )
            )

            pending_steps = []
            for step in trace.session_time_steps:
                db_step = DBSessionTimestep(
                    session_id=trace.session_id,
                    step_id=step.step_id,
                    step_index=step.step_index,
                    turn_number=step.turn_number,
                    started_at=step.timestamp,
                    completed_at=step.completed_at,
                    num_events=len(step.events),
                    num_messages=len(step.markov_blanket_messages),
                    step_metadata=step.step_metadata or {},
                )
                sess.add(db_step)
                pending_steps.append((step.step_id, db_step))
            if pending_steps:
                # A single flush assigns primary keys to every pending timestep
                # without committing (one round trip instead of one per step).
                await sess.flush()
            step_id_map: dict[str, int] = {
                step_id: db_step.id for step_id, db_step in pending_steps
            }

            # Events: one unified table; type-specific columns per event class.
            for event in trace.event_history:
                event_meta = event.metadata or {}
                event_data = {
                    "session_id": trace.session_id,
                    "timestep_id": step_id_map.get(step_ref(event_meta)),
                    "system_instance_id": event.system_instance_id,
                    "event_time": event.time_record.event_time,
                    "message_time": event.time_record.message_time,
                    "event_metadata_json": event_meta,
                    "event_extra_metadata": event.event_metadata,
                }

                if isinstance(event, LMCAISEvent):
                    call_records_data = (
                        [asdict(record) for record in event.call_records]
                        if event.call_records
                        else None
                    )
                    event_data.update(
                        {
                            "event_type": "cais",
                            "model_name": event.model_name,
                            "provider": event.provider,
                            "input_tokens": event.input_tokens,
                            "output_tokens": event.output_tokens,
                            "total_tokens": event.total_tokens,
                            "cost_usd": to_cents(event.cost_usd),
                            "latency_ms": event.latency_ms,
                            "span_id": event.span_id,
                            "trace_id": event.trace_id,
                            "system_state_before": event.system_state_before,
                            "system_state_after": event.system_state_after,
                            "call_records": call_records_data,  # Store in the proper column
                        }
                    )
                elif isinstance(event, EnvironmentEvent):
                    event_data.update(
                        {
                            "event_type": "environment",
                            "reward": event.reward,
                            "terminated": event.terminated,
                            "truncated": event.truncated,
                            "system_state_before": event.system_state_before,
                            "system_state_after": event.system_state_after,
                        }
                    )
                elif isinstance(event, RuntimeEvent):
                    event_data.update(
                        {
                            "event_type": "runtime",
                            "event_metadata_json": {**event_meta, "actions": event.actions},
                        }
                    )
                else:
                    event_data["event_type"] = event.__class__.__name__.lower()

                sess.add(DBEvent(**event_data))

            for msg in trace.markov_blanket_message_history:
                msg_meta = getattr(msg, "metadata", None)
                sess.add(
                    DBMessage(
                        session_id=trace.session_id,
                        timestep_id=step_id_map.get(step_ref(msg_meta)),
                        message_type=msg.message_type,
                        content=msg.content,
                        event_time=msg.time_record.event_time,
                        message_time=msg.time_record.message_time,
                        message_metadata=msg_meta or {},
                    )
                )

            # Commit the entire transaction atomically.
            await sess.commit()
            return trace.session_id
        except IntegrityError as e:
            # Duplicate session IDs (retries, distributed writers) are treated
            # as success to keep inserts idempotent.
            if "UNIQUE constraint failed: session_traces.session_id" in str(e):
                await sess.rollback()
                return trace.session_id  # Return existing ID
            raise
375
-
376
- async def get_session_trace(self, session_id: str) -> dict[str, Any] | None:
377
- """Retrieve a session trace by ID."""
378
- async with self.session() as sess:
379
- result = await sess.execute(
380
- select(DBSessionTrace)
381
- .options(
382
- selectinload(DBSessionTrace.timesteps),
383
- selectinload(DBSessionTrace.events),
384
- selectinload(DBSessionTrace.messages),
385
- )
386
- .where(DBSessionTrace.session_id == session_id)
387
- )
388
- session = result.scalar_one_or_none()
389
-
390
- if not session:
391
- return None
392
-
393
- return {
394
- "session_id": session.session_id,
395
- "created_at": session.created_at,
396
- "num_timesteps": session.num_timesteps,
397
- "num_events": session.num_events,
398
- "num_messages": session.num_messages,
399
- "metadata": session.session_metadata,
400
- "timesteps": [
401
- {
402
- "step_id": step.step_id,
403
- "step_index": step.step_index,
404
- "turn_number": step.turn_number,
405
- "started_at": step.started_at,
406
- "completed_at": step.completed_at,
407
- "metadata": step.step_metadata,
408
- }
409
- for step in sorted(session.timesteps, key=lambda s: s.step_index)
410
- ],
411
- }
412
-
413
- async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
414
- """Execute a query and return results.
415
-
416
- Returns a pandas DataFrame when pandas is available; otherwise a
417
- list of dict records. Callers should handle both.
418
- """
419
- async with self.session() as sess:
420
- result = await sess.execute(text(query), params or {})
421
- rows = result.mappings().all()
422
- if pd is not None:
423
- return pd.DataFrame(rows)
424
- return [dict(r) for r in rows]
425
-
426
- async def get_model_usage(
427
- self,
428
- start_date: datetime | None = None,
429
- end_date: datetime | None = None,
430
- model_name: str | None = None,
431
- ) -> Any:
432
- """Get model usage statistics.
433
-
434
- Returns a pandas DataFrame when pandas is available; otherwise a list
435
- of dict records.
436
- """
437
- query = """
438
- SELECT * FROM model_usage_stats
439
- WHERE 1=1
440
- """
441
- params = {}
442
-
443
- if start_date:
444
- query += " AND last_used >= :start_date"
445
- params["start_date"] = start_date
446
-
447
- if end_date:
448
- query += " AND first_used <= :end_date"
449
- params["end_date"] = end_date
450
-
451
- if model_name:
452
- query += " AND model_name = :model_name"
453
- params["model_name"] = model_name
454
-
455
- query += " ORDER BY usage_count DESC"
456
-
457
- return await self.query_traces(query, params)
458
-
459
- async def create_experiment(
460
- self,
461
- experiment_id: str,
462
- name: str,
463
- description: str | None = None,
464
- configuration: dict[str, Any] | None = None,
465
- ) -> str:
466
- """Create a new experiment."""
467
- async with self.session() as sess:
468
- experiment = DBExperiment(
469
- experiment_id=experiment_id,
470
- name=name,
471
- description=description,
472
- configuration=configuration or {},
473
- )
474
- sess.add(experiment)
475
- await sess.commit()
476
- return experiment_id
477
-
478
- async def link_session_to_experiment(self, session_id: str, experiment_id: str):
479
- """Link a session to an experiment."""
480
- async with self.session() as sess:
481
- await sess.execute(
482
- update(DBSessionTrace)
483
- .where(DBSessionTrace.session_id == session_id)
484
- .values(experiment_id=experiment_id)
485
- )
486
- await sess.commit()
487
-
488
- async def batch_insert_sessions(
489
- self, traces: list[SessionTrace], batch_size: int | None = None
490
- ) -> list[str]:
491
- """Batch insert multiple session traces.
492
-
493
- Processes traces in batches to balance memory usage and performance.
494
- Each batch is inserted in a separate transaction to avoid holding
495
- locks for too long.
496
-
497
- Args:
498
- traces: List of session traces to insert
499
- batch_size: Number of traces per batch (defaults to config)
500
-
501
- Returns:
502
- List of inserted session IDs
503
- """
504
- batch_size = batch_size or CONFIG.batch_size
505
- inserted_ids = []
506
-
507
- # Process in chunks to avoid memory issues with large datasets
508
- for i in range(0, len(traces), batch_size):
509
- batch = traces[i : i + batch_size]
510
- # Insert each trace in the batch - could be optimized further
511
- # with bulk inserts if needed
512
- for trace in batch:
513
- session_id = await self.insert_session_trace(trace)
514
- inserted_ids.append(session_id)
515
-
516
- return inserted_ids
517
-
518
- async def get_sessions_by_experiment(
519
- self, experiment_id: str, limit: int | None = None
520
- ) -> list[dict[str, Any]]:
521
- """Get all sessions for an experiment."""
522
- async with self.session() as sess:
523
- query = (
524
- select(DBSessionTrace)
525
- .where(DBSessionTrace.experiment_id == experiment_id)
526
- .order_by(DBSessionTrace.created_at.desc())
527
- )
528
-
529
- if limit:
530
- query = query.limit(limit)
531
-
532
- result = await sess.execute(query)
533
- sessions = result.scalars().all()
534
-
535
- return [
536
- {
537
- "session_id": s.session_id,
538
- "created_at": s.created_at,
539
- "num_timesteps": s.num_timesteps,
540
- "num_events": s.num_events,
541
- "num_messages": s.num_messages,
542
- "metadata": s.metadata,
543
- }
544
- for s in sessions
545
- ]
546
-
547
- async def delete_session(self, session_id: str) -> bool:
548
- """Delete a session and all related data."""
549
- async with self.session() as sess:
550
- # Get the session object to trigger cascade deletes
551
- result = await sess.execute(
552
- select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
553
- )
554
- session = result.scalar_one_or_none()
555
-
556
- if session:
557
- await sess.delete(session)
558
- await sess.commit()
559
- return True
560
- return False
561
-
562
- async def close(self):
563
- """Close the database connection.
564
-
565
- Properly disposes of the engine and all connections. This is important
566
- for cleanup, especially with SQLite which can leave lock files.
567
- """
568
- if self.engine:
569
- # Dispose of all connections in the pool
570
- await self.engine.dispose()
571
- # Clear all state to allow re-initialization if needed
572
- self.engine = None
573
- self.SessionLocal = None
574
- self._schema_ready = False
575
-
576
- # -------------------------------
577
- # Incremental insert helpers
578
- # -------------------------------
579
-
580
- async def ensure_session(
581
- self,
582
- session_id: str,
583
- *,
584
- created_at: datetime | None = None,
585
- metadata: dict[str, Any] | None = None,
586
- ):
587
- """Ensure a DB session row exists for session_id."""
588
- async with self.session() as sess:
589
- result = await sess.execute(
590
- select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
591
- )
592
- existing = result.scalar_one_or_none()
593
- if existing:
594
- return
595
- row = DBSessionTrace(
596
- session_id=session_id,
597
- created_at=created_at or datetime.utcnow(),
598
- num_timesteps=0,
599
- num_events=0,
600
- num_messages=0,
601
- session_metadata=metadata or {},
602
- )
603
- sess.add(row)
604
- await sess.commit()
605
-
606
- async def ensure_timestep(
607
- self,
608
- session_id: str,
609
- *,
610
- step_id: str,
611
- step_index: int,
612
- turn_number: int | None = None,
613
- started_at: datetime | None = None,
614
- completed_at: datetime | None = None,
615
- metadata: dict[str, Any] | None = None,
616
- ) -> int:
617
- """Ensure a timestep row exists; return its DB id."""
618
- async with self.session() as sess:
619
- result = await sess.execute(
620
- select(DBSessionTimestep).where(
621
- DBSessionTimestep.session_id == session_id, DBSessionTimestep.step_id == step_id
622
- )
623
- )
624
- row = result.scalar_one_or_none()
625
- if row:
626
- return row.id
627
- row = DBSessionTimestep(
628
- session_id=session_id,
629
- step_id=step_id,
630
- step_index=step_index,
631
- turn_number=turn_number,
632
- started_at=started_at or datetime.utcnow(),
633
- completed_at=completed_at,
634
- num_events=0,
635
- num_messages=0,
636
- step_metadata=metadata or {},
637
- )
638
- sess.add(row)
639
- await sess.flush()
640
- # increment session num_timesteps
641
- await sess.execute(
642
- update(DBSessionTrace)
643
- .where(DBSessionTrace.session_id == session_id)
644
- .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
645
- )
646
- await sess.commit()
647
- return row.id
648
-
649
- async def insert_message_row(
650
- self,
651
- session_id: str,
652
- *,
653
- timestep_db_id: int | None,
654
- message_type: str,
655
- content: str,
656
- event_time: float | None = None,
657
- message_time: int | None = None,
658
- metadata: dict[str, Any] | None = None,
659
- ) -> int:
660
- """Insert a message and return its id."""
661
- async with self.session() as sess:
662
- db_msg = DBMessage(
663
- session_id=session_id,
664
- timestep_id=timestep_db_id,
665
- message_type=message_type,
666
- content=content,
667
- event_time=event_time,
668
- message_time=message_time,
669
- message_metadata=metadata or {},
670
- )
671
- sess.add(db_msg)
672
- await sess.flush()
673
- # increment session num_messages
674
- await sess.execute(
675
- update(DBSessionTrace)
676
- .where(DBSessionTrace.session_id == session_id)
677
- .values(num_messages=DBSessionTrace.num_messages + 1)
678
- )
679
- await sess.commit()
680
- return db_msg.id
681
-
682
- async def insert_event_row(
683
- self,
684
- session_id: str,
685
- *,
686
- timestep_db_id: int | None,
687
- event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
688
- metadata_override: dict[str, Any] | None = None,
689
- ) -> int:
690
- """Insert an event and return its id."""
691
-
692
- def to_cents(cost: float | None) -> int | None:
693
- return int(cost * 100) if cost is not None else None
694
-
695
- event_data: dict[str, Any] = {
696
- "session_id": session_id,
697
- "timestep_id": timestep_db_id,
698
- "system_instance_id": event.system_instance_id,
699
- "event_time": event.time_record.event_time,
700
- "message_time": event.time_record.message_time,
701
- "event_metadata_json": metadata_override or event.metadata or {},
702
- "event_extra_metadata": getattr(event, "event_metadata", None),
703
- }
704
- if isinstance(event, LMCAISEvent):
705
- call_records_data = None
706
- if getattr(event, "call_records", None):
707
- from dataclasses import asdict
708
-
709
- call_records_data = [asdict(record) for record in event.call_records]
710
- event_data.update(
711
- {
712
- "event_type": "cais",
713
- "model_name": event.model_name,
714
- "provider": event.provider,
715
- "input_tokens": event.input_tokens,
716
- "output_tokens": event.output_tokens,
717
- "total_tokens": event.total_tokens,
718
- "cost_usd": to_cents(event.cost_usd),
719
- "latency_ms": event.latency_ms,
720
- "span_id": event.span_id,
721
- "trace_id": event.trace_id,
722
- "system_state_before": event.system_state_before,
723
- "system_state_after": event.system_state_after,
724
- "call_records": call_records_data,
725
- }
726
- )
727
- elif isinstance(event, EnvironmentEvent):
728
- event_data.update(
729
- {
730
- "event_type": "environment",
731
- "reward": event.reward,
732
- "terminated": event.terminated,
733
- "truncated": event.truncated,
734
- "system_state_before": event.system_state_before,
735
- "system_state_after": event.system_state_after,
736
- }
737
- )
738
- elif isinstance(event, RuntimeEvent):
739
- event_data.update(
740
- {
741
- "event_type": "runtime",
742
- "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
743
- }
744
- )
745
- else:
746
- event_data["event_type"] = event.__class__.__name__.lower()
747
-
748
- async with self.session() as sess:
749
- db_event = DBEvent(**event_data)
750
- sess.add(db_event)
751
- await sess.flush()
752
- # increment session num_events
753
- await sess.execute(
754
- update(DBSessionTrace)
755
- .where(DBSessionTrace.session_id == session_id)
756
- .values(num_events=DBSessionTrace.num_events + 1)
757
- )
758
- await sess.commit()
759
- return db_event.id
760
-
761
- # -------------------------------
762
- # Reward helpers
763
- # -------------------------------
764
-
765
- async def insert_outcome_reward(
766
- self,
767
- session_id: str,
768
- *,
769
- total_reward: int,
770
- achievements_count: int,
771
- total_steps: int,
772
- reward_metadata: dict | None = None,
773
- ) -> int:
774
- async with self.session() as sess:
775
- row = DBOutcomeReward(
776
- session_id=session_id,
777
- total_reward=total_reward,
778
- achievements_count=achievements_count,
779
- total_steps=total_steps,
780
- reward_metadata=reward_metadata or {},
781
- )
782
- sess.add(row)
783
- await sess.flush()
784
- await sess.commit()
785
- return row.id
786
-
787
- async def insert_event_reward(
788
- self,
789
- session_id: str,
790
- *,
791
- event_id: int,
792
- message_id: int | None = None,
793
- turn_number: int | None = None,
794
- reward_value: float = 0.0,
795
- reward_type: str | None = None,
796
- key: str | None = None,
797
- annotation: dict[str, Any] | None = None,
798
- source: str | None = None,
799
- ) -> int:
800
- async with self.session() as sess:
801
- row = DBEventReward(
802
- event_id=event_id,
803
- session_id=session_id,
804
- message_id=message_id,
805
- turn_number=turn_number,
806
- reward_value=reward_value,
807
- reward_type=reward_type,
808
- key=key,
809
- annotation=annotation or {},
810
- source=source,
811
- )
812
- sess.add(row)
813
- await sess.flush()
814
- await sess.commit()
815
- return row.id
816
-
817
- async def get_outcome_rewards(self) -> list[dict[str, Any]]:
818
- async with self.session() as sess:
819
- result = await sess.execute(select(DBOutcomeReward))
820
- rows = result.scalars().all()
821
- return [
822
- {
823
- "id": r.id,
824
- "session_id": r.session_id,
825
- "total_reward": r.total_reward,
826
- "achievements_count": r.achievements_count,
827
- "total_steps": r.total_steps,
828
- "created_at": r.created_at,
829
- }
830
- for r in rows
831
- ]
832
-
833
- async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
834
- async with self.session() as sess:
835
- result = await sess.execute(
836
- select(DBOutcomeReward.session_id).where(DBOutcomeReward.total_reward >= min_reward)
837
- )
838
- return [row[0] for row in result.all()]