synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/README.md +1 -0
- examples/multi_step/SFT_README.md +147 -0
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +12 -11
- examples/multi_step/configs/crafter_sft_qwen30b_lora.toml +62 -0
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/convert_traces_to_sft.py +84 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/run_sft_qwen30b.sh +45 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +3 -2
- examples/qwen_coder/configs/coder_lora_4b.toml +2 -1
- examples/qwen_coder/configs/coder_lora_small.toml +2 -1
- examples/qwen_vl/BUGS_AND_FIXES.md +232 -0
- examples/qwen_vl/IMAGE_VALIDATION_COMPLETE.md +271 -0
- examples/qwen_vl/IMAGE_VALIDATION_SUMMARY.md +260 -0
- examples/qwen_vl/INFERENCE_SFT_TESTS.md +412 -0
- examples/qwen_vl/NEXT_STEPS_2B.md +325 -0
- examples/qwen_vl/QUICKSTART.md +327 -0
- examples/qwen_vl/QUICKSTART_RL_VISION.md +110 -0
- examples/qwen_vl/README.md +154 -0
- examples/qwen_vl/RL_VISION_COMPLETE.md +475 -0
- examples/qwen_vl/RL_VISION_TESTING.md +333 -0
- examples/qwen_vl/SDK_VISION_INTEGRATION.md +328 -0
- examples/qwen_vl/SETUP_COMPLETE.md +275 -0
- examples/qwen_vl/VISION_TESTS_COMPLETE.md +490 -0
- examples/qwen_vl/VLM_PIPELINE_COMPLETE.md +242 -0
- examples/qwen_vl/__init__.py +2 -0
- examples/qwen_vl/collect_data_via_cli.md +423 -0
- examples/qwen_vl/collect_vision_traces.py +368 -0
- examples/qwen_vl/configs/crafter_rl_vision_qwen3vl4b.toml +127 -0
- examples/qwen_vl/configs/crafter_vlm_sft_example.toml +60 -0
- examples/qwen_vl/configs/eval_gpt4o_mini_vision.toml +43 -0
- examples/qwen_vl/configs/eval_gpt4o_vision_proper.toml +29 -0
- examples/qwen_vl/configs/eval_gpt5nano_vision.toml +45 -0
- examples/qwen_vl/configs/eval_qwen2vl_vision.toml +44 -0
- examples/qwen_vl/configs/filter_qwen2vl_sft.toml +50 -0
- examples/qwen_vl/configs/filter_vision_sft.toml +53 -0
- examples/qwen_vl/configs/filter_vision_test.toml +8 -0
- examples/qwen_vl/configs/sft_qwen3_vl_2b_test.toml +54 -0
- examples/qwen_vl/crafter_gpt5nano_agent.py +308 -0
- examples/qwen_vl/crafter_qwen_vl_agent.py +300 -0
- examples/qwen_vl/run_vision_comparison.sh +62 -0
- examples/qwen_vl/run_vision_sft_pipeline.sh +175 -0
- examples/qwen_vl/test_image_validation.py +201 -0
- examples/qwen_vl/test_sft_vision_data.py +110 -0
- examples/rl/README.md +1 -1
- examples/rl/configs/eval_base_qwen.toml +17 -0
- examples/rl/configs/eval_rl_qwen.toml +13 -0
- examples/rl/configs/rl_from_base_qwen.toml +37 -0
- examples/rl/configs/rl_from_base_qwen17.toml +76 -0
- examples/rl/configs/rl_from_ft_qwen.toml +37 -0
- examples/rl/run_eval.py +436 -0
- examples/rl/run_rl_and_save.py +111 -0
- examples/rl/task_app/README.md +22 -0
- examples/rl/task_app/math_single_step.py +990 -0
- examples/rl/task_app/math_task_app.py +111 -0
- examples/sft/README.md +5 -5
- examples/sft/configs/crafter_fft_qwen0p6b.toml +4 -2
- examples/sft/configs/crafter_lora_qwen0p6b.toml +4 -3
- examples/sft/evaluate.py +4 -4
- examples/sft/export_dataset.py +7 -4
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/README.md +1 -1
- examples/swe/task_app/grpo_swe_mini.py +1 -1
- examples/swe/task_app/grpo_swe_mini_task_app.py +0 -12
- examples/swe/task_app/hosted/envs/mini_swe/environment.py +13 -13
- examples/swe/task_app/hosted/policy_routes.py +0 -2
- examples/swe/task_app/hosted/rollout.py +2 -8
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +309 -14
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +75 -4
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +55 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +114 -32
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +127 -27
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/__init__.py +1 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/vlm/README.md +3 -3
- examples/vlm/configs/crafter_vlm_gpt4o.toml +2 -0
- examples/vlm/crafter_openai_vlm_agent.py +3 -5
- examples/vlm/filter_image_rows.py +1 -1
- examples/vlm/run_crafter_vlm_benchmark.py +2 -2
- examples/warming_up_to_rl/_utils.py +92 -0
- examples/warming_up_to_rl/analyze_trace_db.py +1 -1
- examples/warming_up_to_rl/configs/crafter_fft.toml +2 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +2 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +2 -1
- examples/warming_up_to_rl/configs/rl_from_ft.toml +2 -0
- examples/warming_up_to_rl/export_trace_sft.py +174 -60
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/readme.md +63 -132
- examples/warming_up_to_rl/run_fft_and_save.py +1 -1
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rl_and_save.py +1 -1
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- examples/warming_up_to_rl/task_app/README.md +42 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +696 -0
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +135 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +143 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1226 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +522 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +478 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +108 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +305 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +204 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +618 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +100 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +1081 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +195 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1861 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +211 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +161 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +137 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +62 -0
- synth_ai/__init__.py +44 -30
- synth_ai/_utils/__init__.py +47 -0
- synth_ai/_utils/base_url.py +10 -0
- synth_ai/_utils/http.py +10 -0
- synth_ai/_utils/prompts.py +10 -0
- synth_ai/_utils/task_app_state.py +12 -0
- synth_ai/_utils/user_config.py +10 -0
- synth_ai/api/models/supported.py +145 -7
- synth_ai/api/train/__init__.py +13 -1
- synth_ai/api/train/cli.py +30 -7
- synth_ai/api/train/config_finder.py +18 -11
- synth_ai/api/train/env_resolver.py +13 -10
- synth_ai/cli/__init__.py +66 -49
- synth_ai/cli/_modal_wrapper.py +9 -6
- synth_ai/cli/_typer_patch.py +0 -2
- synth_ai/cli/_validate_task_app.py +22 -4
- synth_ai/cli/legacy_root_backup.py +3 -1
- synth_ai/cli/lib/__init__.py +10 -0
- synth_ai/cli/lib/task_app_discovery.py +7 -0
- synth_ai/cli/lib/task_app_env.py +518 -0
- synth_ai/cli/recent.py +1 -0
- synth_ai/cli/setup.py +266 -0
- synth_ai/cli/task_app_deploy.py +16 -0
- synth_ai/cli/task_app_list.py +25 -0
- synth_ai/cli/task_app_modal_serve.py +16 -0
- synth_ai/cli/task_app_serve.py +18 -0
- synth_ai/cli/task_apps.py +392 -141
- synth_ai/cli/train.py +18 -0
- synth_ai/cli/tui.py +62 -0
- synth_ai/demos/__init__.py +10 -0
- synth_ai/demos/core/__init__.py +28 -1
- synth_ai/demos/crafter/__init__.py +1 -0
- synth_ai/demos/crafter/crafter_fft_4b.toml +55 -0
- synth_ai/demos/crafter/grpo_crafter_task_app.py +185 -0
- synth_ai/demos/crafter/rl_from_base_qwen4b.toml +74 -0
- synth_ai/demos/demo_registry.py +176 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/demos/math/__init__.py +1 -0
- synth_ai/demos/math/_common.py +16 -0
- synth_ai/demos/math/app.py +38 -0
- synth_ai/demos/math/config.toml +76 -0
- synth_ai/demos/math/deploy_modal.py +54 -0
- synth_ai/demos/math/modal_task_app.py +702 -0
- synth_ai/demos/math/task_app_entry.py +51 -0
- synth_ai/environments/environment/core.py +7 -1
- synth_ai/environments/examples/bandit/engine.py +0 -1
- synth_ai/environments/examples/bandit/environment.py +0 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/environments/examples/wordle/environment.py +0 -1
- synth_ai/evals/base.py +16 -5
- synth_ai/evals/client.py +1 -1
- synth_ai/inference/client.py +1 -1
- synth_ai/learning/client.py +1 -1
- synth_ai/learning/health.py +1 -1
- synth_ai/learning/jobs.py +1 -1
- synth_ai/learning/rl/client.py +1 -1
- synth_ai/learning/rl/env_keys.py +1 -1
- synth_ai/learning/rl/secrets.py +1 -1
- synth_ai/learning/sft/client.py +1 -1
- synth_ai/learning/sft/data.py +407 -4
- synth_ai/learning/validators.py +4 -1
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +5 -2
- synth_ai/task/config.py +259 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +4 -2
- synth_ai/task/rubrics/loaders.py +27 -4
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/rubrics.py +219 -0
- synth_ai/task/trace_correlation_helpers.py +328 -0
- synth_ai/task/tracing_utils.py +14 -3
- synth_ai/task/validators.py +145 -2
- synth_ai/tracing_v3/config.py +15 -13
- synth_ai/tracing_v3/constants.py +21 -0
- synth_ai/tracing_v3/db_config.py +3 -1
- synth_ai/tracing_v3/decorators.py +10 -7
- synth_ai/tracing_v3/session_tracer.py +10 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -2
- synth_ai/tracing_v3/turso/native_manager.py +108 -77
- synth_ai/tracing_v3/utils.py +1 -1
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +911 -0
- synth_ai/utils/__init__.py +101 -0
- synth_ai/utils/base_url.py +94 -0
- synth_ai/utils/cli.py +131 -0
- synth_ai/utils/env.py +287 -0
- synth_ai/utils/http.py +169 -0
- synth_ai/utils/modal.py +308 -0
- synth_ai/utils/process.py +212 -0
- synth_ai/utils/prompts.py +39 -0
- synth_ai/utils/sqld.py +122 -0
- synth_ai/utils/task_app_discovery.py +882 -0
- synth_ai/utils/task_app_env.py +186 -0
- synth_ai/utils/task_app_state.py +318 -0
- synth_ai/utils/user_config.py +137 -0
- synth_ai/v0/config/__init__.py +1 -5
- synth_ai/v0/config/base_url.py +1 -7
- synth_ai/v0/tracing/config.py +1 -1
- synth_ai/v0/tracing/decorators.py +1 -1
- synth_ai/v0/tracing/upload.py +1 -1
- synth_ai/v0/tracing_v1/config.py +1 -1
- synth_ai/v0/tracing_v1/decorators.py +1 -1
- synth_ai/v0/tracing_v1/upload.py +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/METADATA +85 -31
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/RECORD +286 -135
- synth_ai/cli/man.py +0 -106
- synth_ai/compound/cais.py +0 -0
- synth_ai/core/experiment.py +0 -13
- synth_ai/core/system.py +0 -15
- synth_ai/demo_registry.py +0 -295
- synth_ai/handshake.py +0 -109
- synth_ai/http.py +0 -26
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.16.dist-info}/top_level.txt +0 -0
synth_ai/task/config.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
"""Configuration dataclasses for task app CLI commands (eval, filter)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(slots=True)
|
|
11
|
+
class EvalConfig:
|
|
12
|
+
"""Configuration for 'synth-ai eval' command.
|
|
13
|
+
|
|
14
|
+
Validates and provides defaults for evaluation runs against task apps.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
# Required: Task app identifier
|
|
18
|
+
app_id: str
|
|
19
|
+
|
|
20
|
+
# Required: Model to evaluate
|
|
21
|
+
model: str
|
|
22
|
+
|
|
23
|
+
# Required: Seeds to run
|
|
24
|
+
seeds: list[int]
|
|
25
|
+
|
|
26
|
+
# Optional: Task app URL (None = spawn in-process)
|
|
27
|
+
task_app_url: str | None = None
|
|
28
|
+
|
|
29
|
+
# Optional: Data split to use
|
|
30
|
+
split: str = "train"
|
|
31
|
+
|
|
32
|
+
# Optional: Maximum turns/steps per episode
|
|
33
|
+
max_turns: int | None = None
|
|
34
|
+
|
|
35
|
+
# Optional: Maximum LLM calls per episode
|
|
36
|
+
max_llm_calls: int = 10
|
|
37
|
+
|
|
38
|
+
# Optional: Concurrency for parallel rollouts
|
|
39
|
+
concurrency: int = 1
|
|
40
|
+
|
|
41
|
+
# Optional: Environment name
|
|
42
|
+
env_name: str | None = None
|
|
43
|
+
|
|
44
|
+
# Optional: Policy name
|
|
45
|
+
policy_name: str | None = None
|
|
46
|
+
|
|
47
|
+
# Optional: Trace format ("compact", "full", "structured")
|
|
48
|
+
trace_format: Literal["compact", "full", "structured"] = "compact"
|
|
49
|
+
|
|
50
|
+
# Optional: Whether to return traces in response
|
|
51
|
+
return_trace: bool = False
|
|
52
|
+
|
|
53
|
+
# Optional: Operations sequence (if not provided, generates default)
|
|
54
|
+
ops: list[str] | None = None
|
|
55
|
+
|
|
56
|
+
# Optional: Environment config overrides
|
|
57
|
+
env_config: dict[str, Any] = field(default_factory=dict)
|
|
58
|
+
|
|
59
|
+
# Optional: Policy config overrides
|
|
60
|
+
policy_config: dict[str, Any] = field(default_factory=dict)
|
|
61
|
+
|
|
62
|
+
# Optional: Metadata for traces
|
|
63
|
+
metadata: dict[str, str] = field(default_factory=dict)
|
|
64
|
+
|
|
65
|
+
# Optional: SQL query for metadata filtering
|
|
66
|
+
metadata_sql: str | None = None
|
|
67
|
+
|
|
68
|
+
def __post_init__(self):
|
|
69
|
+
"""Validate configuration after initialization."""
|
|
70
|
+
if not self.app_id:
|
|
71
|
+
raise ValueError("app_id is required")
|
|
72
|
+
|
|
73
|
+
if not self.model:
|
|
74
|
+
raise ValueError("model is required")
|
|
75
|
+
|
|
76
|
+
if not self.seeds:
|
|
77
|
+
raise ValueError("seeds list cannot be empty")
|
|
78
|
+
|
|
79
|
+
if not isinstance(self.seeds, list):
|
|
80
|
+
raise ValueError("seeds must be a list of integers")
|
|
81
|
+
|
|
82
|
+
if self.concurrency < 1:
|
|
83
|
+
raise ValueError("concurrency must be >= 1")
|
|
84
|
+
|
|
85
|
+
if self.max_llm_calls < 1:
|
|
86
|
+
raise ValueError("max_llm_calls must be >= 1")
|
|
87
|
+
|
|
88
|
+
if self.max_turns is not None and self.max_turns < 1:
|
|
89
|
+
raise ValueError("max_turns must be >= 1")
|
|
90
|
+
|
|
91
|
+
if self.trace_format not in ("compact", "full", "structured"):
|
|
92
|
+
raise ValueError(f"trace_format must be 'compact', 'full', or 'structured', got: {self.trace_format}")
|
|
93
|
+
|
|
94
|
+
@classmethod
|
|
95
|
+
def from_dict(cls, data: dict[str, Any]) -> EvalConfig:
|
|
96
|
+
"""Create EvalConfig from a dictionary (e.g. from TOML).
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
data: Dictionary with eval configuration
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
Validated EvalConfig instance
|
|
103
|
+
"""
|
|
104
|
+
# Extract known fields
|
|
105
|
+
config_dict = {
|
|
106
|
+
"app_id": data.get("app_id"),
|
|
107
|
+
"model": data.get("model"),
|
|
108
|
+
"seeds": data.get("seeds", []),
|
|
109
|
+
"task_app_url": data.get("task_app_url"),
|
|
110
|
+
"split": data.get("split", "train"),
|
|
111
|
+
"max_turns": data.get("max_turns"),
|
|
112
|
+
"max_llm_calls": data.get("max_llm_calls", 10),
|
|
113
|
+
"concurrency": data.get("concurrency", 1),
|
|
114
|
+
"env_name": data.get("env_name"),
|
|
115
|
+
"policy_name": data.get("policy_name"),
|
|
116
|
+
"trace_format": data.get("trace_format", "compact"),
|
|
117
|
+
"return_trace": data.get("return_trace", False),
|
|
118
|
+
"ops": data.get("ops"),
|
|
119
|
+
"env_config": data.get("env_config", {}),
|
|
120
|
+
"policy_config": data.get("policy_config", {}),
|
|
121
|
+
"metadata": data.get("metadata", {}),
|
|
122
|
+
"metadata_sql": data.get("metadata_sql"),
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
return cls(**config_dict)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass(slots=True)
|
|
129
|
+
class FilterConfig:
|
|
130
|
+
"""Configuration for 'synth-ai filter' command.
|
|
131
|
+
|
|
132
|
+
Validates and provides defaults for filtering traces into SFT datasets.
|
|
133
|
+
"""
|
|
134
|
+
|
|
135
|
+
# Required: Database path or URL
|
|
136
|
+
db: str
|
|
137
|
+
|
|
138
|
+
# Required: Output JSONL path
|
|
139
|
+
output: str
|
|
140
|
+
|
|
141
|
+
# Optional: Filter by data splits
|
|
142
|
+
splits: list[str] = field(default_factory=list)
|
|
143
|
+
|
|
144
|
+
# Optional: Filter by task IDs
|
|
145
|
+
task_ids: list[str] = field(default_factory=list)
|
|
146
|
+
|
|
147
|
+
# Optional: Filter by models
|
|
148
|
+
models: list[str] = field(default_factory=list)
|
|
149
|
+
|
|
150
|
+
# Optional: Minimum official score threshold
|
|
151
|
+
min_official_score: float | None = None
|
|
152
|
+
|
|
153
|
+
# Optional: Maximum official score threshold
|
|
154
|
+
max_official_score: float | None = None
|
|
155
|
+
|
|
156
|
+
# Optional: Minimum judge scores (judge_name -> min_score)
|
|
157
|
+
min_judge_scores: dict[str, float] = field(default_factory=dict)
|
|
158
|
+
|
|
159
|
+
# Optional: Maximum judge scores (judge_name -> max_score)
|
|
160
|
+
max_judge_scores: dict[str, float] = field(default_factory=dict)
|
|
161
|
+
|
|
162
|
+
# Optional: Limit number of examples
|
|
163
|
+
limit: int | None = None
|
|
164
|
+
|
|
165
|
+
# Optional: Offset for pagination
|
|
166
|
+
offset: int | None = None
|
|
167
|
+
|
|
168
|
+
# Optional: Whether to shuffle results
|
|
169
|
+
shuffle: bool = False
|
|
170
|
+
|
|
171
|
+
# Optional: Random seed for shuffling
|
|
172
|
+
shuffle_seed: int | None = None
|
|
173
|
+
|
|
174
|
+
def __post_init__(self):
|
|
175
|
+
"""Validate configuration after initialization."""
|
|
176
|
+
if not self.db:
|
|
177
|
+
raise ValueError("db (database path or URL) is required")
|
|
178
|
+
|
|
179
|
+
if not self.output:
|
|
180
|
+
raise ValueError("output (JSONL file path) is required")
|
|
181
|
+
|
|
182
|
+
# Validate output has .jsonl extension
|
|
183
|
+
output_path = Path(self.output)
|
|
184
|
+
if output_path.suffix.lower() not in (".jsonl", ".json"):
|
|
185
|
+
raise ValueError(f"output must be a .jsonl or .json file, got: {self.output}")
|
|
186
|
+
|
|
187
|
+
# Validate score thresholds
|
|
188
|
+
if (
|
|
189
|
+
self.min_official_score is not None
|
|
190
|
+
and self.max_official_score is not None
|
|
191
|
+
and self.min_official_score > self.max_official_score
|
|
192
|
+
):
|
|
193
|
+
raise ValueError("min_official_score cannot be greater than max_official_score")
|
|
194
|
+
|
|
195
|
+
# Validate limit/offset
|
|
196
|
+
if self.limit is not None and self.limit < 1:
|
|
197
|
+
raise ValueError("limit must be >= 1")
|
|
198
|
+
|
|
199
|
+
if self.offset is not None and self.offset < 0:
|
|
200
|
+
raise ValueError("offset must be >= 0")
|
|
201
|
+
|
|
202
|
+
# Validate shuffle seed requires shuffle
|
|
203
|
+
if self.shuffle_seed is not None and not self.shuffle:
|
|
204
|
+
raise ValueError("shuffle_seed requires shuffle=true")
|
|
205
|
+
|
|
206
|
+
@classmethod
|
|
207
|
+
def from_dict(cls, data: dict[str, Any]) -> FilterConfig:
|
|
208
|
+
"""Create FilterConfig from a dictionary (e.g. from TOML).
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
data: Dictionary with filter configuration
|
|
212
|
+
|
|
213
|
+
Returns:
|
|
214
|
+
Validated FilterConfig instance
|
|
215
|
+
"""
|
|
216
|
+
# Extract known fields
|
|
217
|
+
config_dict = {
|
|
218
|
+
"db": data.get("db"),
|
|
219
|
+
"output": data.get("output"),
|
|
220
|
+
"splits": data.get("splits", []),
|
|
221
|
+
"task_ids": data.get("task_ids", []),
|
|
222
|
+
"models": data.get("models", []),
|
|
223
|
+
"min_official_score": data.get("min_official_score"),
|
|
224
|
+
"max_official_score": data.get("max_official_score"),
|
|
225
|
+
"min_judge_scores": data.get("min_judge_scores", {}),
|
|
226
|
+
"max_judge_scores": data.get("max_judge_scores", {}),
|
|
227
|
+
"limit": data.get("limit"),
|
|
228
|
+
"offset": data.get("offset"),
|
|
229
|
+
"shuffle": data.get("shuffle", False),
|
|
230
|
+
"shuffle_seed": data.get("shuffle_seed"),
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
return cls(**config_dict)
|
|
234
|
+
|
|
235
|
+
def get_db_url(self) -> str:
|
|
236
|
+
"""Convert db path to proper SQLite URL if needed.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
Database URL suitable for SQLAlchemy/aiosqlite
|
|
240
|
+
"""
|
|
241
|
+
db_value = self.db.strip()
|
|
242
|
+
if "://" in db_value:
|
|
243
|
+
return db_value
|
|
244
|
+
else:
|
|
245
|
+
db_path = Path(db_value).expanduser().resolve()
|
|
246
|
+
# Ensure parent directory exists
|
|
247
|
+
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
248
|
+
return f"sqlite+aiosqlite:///{db_path}"
|
|
249
|
+
|
|
250
|
+
def get_output_path(self) -> Path:
|
|
251
|
+
"""Get resolved output path with parent directory created.
|
|
252
|
+
|
|
253
|
+
Returns:
|
|
254
|
+
Resolved Path object with parent directory created
|
|
255
|
+
"""
|
|
256
|
+
output_path = Path(self.output).expanduser().resolve()
|
|
257
|
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
258
|
+
return output_path
|
|
259
|
+
|
synth_ai/task/contracts.py
CHANGED
|
@@ -1,11 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
4
5
|
from typing import Any, Literal
|
|
5
6
|
|
|
6
7
|
from pydantic import BaseModel, ConfigDict, Field
|
|
7
8
|
|
|
8
9
|
|
|
10
|
+
class RolloutMode(str, Enum):
    """Rollout mode; controls how rollout infrastructure processes inference URLs.

    Members subclass ``str``, so each one compares equal to its lowercase
    string value (``"rl"`` / ``"eval"``).
    """

    # Reinforcement-learning rollout.
    RL = "rl"
    # Evaluation rollout.
    EVAL = "eval"
|
|
14
|
+
|
|
15
|
+
|
|
9
16
|
@dataclass(frozen=True)
|
|
10
17
|
class TaskAppEndpoints:
|
|
11
18
|
"""Required Task App endpoints used by RL trainers and clients.
|
|
@@ -43,7 +50,7 @@ class RolloutRecordConfig(BaseModel):
|
|
|
43
50
|
logprobs: bool = False
|
|
44
51
|
value: bool = False
|
|
45
52
|
return_trace: bool = False
|
|
46
|
-
trace_format: Literal["compact", "full"] = "compact"
|
|
53
|
+
trace_format: Literal["compact", "full", "structured"] = "compact"
|
|
47
54
|
|
|
48
55
|
|
|
49
56
|
class RolloutSafetyConfig(BaseModel):
|
|
@@ -61,6 +68,7 @@ class RolloutRequest(BaseModel):
|
|
|
61
68
|
safety: RolloutSafetyConfig = RolloutSafetyConfig()
|
|
62
69
|
training_session_id: str | None = None
|
|
63
70
|
synth_base_url: str | None = None
|
|
71
|
+
mode: RolloutMode # Required: explicit RL vs EVAL mode
|
|
64
72
|
|
|
65
73
|
|
|
66
74
|
class RolloutStep(BaseModel):
|
|
@@ -110,7 +118,7 @@ class RolloutTrajectory(BaseModel):
|
|
|
110
118
|
|
|
111
119
|
# Required for trace correlation with inference mesh (optional initially for backward compat)
|
|
112
120
|
# See: monorepo/INFERENCE_URL_REQUIREMENT_PLAN.md and trace_creation_and_judgement.txt
|
|
113
|
-
inference_url: str
|
|
121
|
+
inference_url: str
|
|
114
122
|
|
|
115
123
|
decision_samples: list[dict[str, Any]] | None = None
|
|
116
124
|
|
|
@@ -143,10 +151,15 @@ class RolloutResponse(BaseModel):
|
|
|
143
151
|
aborted: bool = False
|
|
144
152
|
ops_executed: int = 0
|
|
145
153
|
|
|
154
|
+
# OPTIONAL: correlation ID for linking rollout to inference traces
|
|
155
|
+
# If not provided, trainer will infer it from trajectory.inference_url ?cid=... parameter
|
|
156
|
+
trace_correlation_id: str | None = None
|
|
157
|
+
|
|
146
158
|
# PREFERRED: v3 trace format (SessionTrace). This is the single source of truth
|
|
147
159
|
# for rollout data and should be used by all new code. Contains richer data than
|
|
148
160
|
# trajectories including token IDs, logprobs, timing, and multimodal content.
|
|
149
161
|
trace: dict[str, Any] | None = None
|
|
162
|
+
pipeline_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
150
163
|
|
|
151
164
|
|
|
152
165
|
class _ExtraAllowModel(BaseModel):
|
|
@@ -9,10 +9,9 @@ This module provides:
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
# Core models (flexible validation)
|
|
12
|
-
from .models import Criterion, Rubric
|
|
13
|
-
|
|
14
12
|
# Loading and blending
|
|
15
13
|
from .loaders import blend_rubrics, load_rubric
|
|
14
|
+
from .models import Criterion, Rubric
|
|
16
15
|
|
|
17
16
|
# Scoring
|
|
18
17
|
from .scoring import score_events_against_rubric, score_outcome_against_rubric
|
|
@@ -51,3 +50,6 @@ __all__ = [
|
|
|
51
50
|
RubricCriterion = StrictCriterion
|
|
52
51
|
RubricSpec = StrictRubric
|
|
53
52
|
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
|
synth_ai/task/rubrics/loaders.py
CHANGED
|
@@ -60,15 +60,39 @@ def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
|
|
|
60
60
|
|
|
61
61
|
Returns:
|
|
62
62
|
Parsed Rubric instance or None if source is None
|
|
63
|
+
|
|
64
|
+
Raises:
|
|
65
|
+
ValueError: If the rubric format is incorrect (e.g., backend judge format)
|
|
66
|
+
ValidationError: If the rubric fails schema validation
|
|
63
67
|
"""
|
|
64
68
|
if source is None:
|
|
65
69
|
return None
|
|
66
70
|
if isinstance(source, Rubric):
|
|
67
71
|
return source
|
|
72
|
+
|
|
73
|
+
# Load and parse the data
|
|
68
74
|
if isinstance(source, dict):
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
75
|
+
data = source
|
|
76
|
+
else:
|
|
77
|
+
text, suffix = _load_text(str(source))
|
|
78
|
+
data = _parse_structured(text, suffix)
|
|
79
|
+
|
|
80
|
+
# Check if this looks like a backend judge rubric (wrong format)
|
|
81
|
+
if (
|
|
82
|
+
isinstance(data, dict)
|
|
83
|
+
and "event" in data
|
|
84
|
+
and "outcome" in data
|
|
85
|
+
and "version" not in data
|
|
86
|
+
and "goal_text" not in data
|
|
87
|
+
and "criteria" not in data
|
|
88
|
+
):
|
|
89
|
+
source_hint = f" ({source})" if isinstance(source, str) else ""
|
|
90
|
+
raise ValueError(
|
|
91
|
+
f"Rubric appears to be in backend judge format (has 'event'/'outcome' keys){source_hint}. "
|
|
92
|
+
f"Task apps require rubrics with 'version', 'goal_text', and 'criteria' fields. "
|
|
93
|
+
f"Backend judge rubrics should be named '*_backend_judge.json' and loaded by judge functions."
|
|
94
|
+
)
|
|
95
|
+
|
|
72
96
|
return Rubric.model_validate(data)
|
|
73
97
|
|
|
74
98
|
|
|
@@ -130,4 +154,3 @@ def blend_rubrics(base: Rubric | None, override: Rubric | None) -> Rubric | None
|
|
|
130
154
|
criteria=merged,
|
|
131
155
|
aggregation=aggregation,
|
|
132
156
|
)
|
|
133
|
-
|
synth_ai/task/rubrics/scoring.py
CHANGED
synth_ai/task/rubrics.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Rubric schema, loading, and scoring helpers for Task Apps."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field, field_validator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Criterion(BaseModel):
    """A single rubric criterion.

    Fields:
        id: Identifier of the criterion.
        description: Human-readable description of the criterion.
        weight: Strictly positive weight (defaults to 1.0).
        required: Flag marking the criterion as required (defaults to False).
    """

    id: str
    description: str
    weight: float = 1.0
    required: bool = False

    @field_validator("weight")
    @classmethod
    def _validate_weight(cls, value: float) -> float:
        """Reject weights that are not strictly positive.

        Raises:
            ValueError: If ``value`` is zero, negative, or NaN.
        """
        # Written as `not (value > 0)` rather than `value <= 0` on purpose:
        # every comparison with NaN is False, so `value <= 0` would let a NaN
        # weight through, while this form rejects it.
        if not (value > 0):
            raise ValueError("criterion weight must be positive")
        return value
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Rubric(BaseModel):
    """A versioned collection of criteria plus an aggregation mode.

    ``version`` is mandatory. ``goal_text`` is optional, ``criteria``
    defaults to an empty list and ``aggregation`` defaults to
    ``"weighted_sum"``.
    """

    version: str
    goal_text: str | None = None
    criteria: list[Criterion] = Field(default_factory=list)
    aggregation: str = "weighted_sum"

    @field_validator("aggregation")
    @classmethod
    def _validate_aggregation(cls, mode: str) -> str:
        """Only accept one of the known aggregation mode names."""
        known_modes = {"sum", "weighted_sum", "custom", "inherit"}
        if mode in known_modes:
            return mode
        raise ValueError(f"aggregation must be one of {sorted(known_modes)}")

    @field_validator("criteria")
    @classmethod
    def _validate_criteria(cls, criteria: list[Criterion]) -> list[Criterion]:
        """Fail on the first criterion whose id repeats an earlier one."""
        ids_so_far: set[str] = set()
        for entry in criteria:
            entry_id = entry.id
            if entry_id in ids_so_far:
                raise ValueError(f"duplicate criterion id: {entry_id}")
            ids_so_far.add(entry_id)
        return criteria
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _load_text(source: str) -> tuple[str, str | None]:
|
|
53
|
+
path = Path(source)
|
|
54
|
+
if path.exists():
|
|
55
|
+
return path.read_text(encoding="utf-8"), path.suffix.lower()
|
|
56
|
+
return source, None
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _parse_structured(text: str, suffix: str | None) -> dict[str, Any]:
|
|
60
|
+
text = text.strip()
|
|
61
|
+
if not text:
|
|
62
|
+
raise ValueError("Rubric source is empty")
|
|
63
|
+
if suffix in (".yaml", ".yml"):
|
|
64
|
+
try:
|
|
65
|
+
import yaml # type: ignore
|
|
66
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
67
|
+
raise RuntimeError("PyYAML is required to load YAML rubrics") from exc
|
|
68
|
+
data = yaml.safe_load(text)
|
|
69
|
+
if not isinstance(data, dict):
|
|
70
|
+
raise ValueError("Rubric YAML must produce a mapping") from None
|
|
71
|
+
return data
|
|
72
|
+
if text.startswith("{"):
|
|
73
|
+
return json.loads(text)
|
|
74
|
+
if text.startswith("http://") or text.startswith("https://"):
|
|
75
|
+
import requests # type: ignore
|
|
76
|
+
|
|
77
|
+
response = requests.get(text, timeout=15)
|
|
78
|
+
response.raise_for_status()
|
|
79
|
+
return _parse_structured(response.text, suffix)
|
|
80
|
+
try:
|
|
81
|
+
return json.loads(text)
|
|
82
|
+
except json.JSONDecodeError:
|
|
83
|
+
try:
|
|
84
|
+
import yaml # type: ignore
|
|
85
|
+
except Exception as exc: # pragma: no cover - optional dependency
|
|
86
|
+
raise RuntimeError("PyYAML is required to load rubric text") from exc
|
|
87
|
+
data = yaml.safe_load(text)
|
|
88
|
+
if not isinstance(data, dict):
|
|
89
|
+
raise ValueError("Rubric text must decode to a mapping") from None
|
|
90
|
+
return data
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def load_rubric(source: str | dict[str, Any] | Rubric | None) -> Rubric | None:
    """Build a ``Rubric`` from any supported source.

    ``None`` and existing ``Rubric`` instances are handed back untouched.
    A dict is validated directly. Anything else is stringified, resolved
    through ``_load_text`` and decoded with ``_parse_structured`` before
    validation.
    """
    if source is None or isinstance(source, Rubric):
        return source
    if isinstance(source, dict):
        payload = source
    else:
        raw_text, suffix_hint = _load_text(str(source))
        payload = _parse_structured(raw_text, suffix_hint)
    return Rubric.model_validate(payload)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _merge_weights(base: Criterion, override: Criterion) -> float:
|
|
106
|
+
if override.weight != 1.0 and base.weight != 1.0:
|
|
107
|
+
return base.weight * override.weight
|
|
108
|
+
if override.weight != 1.0:
|
|
109
|
+
return override.weight
|
|
110
|
+
return base.weight
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def blend_rubrics(base: Rubric | None, override: Rubric | None) -> Rubric | None:
    """Overlay ``override`` on top of ``base`` and return the merged rubric.

    When either side is ``None`` the other is returned as-is (``None`` when
    both are missing). Otherwise override criteria come first, merged with
    the base criterion of the same id where one exists, followed by the base
    criteria the override did not mention. An override aggregation of
    ``"inherit"`` resolves to the base aggregation.
    """
    if base is None:
        return override
    if override is None:
        return base

    # Base criteria not yet claimed by an override entry, keyed by id.
    unclaimed = {item.id: item for item in base.criteria}
    combined: list[Criterion] = []

    for incoming in override.criteria:
        original = unclaimed.pop(incoming.id, None)
        if original is None:
            combined.append(incoming)
            continue
        blended = Criterion(
            id=incoming.id,
            description=incoming.description or original.description,
            weight=_merge_weights(original, incoming),
            required=incoming.required if incoming.required is not None else original.required,
        )
        combined.append(blended)

    combined.extend(unclaimed.values())

    mode = base.aggregation if override.aggregation == "inherit" else override.aggregation

    return Rubric(
        version=override.version or base.version,
        goal_text=override.goal_text or base.goal_text,
        criteria=combined,
        aggregation=mode,
    )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _as_float(value: Any) -> float | None:
|
|
153
|
+
try:
|
|
154
|
+
return float(value)
|
|
155
|
+
except Exception:
|
|
156
|
+
return None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _score(
|
|
160
|
+
criteria: Iterable[Criterion], values: dict[str, float], aggregation: str
|
|
161
|
+
) -> dict[str, Any]:
|
|
162
|
+
if aggregation == "inherit":
|
|
163
|
+
aggregation = "weighted_sum"
|
|
164
|
+
per_criterion: dict[str, dict[str, Any]] = {}
|
|
165
|
+
total = 0.0
|
|
166
|
+
total_weight = 0.0
|
|
167
|
+
for criterion in criteria:
|
|
168
|
+
score = values.get(criterion.id, 0.0)
|
|
169
|
+
per_criterion[criterion.id] = {
|
|
170
|
+
"score": score,
|
|
171
|
+
"weight": criterion.weight,
|
|
172
|
+
"required": criterion.required,
|
|
173
|
+
}
|
|
174
|
+
if aggregation == "sum":
|
|
175
|
+
total += score
|
|
176
|
+
elif aggregation == "weighted_sum":
|
|
177
|
+
total += score * criterion.weight
|
|
178
|
+
total_weight += criterion.weight
|
|
179
|
+
if aggregation == "weighted_sum" and total_weight > 0:
|
|
180
|
+
total = total / total_weight
|
|
181
|
+
if aggregation == "custom":
|
|
182
|
+
total = None # type: ignore[assignment]
|
|
183
|
+
return {
|
|
184
|
+
"aggregation": aggregation,
|
|
185
|
+
"score": total,
|
|
186
|
+
"per_criterion": per_criterion,
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def score_events_against_rubric(
    events: list[dict[str, Any]], rubric: Rubric | None
) -> dict[str, Any]:
    """Score a list of event dicts against ``rubric``.

    Without a rubric a placeholder result with aggregation ``"none"`` is
    returned. Each dict event contributes its ``"score"`` under the first
    truthy value among ``"criterion_id"``, ``"id"`` and ``"criterion"``.
    Non-dict events, and events lacking an id or a float-convertible score,
    are ignored. A later event for the same criterion replaces an earlier one.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}

    collected: dict[str, float] = {}
    for entry in events or []:
        if not isinstance(entry, dict):
            continue
        ident = None
        for key in ("criterion_id", "id", "criterion"):
            ident = entry.get(key)
            if ident:
                break
        parsed = _as_float(entry.get("score"))
        if parsed is None or not ident:
            continue
        collected[str(ident)] = parsed

    return _score(rubric.criteria, collected, rubric.aggregation)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None) -> dict[str, Any]:
    """Score an outcome dict against ``rubric``.

    Without a rubric a placeholder result with aggregation ``"none"`` is
    returned. Scores are read from ``outcome["criteria"]`` when that is a
    dict, otherwise from ``outcome`` itself; entries whose value cannot be
    converted to float are skipped. A non-dict ``outcome`` contributes no
    scores.
    """
    if rubric is None:
        return {"aggregation": "none", "score": None, "per_criterion": {}}

    collected: dict[str, float] = {}
    if isinstance(outcome, dict):
        nested = outcome.get("criteria")
        pool = nested if isinstance(nested, dict) else outcome
        # Lazy pairs keep conversion and insertion interleaved per entry.
        converted = ((name, _as_float(raw)) for name, raw in pool.items())
        collected = {str(name): number for name, number in converted if number is not None}

    return _score(rubric.criteria, collected, rubric.aggregation)
|