arbiter-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
arbiter/cli/display.py ADDED
@@ -0,0 +1,381 @@
+ """Rich terminal display for Arbiter CLI output."""
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ from rich.columns import Columns
+ from rich.console import Console
+ from rich.layout import Layout
+ from rich.live import Live
+ from rich.panel import Panel
+ from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
+ from rich.table import Table
+ from rich.text import Text
+
+ from arbiter.core.discover import DiscoveredModel
+ from arbiter.core.leaderboard import Leaderboard
+ from arbiter.core.metrics import ComparisonResult, ModelMetrics
+
+
+ console = Console()
+
+ # Color palette for models
+ MODEL_COLORS = [
+     "cyan",
+     "magenta",
+     "green",
+     "yellow",
+     "blue",
+     "red",
+     "bright_cyan",
+     "bright_magenta",
+ ]
+
+
+ def get_model_color(index: int) -> str:
+     return MODEL_COLORS[index % len(MODEL_COLORS)]
+
+
+ def print_header() -> None:
+     """Print the Arbiter header."""
+     header = Text()
+     header.append(" ARBITER ", style="bold white on blue")
+     header.append(" The final word on your local models.", style="dim")
+     console.print()
+     console.print(header)
+     console.print()
+
+
+ def print_comparing(model_specs: list[str], prompt: str) -> None:
+     """Print what we're about to compare."""
+     console.print(f"[dim]Prompt:[/dim] {prompt[:120]}{'...' if len(prompt) > 120 else ''}")
+     models_text = " vs ".join(
+         f"[{get_model_color(i)}]{spec}[/{get_model_color(i)}]"
+         for i, spec in enumerate(model_specs)
+     )
+     console.print(f"[dim]Models:[/dim] {models_text}")
+     console.print()
+
+
+ def create_progress(model_specs: list[str]) -> tuple[Progress, dict[str, int]]:
+     """Create a progress display for streaming generation."""
+     progress = Progress(
+         SpinnerColumn(),
+         TextColumn("[bold]{task.description}"),
+         BarColumn(bar_width=30),
+         TextColumn("{task.fields[tokens]} tokens"),
+         TextColumn("{task.fields[tps]} tok/s"),
+         TimeElapsedColumn(),
+         console=console,
+     )
+     task_ids = {}
+     for i, spec in enumerate(model_specs):
+         color = get_model_color(i)
+         task_id = progress.add_task(
+             f"[{color}]{spec}",
+             total=None,
+             tokens=0,
+             tps="--",
+         )
+         task_ids[spec] = task_id
+     return progress, task_ids
+
+
+ def update_progress(
+     progress: Progress,
+     task_ids: dict[str, int],
+     model: str,
+     metrics: ModelMetrics,
+ ) -> None:
+     """Update progress display with latest metrics."""
+     # Match by model name in task_ids (may need to check original spec too)
+     task_id = None
+     for spec, tid in task_ids.items():
+         if model in spec or spec in model:
+             task_id = tid
+             break
+
+     if task_id is None:
+         return
+
+     tps = f"{metrics.tokens_per_sec:.1f}" if metrics.tokens_per_sec else "--"
+     progress.update(
+         task_id,
+         advance=1,
+         tokens=metrics._token_count,
+         tps=tps,
+     )
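
A minimal sketch of how create_progress and update_progress might be driven during streaming; the sequential token loop below is a stand-in for a real streaming backend, not an Arbiter API:

    import time

    from arbiter.cli.display import create_progress

    progress, task_ids = create_progress(["llama3:8b", "qwen2.5:7b"])
    with progress:  # Rich redraws the bars while this block runs
        for task_id in task_ids.values():
            for n in range(1, 51):
                time.sleep(0.01)  # simulated token arrival
                # Same task fields update_progress() sets: token count and tok/s
                progress.update(task_id, advance=1, tokens=n, tps=f"{n * 2.0:.1f}")
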
+
+
+ def print_results(result: ComparisonResult) -> None:
+     """Print the full comparison results."""
+     console.print()
+
+     # Summary table
+     table = Table(
+         title="Results",
+         show_header=True,
+         header_style="bold",
+         border_style="dim",
+     )
+     table.add_column("Model", style="bold")
+     table.add_column("Tokens/sec", justify="right")
+     table.add_column("TTFT", justify="right")
+     table.add_column("Total Time", justify="right")
+     table.add_column("Tokens", justify="right")
+     table.add_column("Memory", justify="right")
+     table.add_column("Quality", justify="right")
+     table.add_column("", justify="center")
+
+     for i, m in enumerate(result.models):
+         color = get_model_color(i)
+         is_winner = m.model == result.winner
+
+         tps = f"{m.tokens_per_sec:.1f}" if m.tokens_per_sec else "--"
+         ttft = f"{m.ttft_ms:.0f}ms" if m.ttft_ms else "--"
+         total = f"{m.total_time_s:.1f}s" if m.total_time_s else "--"
+         tokens = str(m.total_tokens) if m.total_tokens else "--"
+         memory = (
+             f"{m.peak_memory_delta_mb:+.1f}MB"
+             if m.peak_memory_delta_mb is not None
+             else "--"
+         )
+         quality = f"{m.overall_score:.1f}/10" if m.overall_score else "--"
+         badge = "[bold green]WINNER[/bold green]" if is_winner else ""
+
+         style = f"bold {color}" if is_winner else color
+         table.add_row(
+             f"[{style}]{m.model}[/{style}]",
+             tps,
+             ttft,
+             total,
+             tokens,
+             memory,
+             quality,
+             badge,
+         )
+
+     console.print(table)
+
+     # Quality breakdown if available
+     scored_models = [m for m in result.models if m.quality_scores]
+     if scored_models:
+         console.print()
+         qtable = Table(
+             title="Quality Breakdown",
+             show_header=True,
+             header_style="bold",
+             border_style="dim",
+         )
+         qtable.add_column("Model", style="bold")
+         qtable.add_column("Correctness", justify="center")
+         qtable.add_column("Completeness", justify="center")
+         qtable.add_column("Clarity", justify="center")
+         qtable.add_column("Code Quality", justify="center")
+
+         for i, m in enumerate(scored_models):
+             color = get_model_color(i)
+             s = m.quality_scores
+             qtable.add_row(
+                 f"[{color}]{m.model}[/{color}]",
+                 _score_cell(s.get("correctness", 0)),
+                 _score_cell(s.get("completeness", 0)),
+                 _score_cell(s.get("clarity", 0)),
+                 _score_cell(s.get("code_quality", 0)),
+             )
+         console.print(qtable)
+
+     # Scoring breakdown (the real explanation)
+     if result.scoring:
+         print_scoring_breakdown(result.scoring)
+
+     if result.judge_model:
+         console.print(f"[dim]Judged by: {result.judge_model}[/dim]")
+     console.print()
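
print_results only reads attributes off the result object, so a duck-typed stand-in is enough to preview the table; the real ComparisonResult lives in arbiter.core.metrics and its constructor may differ. A sketch with illustrative values:

    from types import SimpleNamespace

    from arbiter.cli.display import print_results

    m = SimpleNamespace(
        model="llama3:8b", tokens_per_sec=52.3, ttft_ms=180.0, total_time_s=9.4,
        total_tokens=492, peak_memory_delta_mb=310.2, overall_score=8.1,
        quality_scores=None,  # None skips the Quality Breakdown table
    )
    result = SimpleNamespace(models=[m], winner="llama3:8b", scoring=None, judge_model=None)
    print_results(result)
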
+
+
+ def _score_cell(score: int | float) -> str:
+     """Format a score with color coding."""
+     score = float(score)
+     if score >= 8:
+         return f"[green]{score:.0f}[/green]"
+     elif score >= 6:
+         return f"[yellow]{score:.0f}[/yellow]"
+     else:
+         return f"[red]{score:.0f}[/red]"
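
Scores of 8 or more render green, 6 or more yellow, and anything lower red:

    from arbiter.cli.display import _score_cell, console

    console.print(_score_cell(9), _score_cell(7), _score_cell(3))  # green, yellow, red
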
+
+
+ def print_scoring_breakdown(scoring) -> None:
+     """Print the composite scoring breakdown -- shows exactly WHY the winner won."""
+     console.print()
+
+     # Formula
+     console.print(f"[dim]Scoring: {scoring.formula}[/dim]")
+     console.print()
+
+     # Scoring table
+     table = Table(
+         title="Composite Scoring",
+         show_header=True,
+         header_style="bold",
+         border_style="dim",
+     )
+     table.add_column("Model", style="bold")
+     table.add_column("Composite", justify="right", style="bold")
+
+     # Get component names from first model
+     if scoring.model_scores:
+         for comp in scoring.model_scores[0].components:
+             table.add_column(comp.metric_name, justify="right")
+
+     for i, ms in enumerate(scoring.model_scores):
+         color = get_model_color(i)
+         is_winner = ms.model == scoring.winner
+
+         row = [
+             f"[{'bold ' + color if is_winner else color}]{ms.model}[/]",
+             f"[{'bold green' if is_winner else 'white'}]{ms.composite:.2f}[/]",
+         ]
+         for comp in ms.components:
+             rank_badge = " [green]#1[/green]" if comp.rank == 1 and len(scoring.model_scores) > 1 else ""
+             row.append(f"{comp.raw_value:.1f} {comp.raw_unit}{rank_badge}")
+         table.add_row(*row)
+
+     console.print(table)
+
+     # Winner explanation
+     if scoring.winner and scoring.winner_reason:
+         console.print()
+         console.print(
+             Panel(
+                 f"[bold green]{scoring.winner}[/bold green]\n\n{scoring.winner_reason}",
+                 title="[bold]Winner[/bold]",
+                 border_style="green",
+                 padding=(1, 2),
+             )
+         )
+     console.print()
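
As with print_results, duck-typed stand-ins are enough to exercise this renderer; the formula string and component values below are illustrative, and the real scoring objects come from arbiter.core:

    from types import SimpleNamespace

    from arbiter.cli.display import print_scoring_breakdown

    comp = SimpleNamespace(metric_name="Speed", raw_value=52.3, raw_unit="tok/s", rank=1)
    ms = SimpleNamespace(model="llama3:8b", composite=0.87, components=[comp])
    scoring = SimpleNamespace(
        formula="0.6 * speed + 0.4 * quality",  # illustrative only
        model_scores=[ms],
        winner="llama3:8b",
        winner_reason="Fastest response at comparable quality.",
    )
    print_scoring_breakdown(scoring)
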
+
+
+ def print_model_output(model: str, output: str, index: int = 0) -> None:
+     """Print a model's full output in a panel."""
+     color = get_model_color(index)
+     console.print(
+         Panel(
+             output,
+             title=f"[bold {color}]{model}[/bold {color}]",
+             border_style=color,
+             padding=(1, 2),
+         )
+     )
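
Usage is direct; index selects the palette color (0 is cyan):

    from arbiter.cli.display import print_model_output

    print_model_output("llama3:8b", "The answer is 42.", index=0)
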
+
+
+ def print_discover(models: list[DiscoveredModel]) -> None:
+     """Print discovered models."""
+     if not models:
+         console.print("[yellow]No models found. Is Ollama running?[/yellow]")
+         return
+
+     # Group by provider
+     by_provider: dict[str, list[DiscoveredModel]] = {}
+     for m in models:
+         by_provider.setdefault(m.provider, []).append(m)
+
+     for provider, provider_models in by_provider.items():
+         table = Table(
+             title=f"{provider.upper()} Models",
+             show_header=True,
+             header_style="bold",
+             border_style="dim",
+         )
+         table.add_column("Model", style="bold cyan")
+         table.add_column("Size", justify="right")
+         table.add_column("Parameters", justify="right")
+         table.add_column("Quantization")
+         table.add_column("Family")
+         table.add_column("Vision", justify="center")
+         table.add_column("Spec", style="dim")
+
+         for m in provider_models:
+             size = f"{m.size_gb}GB" if m.size_gb else "--"
+             params = m.parameter_size or "--"
+             quant = m.quantization or "--"
+             family = m.family or "--"
+             vision = "[green]Yes[/green]" if m.multimodal else "[dim]--[/dim]"
+             table.add_row(m.name, size, params, quant, family, vision, m.spec)
+
+         console.print(table)
+         console.print()
+
+     console.print(f"[dim]Total: {len(models)} models across {len(by_provider)} providers[/dim]")
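
print_discover likewise only reads attributes, so a stand-in carrying the fields accessed above previews the table; DiscoveredModel's real constructor may differ, and the values here are illustrative:

    from types import SimpleNamespace

    from arbiter.cli.display import print_discover

    fake = SimpleNamespace(
        provider="ollama", name="llama3:8b", size_gb=4.7, parameter_size="8B",
        quantization="Q4_K_M", family="llama", multimodal=False,
        spec="ollama:llama3:8b",
    )
    print_discover([fake])
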
+
+
+ def print_leaderboard(leaderboard: Leaderboard) -> None:
+     """Print the persistent leaderboard."""
+     ratings = leaderboard.sorted_ratings()
+
+     if not ratings:
+         console.print("[yellow]No leaderboard data yet. Run some comparisons first![/yellow]")
+         return
+
+     table = Table(
+         title="Arbiter Leaderboard",
+         show_header=True,
+         header_style="bold",
+         border_style="dim",
+     )
+     table.add_column("#", style="dim", justify="right")
+     table.add_column("Model", style="bold")
+     table.add_column("ELO", justify="right", style="bold")
+     table.add_column("W/L/D", justify="center")
+     table.add_column("Win Rate", justify="right")
+     table.add_column("Avg tok/s", justify="right")
+     table.add_column("Avg Quality", justify="right")
+     table.add_column("Comparisons", justify="right")
+
+     for i, r in enumerate(ratings):
+         rank = str(i + 1)
+         if i == 0:
+             rank = "[gold1]1[/gold1]"
+         elif i == 1:
+             rank = "[grey70]2[/grey70]"
+         elif i == 2:
+             rank = "[dark_orange3]3[/dark_orange3]"
+
+         elo_str = f"{r.elo:.0f}"
+         if r.elo >= 1600:
+             elo_str = f"[green]{elo_str}[/green]"
+         elif r.elo < 1400:
+             elo_str = f"[red]{elo_str}[/red]"
+
+         wld = f"{r.wins}/{r.losses}/{r.draws}"
+         win_rate = f"{r.win_rate * 100:.0f}%"
+         avg_tps = f"{r.avg_tokens_sec:.1f}" if r.avg_tokens_sec else "--"
+         avg_q = f"{r.avg_quality:.1f}/10" if r.avg_quality else "--"
+
+         color = get_model_color(i)
+         table.add_row(
+             rank,
+             f"[{color}]{r.name}[/{color}]",
+             elo_str,
+             wld,
+             win_rate,
+             avg_tps,
+             avg_q,
+             str(r.total_comparisons),
+         )
+
+     console.print(table)
+     console.print()
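
A stand-in Leaderboard whose sorted_ratings() returns one illustrative rating is enough to render the table; the real Leaderboard API is defined in arbiter.core.leaderboard:

    from types import SimpleNamespace

    from arbiter.cli.display import print_leaderboard

    rating = SimpleNamespace(
        name="llama3:8b", elo=1625.0, wins=4, losses=1, draws=0, win_rate=0.8,
        avg_tokens_sec=52.3, avg_quality=8.1, total_comparisons=5,
    )
    board = SimpleNamespace(sorted_ratings=lambda: [rating])
    print_leaderboard(board)
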
+
+
+ def print_error(message: str) -> None:
+     """Print an error message."""
+     console.print(f"[bold red]Error:[/bold red] {message}")
+
+
+ def print_json_output(result: ComparisonResult) -> None:
+     """Print results as JSON for scripting."""
+     import json
+     console.print_json(json.dumps(result.to_dict(), indent=2))
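
Since print_json_output needs only a to_dict() method, a one-line stand-in shows the scripting path (the payload here is illustrative):

    from types import SimpleNamespace

    from arbiter.cli.display import print_json_output

    fake = SimpleNamespace(to_dict=lambda: {"winner": "llama3:8b"})
    print_json_output(fake)  # pretty-prints {"winner": "llama3:8b"}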