synth-ai 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/baseline/banking77_baseline.py +204 -0
- examples/baseline/crafter_baseline.py +407 -0
- examples/baseline/pokemon_red_baseline.py +326 -0
- examples/baseline/simple_baseline.py +56 -0
- examples/baseline/warming_up_to_rl_baseline.py +239 -0
- examples/blog_posts/gepa/README.md +355 -0
- examples/blog_posts/gepa/configs/banking77_gepa_local.toml +95 -0
- examples/blog_posts/gepa/configs/banking77_gepa_test.toml +82 -0
- examples/blog_posts/gepa/configs/banking77_mipro_local.toml +52 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hotpotqa_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hotpotqa_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/hover_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/hover_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/hover_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_local.toml +59 -0
- examples/blog_posts/gepa/configs/ifbench_gepa_qwen.toml +36 -0
- examples/blog_posts/gepa/configs/ifbench_mipro_local.toml +53 -0
- examples/blog_posts/gepa/configs/pupa_gepa_local.toml +60 -0
- examples/blog_posts/gepa/configs/pupa_mipro_local.toml +54 -0
- examples/blog_posts/gepa/deploy_banking77_task_app.sh +41 -0
- examples/blog_posts/gepa/gepa_baseline.py +204 -0
- examples/blog_posts/gepa/query_prompts_example.py +97 -0
- examples/blog_posts/gepa/run_gepa_banking77.sh +87 -0
- examples/blog_posts/gepa/task_apps.py +105 -0
- examples/blog_posts/gepa/test_gepa_local.sh +67 -0
- examples/blog_posts/gepa/verify_banking77_setup.sh +123 -0
- examples/blog_posts/pokemon_vl/configs/eval_gpt5nano.toml +26 -0
- examples/blog_posts/pokemon_vl/configs/eval_qwen3_vl.toml +12 -10
- examples/blog_posts/pokemon_vl/configs/train_rl_from_sft.toml +1 -0
- examples/blog_posts/pokemon_vl/extract_images.py +239 -0
- examples/blog_posts/pokemon_vl/pokemon_vl_baseline.py +326 -0
- examples/blog_posts/pokemon_vl/run_eval_extract_images.py +209 -0
- examples/blog_posts/pokemon_vl/run_qwen_eval_extract_images.py +212 -0
- examples/blog_posts/pokemon_vl/text_box_analysis.md +106 -0
- examples/blog_posts/warming_up_to_rl/ARCHITECTURE.md +195 -0
- examples/blog_posts/warming_up_to_rl/FINAL_TEST_RESULTS.md +127 -0
- examples/blog_posts/warming_up_to_rl/INFERENCE_SUCCESS.md +132 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TESTING.md +164 -0
- examples/blog_posts/warming_up_to_rl/SMOKE_TEST_COMPLETE.md +253 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_baseline_qwen32b_10x20.toml +25 -0
- examples/blog_posts/warming_up_to_rl/configs/eval_ft_qwen4b_10x20.toml +26 -0
- examples/blog_posts/warming_up_to_rl/configs/filter_high_reward_dataset.toml +1 -1
- examples/blog_posts/warming_up_to_rl/configs/smoke_test.toml +75 -0
- examples/blog_posts/warming_up_to_rl/configs/train_rl_from_sft.toml +60 -10
- examples/blog_posts/warming_up_to_rl/configs/train_sft_qwen4b.toml +1 -1
- examples/blog_posts/warming_up_to_rl/warming_up_to_rl_baseline.py +187 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +4 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +4 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +1 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +1 -0
- examples/rl/configs/rl_from_base_qwen17.toml +1 -0
- examples/swe/task_app/hosted/inference/openai_client.py +0 -34
- examples/swe/task_app/hosted/policy_routes.py +17 -0
- examples/swe/task_app/hosted/rollout.py +4 -2
- examples/task_apps/banking77/__init__.py +6 -0
- examples/task_apps/banking77/banking77_task_app.py +841 -0
- examples/task_apps/banking77/deploy_wrapper.py +46 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +4 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +4 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +24 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +355 -58
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +68 -7
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +78 -21
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +194 -1
- examples/task_apps/gepa_benchmarks/__init__.py +7 -0
- examples/task_apps/gepa_benchmarks/common.py +260 -0
- examples/task_apps/gepa_benchmarks/hotpotqa_task_app.py +507 -0
- examples/task_apps/gepa_benchmarks/hover_task_app.py +436 -0
- examples/task_apps/gepa_benchmarks/ifbench_task_app.py +563 -0
- examples/task_apps/gepa_benchmarks/pupa_task_app.py +460 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +4 -0
- examples/task_apps/pokemon_red/task_app.py +254 -36
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +1 -0
- examples/warming_up_to_rl/task_app/grpo_crafter.py +53 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +49 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +152 -41
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +31 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +33 -3
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +67 -0
- examples/workflows/math_rl/configs/rl_from_base_qwen17.toml +1 -0
- synth_ai/api/train/builders.py +90 -1
- synth_ai/api/train/cli.py +396 -21
- synth_ai/api/train/config_finder.py +13 -2
- synth_ai/api/train/configs/__init__.py +15 -1
- synth_ai/api/train/configs/prompt_learning.py +442 -0
- synth_ai/api/train/configs/rl.py +29 -0
- synth_ai/api/train/task_app.py +1 -1
- synth_ai/api/train/validators.py +277 -0
- synth_ai/baseline/__init__.py +25 -0
- synth_ai/baseline/config.py +209 -0
- synth_ai/baseline/discovery.py +214 -0
- synth_ai/baseline/execution.py +146 -0
- synth_ai/cli/__init__.py +85 -17
- synth_ai/cli/__main__.py +0 -0
- synth_ai/cli/claude.py +70 -0
- synth_ai/cli/codex.py +84 -0
- synth_ai/cli/commands/__init__.py +1 -0
- synth_ai/cli/commands/baseline/__init__.py +12 -0
- synth_ai/cli/commands/baseline/core.py +637 -0
- synth_ai/cli/commands/baseline/list.py +93 -0
- synth_ai/cli/commands/eval/core.py +13 -10
- synth_ai/cli/commands/filter/core.py +53 -17
- synth_ai/cli/commands/help/core.py +0 -1
- synth_ai/cli/commands/smoke/__init__.py +7 -0
- synth_ai/cli/commands/smoke/core.py +1436 -0
- synth_ai/cli/commands/status/subcommands/pricing.py +22 -0
- synth_ai/cli/commands/status/subcommands/usage.py +203 -0
- synth_ai/cli/commands/train/judge_schemas.py +1 -0
- synth_ai/cli/commands/train/judge_validation.py +1 -0
- synth_ai/cli/commands/train/validation.py +0 -57
- synth_ai/cli/demo.py +35 -3
- synth_ai/cli/deploy/__init__.py +40 -25
- synth_ai/cli/deploy.py +162 -0
- synth_ai/cli/legacy_root_backup.py +14 -8
- synth_ai/cli/opencode.py +107 -0
- synth_ai/cli/root.py +9 -5
- synth_ai/cli/task_app_deploy.py +1 -1
- synth_ai/cli/task_apps.py +53 -53
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +7 -4
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +9 -5
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +4 -3
- synth_ai/judge_schemas.py +1 -0
- synth_ai/learning/__init__.py +10 -0
- synth_ai/learning/prompt_learning_client.py +276 -0
- synth_ai/learning/prompt_learning_types.py +184 -0
- synth_ai/pricing/__init__.py +2 -0
- synth_ai/pricing/model_pricing.py +57 -0
- synth_ai/streaming/handlers.py +53 -4
- synth_ai/streaming/streamer.py +19 -0
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +2 -0
- synth_ai/task/tracing_utils.py +25 -25
- synth_ai/task/validators.py +44 -8
- synth_ai/task_app_cfgs.py +21 -0
- synth_ai/tracing_v3/config.py +162 -19
- synth_ai/tracing_v3/constants.py +1 -1
- synth_ai/tracing_v3/db_config.py +24 -38
- synth_ai/tracing_v3/storage/config.py +47 -13
- synth_ai/tracing_v3/storage/factory.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +113 -11
- synth_ai/tracing_v3/turso/native_manager.py +92 -16
- synth_ai/types.py +8 -0
- synth_ai/urls.py +11 -0
- synth_ai/utils/__init__.py +30 -1
- synth_ai/utils/agents.py +74 -0
- synth_ai/utils/bin.py +39 -0
- synth_ai/utils/cli.py +149 -5
- synth_ai/utils/env.py +17 -17
- synth_ai/utils/json.py +72 -0
- synth_ai/utils/modal.py +283 -1
- synth_ai/utils/paths.py +48 -0
- synth_ai/utils/uvicorn.py +113 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/METADATA +102 -4
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/RECORD +162 -88
- synth_ai/cli/commands/deploy/__init__.py +0 -23
- synth_ai/cli/commands/deploy/core.py +0 -614
- synth_ai/cli/commands/deploy/errors.py +0 -72
- synth_ai/cli/commands/deploy/validation.py +0 -11
- synth_ai/cli/deploy/core.py +0 -5
- synth_ai/cli/deploy/errors.py +0 -23
- synth_ai/cli/deploy/validation.py +0 -5
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.17.dist-info → synth_ai-0.2.19.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
"""HotpotQA multi-hop QA task app for Synth prompt learning benchmarks."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import os
|
|
7
|
+
import uuid
|
|
8
|
+
from collections.abc import Iterable, Sequence
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Mapping, cast
|
|
11
|
+
|
|
12
|
+
from datasets import load_dataset
|
|
13
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
14
|
+
|
|
15
|
+
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
16
|
+
from synth_ai.task.contracts import (
|
|
17
|
+
RolloutMetrics,
|
|
18
|
+
RolloutRequest,
|
|
19
|
+
RolloutResponse,
|
|
20
|
+
RolloutStep,
|
|
21
|
+
RolloutTrajectory,
|
|
22
|
+
TaskInfo,
|
|
23
|
+
)
|
|
24
|
+
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
25
|
+
from synth_ai.task.rubrics import Rubric, load_rubric
|
|
26
|
+
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
27
|
+
from synth_ai.task.vendors import normalize_vendor_keys
|
|
28
|
+
|
|
29
|
+
from .common import call_chat_completion, normalise_answer
|
|
30
|
+
|
|
31
|
+
# Directory three levels above this file's folder: the repository root when this
# module lives at examples/task_apps/gepa_benchmarks/. Mounted into the Modal image.
REPO_ROOT = Path(__file__).resolve().parents[3]

# Hugging Face coordinates for HotpotQA (distractor setting) and the splits we serve.
HOTPOTQA_DATASET = "hotpot_qa"
HOTPOTQA_CONFIG = "distractor"
DEFAULT_SPLIT = "validation"
AVAILABLE_SPLITS: tuple[str, ...] = ("train", "validation")

# Router passed to the task app config; this module adds no routes to it.
hotpotqa_router = APIRouter()

# Static description of the dataset exposed through the task dataset registry.
HOTPOTQA_DATASET_SPEC = TaskDatasetSpec(
    id="hotpotqa",
    name="HotpotQA Multi-Hop Question Answering",
    version="1.0.0",
    splits=[*AVAILABLE_SPLITS],
    default_split=DEFAULT_SPLIT,
    description="HotpotQA question answering with multi-hop supporting facts.",
)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class HotpotQADataset:
    """Lazy loader and sampler for the HotpotQA dataset.

    Each split is fetched with ``load_dataset(HOTPOTQA_DATASET, HOTPOTQA_CONFIG)``
    the first time it is requested and then cached on the instance.
    """

    def __init__(self) -> None:
        # split name -> loaded split object, filled lazily by ``_load_split``.
        self._splits: dict[str, Any] = {}

    def _load_split(self, split: str):
        """Return ``split``, downloading and caching it on first access.

        Raises:
            ValueError: ``split`` is not listed in ``AVAILABLE_SPLITS``.
            RuntimeError: ``load_dataset`` failed (e.g. no network access).
        """
        # Only validated splits are ever cached, so checking the cache first is
        # equivalent to validating first.
        if split in self._splits:
            return self._splits[split]
        if split not in AVAILABLE_SPLITS:
            raise ValueError(f"Unknown split '{split}'. Available: {AVAILABLE_SPLITS}")
        try:
            self._splits[split] = load_dataset(HOTPOTQA_DATASET, HOTPOTQA_CONFIG, split=split)
        except Exception as exc:  # pragma: no cover - network/dataset errors
            raise RuntimeError(
                f"Failed to download HotpotQA split '{split}'. "
                "Ensure network access to Hugging Face datasets."
            ) from exc
        return self._splits[split]

    def ensure_ready(self, required_splits: Sequence[str]) -> None:
        """Eagerly load every split named in ``required_splits``."""
        for split in required_splits:
            self._load_split(split)

    def size(self, split: str) -> int:
        """Return the number of examples in ``split`` (loading it if needed)."""
        return len(self._load_split(split))

    @staticmethod
    def _sentence_list(sentences: Any) -> list[str]:
        """Normalise one paragraph's sentences to a list of strings.

        A bare string is a single sentence. It must not be iterated character by
        character. ``None`` and other non-sequence values yield no sentences.
        """
        if isinstance(sentences, str):
            return [sentences] if sentences else []
        if isinstance(sentences, Sequence):
            return [str(sentence) for sentence in sentences]
        return []

    @staticmethod
    def _format_context(context: Any) -> tuple[str, list[str]]:
        """Convert HotpotQA context paragraphs into display text and titles.

        Accepts two layouts:

        * the column layout ``{"title": [...], "sentences": [[...], ...]}``;
        * the pair layout ``[[title, sentences], ...]``.

        Each paragraph is rendered as a ``### <title>`` line followed by one line
        per sentence, with paragraphs separated by blank lines.

        Returns:
            ``(text, titles)``, where ``titles`` are the paragraph titles in order.
        """
        paragraphs: list[tuple[Any, Any]] = []
        if isinstance(context, Mapping):
            paragraphs = list(zip(context.get("title") or [], context.get("sentences") or []))
        elif isinstance(context, Sequence) and not isinstance(context, str):
            paragraphs = [
                (entry[0], entry[1])
                for entry in context
                # Strings are Sequences too; only real (title, sentences) pairs count.
                if isinstance(entry, Sequence) and not isinstance(entry, str) and len(entry) == 2
            ]

        lines: list[str] = []
        titles: list[str] = []
        for title, sentences in paragraphs:
            title_str = str(title)
            titles.append(title_str)
            lines.append(f"### {title_str}")
            lines.extend(HotpotQADataset._sentence_list(sentences))
            lines.append("")
        return "\n".join(lines).strip(), titles

    @staticmethod
    def _supporting_titles(supporting: Any) -> list[str]:
        """Return the sorted, de-duplicated titles of the gold supporting facts.

        Accepts the column layout ``{"title": [...], ...}`` or a sequence whose
        entries are ``[title, sent_id]`` pairs or bare title strings.
        """
        titles: list[str] = []
        if isinstance(supporting, Mapping):
            titles = [str(title) for title in (supporting.get("title") or [])]
        elif isinstance(supporting, Sequence) and not isinstance(supporting, str):
            for entry in supporting:
                if isinstance(entry, str):
                    # A bare title; indexing it would keep only its first character.
                    if entry:
                        titles.append(entry)
                elif isinstance(entry, Sequence) and entry:
                    titles.append(str(entry[0]))
        return sorted(set(titles))

    def sample(self, *, split: str, index: int) -> dict[str, Any]:
        """Return one formatted example from ``split``.

        ``index`` is wrapped modulo the split size, so any integer seed is valid.

        Raises:
            RuntimeError: the split contains no examples.
        """
        dataset = self._load_split(split)
        size = len(dataset)
        if size == 0:
            raise RuntimeError(f"HotpotQA split '{split}' is empty")
        idx = int(index) % size
        row = dataset[idx]

        context_text, context_titles = self._format_context(row.get("context") or [])
        return {
            "index": idx,
            "split": split,
            "question": str(row.get("question") or ""),
            "answer": str(row.get("answer") or ""),
            "context_text": context_text,
            "context_titles": context_titles,
            "supporting_titles": self._supporting_titles(row.get("supporting_facts") or []),
        }
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _parse_answer(response_text: str) -> tuple[str, str]:
    """Parse response text into (answer, support) segments.

    The expected format is ``Answer: <answer>`` / ``Support: <justification>``.
    Markers are matched case-insensitively.

    * answer: the text after the first ``Answer:`` marker, up to a ``Support:``
      marker that follows it. Without an ``Answer:`` marker, it is the text
      before ``Support:``, or the whole response when neither marker appears.
    * support: the text after the ``Support:`` marker, stopping at an
      ``Answer:`` marker that comes after it.

    Segments are sliced straight from the original text, so casing is kept.
    Both are whitespace-stripped, and a missing segment is ``""``.
    """
    import re  # local import: the module header does not import ``re``

    if not response_text:
        return "", ""

    answer_marker = re.compile(r"answer:", re.IGNORECASE)
    support_marker = re.compile(r"support:", re.IGNORECASE)

    answer_match = answer_marker.search(response_text)
    support_match = None
    if answer_match is not None:
        # Prefer the Support: section that follows the answer.
        support_match = support_marker.search(response_text, answer_match.end())
    if support_match is None:
        support_match = support_marker.search(response_text)

    if answer_match is not None:
        answer_end = len(response_text)
        if support_match is not None and support_match.start() >= answer_match.end():
            answer_end = support_match.start()
        answer = response_text[answer_match.end() : answer_end]
    elif support_match is not None:
        # No Answer: marker; the answer must not swallow the Support: section.
        answer = response_text[: support_match.start()]
    else:
        answer = response_text

    support = ""
    if support_match is not None:
        support_end = len(response_text)
        if answer_match is not None and answer_match.start() >= support_match.end():
            support_end = answer_match.start()
        support = response_text[support_match.end() : support_end]

    return answer.strip(), support.strip()
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
# Values of TASKAPP_TRACING_ENABLED that leave environment-driven tracing off.
_TRACING_DISABLED_VALUES = frozenset({"", "0", "false", "no", "off"})


def _count_support_hits(support_titles: Sequence[str], support_text: str) -> int:
    """Count gold supporting titles that appear (case-insensitively) in ``support_text``."""
    if not support_titles or not support_text:
        return 0
    lowered = support_text.lower()
    return sum(1 for title in support_titles if title.lower() in lowered)


def _tracing_requested(request: RolloutRequest) -> bool:
    """Return whether a trace payload should be attached to the rollout response.

    Tracing is on in either of two cases:

    * the request asks for it via ``record.return_trace``;
    * ``TASKAPP_TRACING_ENABLED`` is set to something other than an explicit
      off value (``0``/``false``/``no``/``off``/empty).

    Any non-empty value used to count as "on", so ``0`` or ``false`` enabled tracing.
    """
    if request.record and getattr(request.record, "return_trace", False):
        return True
    flag = os.getenv("TASKAPP_TRACING_ENABLED")
    return flag is not None and flag.strip().lower() not in _TRACING_DISABLED_VALUES


async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
    """Run one single-turn HotpotQA episode and score the policy's response.

    The example is picked as follows:

    * the split comes from ``request.env.config["split"]``, defaulting to
      ``DEFAULT_SPLIT``;
    * the seed ``request.env.seed`` is used as an index, wrapped modulo the
      split size.

    The policy is queried once through ``call_chat_completion``. The reward is
    ``0.7 * answer_match + 0.3 * supporting_title_coverage``, where
    ``answer_match`` compares ``normalise_answer`` of the predicted and gold
    answers.

    Errors from the policy call are recorded in the step ``info`` instead of
    being raised. Scoring then runs on an empty response.

    Raises:
        HTTPException: 400 when the requested split is not in ``AVAILABLE_SPLITS``.
    """
    dataset: HotpotQADataset = fastapi_request.app.state.hotpotqa_dataset

    split = str(((request.env.config or {}).get("split")) or DEFAULT_SPLIT)
    if split not in AVAILABLE_SPLITS:
        # Report a client error instead of letting the dataset's ValueError become a 500.
        raise HTTPException(
            status_code=400,
            detail=f"Unknown split '{split}'. Available: {AVAILABLE_SPLITS}",
        )
    seed = request.env.seed or 0

    sample = dataset.sample(split=split, index=seed)
    observation = {
        "question": sample["question"],
        "context": sample["context_text"],
        "supporting_titles": sample["supporting_titles"],
        "index": sample["index"],
        "split": sample["split"],
    }

    # Values substituted into the "{question}" / "{context}" message patterns.
    placeholders = {
        "question": sample["question"],
        "context": sample["context_text"],
    }

    default_messages = [
        {
            "role": "system",
            "pattern": (
                "You are a research assistant that answers multi-hop questions. "
                "Read the passages carefully and respond in the format:\n"
                "Answer: <short answer>\nSupport: <brief justification citing passages>."
            ),
        },
        {
            "role": "user",
            "pattern": "Question: {question}\n\nPassages:\n{context}\n\nProvide the final answer.",
        },
    ]

    tool_calls: list[dict[str, Any]] = []
    response_json: dict[str, Any] | None = None
    response_text = ""
    error_info: dict[str, Any] = {}

    try:
        response_text, response_json, _ = await call_chat_completion(
            request.policy.config or {},
            placeholders,
            default_messages,
        )
    except HTTPException as http_err:  # pragma: no cover - passthrough to metrics
        error_info = {"error": str(http_err.detail), "code": http_err.status_code}
    except Exception as exc:  # pragma: no cover - defensive logging
        error_info = {"error": str(exc)}

    answer_text, support_text = _parse_answer(response_text)

    expected_answer = sample["answer"]
    answer_correct = int(normalise_answer(answer_text) == normalise_answer(expected_answer))

    support_titles = sample["supporting_titles"]
    support_hits = _count_support_hits(support_titles, support_text)
    support_coverage = (support_hits / len(support_titles)) if support_titles else 0.0

    # Same weights as OUTCOME_RUBRIC (answer_accuracy 0.7, supporting_evidence 0.3).
    reward = 0.7 * answer_correct + 0.3 * support_coverage

    info_payload = {
        "expected_answer": expected_answer,
        "predicted_answer": answer_text,
        "support_text": support_text,
        "answer_em": answer_correct,
        "support_coverage": support_coverage,
        "response_json": response_json,
        **error_info,
    }

    with contextlib.suppress(Exception):
        print(
            f"[HOTPOTQA_ROLLOUT] run_id={request.run_id} split={sample['split']} "
            f"index={sample['index']} answer_em={answer_correct} "
            f"support={support_hits}/{len(support_titles) or 1} reward={reward:.3f}",
            flush=True,
        )

    step = RolloutStep(
        obs=observation,
        tool_calls=tool_calls,
        reward=reward,
        done=True,
        info=info_payload,
    )

    inference_url = (request.policy.config or {}).get("inference_url")

    trajectory = RolloutTrajectory(
        env_id=f"hotpotqa::{sample['split']}::{sample['index']}",
        policy_id=request.policy.policy_id or request.policy.policy_name or "policy",
        steps=[step],
        final={"observation": observation, "reward": reward},
        length=1,
        inference_url=str(inference_url or ""),
    )

    metrics = RolloutMetrics(
        episode_returns=[reward],
        mean_return=reward,
        num_steps=1,
        num_episodes=1,
        outcome_score=reward,
        events_score=reward,
        details={
            "answer_correct": bool(answer_correct),
            "support_coverage": support_coverage,
        },
    )

    trace_payload = None
    if _tracing_requested(request):
        trace_payload = {
            "session_id": str(uuid.uuid4()),
            "events_count": 1,
            "decision_rewards": [reward],
            "metadata": {
                "env": "hotpotqa",
                "split": sample["split"],
                "index": sample["index"],
                "answer_em": answer_correct,
                "support_coverage": support_coverage,
            },
        }

    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        # NOTE(review): fixed value carried over unchanged; meaning of "2" not established here.
        ops_executed=2,
        trace=trace_payload,
    )
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, HotpotQADataset]:
    """Create the HotpotQA dataset wrapper and a registry that serves it.

    The default split is loaded eagerly, so a download failure surfaces while
    the app is being configured rather than on the first rollout.
    """
    hotpot = HotpotQADataset()
    hotpot.ensure_ready([DEFAULT_SPLIT])

    registry = TaskDatasetRegistry()
    # Every lookup of the spec resolves to the same already-loaded wrapper.
    registry.register(HOTPOTQA_DATASET_SPEC, lambda _spec: hotpot, cache=True)
    return registry, hotpot
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def _base_task_info() -> TaskInfo:
    """Build the TaskInfo fields shared by every HotpotQA task instance."""
    action_space = {
        "type": "free_text",
        "description": "Respond with an answer and supporting justification.",
    }
    task_descriptor = {
        "id": "hotpotqa",
        "name": "HotpotQA Multi-Hop QA",
        "version": "1.0.0",
        "action_space": action_space,
    }
    dataset_descriptor = {
        **HOTPOTQA_DATASET_SPEC.model_dump(),
        "hf_dataset": HOTPOTQA_DATASET,
        "hf_config": HOTPOTQA_CONFIG,
    }
    return TaskInfo(
        task=task_descriptor,
        environment="hotpotqa",
        dataset=dataset_descriptor,
        # Two criteria, matching the answer/support split of OUTCOME_RUBRIC.
        rubric={"version": "1", "criteria_count": 2, "source": "inline"},
        inference={"supports_proxy": True, "tool": None},
        limits={"max_turns": 1},
        task_metadata={"format": "Answer: ... / Support: ...", "support_titles": True},
    )
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
def describe_taskset(dataset: HotpotQADataset) -> Mapping[str, Any]:
    """Summarise the taskset: dataset spec, Hugging Face coordinates, split sizes.

    Sizes are reported for every split in ``AVAILABLE_SPLITS``. Computing them
    loads each split, which may download splits that are not yet cached.
    """
    sizes: dict[str, int] = {}
    for split_name in AVAILABLE_SPLITS:
        sizes[split_name] = dataset.size(split_name)

    summary = dict(HOTPOTQA_DATASET_SPEC.model_dump())
    summary["hf_dataset"] = HOTPOTQA_DATASET
    summary["hf_config"] = HOTPOTQA_CONFIG
    summary["sizes"] = sizes
    return summary
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def provide_task_instances(dataset: HotpotQADataset, seeds: Sequence[int]) -> Iterable[TaskInfo]:
    """Lazily yield one TaskInfo per seed, sampled from ``DEFAULT_SPLIT``.

    Each instance copies the base info and adds the sampled split, index and
    question text.
    """
    base = _base_task_info()
    for seed in seeds:
        example = dataset.sample(split=DEFAULT_SPLIT, index=seed)
        instance_dataset = {**base.dataset, "split": example["split"], "index": example["index"]}
        instance_metadata = {**base.task_metadata, "question": example["question"]}
        yield TaskInfo(
            task=base.task,
            environment=base.environment,
            dataset=instance_dataset,
            rubric=base.rubric,
            inference=base.inference,
            limits=base.limits,
            task_metadata=instance_metadata,
        )
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
# Outcome rubric. Its weights mirror the reward computed in rollout_executor
# (0.7 for the answer, 0.3 for supporting-passage coverage).
OUTCOME_RUBRIC: Rubric = cast(
    Rubric,
    load_rubric(
        {
            "version": "1",
            "goal_text": "Answer HotpotQA questions accurately with supporting justification.",
            "aggregation": "weighted_sum",
            "criteria": [
                {"id": "answer_accuracy", "description": "Final answer matches the gold answer.", "weight": 0.7},
                {"id": "supporting_evidence", "description": "Support references the correct passages.", "weight": 0.3},
            ],
        }
    ),
)
|
|
412
|
+
|
|
413
|
+
# Event-level rubric with a single criterion: the response follows the
# requested "Answer:" / "Support:" layout.
EVENTS_RUBRIC: Rubric = cast(
    Rubric,
    load_rubric(
        {
            "version": "1",
            "goal_text": "Encourage concise responses with the requested format.",
            "aggregation": "weighted_sum",
            "criteria": [
                {"id": "format_compliance", "description": "Respond using 'Answer:' and 'Support:' sections.", "weight": 1.0}
            ],
        }
    ),
)
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
def build_config() -> TaskAppConfig:
    """Assemble the TaskAppConfig for the HotpotQA task app.

    Steps:

    * load the dataset, with the default split loaded eagerly;
    * wire in the rollout executor, rubrics and router;
    * enable the OpenAI/Groq proxies only when ``normalize_vendor_keys()``
      reports the matching key.
    """
    registry, dataset = build_dataset()

    vendor_keys = normalize_vendor_keys()
    proxy = ProxyConfig(
        enable_openai=vendor_keys.get("OPENAI_API_KEY") is not None,
        enable_groq=vendor_keys.get("GROQ_API_KEY") is not None,
        system_hint="Provide an answer followed by supporting justification.",
    )

    return TaskAppConfig(
        app_id="hotpotqa",
        name="HotpotQA Multi-Hop QA Task",
        description="HotpotQA environment for evaluating prompt optimisers.",
        base_task_info=_base_task_info(),
        describe_taskset=lambda: describe_taskset(dataset),
        provide_task_instances=lambda seeds: provide_task_instances(dataset, seeds),
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=proxy,
        routers=(hotpotqa_router,),
        # rollout_executor reads the dataset back from app.state.hotpotqa_dataset.
        app_state={"hotpotqa_dataset": dataset},
        cors_origins=["*"],
    )
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
# Register the app at import time under its id and alias. The Modal config lists
# the extra pip packages and mounts the local synth_ai package into the image.
register_task_app(
    entry=TaskAppEntry(
        app_id="hotpotqa",
        aliases=("hotpotqa-multihop",),
        description="HotpotQA multi-hop QA task app using the distractor split.",
        config_factory=build_config,
        modal=ModalDeploymentConfig(
            app_name="synth-hotpotqa",
            extra_local_dirs=((str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),),
            pip_packages=(
                "datasets>=2.14.0",
                "fastapi>=0.115.0",
                "pydantic>=2.0.0",
                "httpx>=0.26.0",
            ),
        ),
    )
)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
if __name__ == "__main__":  # pragma: no cover - manual local run helper
    import argparse

    from synth_ai.task.server import run_task_app

    cli = argparse.ArgumentParser(description="Run the HotpotQA task app locally")
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8110)
    cli.add_argument("--reload", action="store_true", help="Enable uvicorn autoreload")
    cli.add_argument(
        "--env-file",
        action="append",
        default=[],
        help="Additional .env files to load before startup",
    )
    cli_args = cli.parse_args()

    # A .env two directories above this file's folder is loaded first when it
    # exists. Any --env-file arguments follow in the order given.
    env_files: list[str] = []
    default_env = Path(__file__).resolve().parents[2] / ".env"
    if default_env.exists():
        env_files.append(str(default_env))
    env_files += cli_args.env_file or []

    run_task_app(
        build_config,
        host=cli_args.host,
        port=cli_args.port,
        reload=cli_args.reload,
        env_files=env_files,
    )
|