synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/__init__.py +16 -0
- examples/crafter_debug_render.py +23 -17
- examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
- examples/multi_step/crafter_rl_lora.md +29 -0
- examples/qwen_coder/README.md +102 -0
- examples/qwen_coder/_shared.py +113 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +61 -0
- examples/qwen_coder/configs/coder_lora_4b.toml +57 -0
- examples/qwen_coder/configs/coder_lora_small.toml +58 -0
- examples/qwen_coder/generate_dataset.py +98 -0
- examples/qwen_coder/infer_ft_smoke.py +65 -0
- examples/qwen_coder/infer_prod_proxy.py +73 -0
- examples/qwen_coder/infer_via_synth.py +87 -0
- examples/qwen_coder/scripts/infer_coder.sh +19 -0
- examples/qwen_coder/scripts/train_coder_30b.sh +22 -0
- examples/qwen_coder/sft_full_17b.py +103 -0
- examples/qwen_coder/sft_lora_30b.py +110 -0
- examples/qwen_coder/subset_jsonl.py +39 -0
- examples/qwen_coder/todos.md +38 -0
- examples/qwen_coder/validate_jsonl.py +60 -0
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +53 -52
- examples/rl/run_rl_and_save.py +29 -12
- examples/rl/task_app/math_single_step.py +180 -41
- examples/rl/task_app/math_task_app.py +14 -6
- examples/sft/README.md +139 -0
- examples/sft/configs/crafter_fft_qwen0p6b.toml +44 -0
- examples/sft/configs/crafter_lora_qwen0p6b.toml +45 -0
- examples/sft/evaluate.py +117 -0
- examples/sft/export_dataset.py +117 -0
- examples/sft/generate_traces.py +162 -0
- examples/swe/__init__.py +12 -0
- examples/swe/task_app/README.md +105 -0
- examples/swe/task_app/__init__.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +571 -0
- examples/swe/task_app/grpo_swe_mini_task_app.py +136 -0
- examples/swe/task_app/hosted/README.md +173 -0
- examples/swe/task_app/hosted/__init__.py +5 -0
- examples/swe/task_app/hosted/branching.py +143 -0
- examples/swe/task_app/hosted/environment_routes.py +1289 -0
- examples/swe/task_app/hosted/envs/__init__.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/__init__.py +6 -0
- examples/swe/task_app/hosted/envs/crafter/app.py +1 -0
- examples/swe/task_app/hosted/envs/crafter/environment.py +522 -0
- examples/swe/task_app/hosted/envs/crafter/policy.py +478 -0
- examples/swe/task_app/hosted/envs/crafter/react_agent.py +108 -0
- examples/swe/task_app/hosted/envs/crafter/shared.py +305 -0
- examples/swe/task_app/hosted/envs/crafter/tools.py +47 -0
- examples/swe/task_app/hosted/envs/mini_swe/__init__.py +8 -0
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +1164 -0
- examples/swe/task_app/hosted/envs/mini_swe/policy.py +355 -0
- examples/swe/task_app/hosted/envs/mini_swe/shared.py +83 -0
- examples/swe/task_app/hosted/envs/mini_swe/tools.py +96 -0
- examples/swe/task_app/hosted/hosted_app.py +204 -0
- examples/swe/task_app/hosted/inference/__init__.py +5 -0
- examples/swe/task_app/hosted/inference/openai_client.py +618 -0
- examples/swe/task_app/hosted/main.py +100 -0
- examples/swe/task_app/hosted/policy_routes.py +1079 -0
- examples/swe/task_app/hosted/registry.py +195 -0
- examples/swe/task_app/hosted/rollout.py +1869 -0
- examples/swe/task_app/hosted/storage/__init__.py +5 -0
- examples/swe/task_app/hosted/storage/volume.py +211 -0
- examples/swe/task_app/hosted/test_agents.py +161 -0
- examples/swe/task_app/hosted/test_service.py +137 -0
- examples/swe/task_app/hosted/utils.py +62 -0
- examples/vlm/PROPOSAL.md +53 -0
- examples/vlm/README.md +68 -0
- examples/vlm/configs/crafter_vlm_gpt4o.toml +44 -0
- examples/vlm/crafter_image_only_agent.py +207 -0
- examples/vlm/crafter_openai_vlm_agent.py +277 -0
- examples/vlm/filter_image_rows.py +63 -0
- examples/vlm/run_crafter_vlm_benchmark.py +316 -0
- examples/warming_up_to_rl/analyze_trace_db.py +12 -10
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +11 -1
- examples/warming_up_to_rl/export_trace_sft.py +218 -36
- examples/warming_up_to_rl/groq_test.py +15 -8
- examples/warming_up_to_rl/manage_secrets.py +29 -25
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +137 -61
- examples/warming_up_to_rl/run_fft_and_save.py +131 -60
- examples/warming_up_to_rl/run_local_rollout.py +88 -39
- examples/warming_up_to_rl/run_local_rollout_modal.py +114 -28
- examples/warming_up_to_rl/run_local_rollout_parallel.py +81 -20
- examples/warming_up_to_rl/run_local_rollout_traced.py +126 -23
- examples/warming_up_to_rl/run_rl_and_save.py +35 -12
- examples/warming_up_to_rl/run_rollout_remote.py +44 -19
- examples/warming_up_to_rl/task_app/README.md +6 -2
- examples/warming_up_to_rl/task_app/grpo_crafter.py +319 -57
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +11 -30
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +9 -11
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +137 -182
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +150 -57
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +105 -69
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +19 -7
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +45 -42
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +47 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +198 -92
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +0 -2
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +361 -263
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +21 -23
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +394 -274
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +56 -62
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +6 -15
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +4 -3
- synth_ai/__init__.py +1 -0
- synth_ai/api/models/supported.py +376 -0
- synth_ai/api/train/builders.py +157 -26
- synth_ai/api/train/cli.py +213 -57
- synth_ai/api/train/config_finder.py +65 -5
- synth_ai/api/train/env_resolver.py +33 -15
- synth_ai/api/train/pollers.py +13 -4
- synth_ai/api/train/supported_algos.py +139 -0
- synth_ai/api/train/task_app.py +5 -3
- synth_ai/api/train/utils.py +33 -48
- synth_ai/cli/__init__.py +19 -4
- synth_ai/cli/_modal_wrapper.py +28 -0
- synth_ai/cli/_typer_patch.py +49 -0
- synth_ai/cli/balance.py +2 -3
- synth_ai/cli/calc.py +1 -1
- synth_ai/cli/demo.py +21 -6
- synth_ai/cli/recent.py +2 -2
- synth_ai/cli/rl_demo.py +77 -17
- synth_ai/cli/root.py +116 -39
- synth_ai/cli/status.py +2 -2
- synth_ai/cli/task_apps.py +1699 -259
- synth_ai/cli/traces.py +7 -4
- synth_ai/cli/turso.py +73 -0
- synth_ai/cli/watch.py +12 -18
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +68 -31
- synth_ai/demos/core/cli.py +516 -194
- synth_ai/demos/demo_task_apps/__init__.py +3 -3
- synth_ai/demos/demo_task_apps/core.py +64 -28
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +37 -30
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/app.py +2 -1
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +3 -6
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +183 -82
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -2
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/examples/crafter_classic/environment.py +76 -1
- synth_ai/environments/reproducibility/tree.py +5 -6
- synth_ai/environments/service/app.py +11 -12
- synth_ai/environments/service/core_routes.py +10 -9
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/core.py +1 -0
- synth_ai/environments/tasks/filters.py +5 -6
- synth_ai/environments/tasks/utils.py +4 -5
- synth_ai/evals/base.py +0 -2
- synth_ai/handshake.py +11 -9
- synth_ai/http.py +1 -1
- synth_ai/http_client.py +43 -11
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +20 -6
- synth_ai/jobs/client.py +103 -78
- synth_ai/learning/__init__.py +41 -6
- synth_ai/learning/algorithms.py +14 -0
- synth_ai/learning/client.py +121 -29
- synth_ai/learning/config.py +2 -40
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +4 -56
- synth_ai/learning/health.py +13 -7
- synth_ai/learning/jobs.py +43 -47
- synth_ai/{rl → learning/rl}/__init__.py +14 -5
- synth_ai/learning/rl/client.py +267 -0
- synth_ai/learning/rl/config.py +31 -0
- synth_ai/{rl → learning/rl}/contracts.py +5 -10
- synth_ai/{rl → learning/rl}/env_keys.py +45 -16
- synth_ai/learning/rl/secrets.py +13 -0
- synth_ai/learning/rl_client.py +2 -253
- synth_ai/learning/sft/__init__.py +29 -0
- synth_ai/learning/sft/client.py +68 -0
- synth_ai/learning/sft/config.py +270 -0
- synth_ai/learning/sft/data.py +295 -0
- synth_ai/learning/sse.py +25 -26
- synth_ai/learning/validators.py +25 -24
- synth_ai/lm/__init__.py +21 -47
- synth_ai/task/__init__.py +26 -27
- synth_ai/task/apps/__init__.py +18 -19
- synth_ai/task/auth.py +35 -23
- synth_ai/task/client.py +15 -13
- synth_ai/task/contracts.py +37 -35
- synth_ai/task/datasets.py +9 -6
- synth_ai/task/errors.py +11 -10
- synth_ai/task/health.py +17 -11
- synth_ai/task/json.py +58 -24
- synth_ai/task/proxy.py +15 -14
- synth_ai/task/rubrics.py +22 -15
- synth_ai/task/server.py +43 -17
- synth_ai/task/tracing_utils.py +12 -7
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +5 -7
- synth_ai/tracing_v3/__init__.py +2 -0
- synth_ai/tracing_v3/abstractions.py +21 -4
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +18 -15
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +6 -4
- synth_ai/tracing_v3/llm_call_record_helpers.py +6 -6
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +63 -16
- synth_ai/tracing_v3/storage/base.py +89 -1
- synth_ai/tracing_v3/storage/config.py +21 -8
- synth_ai/tracing_v3/storage/factory.py +10 -8
- synth_ai/tracing_v3/storage/utils.py +4 -2
- synth_ai/tracing_v3/turso/daemon.py +7 -2
- synth_ai/tracing_v3/turso/models.py +5 -2
- synth_ai/tracing_v3/turso/native_manager.py +1173 -0
- synth_ai/tracing_v3/utils.py +4 -3
- synth_ai/v0/api/__init__.py +8 -0
- synth_ai/v0/api/models/__init__.py +8 -0
- synth_ai/v0/api/models/supported.py +8 -0
- synth_ai/v0/config/__init__.py +15 -0
- synth_ai/v0/config/base_url.py +12 -0
- synth_ai/v0/lm/__init__.py +51 -0
- synth_ai/{lm → v0/lm}/caching/ephemeral.py +3 -5
- synth_ai/{lm → v0/lm}/caching/handler.py +4 -4
- synth_ai/{lm → v0/lm}/caching/initialize.py +1 -1
- synth_ai/{lm → v0/lm}/caching/persistent.py +1 -1
- synth_ai/{lm → v0/lm}/config.py +6 -1
- synth_ai/{lm → v0/lm}/core/all.py +9 -9
- synth_ai/{lm → v0/lm}/core/exceptions.py +0 -2
- synth_ai/{lm → v0/lm}/core/main.py +19 -7
- synth_ai/{lm → v0/lm}/core/main_v3.py +10 -10
- synth_ai/{lm → v0/lm}/core/synth_models.py +2 -15
- synth_ai/{lm → v0/lm}/core/vendor_clients.py +6 -4
- synth_ai/{lm → v0/lm}/overrides.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/anthropic.py +4 -4
- synth_ai/{lm → v0/lm}/provider_support/openai.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/handler.py +5 -5
- synth_ai/{lm → v0/lm}/structured_outputs/rehabilitate.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/core/anthropic_api.py +16 -16
- synth_ai/{lm → v0/lm}/vendors/core/gemini_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/mistral_api.py +5 -5
- synth_ai/{lm → v0/lm}/vendors/core/openai_api.py +12 -10
- synth_ai/{lm → v0/lm}/vendors/openai_standard.py +11 -9
- synth_ai/{lm → v0/lm}/vendors/openai_standard_responses.py +8 -5
- synth_ai/{lm → v0/lm}/vendors/supported/custom_endpoint.py +4 -6
- synth_ai/{lm → v0/lm}/vendors/supported/deepseek.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/grok.py +2 -2
- synth_ai/{lm → v0/lm}/vendors/supported/groq.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/ollama.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/supported/openrouter.py +3 -3
- synth_ai/{lm → v0/lm}/vendors/supported/together.py +1 -1
- synth_ai/{lm → v0/lm}/vendors/synth_client.py +38 -11
- synth_ai/v0/tracing/upload.py +32 -135
- synth_ai/v0/tracing_v3/__init__.py +10 -0
- synth_ai/v0/tracing_v3/abstractions.py +3 -0
- synth_ai/v0/tracing_v3/decorators.py +3 -0
- synth_ai/v0/tracing_v3/llm_call_record_helpers.py +3 -0
- synth_ai/v0/tracing_v3/session_tracer.py +3 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/METADATA +10 -7
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/RECORD +294 -258
- examples/common_old/backend.py +0 -21
- examples/evals_old/README.md +0 -98
- examples/evals_old/__init__.py +0 -6
- examples/evals_old/compare_models.py +0 -1037
- examples/evals_old/example_log.md +0 -145
- examples/evals_old/run_demo.sh +0 -126
- examples/evals_old/trace_analysis.py +0 -270
- examples/finetuning_old/_backup_synth_qwen/config.toml +0 -29
- examples/finetuning_old/_backup_synth_qwen/example_log.md +0 -324
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +0 -60
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +0 -239
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +0 -109
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +0 -1924
- examples/finetuning_old/_backup_synth_qwen/readme.md +0 -49
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +0 -114
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +0 -195
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +0 -118
- examples/finetuning_old/synth_qwen_v1/README.md +0 -68
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +0 -60
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +0 -239
- examples/finetuning_old/synth_qwen_v1/finetune.py +0 -46
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +0 -71
- examples/finetuning_old/synth_qwen_v1/infer.py +0 -37
- examples/finetuning_old/synth_qwen_v1/poll.py +0 -44
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +0 -35
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +0 -109
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +0 -1932
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +0 -207
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +0 -232
- examples/finetuning_old/synth_qwen_v1/upload_data.py +0 -34
- examples/finetuning_old/synth_qwen_v1/util.py +0 -147
- examples/rl_old/task_app.py +0 -962
- synth_ai/experimental/synth_oss.py +0 -446
- synth_ai/install_sqld.sh +0 -40
- synth_ai/learning/filtering.py +0 -0
- synth_ai/learning/offline/dpo.py +0 -0
- synth_ai/learning/offline/providers.py +0 -7
- synth_ai/learning/offline/sft.py +0 -0
- synth_ai/learning/offline/shared.py +0 -0
- synth_ai/learning/online/grpo.py +0 -0
- synth_ai/learning/online/irft.py +0 -0
- synth_ai/learning/prompts/banking77_injection_eval.py +0 -168
- synth_ai/learning/prompts/gepa.py +0 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +0 -213
- synth_ai/learning/prompts/mipro.py +0 -289
- synth_ai/learning/prompts/random_search.py +0 -246
- synth_ai/learning/prompts/run_mipro_banking77.py +0 -172
- synth_ai/learning/prompts/run_random_search_banking77.py +0 -324
- synth_ai/rl/secrets.py +0 -19
- synth_ai/scripts/verify_rewards.py +0 -100
- synth_ai/tracing/__init__.py +0 -30
- synth_ai/tracing_v1/__init__.py +0 -33
- synth_ai/tracing_v3/turso/__init__.py +0 -25
- synth_ai/tracing_v3/turso/manager.py +0 -774
- synth_ai/zyk/__init__.py +0 -30
- /synth_ai/{lm → v0/lm}/caching/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/caching/dbs.py +0 -0
- /synth_ai/{lm → v0/lm}/constants.py +0 -0
- /synth_ai/{lm → v0/lm}/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/monitor.py +0 -0
- /synth_ai/{lm → v0/lm}/cost/statefulness.py +0 -0
- /synth_ai/{lm → v0/lm}/injection.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/provider_support/suppress_logging.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/structured_outputs/inject.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/tools/base.py +0 -0
- /synth_ai/{lm → v0/lm}/unified_interface.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/base.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/core/synth_dev_api.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/local/ollama.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/retries.py +0 -0
- /synth_ai/{lm → v0/lm}/vendors/supported/__init__.py +0 -0
- /synth_ai/{lm → v0/lm}/warmup.py +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -1,774 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
"""Async SQLAlchemy-based trace manager for Turso/sqld.
|
|
3
|
-
|
|
4
|
-
This module provides the database interface for the tracing system using
|
|
5
|
-
async SQLAlchemy with a Turso/sqld backend. It handles all database operations
|
|
6
|
-
including schema creation, session storage, and analytics queries.
|
|
7
|
-
|
|
8
|
-
Key Features:
|
|
9
|
-
------------
|
|
10
|
-
- Async-first design using aiosqlite for local SQLite
|
|
11
|
-
- Automatic schema creation and migration
|
|
12
|
-
- Batch insert capabilities for high-throughput scenarios
|
|
13
|
-
- Analytics views for efficient querying
|
|
14
|
-
- Connection pooling and retry logic
|
|
15
|
-
|
|
16
|
-
Performance Considerations:
|
|
17
|
-
--------------------------
|
|
18
|
-
- Uses NullPool for SQLite to avoid connection issues
|
|
19
|
-
- Implements busy timeout for concurrent access
|
|
20
|
-
- Batches inserts to reduce transaction overhead
|
|
21
|
-
- Creates indexes for common query patterns
|
|
22
|
-
"""
|
|
23
|
-
|
|
24
|
-
import asyncio
|
|
25
|
-
import logging
|
|
26
|
-
from contextlib import asynccontextmanager
|
|
27
|
-
from datetime import datetime
|
|
28
|
-
from typing import Any
|
|
29
|
-
|
|
30
|
-
# Optional pandas import: fall back to records (list[dict]) if unavailable
|
|
31
|
-
try: # pragma: no cover - exercised in environments without pandas
|
|
32
|
-
import pandas as pd # type: ignore
|
|
33
|
-
except Exception: # pragma: no cover
|
|
34
|
-
pd = None # type: ignore[assignment]
|
|
35
|
-
from sqlalchemy import select, text, update
|
|
36
|
-
from sqlalchemy.exc import IntegrityError
|
|
37
|
-
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
38
|
-
from sqlalchemy import event
|
|
39
|
-
from sqlalchemy.orm import selectinload, sessionmaker
|
|
40
|
-
from sqlalchemy.pool import NullPool
|
|
41
|
-
|
|
42
|
-
from ..abstractions import (
|
|
43
|
-
EnvironmentEvent,
|
|
44
|
-
LMCAISEvent,
|
|
45
|
-
RuntimeEvent,
|
|
46
|
-
SessionTrace,
|
|
47
|
-
)
|
|
48
|
-
from ..config import CONFIG
|
|
49
|
-
from .models import (
|
|
50
|
-
Base,
|
|
51
|
-
analytics_views,
|
|
52
|
-
)
|
|
53
|
-
from .models import (
|
|
54
|
-
Event as DBEvent,
|
|
55
|
-
)
|
|
56
|
-
from .models import (
|
|
57
|
-
Experiment as DBExperiment,
|
|
58
|
-
)
|
|
59
|
-
from .models import (
|
|
60
|
-
Message as DBMessage,
|
|
61
|
-
)
|
|
62
|
-
from .models import (
|
|
63
|
-
SessionTimestep as DBSessionTimestep,
|
|
64
|
-
)
|
|
65
|
-
from .models import (
|
|
66
|
-
SessionTrace as DBSessionTrace,
|
|
67
|
-
)
|
|
68
|
-
from .models import (
|
|
69
|
-
OutcomeReward as DBOutcomeReward,
|
|
70
|
-
)
|
|
71
|
-
from .models import (
|
|
72
|
-
EventReward as DBEventReward,
|
|
73
|
-
)
|
|
74
|
-
|
|
75
|
-
logger = logging.getLogger(__name__)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class AsyncSQLTraceManager:
|
|
79
|
-
"""Async trace storage manager using SQLAlchemy and Turso/sqld.
|
|
80
|
-
|
|
81
|
-
Handles all database operations for the tracing system. Designed to work
|
|
82
|
-
with both local SQLite (via aiosqlite) and remote Turso databases.
|
|
83
|
-
|
|
84
|
-
The manager handles:
|
|
85
|
-
- Connection lifecycle management
|
|
86
|
-
- Schema creation and verification
|
|
87
|
-
- Transaction management
|
|
88
|
-
- Batch operations for efficiency
|
|
89
|
-
- Analytics view creation
|
|
90
|
-
"""
|
|
91
|
-
|
|
92
|
-
def __init__(self, db_url: str | None = None):
|
|
93
|
-
self.db_url = db_url or CONFIG.db_url
|
|
94
|
-
self.engine: AsyncEngine | None = None
|
|
95
|
-
self.SessionLocal: sessionmaker | None = None
|
|
96
|
-
self._schema_lock = asyncio.Lock()
|
|
97
|
-
self._schema_ready = False
|
|
98
|
-
|
|
99
|
-
async def initialize(self):
|
|
100
|
-
"""Initialize the database connection and schema.
|
|
101
|
-
|
|
102
|
-
This method is idempotent and thread-safe. It:
|
|
103
|
-
1. Creates the async engine with appropriate settings
|
|
104
|
-
2. Verifies database file exists (for SQLite)
|
|
105
|
-
3. Creates schema if needed
|
|
106
|
-
4. Sets up analytics views
|
|
107
|
-
|
|
108
|
-
The schema lock ensures only one worker creates the schema in
|
|
109
|
-
concurrent scenarios.
|
|
110
|
-
"""
|
|
111
|
-
if self.engine is None:
|
|
112
|
-
logger.debug(f"🔗 Initializing database connection to: {self.db_url}")
|
|
113
|
-
|
|
114
|
-
# For SQLite, use NullPool to avoid connection pool issues
|
|
115
|
-
# SQLite doesn't handle concurrent connections well, so we create
|
|
116
|
-
# a new connection for each operation
|
|
117
|
-
if self.db_url.startswith("sqlite"):
|
|
118
|
-
# Extract the file path from the URL
|
|
119
|
-
db_path = self.db_url.replace("sqlite+aiosqlite:///", "")
|
|
120
|
-
import os
|
|
121
|
-
|
|
122
|
-
# Check if database file exists
|
|
123
|
-
if not os.path.exists(db_path):
|
|
124
|
-
logger.debug(f"⚠️ Database file not found: {db_path}")
|
|
125
|
-
logger.debug(
|
|
126
|
-
"🔧 Make sure './serve.sh' is running to start the turso/sqld service"
|
|
127
|
-
)
|
|
128
|
-
else:
|
|
129
|
-
logger.debug(f"✅ Found database file: {db_path}")
|
|
130
|
-
|
|
131
|
-
# Set a high busy timeout to handle concurrent access
|
|
132
|
-
# This allows SQLite to wait instead of immediately failing
|
|
133
|
-
connect_args = {"timeout": 30.0} # 30 second busy timeout
|
|
134
|
-
self.engine = create_async_engine(
|
|
135
|
-
self.db_url, # Use instance db_url, not CONFIG
|
|
136
|
-
poolclass=NullPool, # No connection pooling for SQLite
|
|
137
|
-
connect_args=connect_args,
|
|
138
|
-
echo=CONFIG.echo_sql,
|
|
139
|
-
)
|
|
140
|
-
# Ensure PRAGMA foreign_keys=ON for every connection
|
|
141
|
-
try:
|
|
142
|
-
@event.listens_for(self.engine.sync_engine, "connect")
|
|
143
|
-
def _set_sqlite_pragma(dbapi_connection, connection_record): # type: ignore[no-redef]
|
|
144
|
-
try:
|
|
145
|
-
cursor = dbapi_connection.cursor()
|
|
146
|
-
cursor.execute("PRAGMA foreign_keys=ON")
|
|
147
|
-
cursor.close()
|
|
148
|
-
except Exception:
|
|
149
|
-
pass
|
|
150
|
-
except Exception:
|
|
151
|
-
pass
|
|
152
|
-
else:
|
|
153
|
-
connect_args = CONFIG.get_connect_args()
|
|
154
|
-
engine_kwargs = CONFIG.get_engine_kwargs()
|
|
155
|
-
self.engine = create_async_engine(
|
|
156
|
-
self.db_url, # Use instance db_url, not CONFIG
|
|
157
|
-
connect_args=connect_args,
|
|
158
|
-
**engine_kwargs,
|
|
159
|
-
)
|
|
160
|
-
|
|
161
|
-
self.SessionLocal = sessionmaker(
|
|
162
|
-
self.engine, class_=AsyncSession, expire_on_commit=False
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
await self._ensure_schema()
|
|
166
|
-
|
|
167
|
-
async def _ensure_schema(self):
|
|
168
|
-
"""Ensure database schema is created.
|
|
169
|
-
|
|
170
|
-
Uses a lock to prevent race conditions when multiple workers start
|
|
171
|
-
simultaneously. The checkfirst=True parameter handles cases where
|
|
172
|
-
another worker already created the schema.
|
|
173
|
-
"""
|
|
174
|
-
async with self._schema_lock:
|
|
175
|
-
if self._schema_ready:
|
|
176
|
-
return
|
|
177
|
-
|
|
178
|
-
logger.debug("📊 Initializing database schema...")
|
|
179
|
-
|
|
180
|
-
async with self.engine.begin() as conn:
|
|
181
|
-
# Use a transaction to ensure atomic schema creation
|
|
182
|
-
# checkfirst=True prevents errors if tables already exist
|
|
183
|
-
try:
|
|
184
|
-
await conn.run_sync(
|
|
185
|
-
lambda sync_conn: Base.metadata.create_all(sync_conn, checkfirst=True)
|
|
186
|
-
)
|
|
187
|
-
# logger.info("✅ Database schema created/verified successfully")
|
|
188
|
-
except Exception as e:
|
|
189
|
-
# If tables already exist, that's fine - another worker created them
|
|
190
|
-
if "already exists" not in str(e):
|
|
191
|
-
logger.error(f"❌ Failed to create database schema: {e}")
|
|
192
|
-
raise
|
|
193
|
-
else:
|
|
194
|
-
logger.debug("✅ Database schema already exists")
|
|
195
|
-
|
|
196
|
-
# Enable foreign keys for SQLite - critical for data integrity
|
|
197
|
-
# This must be done for each connection in SQLite
|
|
198
|
-
if CONFIG.foreign_keys:
|
|
199
|
-
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
|
200
|
-
|
|
201
|
-
# Set journal mode
|
|
202
|
-
if CONFIG.journal_mode:
|
|
203
|
-
await conn.execute(text(f"PRAGMA journal_mode = {CONFIG.journal_mode}"))
|
|
204
|
-
|
|
205
|
-
# Create analytics views for efficient querying
|
|
206
|
-
# These are materialized as views to avoid recalculation
|
|
207
|
-
for view_name, view_sql in analytics_views.items():
|
|
208
|
-
try:
|
|
209
|
-
await conn.execute(text(view_sql))
|
|
210
|
-
except Exception as e:
|
|
211
|
-
# Views might already exist from another worker
|
|
212
|
-
if "already exists" not in str(e):
|
|
213
|
-
logger.warning(f"Could not create view {view_name}: {e}")
|
|
214
|
-
|
|
215
|
-
self._schema_ready = True
|
|
216
|
-
# logger.debug("🎯 Database ready for use!")
|
|
217
|
-
|
|
218
|
-
@asynccontextmanager
|
|
219
|
-
async def session(self):
|
|
220
|
-
"""Get an async database session."""
|
|
221
|
-
if not self.SessionLocal:
|
|
222
|
-
await self.initialize()
|
|
223
|
-
async with self.SessionLocal() as session:
|
|
224
|
-
yield session
|
|
225
|
-
|
|
226
|
-
async def insert_session_trace(self, trace: SessionTrace) -> str:
|
|
227
|
-
"""Insert a complete session trace.
|
|
228
|
-
|
|
229
|
-
This method handles the complex task of inserting a complete session
|
|
230
|
-
with all its timesteps, events, and messages. It uses a single
|
|
231
|
-
transaction for atomicity and flushes after timesteps to get their
|
|
232
|
-
auto-generated IDs for foreign keys.
|
|
233
|
-
|
|
234
|
-
Args:
|
|
235
|
-
trace: The complete session trace to store
|
|
236
|
-
|
|
237
|
-
Returns:
|
|
238
|
-
The session ID
|
|
239
|
-
|
|
240
|
-
Raises:
|
|
241
|
-
IntegrityError: If session ID already exists (handled gracefully)
|
|
242
|
-
"""
|
|
243
|
-
async with self.session() as sess:
|
|
244
|
-
try:
|
|
245
|
-
# Convert to cents for cost storage - avoids floating point
|
|
246
|
-
# precision issues and allows for integer arithmetic
|
|
247
|
-
def to_cents(cost: float | None) -> int | None:
|
|
248
|
-
return int(cost * 100) if cost is not None else None
|
|
249
|
-
|
|
250
|
-
# Insert session
|
|
251
|
-
db_session = DBSessionTrace(
|
|
252
|
-
session_id=trace.session_id,
|
|
253
|
-
created_at=trace.created_at,
|
|
254
|
-
num_timesteps=len(trace.session_time_steps),
|
|
255
|
-
num_events=len(trace.event_history),
|
|
256
|
-
num_messages=len(trace.markov_blanket_message_history),
|
|
257
|
-
session_metadata=trace.metadata or {},
|
|
258
|
-
)
|
|
259
|
-
sess.add(db_session)
|
|
260
|
-
|
|
261
|
-
# Track timestep IDs for foreign keys - we need these to link
|
|
262
|
-
# events and messages to their respective timesteps
|
|
263
|
-
step_id_map: dict[str, int] = {}
|
|
264
|
-
|
|
265
|
-
# Insert timesteps
|
|
266
|
-
for step in trace.session_time_steps:
|
|
267
|
-
db_step = DBSessionTimestep(
|
|
268
|
-
session_id=trace.session_id,
|
|
269
|
-
step_id=step.step_id,
|
|
270
|
-
step_index=step.step_index,
|
|
271
|
-
turn_number=step.turn_number,
|
|
272
|
-
started_at=step.timestamp,
|
|
273
|
-
completed_at=step.completed_at,
|
|
274
|
-
num_events=len(step.events),
|
|
275
|
-
num_messages=len(step.markov_blanket_messages),
|
|
276
|
-
step_metadata=step.step_metadata or {},
|
|
277
|
-
)
|
|
278
|
-
sess.add(db_step)
|
|
279
|
-
# Flush to get the auto-generated ID without committing
|
|
280
|
-
# This allows us to use the ID for foreign keys while
|
|
281
|
-
# maintaining transaction atomicity
|
|
282
|
-
await sess.flush() # Get the auto-generated ID
|
|
283
|
-
step_id_map[step.step_id] = db_step.id
|
|
284
|
-
|
|
285
|
-
# Insert events - handle different event types with their
|
|
286
|
-
# specific fields while maintaining a unified storage model
|
|
287
|
-
for event in trace.event_history:
|
|
288
|
-
event_data = {
|
|
289
|
-
"session_id": trace.session_id,
|
|
290
|
-
"timestep_id": step_id_map.get(event.metadata.get("step_id")),
|
|
291
|
-
"system_instance_id": event.system_instance_id,
|
|
292
|
-
"event_time": event.time_record.event_time,
|
|
293
|
-
"message_time": event.time_record.message_time,
|
|
294
|
-
"event_metadata_json": event.metadata or {},
|
|
295
|
-
"event_extra_metadata": event.event_metadata,
|
|
296
|
-
}
|
|
297
|
-
|
|
298
|
-
if isinstance(event, LMCAISEvent):
|
|
299
|
-
# Serialize call_records if present
|
|
300
|
-
call_records_data = None
|
|
301
|
-
if event.call_records:
|
|
302
|
-
from dataclasses import asdict
|
|
303
|
-
|
|
304
|
-
call_records_data = [asdict(record) for record in event.call_records]
|
|
305
|
-
|
|
306
|
-
event_data.update(
|
|
307
|
-
{
|
|
308
|
-
"event_type": "cais",
|
|
309
|
-
"model_name": event.model_name,
|
|
310
|
-
"provider": event.provider,
|
|
311
|
-
"input_tokens": event.input_tokens,
|
|
312
|
-
"output_tokens": event.output_tokens,
|
|
313
|
-
"total_tokens": event.total_tokens,
|
|
314
|
-
"cost_usd": to_cents(event.cost_usd),
|
|
315
|
-
"latency_ms": event.latency_ms,
|
|
316
|
-
"span_id": event.span_id,
|
|
317
|
-
"trace_id": event.trace_id,
|
|
318
|
-
"system_state_before": event.system_state_before,
|
|
319
|
-
"system_state_after": event.system_state_after,
|
|
320
|
-
"call_records": call_records_data, # Store in the proper column
|
|
321
|
-
}
|
|
322
|
-
)
|
|
323
|
-
elif isinstance(event, EnvironmentEvent):
|
|
324
|
-
event_data.update(
|
|
325
|
-
{
|
|
326
|
-
"event_type": "environment",
|
|
327
|
-
"reward": event.reward,
|
|
328
|
-
"terminated": event.terminated,
|
|
329
|
-
"truncated": event.truncated,
|
|
330
|
-
"system_state_before": event.system_state_before,
|
|
331
|
-
"system_state_after": event.system_state_after,
|
|
332
|
-
}
|
|
333
|
-
)
|
|
334
|
-
elif isinstance(event, RuntimeEvent):
|
|
335
|
-
event_data.update(
|
|
336
|
-
{
|
|
337
|
-
"event_type": "runtime",
|
|
338
|
-
"event_metadata_json": {**event.metadata, "actions": event.actions},
|
|
339
|
-
}
|
|
340
|
-
)
|
|
341
|
-
else:
|
|
342
|
-
event_data["event_type"] = event.__class__.__name__.lower()
|
|
343
|
-
|
|
344
|
-
db_event = DBEvent(**event_data)
|
|
345
|
-
sess.add(db_event)
|
|
346
|
-
|
|
347
|
-
# Insert messages
|
|
348
|
-
for msg in trace.markov_blanket_message_history:
|
|
349
|
-
db_msg = DBMessage(
|
|
350
|
-
session_id=trace.session_id,
|
|
351
|
-
timestep_id=step_id_map.get(msg.metadata.get("step_id"))
|
|
352
|
-
if hasattr(msg, "metadata")
|
|
353
|
-
else None,
|
|
354
|
-
message_type=msg.message_type,
|
|
355
|
-
content=msg.content,
|
|
356
|
-
event_time=msg.time_record.event_time,
|
|
357
|
-
message_time=msg.time_record.message_time,
|
|
358
|
-
message_metadata=msg.metadata if hasattr(msg, "metadata") else {},
|
|
359
|
-
)
|
|
360
|
-
sess.add(db_msg)
|
|
361
|
-
|
|
362
|
-
# Commit the entire transaction atomically
|
|
363
|
-
await sess.commit()
|
|
364
|
-
return trace.session_id
|
|
365
|
-
except IntegrityError as e:
|
|
366
|
-
# Handle duplicate session IDs gracefully - this can happen
|
|
367
|
-
# in distributed systems or retries. We return the existing
|
|
368
|
-
# ID to maintain idempotency
|
|
369
|
-
if "UNIQUE constraint failed: session_traces.session_id" in str(e):
|
|
370
|
-
await sess.rollback()
|
|
371
|
-
return trace.session_id # Return existing ID
|
|
372
|
-
raise
|
|
373
|
-
|
|
374
|
-
async def get_session_trace(self, session_id: str) -> dict[str, Any] | None:
    """Fetch one session trace (with its timesteps) as a plain dict.

    Returns None when no session with *session_id* exists.
    """
    async with self.session() as sess:
        stmt = (
            select(DBSessionTrace)
            .options(
                selectinload(DBSessionTrace.timesteps),
                selectinload(DBSessionTrace.events),
                selectinload(DBSessionTrace.messages),
            )
            .where(DBSessionTrace.session_id == session_id)
        )
        row = (await sess.execute(stmt)).scalar_one_or_none()

        if row is None:
            return None

        # Timesteps are surfaced in step_index order regardless of load order.
        ordered_steps = sorted(row.timesteps, key=lambda s: s.step_index)
        return {
            "session_id": row.session_id,
            "created_at": row.created_at,
            "num_timesteps": row.num_timesteps,
            "num_events": row.num_events,
            "num_messages": row.num_messages,
            "metadata": row.session_metadata,
            "timesteps": [
                {
                    "step_id": step.step_id,
                    "step_index": step.step_index,
                    "turn_number": step.turn_number,
                    "started_at": step.started_at,
                    "completed_at": step.completed_at,
                    "metadata": step.step_metadata,
                }
                for step in ordered_steps
            ],
        }
async def query_traces(
    self, query: str, params: dict[str, Any] | None = None
) -> Any:
    """Run a raw SQL query against the trace database.

    Returns a pandas DataFrame when pandas is importable; otherwise a list
    of plain dicts. Callers must be prepared to handle either shape.
    """
    async with self.session() as sess:
        cursor = await sess.execute(text(query), params or {})
        records = cursor.mappings().all()
    # Records are fully materialized above, so building the result outside
    # the session context is safe.
    if pd is None:
        return [dict(rec) for rec in records]
    return pd.DataFrame(records)
async def get_model_usage(
    self,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    model_name: str | None = None,
) -> Any:
    """Summarize model usage from the model_usage_stats view.

    Optional date and model filters narrow the result set. Returns a
    pandas DataFrame when pandas is available; otherwise a list of dicts.
    """
    sql = """
    SELECT * FROM model_usage_stats
    WHERE 1=1
    """
    bind: dict[str, Any] = {}

    # Each filter appends a parameterized clause; values are bound, never
    # interpolated, so the query stays injection-safe.
    if start_date:
        sql += " AND last_used >= :start_date"
        bind["start_date"] = start_date
    if end_date:
        sql += " AND first_used <= :end_date"
        bind["end_date"] = end_date
    if model_name:
        sql += " AND model_name = :model_name"
        bind["model_name"] = model_name

    sql += " ORDER BY usage_count DESC"
    return await self.query_traces(sql, bind)
async def create_experiment(
    self,
    experiment_id: str,
    name: str,
    description: str | None = None,
    configuration: dict[str, Any] | None = None,
) -> str:
    """Insert a new experiment row and return its experiment_id."""
    record = DBExperiment(
        experiment_id=experiment_id,
        name=name,
        description=description,
        configuration=configuration or {},
    )
    async with self.session() as sess:
        sess.add(record)
        await sess.commit()
    return experiment_id
async def link_session_to_experiment(self, session_id: str, experiment_id: str):
    """Attach an existing session to an experiment via its foreign key."""
    stmt = (
        update(DBSessionTrace)
        .where(DBSessionTrace.session_id == session_id)
        .values(experiment_id=experiment_id)
    )
    async with self.session() as sess:
        await sess.execute(stmt)
        await sess.commit()
async def batch_insert_sessions(
    self, traces: list[SessionTrace], batch_size: int | None = None
) -> list[str]:
    """Insert many session traces, chunked to bound per-iteration memory.

    Each trace is written via insert_session_trace (one transaction per
    trace), so locks are held only briefly and a mid-way failure leaves
    earlier traces committed.

    Args:
        traces: Session traces to persist.
        batch_size: Traces per chunk; defaults to CONFIG.batch_size.

    Returns:
        The inserted session IDs, in input order.
    """
    chunk = batch_size or CONFIG.batch_size
    ids: list[str] = []
    # Walk the input in chunks to avoid holding large intermediate state
    # for big datasets; per-trace inserts could be bulk-optimized later.
    for start in range(0, len(traces), chunk):
        for trace in traces[start : start + chunk]:
            ids.append(await self.insert_session_trace(trace))
    return ids
async def get_sessions_by_experiment(
    self, experiment_id: str, limit: int | None = None
) -> list[dict[str, Any]]:
    """Get all sessions for an experiment, newest first.

    Args:
        experiment_id: Experiment whose sessions to fetch.
        limit: Optional cap on the number of rows returned.

    Returns:
        One summary dict per session.
    """
    async with self.session() as sess:
        query = (
            select(DBSessionTrace)
            .where(DBSessionTrace.experiment_id == experiment_id)
            .order_by(DBSessionTrace.created_at.desc())
        )

        if limit:
            query = query.limit(limit)

        result = await sess.execute(query)
        sessions = result.scalars().all()

        return [
            {
                "session_id": s.session_id,
                "created_at": s.created_at,
                "num_timesteps": s.num_timesteps,
                "num_events": s.num_events,
                "num_messages": s.num_messages,
                # BUG FIX: was `s.metadata`, which on a SQLAlchemy declarative
                # model resolves to the table MetaData registry rather than the
                # JSON column. The column attribute used everywhere else in
                # this module is `session_metadata`.
                "metadata": s.session_metadata,
            }
            for s in sessions
        ]
async def delete_session(self, session_id: str) -> bool:
    """Delete one session; related rows are removed by cascade.

    Returns True when a matching row was found and deleted, else False.
    """
    async with self.session() as sess:
        found = (
            await sess.execute(
                select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
            )
        ).scalar_one_or_none()

        if found is None:
            return False

        # Deleting through the ORM object (not a bulk DELETE) triggers the
        # cascade rules configured on the DBSessionTrace relationships.
        await sess.delete(found)
        await sess.commit()
        return True
async def close(self):
    """Dispose of the engine and reset state so init can run again.

    Proper disposal matters most for SQLite, which can otherwise leave
    lock files behind.
    """
    if not self.engine:
        return
    # Releases every pooled connection.
    await self.engine.dispose()
    # Clearing all handles allows a later re-initialization.
    self.engine = None
    self.SessionLocal = None
    self._schema_ready = False
# -------------------------------
|
|
577
|
-
# Incremental insert helpers
|
|
578
|
-
# -------------------------------
|
|
579
|
-
|
|
580
|
-
async def ensure_session(
    self,
    session_id: str,
    *,
    created_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
):
    """Insert a session row for *session_id* unless one already exists."""
    async with self.session() as sess:
        found = (
            await sess.execute(
                select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
            )
        ).scalar_one_or_none()
        if found is not None:
            return
        # NOTE(review): datetime.utcnow() is deprecated as of Python 3.12;
        # kept here to preserve the existing naive-UTC timestamp behavior.
        sess.add(
            DBSessionTrace(
                session_id=session_id,
                created_at=created_at or datetime.utcnow(),
                num_timesteps=0,
                num_events=0,
                num_messages=0,
                session_metadata=metadata or {},
            )
        )
        await sess.commit()
async def ensure_timestep(
    self,
    session_id: str,
    *,
    step_id: str,
    step_index: int,
    turn_number: int | None = None,
    started_at: datetime | None = None,
    completed_at: datetime | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Get-or-create a timestep row; return its surrogate DB id."""
    async with self.session() as sess:
        existing = (
            await sess.execute(
                select(DBSessionTimestep).where(
                    DBSessionTimestep.session_id == session_id,
                    DBSessionTimestep.step_id == step_id,
                )
            )
        ).scalar_one_or_none()
        if existing is not None:
            return existing.id

        new_step = DBSessionTimestep(
            session_id=session_id,
            step_id=step_id,
            step_index=step_index,
            turn_number=turn_number,
            started_at=started_at or datetime.utcnow(),
            completed_at=completed_at,
            num_events=0,
            num_messages=0,
            step_metadata=metadata or {},
        )
        sess.add(new_step)
        # Flush so the autoincrement id is assigned before we return it.
        await sess.flush()
        # Keep the denormalized timestep counter on the parent session in sync.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_timesteps=DBSessionTrace.num_timesteps + 1)
        )
        await sess.commit()
        return new_step.id
async def insert_message_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    message_type: str,
    content: str,
    event_time: float | None = None,
    message_time: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> int:
    """Persist a single message row and return its DB id."""
    row = DBMessage(
        session_id=session_id,
        timestep_id=timestep_db_id,
        message_type=message_type,
        content=content,
        event_time=event_time,
        message_time=message_time,
        message_metadata=metadata or {},
    )
    async with self.session() as sess:
        sess.add(row)
        # Flush so the autoincrement id is assigned before we return it.
        await sess.flush()
        # Keep the denormalized message counter on the session current.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_messages=DBSessionTrace.num_messages + 1)
        )
        await sess.commit()
        return row.id
async def insert_event_row(
    self,
    session_id: str,
    *,
    timestep_db_id: int | None,
    event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
    metadata_override: dict[str, Any] | None = None,
) -> int:
    """Persist one event row and return its DB id.

    The unified event table has type-specific columns; fields not used by
    a given event type are simply left unset on the row.
    """

    def usd_to_cents(cost: float | None) -> int | None:
        # Truncates rather than rounds — matches the storage convention
        # already used by this module for cost_usd.
        return int(cost * 100) if cost is not None else None

    payload: dict[str, Any] = {
        "session_id": session_id,
        "timestep_id": timestep_db_id,
        "system_instance_id": event.system_instance_id,
        "event_time": event.time_record.event_time,
        "message_time": event.time_record.message_time,
        "event_metadata_json": metadata_override or event.metadata or {},
        "event_extra_metadata": getattr(event, "event_metadata", None),
    }

    if isinstance(event, LMCAISEvent):
        records_json = None
        if getattr(event, "call_records", None):
            from dataclasses import asdict

            records_json = [asdict(rec) for rec in event.call_records]
        payload.update(
            {
                "event_type": "cais",
                "model_name": event.model_name,
                "provider": event.provider,
                "input_tokens": event.input_tokens,
                "output_tokens": event.output_tokens,
                "total_tokens": event.total_tokens,
                "cost_usd": usd_to_cents(event.cost_usd),
                "latency_ms": event.latency_ms,
                "span_id": event.span_id,
                "trace_id": event.trace_id,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
                "call_records": records_json,
            }
        )
    elif isinstance(event, EnvironmentEvent):
        payload.update(
            {
                "event_type": "environment",
                "reward": event.reward,
                "terminated": event.terminated,
                "truncated": event.truncated,
                "system_state_before": event.system_state_before,
                "system_state_after": event.system_state_after,
            }
        )
    elif isinstance(event, RuntimeEvent):
        # NOTE(review): this overwrite discards any metadata_override for
        # runtime events — confirm that is intentional.
        payload.update(
            {
                "event_type": "runtime",
                "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
            }
        )
    else:
        payload["event_type"] = event.__class__.__name__.lower()

    async with self.session() as sess:
        row = DBEvent(**payload)
        sess.add(row)
        # Flush so the autoincrement id is assigned before we return it.
        await sess.flush()
        # Keep the denormalized event counter on the session current.
        await sess.execute(
            update(DBSessionTrace)
            .where(DBSessionTrace.session_id == session_id)
            .values(num_events=DBSessionTrace.num_events + 1)
        )
        await sess.commit()
        return row.id
# -------------------------------
|
|
718
|
-
# Reward helpers
|
|
719
|
-
# -------------------------------
|
|
720
|
-
|
|
721
|
-
async def insert_outcome_reward(
    self,
    session_id: str,
    *,
    total_reward: int,
    achievements_count: int,
    total_steps: int,
    reward_metadata: dict | None = None,
) -> int:
    """Store the episode-level (outcome) reward for a session; return row id."""
    record = DBOutcomeReward(
        session_id=session_id,
        total_reward=total_reward,
        achievements_count=achievements_count,
        total_steps=total_steps,
        reward_metadata=reward_metadata or {},
    )
    async with self.session() as sess:
        sess.add(record)
        # Flush so the autoincrement id is assigned before we return it.
        await sess.flush()
        await sess.commit()
        return record.id
async def insert_event_reward(
    self,
    session_id: str,
    *,
    event_id: int,
    message_id: int | None = None,
    turn_number: int | None = None,
    reward_value: float = 0.0,
    reward_type: str | None = None,
    key: str | None = None,
    annotation: dict[str, Any] | None = None,
    source: str | None = None,
) -> int:
    """Store a per-event reward annotation; return the new row id."""
    record = DBEventReward(
        event_id=event_id,
        session_id=session_id,
        message_id=message_id,
        turn_number=turn_number,
        reward_value=reward_value,
        reward_type=reward_type,
        key=key,
        annotation=annotation or {},
        source=source,
    )
    async with self.session() as sess:
        sess.add(record)
        # Flush so the autoincrement id is assigned before we return it.
        await sess.flush()
        await sess.commit()
        return record.id
async def get_outcome_rewards(self) -> list[dict[str, Any]]:
    """Return every outcome-reward row as a summary dict."""
    async with self.session() as sess:
        rows = (await sess.execute(select(DBOutcomeReward))).scalars().all()
        return [
            {
                "id": r.id,
                "session_id": r.session_id,
                "total_reward": r.total_reward,
                "achievements_count": r.achievements_count,
                "total_steps": r.total_steps,
                "created_at": r.created_at,
            }
            for r in rows
        ]
async def get_outcome_rewards_by_min_reward(self, min_reward: int) -> list[str]:
    """Session IDs whose outcome reward is at least *min_reward*."""
    async with self.session() as sess:
        result = await sess.execute(
            select(DBOutcomeReward.session_id).where(
                DBOutcomeReward.total_reward >= min_reward
            )
        )
        # Single-column select: scalars() yields the session_id values directly.
        return list(result.scalars().all())
|