themis-eval 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- themis/cli/__init__.py +5 -0
- themis/cli/__main__.py +6 -0
- themis/cli/commands/__init__.py +19 -0
- themis/cli/commands/benchmarks.py +221 -0
- themis/cli/commands/comparison.py +394 -0
- themis/cli/commands/config_commands.py +244 -0
- themis/cli/commands/cost.py +214 -0
- themis/cli/commands/demo.py +68 -0
- themis/cli/commands/info.py +90 -0
- themis/cli/commands/leaderboard.py +362 -0
- themis/cli/commands/math_benchmarks.py +318 -0
- themis/cli/commands/mcq_benchmarks.py +207 -0
- themis/cli/commands/sample_run.py +244 -0
- themis/cli/commands/visualize.py +299 -0
- themis/cli/main.py +93 -0
- themis/cli/new_project.py +33 -0
- themis/cli/utils.py +51 -0
- themis/config/__init__.py +19 -0
- themis/config/loader.py +27 -0
- themis/config/registry.py +34 -0
- themis/config/runtime.py +214 -0
- themis/config/schema.py +112 -0
- themis/core/__init__.py +5 -0
- themis/core/conversation.py +354 -0
- themis/core/entities.py +164 -0
- themis/core/serialization.py +231 -0
- themis/core/tools.py +393 -0
- themis/core/types.py +141 -0
- themis/datasets/__init__.py +273 -0
- themis/datasets/base.py +264 -0
- themis/datasets/commonsense_qa.py +174 -0
- themis/datasets/competition_math.py +265 -0
- themis/datasets/coqa.py +133 -0
- themis/datasets/gpqa.py +190 -0
- themis/datasets/gsm8k.py +123 -0
- themis/datasets/gsm_symbolic.py +124 -0
- themis/datasets/math500.py +122 -0
- themis/datasets/med_qa.py +179 -0
- themis/datasets/medmcqa.py +169 -0
- themis/datasets/mmlu_pro.py +262 -0
- themis/datasets/piqa.py +146 -0
- themis/datasets/registry.py +201 -0
- themis/datasets/schema.py +245 -0
- themis/datasets/sciq.py +150 -0
- themis/datasets/social_i_qa.py +151 -0
- themis/datasets/super_gpqa.py +263 -0
- themis/evaluation/__init__.py +1 -0
- themis/evaluation/conditional.py +410 -0
- themis/evaluation/extractors/__init__.py +19 -0
- themis/evaluation/extractors/error_taxonomy_extractor.py +80 -0
- themis/evaluation/extractors/exceptions.py +7 -0
- themis/evaluation/extractors/identity_extractor.py +29 -0
- themis/evaluation/extractors/json_field_extractor.py +45 -0
- themis/evaluation/extractors/math_verify_extractor.py +37 -0
- themis/evaluation/extractors/regex_extractor.py +43 -0
- themis/evaluation/math_verify_utils.py +87 -0
- themis/evaluation/metrics/__init__.py +21 -0
- themis/evaluation/metrics/composite_metric.py +47 -0
- themis/evaluation/metrics/consistency_metric.py +80 -0
- themis/evaluation/metrics/exact_match.py +51 -0
- themis/evaluation/metrics/length_difference_tolerance.py +33 -0
- themis/evaluation/metrics/math_verify_accuracy.py +40 -0
- themis/evaluation/metrics/pairwise_judge_metric.py +141 -0
- themis/evaluation/metrics/response_length.py +33 -0
- themis/evaluation/metrics/rubric_judge_metric.py +134 -0
- themis/evaluation/pipeline.py +49 -0
- themis/evaluation/pipelines/__init__.py +15 -0
- themis/evaluation/pipelines/composable_pipeline.py +357 -0
- themis/evaluation/pipelines/standard_pipeline.py +288 -0
- themis/evaluation/reports.py +293 -0
- themis/evaluation/statistics/__init__.py +53 -0
- themis/evaluation/statistics/bootstrap.py +79 -0
- themis/evaluation/statistics/confidence_intervals.py +121 -0
- themis/evaluation/statistics/distributions.py +207 -0
- themis/evaluation/statistics/effect_sizes.py +124 -0
- themis/evaluation/statistics/hypothesis_tests.py +305 -0
- themis/evaluation/statistics/types.py +139 -0
- themis/evaluation/strategies/__init__.py +13 -0
- themis/evaluation/strategies/attempt_aware_evaluation_strategy.py +51 -0
- themis/evaluation/strategies/default_evaluation_strategy.py +25 -0
- themis/evaluation/strategies/evaluation_strategy.py +24 -0
- themis/evaluation/strategies/judge_evaluation_strategy.py +64 -0
- themis/experiment/__init__.py +5 -0
- themis/experiment/builder.py +151 -0
- themis/experiment/cache_manager.py +129 -0
- themis/experiment/comparison.py +631 -0
- themis/experiment/cost.py +310 -0
- themis/experiment/definitions.py +62 -0
- themis/experiment/export.py +690 -0
- themis/experiment/export_csv.py +159 -0
- themis/experiment/integration_manager.py +104 -0
- themis/experiment/math.py +192 -0
- themis/experiment/mcq.py +169 -0
- themis/experiment/orchestrator.py +373 -0
- themis/experiment/pricing.py +317 -0
- themis/experiment/storage.py +255 -0
- themis/experiment/visualization.py +588 -0
- themis/generation/__init__.py +1 -0
- themis/generation/agentic_runner.py +420 -0
- themis/generation/batching.py +254 -0
- themis/generation/clients.py +143 -0
- themis/generation/conversation_runner.py +236 -0
- themis/generation/plan.py +456 -0
- themis/generation/providers/litellm_provider.py +221 -0
- themis/generation/providers/vllm_provider.py +135 -0
- themis/generation/router.py +34 -0
- themis/generation/runner.py +207 -0
- themis/generation/strategies.py +98 -0
- themis/generation/templates.py +71 -0
- themis/generation/turn_strategies.py +393 -0
- themis/generation/types.py +9 -0
- themis/integrations/__init__.py +0 -0
- themis/integrations/huggingface.py +61 -0
- themis/integrations/wandb.py +65 -0
- themis/interfaces/__init__.py +83 -0
- themis/project/__init__.py +20 -0
- themis/project/definitions.py +98 -0
- themis/project/patterns.py +230 -0
- themis/providers/__init__.py +5 -0
- themis/providers/registry.py +39 -0
- themis/utils/api_generator.py +379 -0
- themis/utils/cost_tracking.py +376 -0
- themis/utils/dashboard.py +452 -0
- themis/utils/logging_utils.py +41 -0
- themis/utils/progress.py +58 -0
- themis/utils/tracing.py +320 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/METADATA +1 -1
- themis_eval-0.1.1.dist-info/RECORD +134 -0
- themis_eval-0.1.0.dist-info/RECORD +0 -8
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/WHEEL +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {themis_eval-0.1.0.dist-info → themis_eval-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,588 @@
|
|
|
1
|
+
"""Interactive visualizations for experiments using Plotly."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import plotly.graph_objects as go
|
|
9
|
+
import plotly.express as px
|
|
10
|
+
from plotly.subplots import make_subplots
|
|
11
|
+
|
|
12
|
+
PLOTLY_AVAILABLE = True
|
|
13
|
+
except ImportError:
|
|
14
|
+
PLOTLY_AVAILABLE = False
|
|
15
|
+
go = None # type: ignore
|
|
16
|
+
px = None # type: ignore
|
|
17
|
+
make_subplots = None # type: ignore
|
|
18
|
+
|
|
19
|
+
from themis.experiment.comparison import MultiExperimentComparison
|
|
20
|
+
from themis.experiment.cost import CostBreakdown
|
|
21
|
+
from themis.evaluation.reports import EvaluationReport
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _check_plotly():
    """Ensure the optional Plotly dependency is importable.

    Every entry point of this module calls this first so that a missing
    install surfaces as an actionable ``ImportError`` with an install hint,
    rather than as an ``AttributeError`` on the ``None`` placeholders bound
    when the import failed.

    Raises:
        ImportError: If Plotly could not be imported.
    """
    if PLOTLY_AVAILABLE:
        return
    hint = (
        "Plotly is required for interactive visualizations. "
        "Install with: pip install plotly"
    )
    raise ImportError(hint)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class InteractiveVisualizer:
    """Create interactive visualizations for experiments using Plotly.

    Every public method returns a ``plotly.graph_objects.Figure`` that callers
    can ``show()`` or export (e.g. with ``write_html``).

    Experiments that do not report a requested metric (``get_metric`` returns
    ``None``) are handed to Plotly as ``None``, so they appear as a gap rather
    than as a bar/point at 0.0 that would be indistinguishable from a genuine
    score of zero. Their text labels read ``"N/A"``.

    Example:
        >>> visualizer = InteractiveVisualizer()
        >>> fig = visualizer.plot_metric_comparison(comparison, "accuracy")
        >>> fig.write_html("comparison.html")
    """

    def __init__(self):
        """Initialize visualizer.

        Raises:
            ImportError: If Plotly is not installed.
        """
        _check_plotly()

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _validate_comparison_metric(
        comparison: MultiExperimentComparison, metric: str
    ) -> None:
        """Raise ``ValueError`` unless ``metric`` can be plotted for ``comparison``.

        ``"cost"`` and ``"total_cost"`` are accepted even when they are not
        listed in ``comparison.metrics``. NOTE(review): this assumes
        ``get_metric`` can resolve those names; confirm against
        ``MultiExperimentComparison``.
        """
        if metric not in comparison.metrics and metric not in ("cost", "total_cost"):
            raise ValueError(
                f"Metric '{metric}' not found. Available: {comparison.metrics}"
            )

    @staticmethod
    def _format_labels(values: list[float | None], prefix: str = "") -> list[str]:
        """Format values to 4 decimals with ``prefix``; ``None`` becomes ``"N/A"``."""
        return ["N/A" if value is None else f"{prefix}{value:.4f}" for value in values]

    @staticmethod
    def _cost_pie(labels: list[str], values: list[float]) -> go.Pie:
        """Build a pie trace using the shared cost label/hover formatting."""
        return go.Pie(
            labels=labels,
            values=values,
            textinfo="label+percent+value",
            hovertemplate="<b>%{label}</b><br>"
            "Cost: $%{value:.4f}<br>"
            "Percentage: %{percent}<br>"
            "<extra></extra>",
        )

    # -- public plots -------------------------------------------------------

    def plot_metric_comparison(
        self,
        comparison: MultiExperimentComparison,
        metric: str,
        title: str | None = None,
        show_values: bool = True,
    ) -> go.Figure:
        """Create bar chart comparing metric across experiments.

        Args:
            comparison: Multi-experiment comparison
            metric: Metric name to visualize
            title: Chart title (default: "{metric} Comparison")
            show_values: Show values on bars ("N/A" for missing values)

        Returns:
            Plotly Figure object

        Raises:
            ValueError: If ``metric`` is not known to ``comparison``.

        Example:
            >>> fig = visualizer.plot_metric_comparison(comparison, "accuracy")
            >>> fig.show()
        """
        self._validate_comparison_metric(comparison, metric)

        run_ids = [exp.run_id for exp in comparison.experiments]
        # Keep None for experiments lacking the metric: Plotly leaves a gap,
        # whereas coercing to 0.0 would fake a real zero score.
        values = [exp.get_metric(metric) for exp in comparison.experiments]

        fig = go.Figure(
            data=[
                go.Bar(
                    x=run_ids,
                    y=values,
                    text=self._format_labels(values) if show_values else None,
                    textposition="auto",
                    hovertemplate=f"<b>%{{x}}</b><br>{metric}: %{{y:.4f}}<br>"
                    "<extra></extra>",
                )
            ]
        )

        fig.update_layout(
            title=title or f"{metric} Comparison",
            xaxis_title="Run ID",
            yaxis_title=metric,
            hovermode="x unified",
            template="plotly_white",
            font=dict(size=12),
        )

        return fig

    def plot_pareto_frontier(
        self,
        comparison: MultiExperimentComparison,
        metric1: str,
        metric2: str,
        pareto_ids: list[str],
        maximize1: bool = True,
        maximize2: bool = True,
        title: str | None = None,
    ) -> go.Figure:
        """Create scatter plot with Pareto frontier highlighted.

        Experiments missing either metric are omitted from the plot. Points
        whose run ID is in ``pareto_ids`` are drawn larger and in red; the
        rest are drawn in blue.

        Args:
            comparison: Multi-experiment comparison
            metric1: First metric (x-axis)
            metric2: Second metric (y-axis)
            pareto_ids: Run IDs on Pareto frontier
            maximize1: Whether metric1 should be maximized (axis label only)
            maximize2: Whether metric2 should be maximized (axis label only)
            title: Chart title

        Returns:
            Plotly Figure object

        Example:
            >>> pareto = comparison.pareto_frontier(["accuracy", "cost"], [True, False])
            >>> fig = visualizer.plot_pareto_frontier(
            ...     comparison, "accuracy", "cost", pareto, True, False
            ... )
        """
        pareto_set = set(pareto_ids)  # O(1) membership per experiment

        x_values = []
        y_values = []
        run_ids = []
        is_pareto = []

        for exp in comparison.experiments:
            x_val = exp.get_metric(metric1)
            y_val = exp.get_metric(metric2)

            if x_val is not None and y_val is not None:
                x_values.append(x_val)
                y_values.append(y_val)
                run_ids.append(exp.run_id)
                is_pareto.append(exp.run_id in pareto_set)

        colors = ["red" if p else "blue" for p in is_pareto]
        sizes = [12 if p else 8 for p in is_pareto]

        fig = go.Figure(
            data=[
                go.Scatter(
                    x=x_values,
                    y=y_values,
                    mode="markers+text",
                    text=run_ids,
                    textposition="top center",
                    marker=dict(
                        color=colors,
                        size=sizes,
                        line=dict(width=1, color="white"),
                    ),
                    hovertemplate="<b>%{text}</b><br>"
                    + f"{metric1}: %{{x:.4f}}<br>"
                    + f"{metric2}: %{{y:.4f}}<br>"
                    + "<extra></extra>",
                )
            ]
        )

        fig.update_layout(
            title=title or f"Pareto Frontier: {metric1} vs {metric2}",
            xaxis_title=f"{metric1} ({'maximize' if maximize1 else 'minimize'})",
            yaxis_title=f"{metric2} ({'maximize' if maximize2 else 'minimize'})",
            template="plotly_white",
            font=dict(size=12),
            showlegend=False,
        )

        # A single trace carries both colors, so the built-in legend cannot
        # explain them; a text annotation does instead.
        fig.add_annotation(
            text="<b style='color:red'>●</b> Pareto optimal<br>"
            "<b style='color:blue'>●</b> Dominated",
            xref="paper",
            yref="paper",
            x=1.0,
            y=1.0,
            xanchor="left",
            yanchor="top",
            showarrow=False,
            bgcolor="white",
            bordercolor="black",
            borderwidth=1,
        )

        return fig

    def plot_metric_distribution(
        self,
        report: EvaluationReport,
        metric: str,
        plot_type: str = "histogram",
        title: str | None = None,
    ) -> go.Figure:
        """Create histogram, box, or violin plot of a metric's per-sample values.

        Values are collected from every score in ``report.records`` whose
        ``metric_name`` equals ``metric``.

        Args:
            report: Evaluation report
            metric: Metric name
            plot_type: "histogram", "box", or "violin"
            title: Chart title

        Returns:
            Plotly Figure object

        Raises:
            ValueError: If ``metric`` is not in ``report.metrics``, if no
                per-sample values exist for it, or if ``plot_type`` is unknown.

        Example:
            >>> fig = visualizer.plot_metric_distribution(report, "accuracy")
            >>> fig = visualizer.plot_metric_distribution(report, "accuracy", "violin")
        """
        if metric not in report.metrics:
            raise ValueError(
                f"Metric '{metric}' not found. Available: {list(report.metrics.keys())}"
            )

        values = []
        for record in report.records:
            for score in record.scores:
                if score.metric_name == metric:
                    values.append(score.value)

        if not values:
            raise ValueError(f"No values found for metric '{metric}'")

        if plot_type == "histogram":
            fig = go.Figure(
                data=[
                    go.Histogram(
                        x=values,
                        nbinsx=30,
                        hovertemplate="Value: %{x:.4f}<br>Count: %{y}<extra></extra>",
                    )
                ]
            )
            fig.update_layout(
                xaxis_title=metric,
                yaxis_title="Count",
            )
        elif plot_type == "box":
            fig = go.Figure(
                data=[
                    go.Box(
                        y=values,
                        name=metric,
                        boxmean="sd",
                        hovertemplate="Value: %{y:.4f}<extra></extra>",
                    )
                ]
            )
            fig.update_layout(yaxis_title=metric)
        elif plot_type == "violin":
            fig = go.Figure(
                data=[
                    go.Violin(
                        y=values,
                        name=metric,
                        box_visible=True,
                        meanline_visible=True,
                        hovertemplate="Value: %{y:.4f}<extra></extra>",
                    )
                ]
            )
            fig.update_layout(yaxis_title=metric)
        else:
            raise ValueError(
                f"Unknown plot_type '{plot_type}'. Use 'histogram', 'box', or 'violin'"
            )

        fig.update_layout(
            title=title or f"{metric} Distribution ({len(values)} samples)",
            template="plotly_white",
            font=dict(size=12),
        )

        return fig

    def plot_cost_breakdown(
        self,
        cost_breakdown: CostBreakdown,
        title: str | None = None,
    ) -> go.Figure:
        """Create pie chart of cost breakdown.

        The phase pie shows generation vs. evaluation cost; a phase with zero
        cost is omitted. When ``cost_breakdown.per_model_costs`` is non-empty,
        a second "Cost by Model" pie is placed beside it.

        Args:
            cost_breakdown: Cost breakdown data
            title: Chart title (default includes the total cost)

        Returns:
            Plotly Figure object

        Example:
            >>> breakdown = tracker.get_breakdown()
            >>> fig = visualizer.plot_cost_breakdown(breakdown)
        """
        labels: list[str] = []
        values: list[float] = []

        if cost_breakdown.generation_cost > 0:
            labels.append("Generation")
            values.append(cost_breakdown.generation_cost)

        if cost_breakdown.evaluation_cost > 0:
            labels.append("Evaluation")
            values.append(cost_breakdown.evaluation_cost)

        default_title = f"Cost Breakdown (Total: ${cost_breakdown.total_cost:.4f})"

        if cost_breakdown.per_model_costs:
            fig = make_subplots(
                rows=1,
                cols=2,
                subplot_titles=("Cost by Phase", "Cost by Model"),
                specs=[[{"type": "pie"}, {"type": "pie"}]],
            )
            fig.add_trace(self._cost_pie(labels, values), row=1, col=1)

            per_model = cost_breakdown.per_model_costs
            fig.add_trace(
                self._cost_pie(list(per_model.keys()), list(per_model.values())),
                row=1,
                col=2,
            )
        else:
            fig = go.Figure(data=[self._cost_pie(labels, values)])

        fig.update_layout(
            title=title or default_title,
            template="plotly_white",
            font=dict(size=12),
        )

        return fig

    def plot_metric_evolution(
        self,
        comparison: MultiExperimentComparison,
        metric: str,
        title: str | None = None,
    ) -> go.Figure:
        """Create line plot showing metric evolution across runs.

        Experiments are sorted by ``timestamp``; experiments without one sort
        as the empty string (i.e. first). Missing metric values leave a gap
        in the line instead of dropping to zero.

        Args:
            comparison: Multi-experiment comparison
            metric: Metric name
            title: Chart title

        Returns:
            Plotly Figure object

        Raises:
            ValueError: If ``metric`` is not known to ``comparison``.

        Example:
            >>> fig = visualizer.plot_metric_evolution(comparison, "accuracy")
        """
        self._validate_comparison_metric(comparison, metric)

        # NOTE(review): assumes timestamps are strings (e.g. ISO-8601), since
        # missing ones are replaced by "" for comparison — confirm.
        sorted_exps = sorted(
            comparison.experiments,
            key=lambda e: e.timestamp or "",
        )

        x_labels = [exp.run_id for exp in sorted_exps]
        y_values = [exp.get_metric(metric) for exp in sorted_exps]

        fig = go.Figure(
            data=[
                go.Scatter(
                    x=x_labels,
                    y=y_values,
                    mode="lines+markers",
                    line=dict(width=2),
                    marker=dict(size=8),
                    hovertemplate="<b>%{x}</b><br>"
                    + f"{metric}: %{{y:.4f}}<br>"
                    + "<extra></extra>",
                )
            ]
        )

        fig.update_layout(
            title=title or f"{metric} Evolution Over Time",
            xaxis_title="Run ID (chronological)",
            yaxis_title=metric,
            template="plotly_white",
            font=dict(size=12),
            hovermode="x unified",
        )

        return fig

    def create_dashboard(
        self,
        comparison: MultiExperimentComparison,
        metrics: list[str] | None = None,
        include_cost: bool = True,
    ) -> go.Figure:
        """Create comprehensive dashboard with multiple charts.

        One bar chart per metric (plus a cost chart when cost data exists)
        is laid out in a two-column grid, filled row by row.

        Args:
            comparison: Multi-experiment comparison
            metrics: Metrics to visualize (default / empty: first 4 of
                ``comparison.metrics``)
            include_cost: Include cost visualization if available

        Returns:
            Plotly Figure with subplots

        Raises:
            ValueError: If a requested metric is unknown to ``comparison``,
                or if there is nothing to plot (no metrics and no cost data).

        Example:
            >>> fig = visualizer.create_dashboard(comparison)
            >>> fig.write_html("dashboard.html")
        """
        metrics_to_plot = metrics or comparison.metrics[:4]  # Limit to 4 for layout
        for metric in metrics_to_plot:
            self._validate_comparison_metric(comparison, metric)

        has_cost = include_cost and any(
            exp.get_cost() is not None for exp in comparison.experiments
        )

        n_plots = len(metrics_to_plot) + (1 if has_cost else 0)
        if n_plots == 0:
            # make_subplots(rows=0) would fail with an opaque message.
            raise ValueError(
                "Nothing to plot: no metrics were given or found, and no cost data "
                "is available."
            )

        rows = (n_plots + 1) // 2  # 2 columns
        cols = 2 if n_plots > 1 else 1

        subplot_titles = [f"{m} Comparison" for m in metrics_to_plot]
        if has_cost:
            subplot_titles.append("Cost Comparison")

        fig = make_subplots(
            rows=rows,
            cols=cols,
            subplot_titles=subplot_titles,
        )

        run_ids = [exp.run_id for exp in comparison.experiments]

        for idx, metric in enumerate(metrics_to_plot):
            values = [exp.get_metric(metric) for exp in comparison.experiments]
            fig.add_trace(
                go.Bar(
                    x=run_ids,
                    y=values,
                    name=metric,
                    text=self._format_labels(values),
                    textposition="auto",
                    hovertemplate=(
                        f"<b>%{{x}}</b><br>{metric}: %{{y:.4f}}<extra></extra>"
                    ),
                ),
                row=(idx // 2) + 1,
                col=(idx % 2) + 1,
            )

        if has_cost:
            idx = len(metrics_to_plot)
            # Experiments without cost data stay None (gap), not $0.0000.
            costs = [exp.get_cost() for exp in comparison.experiments]
            fig.add_trace(
                go.Bar(
                    x=run_ids,
                    y=costs,
                    name="Cost",
                    text=self._format_labels(costs, prefix="$"),
                    textposition="auto",
                    marker_color="green",
                    hovertemplate="<b>%{x}</b><br>Cost: $%{y:.4f}<extra></extra>",
                ),
                row=(idx // 2) + 1,
                col=(idx % 2) + 1,
            )

        fig.update_layout(
            title_text="Experiment Dashboard",
            template="plotly_white",
            font=dict(size=12),
            showlegend=False,
            height=400 * rows,
        )

        return fig
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def export_interactive_html(
    fig: go.Figure,
    output_path: Path | str,
    include_plotlyjs: bool | str = "cdn",
) -> None:
    """Export Plotly figure to standalone HTML file.

    Missing parent directories of ``output_path`` are created first, so a
    figure can be written into a not-yet-existing results folder.

    Args:
        fig: Plotly Figure object
        output_path: Where to save HTML file
        include_plotlyjs: How to include Plotly.js (forwarded unchanged to
            ``Figure.write_html``):
            - "cdn": Link to CDN (smaller file, requires internet)
            - True: Embed full library (larger file, works offline)
            - False: Don't include (for embedding in existing HTML)
            Plotly accepts further values (e.g. "directory"); see its docs.

    Raises:
        ImportError: If Plotly is not installed.

    Example:
        >>> fig = visualizer.plot_metric_comparison(comparison, "accuracy")
        >>> export_interactive_html(fig, "comparison.html")
    """
    _check_plotly()
    output_path = Path(output_path)
    # Create the destination folder so writing into a fresh directory works.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(output_path), include_plotlyjs=include_plotlyjs)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
# Public API: the visualizer, the HTML export helper, and the feature flag
# callers can check before constructing a visualizer.
__all__ = ["InteractiveVisualizer", "export_interactive_html", "PLOTLY_AVAILABLE"]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Generation domain primitives."""
|