local-deep-research 0.3.12__py3-none-any.whl → 0.4.1__py3-none-any.whl
- local_deep_research/__init__.py +1 -0
- local_deep_research/__version__.py +1 -1
- local_deep_research/advanced_search_system/filters/base_filter.py +2 -3
- local_deep_research/advanced_search_system/filters/cross_engine_filter.py +4 -5
- local_deep_research/advanced_search_system/filters/journal_reputation_filter.py +298 -0
- local_deep_research/advanced_search_system/findings/repository.py +0 -3
- local_deep_research/advanced_search_system/strategies/base_strategy.py +1 -2
- local_deep_research/advanced_search_system/strategies/iterdrag_strategy.py +14 -18
- local_deep_research/advanced_search_system/strategies/parallel_search_strategy.py +4 -8
- local_deep_research/advanced_search_system/strategies/rapid_search_strategy.py +5 -6
- local_deep_research/advanced_search_system/strategies/source_based_strategy.py +2 -2
- local_deep_research/advanced_search_system/strategies/standard_strategy.py +9 -7
- local_deep_research/api/benchmark_functions.py +288 -0
- local_deep_research/api/research_functions.py +8 -4
- local_deep_research/benchmarks/README.md +162 -0
- local_deep_research/benchmarks/__init__.py +51 -0
- local_deep_research/benchmarks/benchmark_functions.py +353 -0
- local_deep_research/benchmarks/cli/__init__.py +16 -0
- local_deep_research/benchmarks/cli/benchmark_commands.py +338 -0
- local_deep_research/benchmarks/cli.py +347 -0
- local_deep_research/benchmarks/comparison/__init__.py +12 -0
- local_deep_research/benchmarks/comparison/evaluator.py +768 -0
- local_deep_research/benchmarks/datasets/__init__.py +53 -0
- local_deep_research/benchmarks/datasets/base.py +295 -0
- local_deep_research/benchmarks/datasets/browsecomp.py +116 -0
- local_deep_research/benchmarks/datasets/custom_dataset_template.py +98 -0
- local_deep_research/benchmarks/datasets/simpleqa.py +74 -0
- local_deep_research/benchmarks/datasets/utils.py +116 -0
- local_deep_research/benchmarks/datasets.py +31 -0
- local_deep_research/benchmarks/efficiency/__init__.py +14 -0
- local_deep_research/benchmarks/efficiency/resource_monitor.py +367 -0
- local_deep_research/benchmarks/efficiency/speed_profiler.py +214 -0
- local_deep_research/benchmarks/evaluators/__init__.py +18 -0
- local_deep_research/benchmarks/evaluators/base.py +74 -0
- local_deep_research/benchmarks/evaluators/browsecomp.py +83 -0
- local_deep_research/benchmarks/evaluators/composite.py +121 -0
- local_deep_research/benchmarks/evaluators/simpleqa.py +271 -0
- local_deep_research/benchmarks/graders.py +410 -0
- local_deep_research/benchmarks/metrics/README.md +80 -0
- local_deep_research/benchmarks/metrics/__init__.py +24 -0
- local_deep_research/benchmarks/metrics/calculation.py +385 -0
- local_deep_research/benchmarks/metrics/reporting.py +155 -0
- local_deep_research/benchmarks/metrics/visualization.py +205 -0
- local_deep_research/benchmarks/metrics.py +11 -0
- local_deep_research/benchmarks/optimization/__init__.py +32 -0
- local_deep_research/benchmarks/optimization/api.py +274 -0
- local_deep_research/benchmarks/optimization/metrics.py +20 -0
- local_deep_research/benchmarks/optimization/optuna_optimizer.py +1163 -0
- local_deep_research/benchmarks/runners.py +434 -0
- local_deep_research/benchmarks/templates.py +65 -0
- local_deep_research/config/llm_config.py +26 -23
- local_deep_research/config/search_config.py +1 -5
- local_deep_research/defaults/default_settings.json +108 -7
- local_deep_research/search_system.py +16 -8
- local_deep_research/utilities/db_utils.py +3 -6
- local_deep_research/utilities/es_utils.py +441 -0
- local_deep_research/utilities/log_utils.py +36 -0
- local_deep_research/utilities/search_utilities.py +8 -9
- local_deep_research/web/app.py +15 -10
- local_deep_research/web/app_factory.py +9 -12
- local_deep_research/web/database/migrations.py +8 -5
- local_deep_research/web/database/models.py +20 -0
- local_deep_research/web/database/schema_upgrade.py +5 -8
- local_deep_research/web/models/database.py +15 -18
- local_deep_research/web/routes/benchmark_routes.py +427 -0
- local_deep_research/web/routes/research_routes.py +13 -17
- local_deep_research/web/routes/settings_routes.py +264 -67
- local_deep_research/web/services/research_service.py +58 -73
- local_deep_research/web/services/settings_manager.py +1 -4
- local_deep_research/web/services/settings_service.py +4 -6
- local_deep_research/web/static/css/styles.css +12 -0
- local_deep_research/web/static/js/components/logpanel.js +164 -155
- local_deep_research/web/static/js/components/research.js +44 -3
- local_deep_research/web/static/js/components/settings.js +27 -0
- local_deep_research/web/static/js/services/socket.js +47 -0
- local_deep_research/web_search_engines/default_search_engines.py +38 -0
- local_deep_research/web_search_engines/engines/meta_search_engine.py +100 -33
- local_deep_research/web_search_engines/engines/search_engine_arxiv.py +31 -17
- local_deep_research/web_search_engines/engines/search_engine_brave.py +8 -3
- local_deep_research/web_search_engines/engines/search_engine_elasticsearch.py +343 -0
- local_deep_research/web_search_engines/engines/search_engine_google_pse.py +14 -6
- local_deep_research/web_search_engines/engines/search_engine_local.py +19 -23
- local_deep_research/web_search_engines/engines/search_engine_local_all.py +9 -12
- local_deep_research/web_search_engines/engines/search_engine_searxng.py +12 -17
- local_deep_research/web_search_engines/engines/search_engine_serpapi.py +8 -4
- local_deep_research/web_search_engines/search_engine_base.py +22 -5
- local_deep_research/web_search_engines/search_engine_factory.py +30 -11
- local_deep_research/web_search_engines/search_engines_config.py +14 -1
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/METADATA +10 -2
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/RECORD +93 -51
- local_deep_research/app.py +0 -8
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/WHEEL +0 -0
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/entry_points.txt +0 -0
- {local_deep_research-0.3.12.dist-info → local_deep_research-0.4.1.dist-info}/licenses/LICENSE +0 -0
--- /dev/null
+++ local_deep_research/benchmarks/optimization/optuna_optimizer.py
@@ -0,0 +1,1163 @@
+"""
+Optuna-based parameter optimizer for Local Deep Research.
+
+This module provides the core optimization functionality using Optuna
+to find optimal parameters for the research system, balancing quality
+and performance metrics.
+"""
+
+import json
+import logging
+import os
+import time
+from datetime import datetime
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import joblib
+import numpy as np
+import optuna
+from optuna.visualization import (
+    plot_contour,
+    plot_optimization_history,
+    plot_param_importances,
+    plot_slice,
+)
+
+from local_deep_research.benchmarks.efficiency.speed_profiler import SpeedProfiler
+from local_deep_research.benchmarks.evaluators import CompositeBenchmarkEvaluator
+
+# Import benchmark evaluator components
+
+logger = logging.getLogger(__name__)
+
+# Try to import visualization libraries, but don't fail if not available
+try:
+    import matplotlib.pyplot as plt
+    from matplotlib.lines import Line2D
+
+    # We'll use matplotlib for plotting visualization results
+
+    PLOTTING_AVAILABLE = True
+except ImportError:
+    PLOTTING_AVAILABLE = False
+    logger.warning("Matplotlib not available, visualization will be limited")
+
+
+class OptunaOptimizer:
+    """
+    Optimize parameters for Local Deep Research using Optuna.
+
+    This class provides functionality to:
+    1. Define search spaces for parameter optimization
+    2. Evaluate parameter combinations using objective functions
+    3. Find optimal parameters via Optuna
+    4. Visualize and analyze optimization results
+    """
+
+    def __init__(
+        self,
+        base_query: str,
+        output_dir: str = "optimization_results",
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        search_tool: Optional[str] = None,
+        temperature: float = 0.7,
+        n_trials: int = 30,
+        timeout: Optional[int] = None,
+        n_jobs: int = 1,
+        study_name: Optional[str] = None,
+        optimization_metrics: Optional[List[str]] = None,
+        metric_weights: Optional[Dict[str, float]] = None,
+        progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+        benchmark_weights: Optional[Dict[str, float]] = None,
+    ):
+        """
+        Initialize the optimizer.
+
+        Args:
+            base_query: The research query to use for all experiments
+            output_dir: Directory to save optimization results
+            model_name: Name of the LLM model to use
+            provider: LLM provider
+            search_tool: Search engine to use
+            temperature: LLM temperature
+            n_trials: Number of parameter combinations to try
+            timeout: Maximum seconds to run optimization (None for no limit)
+            n_jobs: Number of parallel jobs for optimization
+            study_name: Name of the Optuna study
+            optimization_metrics: List of metrics to optimize (default: ["quality", "speed"])
+            metric_weights: Dictionary of weights for each metric (e.g., {"quality": 0.6, "speed": 0.4})
+            progress_callback: Optional callback for progress updates
+            benchmark_weights: Dictionary mapping benchmark types to weights
+                (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
+                If None, only SimpleQA is used with weight 1.0
+        """
+        self.base_query = base_query
+        self.output_dir = output_dir
+        self.model_name = model_name
+        self.provider = provider
+        self.search_tool = search_tool
+        self.temperature = temperature
+        self.n_trials = n_trials
+        self.timeout = timeout
+        self.n_jobs = n_jobs
+        self.optimization_metrics = optimization_metrics or ["quality", "speed"]
+        self.metric_weights = metric_weights or {"quality": 0.6, "speed": 0.4}
+        self.progress_callback = progress_callback
+
+        # Initialize benchmark evaluator with weights
+        self.benchmark_weights = benchmark_weights or {"simpleqa": 1.0}
+        self.benchmark_evaluator = CompositeBenchmarkEvaluator(self.benchmark_weights)
+
+        # Normalize weights to sum to 1.0
+        total_weight = sum(self.metric_weights.values())
+        if total_weight > 0:
+            self.metric_weights = {
+                k: v / total_weight for k, v in self.metric_weights.items()
+            }
+
+        # Generate a unique study name if not provided
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        self.study_name = study_name or f"ldr_opt_{timestamp}"
+
+        # Create output directory
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Store the trial history for analysis
+        self.trials_history = []
+
+        # Storage for the best parameters and study
+        self.best_params = None
+        self.study = None
+
+    def optimize(
+        self, param_space: Optional[Dict[str, Any]] = None
+    ) -> Tuple[Dict[str, Any], float]:
+        """
+        Run the optimization process using Optuna.
+
+        Args:
+            param_space: Dictionary defining parameter search spaces
+                (if None, use default spaces)
+
+        Returns:
+            Tuple containing (best_parameters, best_score)
+        """
+        param_space = param_space or self._get_default_param_space()
+
+        # Create a study object
+        storage_name = f"sqlite:///{self.output_dir}/{self.study_name}.db"
+        self.study = optuna.create_study(
+            study_name=self.study_name,
+            storage=storage_name,
+            load_if_exists=True,
+            direction="maximize",
+            sampler=optuna.samplers.TPESampler(seed=42),
+        )
+
+        # Create partial function with param_space
+        objective = partial(self._objective, param_space=param_space)
+
+        # Log optimization start
+        logger.info(
+            f"Starting optimization with {self.n_trials} trials, {self.n_jobs} parallel jobs"
+        )
+        logger.info(f"Parameter space: {param_space}")
+        logger.info(f"Metric weights: {self.metric_weights}")
+        logger.info(f"Benchmark weights: {self.benchmark_weights}")
+
+        # Initialize progress tracking
+        if self.progress_callback:
+            self.progress_callback(
+                0,
+                self.n_trials,
+                {
+                    "status": "starting",
+                    "stage": "initialization",
+                    "trials_completed": 0,
+                    "total_trials": self.n_trials,
+                },
+            )
+
+        try:
+            # Run optimization
+            self.study.optimize(
+                objective,
+                n_trials=self.n_trials,
+                timeout=self.timeout,
+                n_jobs=self.n_jobs,
+                callbacks=[self._optimization_callback],
+                show_progress_bar=True,
+            )
+
+            # Store best parameters
+            self.best_params = self.study.best_params
+
+            # Save the results
+            self._save_results()
+
+            # Create visualizations
+            self._create_visualizations()
+
+            logger.info(f"Optimization complete. Best parameters: {self.best_params}")
+            logger.info(f"Best value: {self.study.best_value}")
+
+            # Report completion
+            if self.progress_callback:
+                self.progress_callback(
+                    self.n_trials,
+                    self.n_trials,
+                    {
+                        "status": "completed",
+                        "stage": "finished",
+                        "trials_completed": len(self.study.trials),
+                        "total_trials": self.n_trials,
+                        "best_params": self.best_params,
+                        "best_value": self.study.best_value,
+                    },
+                )
+
+            return self.best_params, self.study.best_value
+
+        except KeyboardInterrupt:
+            logger.info("Optimization interrupted by user")
+            # Still save what we have
+            self._save_results()
+            self._create_visualizations()
+
+            # Report interruption
+            if self.progress_callback:
+                self.progress_callback(
+                    len(self.study.trials),
+                    self.n_trials,
+                    {
+                        "status": "interrupted",
+                        "stage": "interrupted",
+                        "trials_completed": len(self.study.trials),
+                        "total_trials": self.n_trials,
+                        "best_params": self.study.best_params,
+                        "best_value": self.study.best_value,
+                    },
+                )
+
+            return self.study.best_params, self.study.best_value
+
+    def _get_default_param_space(self) -> Dict[str, Any]:
+        """
+        Get default parameter search space.
+
+        Returns:
+            Dictionary defining the default parameter search spaces
+        """
+        return {
+            "iterations": {
+                "type": "int",
+                "low": 1,
+                "high": 5,
+                "step": 1,
+            },
+            "questions_per_iteration": {
+                "type": "int",
+                "low": 1,
+                "high": 5,
+                "step": 1,
+            },
+            "search_strategy": {
+                "type": "categorical",
+                "choices": [
+                    "iterdrag",
+                    "standard",
+                    "rapid",
+                    "parallel",
+                    "source_based",
+                ],
+            },
+            "max_results": {
+                "type": "int",
+                "low": 10,
+                "high": 100,
+                "step": 10,
+            },
+        }
+
+    def _objective(self, trial: optuna.Trial, param_space: Dict[str, Any]) -> float:
+        """
+        Objective function for Optuna optimization.
+
+        Args:
+            trial: Optuna trial object
+            param_space: Dictionary defining parameter search spaces
+
+        Returns:
+            Score to maximize
+        """
+        # Generate parameters for this trial
+        params = {}
+        for param_name, param_config in param_space.items():
+            param_type = param_config["type"]
+
+            if param_type == "int":
+                params[param_name] = trial.suggest_int(
+                    param_name,
+                    param_config["low"],
+                    param_config["high"],
+                    step=param_config.get("step", 1),
+                )
+            elif param_type == "float":
+                params[param_name] = trial.suggest_float(
+                    param_name,
+                    param_config["low"],
+                    param_config["high"],
+                    step=param_config.get("step"),
+                    log=param_config.get("log", False),
+                )
+            elif param_type == "categorical":
+                params[param_name] = trial.suggest_categorical(
+                    param_name, param_config["choices"]
+                )
+
+        # Log the trial parameters
+        logger.info(f"Trial {trial.number}: {params}")
+
+        # Update progress callback if available
+        if self.progress_callback:
+            self.progress_callback(
+                trial.number,
+                self.n_trials,
+                {
+                    "status": "running",
+                    "stage": "trial_started",
+                    "trial_number": trial.number,
+                    "params": params,
+                    "trials_completed": trial.number,
+                    "total_trials": self.n_trials,
+                },
+            )
+
+        # Run an experiment with these parameters
+        try:
+            start_time = time.time()
+            result = self._run_experiment(params)
+            duration = time.time() - start_time
+
+            # Store details about the trial
+            trial_info = {
+                "trial_number": trial.number,
+                "params": params,
+                "result": result,
+                "score": result.get("score", 0),
+                "duration": duration,
+                "timestamp": datetime.now().isoformat(),
+            }
+            self.trials_history.append(trial_info)
+
+            # Update callback with results
+            if self.progress_callback:
+                self.progress_callback(
+                    trial.number,
+                    self.n_trials,
+                    {
+                        "status": "completed",
+                        "stage": "trial_completed",
+                        "trial_number": trial.number,
+                        "params": params,
+                        "score": result.get("score", 0),
+                        "trials_completed": trial.number + 1,
+                        "total_trials": self.n_trials,
+                    },
+                )
+
+            logger.info(
+                f"Trial {trial.number} completed: {params}, score: {result['score']:.4f}"
+            )
+
+            return result["score"]
+        except Exception as e:
+            logger.error(f"Error in trial {trial.number}: {str(e)}")
+
+            # Update callback with error
+            if self.progress_callback:
+                self.progress_callback(
+                    trial.number,
+                    self.n_trials,
+                    {
+                        "status": "error",
+                        "stage": "trial_error",
+                        "trial_number": trial.number,
+                        "params": params,
+                        "error": str(e),
+                        "trials_completed": trial.number,
+                        "total_trials": self.n_trials,
+                    },
+                )
+
+            return float("-inf")  # Return a very low score for failed trials
+
+    def _run_experiment(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Run a single experiment with the given parameters.
+
+        Args:
+            params: Dictionary of parameters to test
+
+        Returns:
+            Results dictionary with metrics and score
+        """
+        # Extract parameters
+        iterations = params.get("iterations", 2)
+        questions_per_iteration = params.get("questions_per_iteration", 2)
+        search_strategy = params.get("search_strategy", "iterdrag")
+        max_results = params.get("max_results", 50)
+
+        # Initialize profiling tools
+        speed_profiler = SpeedProfiler()
+
+        # Start profiling
+        speed_profiler.start()
+
+        try:
+            # Create system configuration
+            system_config = {
+                "iterations": iterations,
+                "questions_per_iteration": questions_per_iteration,
+                "search_strategy": search_strategy,
+                "search_tool": self.search_tool,
+                "max_results": max_results,
+                "model_name": self.model_name,
+                "provider": self.provider,
+            }
+
+            # Evaluate quality using composite benchmark evaluator
+            # Use a small number of examples for efficiency
+            benchmark_dir = os.path.join(self.output_dir, "benchmark_temp")
+            quality_results = self.benchmark_evaluator.evaluate(
+                system_config=system_config,
+                num_examples=5,  # Small number for optimization efficiency
+                output_dir=benchmark_dir,
+            )
+
+            # Stop timing
+            speed_profiler.stop()
+            timing_results = speed_profiler.get_summary()
+
+            # Extract key metrics
+            quality_score = quality_results.get("quality_score", 0.0)
+            benchmark_results = quality_results.get("benchmark_results", {})
+
+            # Speed score: convert duration to a 0-1 score where faster is better
+            # Using a reasonable threshold (e.g., 180 seconds for 5 examples)
+            # Below this threshold: high score, above it: declining score
+            total_duration = timing_results.get("total_duration", 180)
+            speed_score = max(0.0, min(1.0, 1.0 - (total_duration - 60) / 180))
+
+            # Calculate combined score based on weights
+            combined_score = (
+                self.metric_weights.get("quality", 0.6) * quality_score
+                + self.metric_weights.get("speed", 0.4) * speed_score
+            )
+
+            # Return streamlined results
+            return {
+                "quality_score": quality_score,
+                "benchmark_results": benchmark_results,
+                "speed_score": speed_score,
+                "total_duration": total_duration,
+                "score": combined_score,
+                "success": True,
+            }
+
+        except Exception as e:
+            # Stop profiling on error
+            speed_profiler.stop()
+
+            # Log error
+            logger.error(f"Error in experiment: {str(e)}")
+
+            # Return error information
+            return {"error": str(e), "score": 0.0, "success": False}
+
+    def _optimization_callback(self, study: optuna.Study, trial: optuna.Trial):
+        """
+        Callback for the Optuna optimization process.
+
+        Args:
+            study: Optuna study object
+            trial: Current trial
+        """
+        # Save intermediate results periodically
+        if trial.number % 10 == 0 and trial.number > 0:
+            self._save_results()
+            self._create_quick_visualizations()
+
+    def _save_results(self):
+        """Save the optimization results to disk."""
+        # Create a timestamp for filenames
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+        # Save trial history
+        history_file = os.path.join(self.output_dir, f"{self.study_name}_history.json")
+        with open(history_file, "w") as f:
+            # Convert numpy values to native Python types for JSON serialization
+            clean_history = []
+            for trial in self.trials_history:
+                clean_trial = {}
+                for k, v in trial.items():
+                    if isinstance(v, dict):
+                        clean_trial[k] = {
+                            dk: (float(dv) if isinstance(dv, np.number) else dv)
+                            for dk, dv in v.items()
+                        }
+                    elif isinstance(v, np.number):
+                        clean_trial[k] = float(v)
+                    else:
+                        clean_trial[k] = v
+                clean_history.append(clean_trial)
+
+            json.dump(clean_history, f, indent=2)
+
+        # Save current best parameters
+        if self.study and hasattr(self.study, "best_params") and self.study.best_params:
+            best_params_file = os.path.join(
+                self.output_dir, f"{self.study_name}_best_params.json"
+            )
+            with open(best_params_file, "w") as f:
+                json.dump(
+                    {
+                        "best_params": self.study.best_params,
+                        "best_value": float(self.study.best_value),
+                        "n_trials": len(self.study.trials),
+                        "timestamp": timestamp,
+                        "base_query": self.base_query,
+                        "model_name": self.model_name,
+                        "provider": self.provider,
+                        "search_tool": self.search_tool,
+                        "metric_weights": self.metric_weights,
+                        "benchmark_weights": self.benchmark_weights,
+                    },
+                    f,
+                    indent=2,
+                )
+
+        # Save the Optuna study
+        if self.study:
+            study_file = os.path.join(self.output_dir, f"{self.study_name}_study.pkl")
+            joblib.dump(self.study, study_file)
+
+        logger.info(f"Results saved to {self.output_dir}")
+
+    def _create_visualizations(self):
+        """Create and save comprehensive visualizations of the optimization results."""
+        if not PLOTTING_AVAILABLE:
+            logger.warning("Matplotlib not available, skipping visualization creation")
+            return
+
+        if not self.study or len(self.study.trials) < 2:
+            logger.warning("Not enough trials to create visualizations")
+            return
+
+        # Create directory for visualizations
+        viz_dir = os.path.join(self.output_dir, "visualizations")
+        os.makedirs(viz_dir, exist_ok=True)
+
+        # Create Optuna visualizations
+        self._create_optuna_visualizations(viz_dir)
+
+        # Create custom visualizations
+        self._create_custom_visualizations(viz_dir)
+
+        logger.info(f"Visualizations saved to {viz_dir}")
+
+    def _create_quick_visualizations(self):
+        """Create a smaller set of visualizations for intermediate progress."""
+        if not PLOTTING_AVAILABLE or not self.study or len(self.study.trials) < 2:
+            return
+
+        # Create directory for visualizations
+        viz_dir = os.path.join(self.output_dir, "visualizations")
+        os.makedirs(viz_dir, exist_ok=True)
+
+        # Create optimization history only (faster than full visualization)
+        try:
+            fig = plot_optimization_history(self.study)
+            fig.write_image(
+                os.path.join(
+                    viz_dir, f"{self.study_name}_optimization_history_current.png"
+                )
+            )
+        except Exception as e:
+            logger.error(f"Error creating optimization history plot: {str(e)}")
+
+    def _create_optuna_visualizations(self, viz_dir: str):
+        """
+        Create and save Optuna's built-in visualizations.
+
+        Args:
+            viz_dir: Directory to save visualizations
+        """
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+        # 1. Optimization history
+        try:
+            fig = plot_optimization_history(self.study)
+            fig.write_image(
+                os.path.join(
+                    viz_dir, f"{self.study_name}_optimization_history_{timestamp}.png"
+                )
+            )
+        except Exception as e:
+            logger.error(f"Error creating optimization history plot: {str(e)}")
+
+        # 2. Parameter importances
+        try:
+            fig = plot_param_importances(self.study)
+            fig.write_image(
+                os.path.join(
+                    viz_dir, f"{self.study_name}_param_importances_{timestamp}.png"
+                )
+            )
+        except Exception as e:
+            logger.error(f"Error creating parameter importances plot: {str(e)}")
+
+        # 3. Slice plot for each parameter
+        try:
+            for param_name in self.study.best_params.keys():
+                fig = plot_slice(self.study, [param_name])
+                fig.write_image(
+                    os.path.join(
+                        viz_dir, f"{self.study_name}_slice_{param_name}_{timestamp}.png"
+                    )
+                )
+        except Exception as e:
+            logger.error(f"Error creating slice plots: {str(e)}")
+
+        # 4. Contour plots for important parameter pairs
+        try:
+            # Get all parameter names
+            param_names = list(self.study.best_params.keys())
+
+            # Create contour plots for each pair
+            for i in range(len(param_names)):
+                for j in range(i + 1, len(param_names)):
+                    try:
+                        fig = plot_contour(
+                            self.study, params=[param_names[i], param_names[j]]
+                        )
+                        fig.write_image(
+                            os.path.join(
+                                viz_dir,
+                                f"{self.study_name}_contour_{param_names[i]}_{param_names[j]}_{timestamp}.png",
+                            )
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Error creating contour plot for {param_names[i]} vs {param_names[j]}: {str(e)}"
+                        )
+        except Exception as e:
+            logger.error(f"Error creating contour plots: {str(e)}")
+
+    def _create_custom_visualizations(self, viz_dir: str):
+        """
+        Create custom visualizations based on trial history.
+
+        Args:
+            viz_dir: Directory to save visualizations
+        """
+        if not self.trials_history:
+            return
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+        # Create quality vs speed plot
+        self._create_quality_vs_speed_plot(viz_dir, timestamp)
+
+        # Create parameter evolution plots
+        self._create_parameter_evolution_plots(viz_dir, timestamp)
+
+        # Create trial duration vs score plot
+        self._create_duration_vs_score_plot(viz_dir, timestamp)
+
+    def _create_quality_vs_speed_plot(self, viz_dir: str, timestamp: str):
+        """Create a plot showing quality vs. speed trade-off."""
+        if not self.trials_history:
+            return
+
+        # Extract data from successful trials
+        successful_trials = [
+            t for t in self.trials_history if t.get("result", {}).get("success", False)
+        ]
+
+        if not successful_trials:
+            logger.warning("No successful trials for visualization")
+            return
+
+        try:
+            plt.figure(figsize=(10, 8))
+
+            # Extract metrics
+            quality_scores = []
+            speed_scores = []
+            labels = []
+            iterations_values = []
+            questions_values = []
+
+            for trial in successful_trials:
+                result = trial["result"]
+                quality = result.get("quality_score", 0)
+                speed = result.get("speed_score", 0)
+                iterations = trial["params"].get("iterations", 0)
+                questions = trial["params"].get("questions_per_iteration", 0)
+
+                quality_scores.append(quality)
+                speed_scores.append(speed)
+                labels.append(f"Trial {trial['trial_number']}")
+                iterations_values.append(iterations)
+                questions_values.append(questions)
+
+            # Create scatter plot with size based on iterations*questions
+            sizes = [i * q * 5 for i, q in zip(iterations_values, questions_values)]
+            scatter = plt.scatter(
+                quality_scores,
+                speed_scores,
+                s=sizes,
+                alpha=0.7,
+                c=range(len(quality_scores)),
+                cmap="viridis",
+            )
+
+            # Highlight best trial
+            best_trial = max(
+                successful_trials, key=lambda x: x.get("result", {}).get("score", 0)
+            )
+            best_quality = best_trial["result"].get("quality_score", 0)
+            best_speed = best_trial["result"].get("speed_score", 0)
+            best_iter = best_trial["params"].get("iterations", 0)
+            best_questions = best_trial["params"].get("questions_per_iteration", 0)
+
+            plt.scatter(
+                [best_quality],
+                [best_speed],
+                s=200,
+                facecolors="none",
+                edgecolors="red",
+                linewidth=2,
+                label=f"Best: {best_iter}×{best_questions}",
+            )
+
+            # Add annotations for key points
+            for i, (q, s, l) in enumerate(zip(quality_scores, speed_scores, labels)):
+                if i % max(1, len(quality_scores) // 5) == 0:  # Label ~5 points
+                    plt.annotate(
+                        f"{iterations_values[i]}×{questions_values[i]}",
+                        (q, s),
+                        xytext=(5, 5),
+                        textcoords="offset points",
+                    )
+
+            # Add colorbar and labels
+            cbar = plt.colorbar(scatter)
+            cbar.set_label("Trial Progression")
+
+            # Add benchmark weight information
+            weights_str = ", ".join(
+                [f"{k}:{v:.1f}" for k, v in self.benchmark_weights.items()]
+            )
+            plt.title(f"Quality vs. Speed Trade-off\nBenchmark Weights: {weights_str}")
+            plt.xlabel("Quality Score (Benchmark Accuracy)")
+            plt.ylabel("Speed Score")
+            plt.grid(True, linestyle="--", alpha=0.7)
+
+            # Add legend explaining size
+            legend_elements = [
+                Line2D(
+                    [0],
+                    [0],
+                    marker="o",
+                    color="w",
+                    markerfacecolor="gray",
+                    markersize=np.sqrt(n * 5 / np.pi),
+                    label=f"{n} Total Questions",
+                )
+                for n in [5, 10, 15, 20, 25]
+            ]
+            plt.legend(handles=legend_elements, title="Workload")
+
+            # Save the figure
+            plt.tight_layout()
+            plt.savefig(
+                os.path.join(
+                    viz_dir, f"{self.study_name}_quality_vs_speed_{timestamp}.png"
+                )
+            )
+            plt.close()
+        except Exception as e:
+            logger.error(f"Error creating quality vs speed plot: {str(e)}")
+
+    def _create_parameter_evolution_plots(self, viz_dir: str, timestamp: str):
+        """Create plots showing how parameter values evolve over trials."""
+        try:
+            successful_trials = [
+                t
+                for t in self.trials_history
+                if t.get("result", {}).get("success", False)
+            ]
+
+            if not successful_trials or len(successful_trials) < 5:
+                return
+
+            # Get key parameters
+            main_params = list(successful_trials[0]["params"].keys())
+
+            # For each parameter, plot its values over trials
+            for param_name in main_params:
+                plt.figure(figsize=(12, 6))
+
+                trial_numbers = []
+                param_values = []
+                scores = []
+
+                for trial in self.trials_history:
+                    if "params" in trial and param_name in trial["params"]:
+                        trial_numbers.append(trial["trial_number"])
+                        param_values.append(trial["params"][param_name])
+                        scores.append(trial.get("score", 0))
+
+                # Create evolution plot
+                scatter = plt.scatter(
+                    trial_numbers,
+                    param_values,
+                    c=scores,
+                    cmap="plasma",
+                    alpha=0.8,
+                    s=80,
+                )
+
+                # Add best trial marker
+                best_trial_idx = scores.index(max(scores))
+                plt.scatter(
+                    [trial_numbers[best_trial_idx]],
+                    [param_values[best_trial_idx]],
+                    s=150,
+                    facecolors="none",
+                    edgecolors="red",
+                    linewidth=2,
+                    label=f"Best Value: {param_values[best_trial_idx]}",
+                )
+
+                # Add colorbar
+                cbar = plt.colorbar(scatter)
+                cbar.set_label("Score")
+
+                # Set chart properties
+                plt.title(f"Evolution of {param_name} Values")
+                plt.xlabel("Trial Number")
+                plt.ylabel(param_name)
+                plt.grid(True, linestyle="--", alpha=0.7)
+                plt.legend()
+
+                # For categorical parameters, adjust y-axis
+                if isinstance(param_values[0], str):
+                    unique_values = sorted(set(param_values))
+                    plt.yticks(range(len(unique_values)), unique_values)
+
+                # Save the figure
+                plt.tight_layout()
+                plt.savefig(
+                    os.path.join(
+                        viz_dir,
+                        f"{self.study_name}_param_evolution_{param_name}_{timestamp}.png",
+                    )
+                )
+                plt.close()
+        except Exception as e:
+            logger.error(f"Error creating parameter evolution plots: {str(e)}")
+
+    def _create_duration_vs_score_plot(self, viz_dir: str, timestamp: str):
+        """Create a plot showing trial duration vs score."""
+        try:
+            plt.figure(figsize=(10, 6))
+
+            successful_trials = [
+                t
+                for t in self.trials_history
+                if t.get("result", {}).get("success", False)
+            ]
+
+            if not successful_trials:
+                return
+
+            trial_durations = []
+            trial_scores = []
+            trial_iterations = []
+            trial_questions = []
+
+            for trial in successful_trials:
+                duration = trial.get("duration", 0)
+                score = trial.get("score", 0)
+                iterations = trial.get("params", {}).get("iterations", 1)
+                questions = trial.get("params", {}).get("questions_per_iteration", 1)
+
+                trial_durations.append(duration)
+                trial_scores.append(score)
+                trial_iterations.append(iterations)
+                trial_questions.append(questions)
+
+            # Total questions per trial
+            total_questions = [i * q for i, q in zip(trial_iterations, trial_questions)]
+
+            # Create scatter plot with size based on total questions
+            plt.scatter(
+                trial_durations,
+                trial_scores,
+                s=[q * 5 for q in total_questions],  # Size based on total questions
+                alpha=0.7,
+                c=range(len(trial_durations)),
+                cmap="viridis",
+            )
+
+            # Add labels
+            plt.xlabel("Trial Duration (seconds)")
+            plt.ylabel("Score")
+            plt.title("Trial Duration vs. Score")
+            plt.grid(True, linestyle="--", alpha=0.7)
+
+            # Add trial number annotations for selected points
+            for i, (d, s) in enumerate(zip(trial_durations, trial_scores)):
+                if i % max(1, len(trial_durations) // 5) == 0:  # Annotate ~5 points
+                    plt.annotate(
+                        f"{trial_iterations[i]}×{trial_questions[i]}",
+                        (d, s),
+                        xytext=(5, 5),
+                        textcoords="offset points",
+                    )
+
+            # Save the figure
+            plt.tight_layout()
+            plt.savefig(
+                os.path.join(
+                    viz_dir, f"{self.study_name}_duration_vs_score_{timestamp}.png"
+                )
+            )
+            plt.close()
+        except Exception as e:
+            logger.error(f"Error creating duration vs score plot: {str(e)}")
+
+
+def optimize_parameters(
+    query: str,
+    param_space: Optional[Dict[str, Any]] = None,
+    output_dir: str = os.path.join("data", "optimization_results"),
+    model_name: Optional[str] = None,
+    provider: Optional[str] = None,
+    search_tool: Optional[str] = None,
+    temperature: float = 0.7,
+    n_trials: int = 30,
+    timeout: Optional[int] = None,
+    n_jobs: int = 1,
+    study_name: Optional[str] = None,
+    optimization_metrics: Optional[List[str]] = None,
+    metric_weights: Optional[Dict[str, float]] = None,
+    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+    benchmark_weights: Optional[Dict[str, float]] = None,
+) -> Tuple[Dict[str, Any], float]:
+    """
+    Optimize parameters for Local Deep Research.
+
+    Args:
+        query: The research query to use for all experiments
+        param_space: Dictionary defining parameter search spaces (optional)
+        output_dir: Directory to save optimization results
+        model_name: Name of the LLM model to use
+        provider: LLM provider
+        search_tool: Search engine to use
+        temperature: LLM temperature
+        n_trials: Number of parameter combinations to try
+        timeout: Maximum seconds to run optimization (None for no limit)
+        n_jobs: Number of parallel jobs for optimization
+        study_name: Name of the Optuna study
+        optimization_metrics: List of metrics to optimize (default: ["quality", "speed"])
+        metric_weights: Dictionary of weights for each metric (e.g., {"quality": 0.6, "speed": 0.4})
+        progress_callback: Optional callback for progress updates
+        benchmark_weights: Dictionary mapping benchmark types to weights
+            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
+            If None, only SimpleQA is used with weight 1.0
+
+    Returns:
+        Tuple of (best_parameters, best_score)
+    """
+    # Create optimizer
+    optimizer = OptunaOptimizer(
+        base_query=query,
+        output_dir=output_dir,
+        model_name=model_name,
+        provider=provider,
+        search_tool=search_tool,
+        temperature=temperature,
+        n_trials=n_trials,
+        timeout=timeout,
+        n_jobs=n_jobs,
+        study_name=study_name,
+        optimization_metrics=optimization_metrics,
+        metric_weights=metric_weights,
+        progress_callback=progress_callback,
+        benchmark_weights=benchmark_weights,
+    )
+
+    # Run optimization
+    return optimizer.optimize(param_space)
+
+
+def optimize_for_speed(
+    query: str,
+    n_trials: int = 20,
+    output_dir: str = os.path.join("data", "optimization_results"),
+    model_name: Optional[str] = None,
+    provider: Optional[str] = None,
+    search_tool: Optional[str] = None,
+    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+    benchmark_weights: Optional[Dict[str, float]] = None,
+) -> Tuple[Dict[str, Any], float]:
+    """
+    Optimize parameters with a focus on speed performance.
+
+    Args:
+        query: The research query to use for all experiments
+        n_trials: Number of parameter combinations to try
+        output_dir: Directory to save optimization results
+        model_name: Name of the LLM model to use
+        provider: LLM provider
+        search_tool: Search engine to use
+        progress_callback: Optional callback for progress updates
+        benchmark_weights: Dictionary mapping benchmark types to weights
+            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
+            If None, only SimpleQA is used with weight 1.0
+
+    Returns:
+        Tuple of (best_parameters, best_score)
+    """
+    # Focus on speed with reduced parameter space
+    param_space = {
+        "iterations": {
+            "type": "int",
+            "low": 1,
+            "high": 3,
+            "step": 1,
+        },
+        "questions_per_iteration": {
+            "type": "int",
+            "low": 1,
+            "high": 3,
+            "step": 1,
+        },
+        "search_strategy": {
+            "type": "categorical",
+            "choices": ["rapid", "parallel", "source_based"],
+        },
+    }
+
+    # Speed-focused weights
+    metric_weights = {"speed": 0.8, "quality": 0.2}
+
+    return optimize_parameters(
+        query=query,
+        param_space=param_space,
+        output_dir=output_dir,
+        model_name=model_name,
+        provider=provider,
+        search_tool=search_tool,
+        n_trials=n_trials,
+        metric_weights=metric_weights,
+        optimization_metrics=["speed", "quality"],
+        progress_callback=progress_callback,
+        benchmark_weights=benchmark_weights,
+    )
+
+
+def optimize_for_quality(
+    query: str,
+    n_trials: int = 30,
+    output_dir: str = os.path.join("data", "optimization_results"),
+    model_name: Optional[str] = None,
+    provider: Optional[str] = None,
+    search_tool: Optional[str] = None,
+    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+    benchmark_weights: Optional[Dict[str, float]] = None,
+) -> Tuple[Dict[str, Any], float]:
+    """
+    Optimize parameters with a focus on result quality.
+
+    Args:
+        query: The research query to use for all experiments
+        n_trials: Number of parameter combinations to try
+        output_dir: Directory to save optimization results
+        model_name: Name of the LLM model to use
+        provider: LLM provider
+        search_tool: Search engine to use
+        progress_callback: Optional callback for progress updates
+        benchmark_weights: Dictionary mapping benchmark types to weights
+            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
+            If None, only SimpleQA is used with weight 1.0
+
+    Returns:
+        Tuple of (best_parameters, best_score)
+    """
+    # Quality-focused weights
+    metric_weights = {"quality": 0.9, "speed": 0.1}
+
+    return optimize_parameters(
+        query=query,
+        output_dir=output_dir,
+        model_name=model_name,
+        provider=provider,
+        search_tool=search_tool,
+        n_trials=n_trials,
+        metric_weights=metric_weights,
+        optimization_metrics=["quality", "speed"],
+        progress_callback=progress_callback,
+        benchmark_weights=benchmark_weights,
+    )
+
+
+def optimize_for_efficiency(
+    query: str,
+    n_trials: int = 25,
+    output_dir: str = os.path.join("data", "optimization_results"),
+    model_name: Optional[str] = None,
+    provider: Optional[str] = None,
+    search_tool: Optional[str] = None,
+    progress_callback: Optional[Callable[[int, int, Dict], None]] = None,
+    benchmark_weights: Optional[Dict[str, float]] = None,
+) -> Tuple[Dict[str, Any], float]:
+    """
+    Optimize parameters with a focus on resource efficiency.
+
+    Args:
+        query: The research query to use for all experiments
+        n_trials: Number of parameter combinations to try
+        output_dir: Directory to save optimization results
+        model_name: Name of the LLM model to use
+        provider: LLM provider
+        search_tool: Search engine to use
+        progress_callback: Optional callback for progress updates
+        benchmark_weights: Dictionary mapping benchmark types to weights
+            (e.g., {"simpleqa": 0.6, "browsecomp": 0.4})
+            If None, only SimpleQA is used with weight 1.0
+
+    Returns:
+        Tuple of (best_parameters, best_score)
+    """
+    # Balance of quality, speed and resource usage
+    metric_weights = {"quality": 0.4, "speed": 0.3, "resource": 0.3}
+
+    return optimize_parameters(
+        query=query,
+        output_dir=output_dir,
+        model_name=model_name,
+        provider=provider,
+        search_tool=search_tool,
+        n_trials=n_trials,
+        metric_weights=metric_weights,
+        optimization_metrics=["quality", "speed", "resource"],
+        progress_callback=progress_callback,
+        benchmark_weights=benchmark_weights,
+    )
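For readers skimming the diff, a minimal usage sketch of the public entry point added in this file. It is illustrative only: the import path simply follows the new file's location, the search-space schema mirrors _get_default_param_space(), and the query string and numeric values are hypothetical placeholders rather than defaults shipped by the package.

# Illustrative sketch, not part of the package: exercises optimize_parameters()
# from the newly added module. Query and trial counts are placeholders.
from local_deep_research.benchmarks.optimization.optuna_optimizer import (
    optimize_parameters,
)

# The search-space dict uses the same schema _objective() consumes:
# "int" entries take low/high (and optional step); "categorical" entries take choices.
param_space = {
    "iterations": {"type": "int", "low": 1, "high": 3, "step": 1},
    "search_strategy": {"type": "categorical", "choices": ["rapid", "source_based"]},
}

best_params, best_score = optimize_parameters(
    query="example research question",      # placeholder query
    param_space=param_space,
    n_trials=10,                             # smaller than the default of 30
    metric_weights={"quality": 0.7, "speed": 0.3},
)
print(best_params, best_score)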