local-deep-research 0.3.12__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- local_deep_research/__init__.py +1 -0
- local_deep_research/__version__.py +1 -1
- local_deep_research/advanced_search_system/filters/base_filter.py +2 -3
- local_deep_research/advanced_search_system/filters/cross_engine_filter.py +4 -5
- local_deep_research/advanced_search_system/filters/journal_reputation_filter.py +298 -0
- local_deep_research/advanced_search_system/findings/repository.py +0 -3
- local_deep_research/advanced_search_system/strategies/base_strategy.py +1 -2
- local_deep_research/advanced_search_system/strategies/iterdrag_strategy.py +14 -18
- local_deep_research/advanced_search_system/strategies/parallel_search_strategy.py +4 -8
- local_deep_research/advanced_search_system/strategies/rapid_search_strategy.py +5 -6
- local_deep_research/advanced_search_system/strategies/source_based_strategy.py +2 -2
- local_deep_research/advanced_search_system/strategies/standard_strategy.py +9 -7
- local_deep_research/api/benchmark_functions.py +288 -0
- local_deep_research/api/research_functions.py +8 -4
- local_deep_research/benchmarks/README.md +162 -0
- local_deep_research/benchmarks/__init__.py +51 -0
- local_deep_research/benchmarks/benchmark_functions.py +353 -0
- local_deep_research/benchmarks/cli/__init__.py +16 -0
- local_deep_research/benchmarks/cli/benchmark_commands.py +338 -0
- local_deep_research/benchmarks/cli.py +347 -0
- local_deep_research/benchmarks/comparison/__init__.py +12 -0
- local_deep_research/benchmarks/comparison/evaluator.py +768 -0
- local_deep_research/benchmarks/datasets/__init__.py +53 -0
- local_deep_research/benchmarks/datasets/base.py +295 -0
- local_deep_research/benchmarks/datasets/browsecomp.py +116 -0
- local_deep_research/benchmarks/datasets/custom_dataset_template.py +98 -0
- local_deep_research/benchmarks/datasets/simpleqa.py +74 -0
- local_deep_research/benchmarks/datasets/utils.py +116 -0
- local_deep_research/benchmarks/datasets.py +31 -0
- local_deep_research/benchmarks/efficiency/__init__.py +14 -0
- local_deep_research/benchmarks/efficiency/resource_monitor.py +367 -0
- local_deep_research/benchmarks/efficiency/speed_profiler.py +214 -0
- local_deep_research/benchmarks/evaluators/__init__.py +18 -0
- local_deep_research/benchmarks/evaluators/base.py +74 -0
- local_deep_research/benchmarks/evaluators/browsecomp.py +83 -0
- local_deep_research/benchmarks/evaluators/composite.py +121 -0
- local_deep_research/benchmarks/evaluators/simpleqa.py +271 -0
- local_deep_research/benchmarks/graders.py +410 -0
- local_deep_research/benchmarks/metrics/README.md +80 -0
- local_deep_research/benchmarks/metrics/__init__.py +24 -0
- local_deep_research/benchmarks/metrics/calculation.py +385 -0
- local_deep_research/benchmarks/metrics/reporting.py +155 -0
- local_deep_research/benchmarks/metrics/visualization.py +205 -0
- local_deep_research/benchmarks/metrics.py +11 -0
- local_deep_research/benchmarks/optimization/__init__.py +32 -0
- local_deep_research/benchmarks/optimization/api.py +274 -0
- local_deep_research/benchmarks/optimization/metrics.py +20 -0
- local_deep_research/benchmarks/optimization/optuna_optimizer.py +1163 -0
- local_deep_research/benchmarks/runners.py +434 -0
- local_deep_research/benchmarks/templates.py +65 -0
- local_deep_research/config/llm_config.py +26 -23
- local_deep_research/config/search_config.py +1 -5
- local_deep_research/defaults/default_settings.json +108 -7
- local_deep_research/search_system.py +16 -8
- local_deep_research/utilities/db_utils.py +3 -6
- local_deep_research/utilities/es_utils.py +441 -0
- local_deep_research/utilities/log_utils.py +36 -0
- local_deep_research/utilities/search_utilities.py +8 -9
- local_deep_research/web/app.py +15 -10
- local_deep_research/web/app_factory.py +9 -12
- local_deep_research/web/database/migrations.py +8 -5
- local_deep_research/web/database/models.py +20 -0
- local_deep_research/web/database/schema_upgrade.py +5 -8
- local_deep_research/web/models/database.py +15 -18
- local_deep_research/web/routes/benchmark_routes.py +427 -0
- local_deep_research/web/routes/research_routes.py +13 -17
- local_deep_research/web/routes/settings_routes.py +264 -67
- local_deep_research/web/services/research_service.py +58 -73
- local_deep_research/web/services/settings_manager.py +1 -4
- local_deep_research/web/services/settings_service.py +4 -6
- local_deep_research/web/static/css/styles.css +12 -0
- local_deep_research/web/static/js/components/logpanel.js +164 -155
- local_deep_research/web/static/js/components/research.js +44 -3
- local_deep_research/web/static/js/components/settings.js +27 -0
- local_deep_research/web/static/js/services/socket.js +47 -0
- local_deep_research/web_search_engines/default_search_engines.py +38 -0
- local_deep_research/web_search_engines/engines/meta_search_engine.py +100 -33
- local_deep_research/web_search_engines/engines/search_engine_arxiv.py +31 -17
- local_deep_research/web_search_engines/engines/search_engine_brave.py +8 -3
- local_deep_research/web_search_engines/engines/search_engine_elasticsearch.py +343 -0
- local_deep_research/web_search_engines/engines/search_engine_google_pse.py +14 -6
- local_deep_research/web_search_engines/engines/search_engine_local.py +19 -23
- local_deep_research/web_search_engines/engines/search_engine_local_all.py +9 -12
- local_deep_research/web_search_engines/engines/search_engine_searxng.py +12 -17
- local_deep_research/web_search_engines/engines/search_engine_serpapi.py +8 -4
- local_deep_research/web_search_engines/search_engine_base.py +22 -5
- local_deep_research/web_search_engines/search_engine_factory.py +30 -11
- local_deep_research/web_search_engines/search_engines_config.py +14 -1
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/METADATA +10 -2
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/RECORD +93 -51
- local_deep_research/app.py +0 -8
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/WHEEL +0 -0
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/entry_points.txt +0 -0
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ local_deep_research/benchmarks/evaluators/simpleqa.py
@@ -0,0 +1,271 @@
+"""
+SimpleQA benchmark evaluator.
+
+This module provides a benchmark evaluator implementation for the SimpleQA
+benchmark, which tests simple question-answering capabilities.
+"""
+
+import json
+import logging
+import os
+import time
+from typing import Any, Dict, List, Optional
+
+from local_deep_research.api import quick_summary
+from ..datasets.base import DatasetRegistry
+from ..metrics import calculate_metrics, generate_report
+from ..runners import run_simpleqa_benchmark  # Keep for backward compatibility
+from .base import BaseBenchmarkEvaluator
+
+logger = logging.getLogger(__name__)
+
+
+class SimpleQAEvaluator(BaseBenchmarkEvaluator):
+    """
+    Evaluator for the SimpleQA benchmark.
+
+    This evaluator runs the SimpleQA benchmark, which tests a system's ability
+    to accurately answer straightforward factual questions.
+    """
+
+    def __init__(self):
+        """Initialize the SimpleQA evaluator."""
+        super().__init__("simpleqa")
+
+    def evaluate(
+        self,
+        system_config: Dict[str, Any],
+        num_examples: int,
+        output_dir: str,
+        use_direct_dataset: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Run SimpleQA benchmark and return metrics.
+
+        Args:
+            system_config: Search and LLM configuration parameters
+            num_examples: Number of benchmark examples to run
+            output_dir: Directory to save evaluation results
+            use_direct_dataset: Whether to use dataset classes directly (recommended)
+                or fall back to runner functions
+
+        Returns:
+            Dictionary with metrics including quality_score based on accuracy
+        """
+        # Create benchmark-specific directory
+        benchmark_dir = self._create_subdirectory(output_dir)
+
+        # Log benchmark execution
+        logger.info(f"Running SimpleQA benchmark with {num_examples} examples")
+
+        try:
+            if use_direct_dataset:
+                # Use dataset classes directly (new approach)
+                results = self._run_with_dataset_class(
+                    system_config=system_config,
+                    num_examples=num_examples,
+                    output_dir=benchmark_dir,
+                )
+            else:
+                # Fall back to legacy runner function
+                results = run_simpleqa_benchmark(
+                    num_examples=num_examples,
+                    output_dir=benchmark_dir,
+                    search_config=system_config,
+                    run_evaluation=True,
+                )
+
+            # Extract metrics
+            metrics = results.get("metrics", {})
+            accuracy = metrics.get("accuracy", 0.0)
+
+            # Return evaluation results with quality score
+            return {
+                "benchmark_type": self.name,
+                "accuracy": accuracy,
+                "quality_score": accuracy,  # Map accuracy directly to quality score
+                "raw_results": results,
+                "report_path": results.get("report_path"),
+            }
+
+        except Exception as e:
+            logger.error(f"Error in SimpleQA evaluation: {str(e)}")
+
+            # Return error information
+            return {
+                "benchmark_type": self.name,
+                "error": str(e),
+                "quality_score": 0.0,
+                "accuracy": 0.0,
+            }
+
+    def _run_with_dataset_class(
+        self,
+        system_config: Dict[str, Any],
+        num_examples: int,
+        output_dir: str,
+    ) -> Dict[str, Any]:
+        """
+        Run SimpleQA benchmark using dataset classes directly.
+
+        This implementation directly uses the dataset classes rather than
+        going through the runner functions, allowing for more flexibility
+        and better integration with the object-oriented architecture.
+
+        Args:
+            system_config: Search and LLM configuration parameters
+            num_examples: Number of benchmark examples to run
+            output_dir: Directory to save evaluation results
+
+        Returns:
+            Dictionary with benchmark results
+        """
+        # Create a dataset instance using the registry
+        try:
+            dataset_instance = DatasetRegistry.create_dataset(
+                dataset_id="simpleqa",
+                num_examples=num_examples,
+                seed=system_config.get("seed", 42),
+            )
+
+            # Load dataset examples
+            examples = dataset_instance.load()
+            logger.info(f"Loaded {len(examples)} SimpleQA examples")
+
+            # Set up output files
+            timestamp = time.strftime("%Y%m%d_%H%M%S")
+            results_file = os.path.join(output_dir, f"simpleqa_{timestamp}_results.jsonl")
+            evaluation_file = os.path.join(output_dir, f"simpleqa_{timestamp}_evaluation.jsonl")
+            report_file = os.path.join(output_dir, f"simpleqa_{timestamp}_report.md")
+
+            # Process each example
+            results = []
+
+            for i, example in enumerate(examples):
+                # Extract question and answer using dataset methods
+                question = dataset_instance.get_question(example)
+                correct_answer = dataset_instance.get_answer(example)
+
+                logger.info(f"Processing {i + 1}/{len(examples)}: {question[:50]}...")
+
+                try:
+                    # Format query based on dataset type
+                    formatted_query = question  # Simple format for SimpleQA
+
+                    # Time the search
+                    start_time = time.time()
+
+                    # Create search config from system_config
+                    search_params = {
+                        "iterations": system_config.get("iterations", 3),
+                        "questions_per_iteration": system_config.get("questions_per_iteration", 3),
+                        "search_tool": system_config.get("search_tool", "searxng"),
+                        # Note: search_strategy is stored in the config but not passed to quick_summary
+                        # as it's not supported by the underlying API
+                    }
+
+                    # Get response from LDR
+                    from local_deep_research.api import quick_summary
+                    search_result = quick_summary(
+                        query=formatted_query,
+                        iterations=search_params.get("iterations"),
+                        questions_per_iteration=search_params.get("questions_per_iteration"),
+                        search_tool=search_params.get("search_tool"),
+                    )
+
+                    end_time = time.time()
+                    processing_time = end_time - start_time
+
+                    # Extract response
+                    response = search_result.get("summary", "")
+
+                    # Extract structured answer
+                    from ..graders import extract_answer_from_response
+                    extracted = extract_answer_from_response(response, "simpleqa")
+
+                    # Format result
+                    result = {
+                        "id": example.get("id", f"example_{i}"),
+                        "problem": question,
+                        "correct_answer": correct_answer,
+                        "response": response,
+                        "extracted_answer": extracted["extracted_answer"],
+                        "confidence": extracted["confidence"],
+                        "processing_time": processing_time,
+                        "sources": search_result.get("sources", []),
+                        "search_config": search_params,
+                    }
+
+                    # Add to results list
+                    results.append(result)
+
+                    # Write result to file
+                    with open(results_file, "a") as f:
+                        f.write(json.dumps(result) + "\n")
+
+                except Exception as e:
+                    logger.error(f"Error processing example {i + 1}: {str(e)}")
+
+                    # Create error result
+                    error_result = {
+                        "id": example.get("id", f"example_{i}"),
+                        "problem": question,
+                        "correct_answer": correct_answer,
+                        "error": str(e),
+                        "processing_time": 0,
+                    }
+
+                    # Add to results list
+                    results.append(error_result)
+
+                    # Write error result to file
+                    with open(results_file, "a") as f:
+                        f.write(json.dumps(error_result) + "\n")
+
+            # Grade results
+            from ..graders import grade_results
+            evaluation_results = grade_results(
+                results_file=results_file,
+                output_file=evaluation_file,
+                dataset_type="simpleqa",
+            )
+
+            # Calculate metrics
+            metrics = calculate_metrics(evaluation_file)
+
+            # Generate report
+            dataset_name = "SimpleQA"
+            report_path = generate_report(
+                metrics=metrics,
+                results_file=evaluation_file,
+                output_file=report_file,
+                dataset_name=dataset_name,
+                config_info={
+                    "Dataset": "SimpleQA",
+                    "Examples": len(examples),
+                    "Iterations": search_params.get("iterations", 3),
+                    "Questions per iteration": search_params.get("questions_per_iteration", 3),
+                    "Search tool": search_params.get("search_tool", "searxng"),
+                    "Search strategy": search_params.get("search_strategy", "source_based"),
+                },
+            )
+
+            # Return results
+            return {
+                "status": "complete",
+                "dataset_type": "simpleqa",
+                "results_path": results_file,
+                "evaluation_path": evaluation_file,
+                "report_path": report_path,
+                "metrics": metrics,
+                "total_examples": len(examples),
+                "accuracy": metrics.get("accuracy", 0),
+            }
+
+        except Exception as e:
+            logger.error(f"Error in direct dataset evaluation: {str(e)}")
+            return {
+                "status": "error",
+                "dataset_type": "simpleqa",
+                "error": str(e),
+            }
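A minimal usage sketch for the new evaluator (not part of the diff): it assumes the package layout shown above, and the directory name and config values are illustrative. The config keys match those read by _run_with_dataset_class, and the returned keys match what evaluate() produces.

# Illustrative only: exercises SimpleQAEvaluator as added in 0.4.1.
from local_deep_research.benchmarks.evaluators.simpleqa import SimpleQAEvaluator

evaluator = SimpleQAEvaluator()
outcome = evaluator.evaluate(
    system_config={
        "iterations": 1,                # read via system_config.get("iterations", 3)
        "questions_per_iteration": 2,
        "search_tool": "searxng",
    },
    num_examples=5,                     # number of SimpleQA examples to run
    output_dir="benchmark_output",      # hypothetical output directory
)
print(outcome["quality_score"], outcome.get("report_path"))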
--- /dev/null
+++ local_deep_research/benchmarks/graders.py
@@ -0,0 +1,410 @@
+"""
+Evaluation and grading functionality.
+
+This module provides tools for evaluating model outputs against reference answers.
+"""
+
+import json
+import logging
+import os
+import re
+from typing import Any, Callable, Dict, List, Optional
+
+from langchain.schema import HumanMessage
+
+from ..config.llm_config import get_llm
+from .templates import BROWSECOMP_GRADER_TEMPLATE, SIMPLEQA_GRADER_TEMPLATE
+
+logger = logging.getLogger(__name__)
+
+# Default evaluation configuration using Claude 3.7 Sonnet via OpenRouter
+DEFAULT_EVALUATION_CONFIG = {
+    "model_name": "anthropic/claude-3.7-sonnet",  # Correct model ID for OpenRouter
+    "provider": "openai_endpoint",  # Use OpenRouter
+    "openai_endpoint_url": "https://openrouter.ai/api/v1",  # OpenRouter URL
+    "temperature": 0,  # Zero temp for consistent evaluation
+    # Note: max_tokens removed as it's not supported by LDR's get_llm()
+}
+
+
+def get_evaluation_llm(custom_config: Optional[Dict[str, Any]] = None):
+    """
+    Get an LLM for evaluation purposes using Claude 3.7 Sonnet via OpenRouter
+    by default, which can be overridden with custom settings.
+
+    Args:
+        custom_config: Optional custom configuration that overrides defaults
+
+    Returns:
+        An LLM instance for evaluation
+    """
+    # Start with default config (Claude 3.7 Sonnet via OpenRouter)
+    config = DEFAULT_EVALUATION_CONFIG.copy()
+
+    # Override with any custom settings
+    if custom_config:
+        config.update(custom_config)
+
+    logger.info(
+        f"Getting evaluation LLM with provider={config['provider']}, model={config['model_name']}"
+    )
+
+    # Remove any parameters that LDR's get_llm doesn't support
+    # This ensures compatibility with LDR's implementation
+    ldr_supported_params = {
+        "model_name",
+        "temperature",
+        "provider",
+        "openai_endpoint_url",
+        "api_key",
+    }
+
+    filtered_config = {k: v for k, v in config.items() if k in ldr_supported_params}
+
+    # Check if we're using openai_endpoint but don't have an API key configured
+    if filtered_config.get("provider") == "openai_endpoint":
+        # Try to get API key from environment or config
+        import os
+
+        api_key = os.getenv("OPENAI_ENDPOINT_API_KEY")
+        if not api_key:
+            logger.warning(
+                "Using openai_endpoint provider but no API key found. "
+                "Set the OPENAI_ENDPOINT_API_KEY environment variable or "
+                "specify api_key in the evaluation_config."
+            )
+            # Try to fall back to LDR's config if API key not explicitly provided
+            # The get_llm function will handle this case
+
+    # Get the LLM using LDR's existing function
+    return get_llm(**filtered_config)
+
+
+def extract_answer_from_response(
+    response: str, dataset_type: str = "simpleqa"
+) -> Dict[str, str]:
+    """
+    Extract structured information from LDR's response.
+
+    Args:
+        response: Response from LDR
+        dataset_type: Type of dataset
+
+    Returns:
+        Dictionary with extracted answer and confidence
+    """
+    # Clean up citations
+    response = re.sub(r"\[\d+\]", "", response)
+
+    # Extract differently based on dataset type
+    if dataset_type.lower() == "browsecomp":
+        # Extract the final answer from structured response
+        answer_match = re.search(r"Exact Answer:\s*(.*?)(?:\n|$)", response)
+        exact_answer = answer_match.group(1).strip() if answer_match else "None"
+
+        # Extract confidence
+        confidence_match = re.search(r"Confidence:\s*(\d+)%", response)
+        confidence = confidence_match.group(1) if confidence_match else "100"
+
+        return {"extracted_answer": exact_answer, "confidence": confidence}
+
+    # For SimpleQA, return the whole response as the answer
+    return {
+        "extracted_answer": response,
+        "confidence": "100",  # SimpleQA doesn't have confidence scores
+    }
+
+
+def grade_results(
+    results_file: str,
+    output_file: str,
+    dataset_type: str = "simpleqa",
+    evaluation_config: Optional[Dict[str, Any]] = None,
+    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+) -> List[Dict[str, Any]]:
+    """
+    Grade benchmark results using LLM.
+
+    Args:
+        results_file: Path to results file
+        output_file: Path to save graded results
+        dataset_type: Type of dataset
+        evaluation_config: Optional custom config for evaluation LLM
+        progress_callback: Optional callback for progress updates
+
+    Returns:
+        List of graded results
+    """
+    # Get evaluation LLM
+    evaluation_llm = get_evaluation_llm(evaluation_config)
+
+    # Select appropriate template
+    template = (
+        BROWSECOMP_GRADER_TEMPLATE
+        if dataset_type.lower() == "browsecomp"
+        else SIMPLEQA_GRADER_TEMPLATE
+    )
+
+    # Load results
+    results = []
+    with open(results_file, "r") as f:
+        for line in f:
+            if line.strip():
+                results.append(json.loads(line))
+
+    # Remove output file if it exists
+    if os.path.exists(output_file):
+        os.remove(output_file)
+
+    graded_results = []
+    correct_count = 0
+
+    # Process each result
+    for idx, result in enumerate(results):
+        question = result.get("problem", "")
+        correct_answer = result.get("correct_answer", "")
+        response = result.get("response", "")
+
+        # Call progress callback if provided
+        if progress_callback:
+            progress_callback(
+                idx,
+                len(results),
+                {"status": "grading", "index": idx, "total": len(results)},
+            )
+
+        logger.info(f"Grading {idx + 1}/{len(results)}: {question[:50]}...")
+
+        # Format grading prompt
+        grading_prompt = template.format(
+            question=question, correct_answer=correct_answer, response=response
+        )
+
+        try:
+            # Grade using LLM
+            if hasattr(evaluation_llm, "invoke") and callable(evaluation_llm.invoke):
+                if hasattr(evaluation_llm, "chat_messages"):
+                    # Handle ChatOpenAI and similar models that use messages
+                    grading_response = evaluation_llm.invoke(
+                        [HumanMessage(content=grading_prompt)]
+                    ).content
+                else:
+                    # Handle other LLM types
+                    grading_response = evaluation_llm.invoke(grading_prompt)
+                    if hasattr(grading_response, "content"):
+                        grading_response = grading_response.content
+            else:
+                # Fallback for other LLM interfaces
+                grading_response = str(evaluation_llm(grading_prompt))
+
+            # Extract grading information using regex
+            if dataset_type.lower() == "browsecomp":
+                # BrowseComp-specific extraction
+                extracted_answer_match = re.search(
+                    r"extracted_final_answer:\s*(.*?)(?:\n|$)", grading_response
+                )
+                extracted_answer = (
+                    extracted_answer_match.group(1).strip()
+                    if extracted_answer_match
+                    else "None"
+                )
+
+                reasoning_match = re.search(
+                    r"reasoning:\s*(.*?)(?:\n\n|\ncorrect:|\Z)",
+                    grading_response,
+                    re.DOTALL,
+                )
+                reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
+
+                correct_match = re.search(
+                    r"correct:\s*(yes|no)", grading_response, re.IGNORECASE
+                )
+                is_correct = (
+                    (correct_match.group(1).lower() == "yes")
+                    if correct_match
+                    else False
+                )
+
+                confidence_match = re.search(r"confidence:\s*(\d+)", grading_response)
+                confidence = confidence_match.group(1) if confidence_match else "100"
+            else:
+                # SimpleQA extraction
+                extracted_answer_match = re.search(
+                    r"Extracted Answer:\s*(.*?)(?:\n|$)", grading_response
+                )
+                extracted_answer = (
+                    extracted_answer_match.group(1).strip()
+                    if extracted_answer_match
+                    else "None"
+                )
+
+                reasoning_match = re.search(
+                    r"Reasoning:\s*(.*?)(?:\nCorrect:|\Z)", grading_response, re.DOTALL
+                )
+                reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
+
+                correct_match = re.search(
+                    r"Correct:\s*(yes|no)", grading_response, re.IGNORECASE
+                )
+                is_correct = (
+                    (correct_match.group(1).lower() == "yes")
+                    if correct_match
+                    else False
+                )
+
+                confidence = "100"  # SimpleQA doesn't have confidence
+
+            if is_correct:
+                correct_count += 1
+
+            # Format graded result
+            graded_result = result.copy()
+            graded_result.update(
+                {
+                    "extracted_by_grader": extracted_answer,
+                    "reasoning": reasoning,
+                    "is_correct": is_correct,
+                    "graded_confidence": confidence,
+                    "grader_response": grading_response,
+                }
+            )
+
+            graded_results.append(graded_result)
+
+            # Write to output file
+            with open(output_file, "a") as f:
+                f.write(json.dumps(graded_result) + "\n")
+
+            # Call progress callback if provided
+            if progress_callback:
+                progress_callback(
+                    idx,
+                    len(results),
+                    {
+                        "status": "graded",
+                        "is_correct": is_correct,
+                        "result": graded_result,
+                    },
+                )
+
+        except Exception as e:
+            logger.error(f"Error grading result {idx + 1}: {str(e)}")
+
+            # Handle error
+            error_result = result.copy()
+            error_result["grading_error"] = str(e)
+
+            with open(output_file, "a") as f:
+                f.write(json.dumps(error_result) + "\n")
+
+            graded_results.append(error_result)
+
+            # Call progress callback if provided
+            if progress_callback:
+                progress_callback(
+                    idx,
+                    len(results),
+                    {"status": "error", "error": str(e), "result": error_result},
+                )
+
+    accuracy = correct_count / len(results) if results else 0
+    logger.info(f"Grading complete. Accuracy: {accuracy:.3f}")
+    logger.info(f"Correct: {correct_count}/{len(results)}")
+
+    return graded_results
+
+
+def human_evaluation(
+    results_file: str, output_file: str, interactive: bool = True
+) -> List[Dict[str, Any]]:
+    """
+    Allow for human evaluation of results.
+
+    Args:
+        results_file: Path to results file
+        output_file: Path to save human-graded results
+        interactive: Whether to run in interactive console mode
+
+    Returns:
+        List of human-graded results
+    """
+    # Load results
+    results = []
+    with open(results_file, "r") as f:
+        for line in f:
+            if line.strip():
+                results.append(json.loads(line))
+
+    # Remove output file if it exists
+    if os.path.exists(output_file):
+        os.remove(output_file)
+
+    human_graded_results = []
+    correct_count = 0
+
+    if interactive:
+        logger.info(f"Human evaluation: {len(results)} examples to grade")
+        print(f"Human evaluation: {len(results)} examples to grade")
+        print(
+            "For each example, you'll see the question, correct answer, and model's response."
+        )
+        print("You'll be asked to judge if the model's answer is correct.")
+
+    for idx, result in enumerate(results):
+        question = result.get("problem", "")
+        correct_answer = result.get("correct_answer", "")
+        response = result.get("response", "")
+        extracted_answer = result.get("extracted_answer", "")
+
+        if interactive:
+            print(f"\n\n===== Example {idx + 1}/{len(results)} =====")
+            print(f"Question: {question}")
+            print(f"\nCorrect Answer: {correct_answer}")
+            print(f"\nModel Response: {response}")
+            print(f"\nExtracted Answer: {extracted_answer}")
+
+            # Get human judgment
+            while True:
+                judgment = (
+                    input("\nIs the model's answer correct? (y/n): ").strip().lower()
+                )
+                if judgment in ["y", "n"]:
+                    break
+                print("Please enter 'y' or 'n'")
+
+            is_correct = judgment == "y"
+
+            # Get reasoning
+            reasoning = input("Please provide reasoning for your judgment: ").strip()
+        else:
+            # Non-interactive mode - placeholder for API/UI implementation
+            # In a real implementation, this would be filled by UI actions
+            is_correct = False
+            reasoning = "Non-interactive evaluation"
+
+        if is_correct:
+            correct_count += 1
+
+        # Update result with human judgment
+        human_result = result.copy()
+        human_result.update(
+            {
+                "is_correct": is_correct,
+                "reasoning": reasoning,
+                "human_evaluation": True,
+            }
+        )
+
+        human_graded_results.append(human_result)
+
+        # Write to output file
+        with open(output_file, "a") as f:
+            f.write(json.dumps(human_result) + "\n")
+
+    accuracy = correct_count / len(results) if results else 0
+    logger.info(f"Human evaluation complete. Accuracy: {accuracy:.3f}")
+    if interactive:
+        print(f"\nHuman evaluation complete. Accuracy: {accuracy:.3f}")
+        print(f"Correct: {correct_count}/{len(results)}")
+
+    return human_graded_results
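A hedged sketch of driving the new grader directly on a results file (not part of the diff): the file paths and progress callback are illustrative, and the default grader expects OPENAI_ENDPOINT_API_KEY for OpenRouter unless a custom evaluation_config supplies its own provider and api_key.

# Illustrative only: grades a JSONL results file with the new grade_results().
from local_deep_research.benchmarks.graders import grade_results

def on_progress(idx, total, info):
    # info carries a "status" key plus per-item details, as emitted by grade_results
    print(f"{info['status']}: {idx + 1}/{total}")

graded = grade_results(
    results_file="benchmark_output/results.jsonl",      # hypothetical path
    output_file="benchmark_output/evaluation.jsonl",    # hypothetical path
    dataset_type="simpleqa",
    evaluation_config={"temperature": 0},                # merged over DEFAULT_EVALUATION_CONFIG
    progress_callback=on_progress,
)
correct = sum(1 for r in graded if r.get("is_correct"))
print(f"{correct}/{len(graded)} correct")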