sandboxy-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sandboxy/__init__.py +3 -0
- sandboxy/agents/__init__.py +21 -0
- sandboxy/agents/base.py +66 -0
- sandboxy/agents/llm_prompt.py +308 -0
- sandboxy/agents/loader.py +222 -0
- sandboxy/api/__init__.py +5 -0
- sandboxy/api/app.py +76 -0
- sandboxy/api/routes/__init__.py +1 -0
- sandboxy/api/routes/agents.py +92 -0
- sandboxy/api/routes/local.py +1388 -0
- sandboxy/api/routes/tools.py +106 -0
- sandboxy/cli/__init__.py +1 -0
- sandboxy/cli/main.py +1196 -0
- sandboxy/cli/type_detector.py +48 -0
- sandboxy/config.py +49 -0
- sandboxy/core/__init__.py +1 -0
- sandboxy/core/async_runner.py +824 -0
- sandboxy/core/mdl_parser.py +441 -0
- sandboxy/core/runner.py +599 -0
- sandboxy/core/safe_eval.py +165 -0
- sandboxy/core/state.py +234 -0
- sandboxy/datasets/__init__.py +20 -0
- sandboxy/datasets/loader.py +193 -0
- sandboxy/datasets/runner.py +442 -0
- sandboxy/errors.py +166 -0
- sandboxy/local/context.py +235 -0
- sandboxy/local/results.py +173 -0
- sandboxy/logging.py +31 -0
- sandboxy/mcp/__init__.py +25 -0
- sandboxy/mcp/client.py +360 -0
- sandboxy/mcp/wrapper.py +99 -0
- sandboxy/providers/__init__.py +34 -0
- sandboxy/providers/anthropic_provider.py +271 -0
- sandboxy/providers/base.py +123 -0
- sandboxy/providers/http_client.py +101 -0
- sandboxy/providers/openai_provider.py +282 -0
- sandboxy/providers/openrouter.py +958 -0
- sandboxy/providers/registry.py +199 -0
- sandboxy/scenarios/__init__.py +11 -0
- sandboxy/scenarios/comparison.py +491 -0
- sandboxy/scenarios/loader.py +262 -0
- sandboxy/scenarios/runner.py +468 -0
- sandboxy/scenarios/unified.py +1434 -0
- sandboxy/session/__init__.py +21 -0
- sandboxy/session/manager.py +278 -0
- sandboxy/tools/__init__.py +34 -0
- sandboxy/tools/base.py +127 -0
- sandboxy/tools/loader.py +270 -0
- sandboxy/tools/yaml_tools.py +708 -0
- sandboxy/ui/__init__.py +27 -0
- sandboxy/ui/dist/assets/index-CgAkYWrJ.css +1 -0
- sandboxy/ui/dist/assets/index-D4zoGFcr.js +347 -0
- sandboxy/ui/dist/index.html +14 -0
- sandboxy/utils/__init__.py +3 -0
- sandboxy/utils/time.py +20 -0
- sandboxy-0.0.1.dist-info/METADATA +241 -0
- sandboxy-0.0.1.dist-info/RECORD +60 -0
- sandboxy-0.0.1.dist-info/WHEEL +4 -0
- sandboxy-0.0.1.dist-info/entry_points.txt +3 -0
- sandboxy-0.0.1.dist-info/licenses/LICENSE +201 -0
sandboxy/providers/registry.py
@@ -0,0 +1,199 @@
+"""Provider registry for managing multiple LLM providers."""
+
+import logging
+import os
+
+from sandboxy.providers.base import BaseProvider, ModelInfo, ProviderError
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderRegistry:
+    """Registry of available LLM providers.
+
+    Automatically detects available providers based on environment variables
+    and provides unified access to models across all providers.
+
+    Priority order:
+    1. Direct providers (OpenAI, Anthropic) - lower latency
+    2. OpenRouter - unified access to all models
+
+    Example:
+        registry = ProviderRegistry()
+        provider = registry.get_provider_for_model("openai/gpt-4o")
+        response = await provider.complete("openai/gpt-4o", messages)
+
+    """
+
+    def __init__(self):
+        """Initialize registry and detect available providers."""
+        self.providers: dict[str, BaseProvider] = {}
+        self._init_providers()
+
+    def _init_providers(self) -> None:
+        """Initialize providers based on available API keys."""
+        # OpenRouter - unified provider (lower priority but covers all)
+        if os.getenv("OPENROUTER_API_KEY"):
+            try:
+                from sandboxy.providers.openrouter import OpenRouterProvider
+
+                self.providers["openrouter"] = OpenRouterProvider()
+                logger.info("OpenRouter provider initialized")
+            except ProviderError as e:
+                logger.warning(f"Failed to init OpenRouter: {e}")
+
+        # Direct OpenAI (higher priority for OpenAI models)
+        if os.getenv("OPENAI_API_KEY"):
+            try:
+                from sandboxy.providers.openai_provider import OpenAIProvider
+
+                self.providers["openai"] = OpenAIProvider()
+                logger.info("OpenAI provider initialized")
+            except ProviderError as e:
+                logger.warning(f"Failed to init OpenAI: {e}")
+
+        # Direct Anthropic (higher priority for Claude models)
+        if os.getenv("ANTHROPIC_API_KEY"):
+            try:
+                from sandboxy.providers.anthropic_provider import AnthropicProvider
+
+                self.providers["anthropic"] = AnthropicProvider()
+                logger.info("Anthropic provider initialized")
+            except ProviderError as e:
+                logger.warning(f"Failed to init Anthropic: {e}")
+
+        if not self.providers:
+            logger.warning(
+                "No providers available. Set at least one API key: "
+                "OPENROUTER_API_KEY, OPENAI_API_KEY, or ANTHROPIC_API_KEY"
+            )
+
+    def get_provider_for_model(self, model_id: str) -> BaseProvider:
+        """Get the best provider for a given model.
+
+        Model ID formats:
+        - "provider/model" (e.g., "openai/gpt-4o") - OpenRouter format, use OpenRouter
+        - "model" (e.g., "gpt-4o") - direct provider format, auto-select
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            Provider instance that can handle the model
+
+        Raises:
+            ProviderError: If no provider available for the model
+
+        """
+        if not self.providers:
+            raise ProviderError(
+                "No providers configured. Set API key environment variables.",
+                provider="registry",
+            )
+
+        # If model has a prefix (openai/gpt-4o format), use OpenRouter
+        # This is OpenRouter's convention - direct APIs don't use prefixes
+        if "/" in model_id:
+            if "openrouter" in self.providers:
+                return self.providers["openrouter"]
+            # If no OpenRouter, try to extract and use direct provider
+            provider_name, model_name = model_id.split("/", 1)
+            if provider_name == "openai" and "openai" in self.providers:
+                # Note: caller should strip prefix when calling direct provider
+                return self.providers["openai"]
+            if provider_name == "anthropic" and "anthropic" in self.providers:
+                return self.providers["anthropic"]
+
+        # No prefix - use direct providers
+        model_lower = model_id.lower()
+
+        # OpenAI models (direct format: gpt-4o, not openai/gpt-4o)
+        if any(m in model_lower for m in ["gpt-4", "gpt-5", "o1", "o3"]):
+            if "openai" in self.providers:
+                return self.providers["openai"]
+            if "openrouter" in self.providers:
+                return self.providers["openrouter"]
+
+        # Anthropic models (direct format: claude-3-opus, not anthropic/claude-3-opus)
+        if "claude" in model_lower:
+            if "anthropic" in self.providers:
+                return self.providers["anthropic"]
+            if "openrouter" in self.providers:
+                return self.providers["openrouter"]
+
+        # Default to OpenRouter if available (covers most models)
+        if "openrouter" in self.providers:
+            return self.providers["openrouter"]
+
+        # Last resort - return first available provider
+        return next(iter(self.providers.values()))
+
+    def list_all_models(self) -> list[ModelInfo]:
+        """List all models from all providers.
+
+        Returns deduplicated list with direct providers preferred
+        over OpenRouter for overlapping models.
+        """
+        seen_ids: set[str] = set()
+        models: list[ModelInfo] = []
+
+        # Add direct provider models first (preferred)
+        for name, provider in self.providers.items():
+            if name == "openrouter":
+                continue  # Add last
+
+            for model in provider.list_models():
+                if model.id not in seen_ids:
+                    seen_ids.add(model.id)
+                    models.append(model)
+
+        # Add OpenRouter models (for ones not covered by direct)
+        if "openrouter" in self.providers:
+            for model in self.providers["openrouter"].list_models():
+                if model.id not in seen_ids:
+                    seen_ids.add(model.id)
+                    models.append(model)
+
+        return models
+
+    def get_provider(self, provider_name: str) -> BaseProvider | None:
+        """Get a specific provider by name.
+
+        Args:
+            provider_name: Provider name (openai, anthropic, openrouter)
+
+        Returns:
+            Provider instance or None if not available
+
+        """
+        return self.providers.get(provider_name)
+
+    @property
+    def available_providers(self) -> list[str]:
+        """List names of available providers."""
+        return list(self.providers.keys())
+
+
+# Global registry instance
+_registry: ProviderRegistry | None = None
+
+
+def get_registry() -> ProviderRegistry:
+    """Get the global provider registry."""
+    global _registry
+    if _registry is None:
+        _registry = ProviderRegistry()
+    return _registry
+
+
+def get_provider(model_id: str) -> BaseProvider:
+    """Get a provider for a model (convenience function).
+
+    Args:
+        model_id: Model identifier
+
+    Returns:
+        Provider that can handle the model
+
+    """
+    return get_registry().get_provider_for_model(model_id)
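
For orientation, here is a minimal usage sketch of the registry above; it is not part of the package diff. It assumes OPENROUTER_API_KEY and OPENAI_API_KEY are set, and it borrows the complete() call from the class docstring; the message payload shape is an assumption, not something defined in this hunk.

    import asyncio

    from sandboxy.providers.registry import ProviderRegistry

    async def main() -> None:
        registry = ProviderRegistry()
        print(registry.available_providers)  # e.g. ["openrouter", "openai"]

        # Prefixed IDs ("provider/model") follow OpenRouter's convention and
        # route to OpenRouter when it is configured.
        provider = registry.get_provider_for_model("openai/gpt-4o")

        # Bare IDs are matched by name: gpt-*/o1/o3 prefer direct OpenAI,
        # claude-* prefers direct Anthropic, anything else falls back to
        # OpenRouter (or the first available provider).
        fallback = registry.get_provider_for_model("gpt-4o")

        # complete() comes from the class docstring; the OpenAI-style message
        # dict below is assumed for illustration.
        response = await provider.complete(
            "openai/gpt-4o",
            [{"role": "user", "content": "Hello"}],
        )
        print(response)

    asyncio.run(main())
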
sandboxy/scenarios/__init__.py
@@ -0,0 +1,11 @@
+"""Scenario module - load and run scenarios with YAML-defined tools."""
+
+from sandboxy.scenarios.loader import ScenarioSpec, load_scenario
+from sandboxy.scenarios.runner import ScenarioResult, ScenarioRunner
+
+__all__ = [
+    "ScenarioSpec",
+    "load_scenario",
+    "ScenarioRunner",
+    "ScenarioResult",
+]
sandboxy/scenarios/comparison.py
@@ -0,0 +1,491 @@
+"""Comparison utilities for multi-model scenario runs.
+
+This module provides:
+- ComparisonResult: Aggregates results from multiple models
+- Statistical summaries (avg, min, max, std)
+- Tabular and JSON output formatting
+"""
+
+from __future__ import annotations
+
+import logging
+import statistics
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+from sandboxy.scenarios.unified import RunResult, UnifiedScenarioSpec
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ModelStats:
+    """Statistics for a single model across multiple runs."""
+
+    model: str
+    runs: int = 0
+    avg_score: float = 0.0
+    min_score: float = 0.0
+    max_score: float = 0.0
+    std_score: float = 0.0
+    avg_latency_ms: int = 0
+    total_input_tokens: int = 0
+    total_output_tokens: int = 0
+    total_cost_usd: float | None = None
+    avg_cost_usd: float | None = None
+    errors: int = 0
+
+    # Message/turn counts
+    total_messages: int = 0
+    avg_messages: float = 0.0
+    total_tool_calls: int = 0
+    avg_tool_calls: float = 0.0
+
+    # Goal achievements (percentage of runs that achieved each goal)
+    goal_rates: dict[str, float] = field(default_factory=dict)
+
+    # Judge scores (if applicable)
+    avg_judge_score: float | None = None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "model": self.model,
+            "runs": self.runs,
+            "avg_score": self.avg_score,
+            "min_score": self.min_score,
+            "max_score": self.max_score,
+            "std_score": self.std_score,
+            "avg_latency_ms": self.avg_latency_ms,
+            "total_input_tokens": self.total_input_tokens,
+            "total_output_tokens": self.total_output_tokens,
+            "total_cost_usd": self.total_cost_usd,
+            "avg_cost_usd": self.avg_cost_usd,
+            "total_messages": self.total_messages,
+            "avg_messages": self.avg_messages,
+            "total_tool_calls": self.total_tool_calls,
+            "avg_tool_calls": self.avg_tool_calls,
+            "errors": self.errors,
+            "goal_rates": self.goal_rates,
+            "avg_judge_score": self.avg_judge_score,
+        }
+
+
+@dataclass
+class ComparisonResult:
+    """Result of running a scenario with multiple models.
+
+    Aggregates results from all runs and provides statistical summaries.
+    """
+
+    scenario_id: str
+    scenario_name: str
+    models: list[str]
+    runs_per_model: int
+    results: list[RunResult] = field(default_factory=list)
+    stats: dict[str, ModelStats] = field(default_factory=dict)
+    created_at: datetime = field(default_factory=datetime.now)
+
+    # Cached ranking
+    _ranking: list[str] | None = field(default=None, repr=False)
+
+    def add_result(self, result: RunResult) -> None:
+        """Add a result and update statistics."""
+        self.results.append(result)
+        self._ranking = None  # Invalidate cache
+
+    def compute_stats(self) -> None:
+        """Compute statistics for all models."""
+        # Import pricing calculator
+        try:
+            from sandboxy.api.routes.local import calculate_cost
+        except ImportError:
+            calculate_cost = None  # type: ignore
+
+        # Group results by model
+        by_model: dict[str, list[RunResult]] = {}
+        for result in self.results:
+            if result.model not in by_model:
+                by_model[result.model] = []
+            by_model[result.model].append(result)
+
+        # Compute stats for each model
+        for model, model_results in by_model.items():
+            scores = []
+            latencies = []
+            input_tokens = 0
+            output_tokens = 0
+            costs: list[float] = []
+            errors = 0
+            goal_achievements: dict[str, list[bool]] = {}
+            judge_scores: list[float] = []
+            message_counts: list[int] = []
+            tool_call_counts: list[int] = []
+
+            for result in model_results:
+                if result.error:
+                    errors += 1
+                    continue
+
+                # Collect scores
+                if result.evaluation:
+                    scores.append(result.evaluation.total_score)
+
+                    # Collect goal achievements
+                    for goal in result.evaluation.goals:
+                        if goal.id not in goal_achievements:
+                            goal_achievements[goal.id] = []
+                        goal_achievements[goal.id].append(goal.achieved)
+
+                    # Collect judge scores
+                    if result.evaluation.judge:
+                        judge_scores.append(result.evaluation.judge.score)
+
+                latencies.append(result.latency_ms)
+                input_tokens += result.input_tokens
+                output_tokens += result.output_tokens
+
+                # Collect cost (from API or calculate from tokens)
+                if result.cost_usd is not None:
+                    costs.append(result.cost_usd)
+                elif calculate_cost and result.input_tokens and result.output_tokens:
+                    calculated = calculate_cost(model, result.input_tokens, result.output_tokens)
+                    if calculated is not None:
+                        costs.append(calculated)
+
+                # Count messages (history length)
+                if result.history:
+                    message_counts.append(len(result.history))
+
+                # Count tool calls
+                if result.tool_calls:
+                    tool_call_counts.append(len(result.tool_calls))
+                else:
+                    tool_call_counts.append(0)
+
+            # Calculate statistics
+            stats = ModelStats(
+                model=model,
+                runs=len(model_results),
+                errors=errors,
+                total_input_tokens=input_tokens,
+                total_output_tokens=output_tokens,
+                total_cost_usd=sum(costs) if costs else None,
+                avg_cost_usd=statistics.mean(costs) if costs else None,
+                total_messages=sum(message_counts),
+                avg_messages=statistics.mean(message_counts) if message_counts else 0.0,
+                total_tool_calls=sum(tool_call_counts),
+                avg_tool_calls=statistics.mean(tool_call_counts) if tool_call_counts else 0.0,
+            )
+
+            if scores:
+                stats.avg_score = statistics.mean(scores)
+                stats.min_score = min(scores)
+                stats.max_score = max(scores)
+                stats.std_score = statistics.stdev(scores) if len(scores) > 1 else 0.0
+
+            if latencies:
+                stats.avg_latency_ms = int(statistics.mean(latencies))
+
+            if judge_scores:
+                stats.avg_judge_score = statistics.mean(judge_scores)
+
+            # Calculate goal rates
+            for goal_id, achievements in goal_achievements.items():
+                stats.goal_rates[goal_id] = sum(achievements) / len(achievements) * 100
+
+            self.stats[model] = stats
+
+    def get_ranking(self) -> list[str]:
+        """Get models ranked by average score (highest first)."""
+        if self._ranking is not None:
+            return self._ranking
+
+        if not self.stats:
+            self.compute_stats()
+
+        self._ranking = sorted(
+            self.stats.keys(),
+            key=lambda m: self.stats[m].avg_score,
+            reverse=True,
+        )
+        return self._ranking
+
+    def get_winner(self) -> str | None:
+        """Get the model with highest average score."""
+        ranking = self.get_ranking()
+        return ranking[0] if ranking else None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary."""
+        if not self.stats:
+            self.compute_stats()
+
+        return {
+            "scenario_id": self.scenario_id,
+            "scenario_name": self.scenario_name,
+            "models": self.models,
+            "runs_per_model": self.runs_per_model,
+            "results": [r.to_dict() for r in self.results],
+            "stats": {k: v.to_dict() for k, v in self.stats.items()},
+            "ranking": self.get_ranking(),
+            "winner": self.get_winner(),
+            "created_at": self.created_at.isoformat(),
+        }
+
+    def to_json(self, indent: int | None = 2) -> str:
+        """Serialize to JSON string."""
+        import json
+
+        return json.dumps(self.to_dict(), indent=indent)
+
+    def to_table(self, max_goals: int = 5) -> str:
+        """Format as ASCII table.
+
+        Args:
+            max_goals: Maximum number of goals to show in table
+
+        Returns:
+            Formatted table string
+
+        """
+        if not self.stats:
+            self.compute_stats()
+
+        lines = [
+            f"Model Comparison: {self.scenario_name}",
+            f"Runs per model: {self.runs_per_model}",
+            "",
+        ]
+
+        # Collect all goal IDs
+        all_goals: set[str] = set()
+        for stats in self.stats.values():
+            all_goals.update(stats.goal_rates.keys())
+        goal_ids = sorted(all_goals)[:max_goals]
+
+        # Build header
+        headers = ["Model", "Avg Score", "Latency", "Messages", "Tools", "Cost"]
+        if goal_ids:
+            headers.extend(goal_ids)
+        if any(s.avg_judge_score is not None for s in self.stats.values()):
+            headers.append("Judge")
+
+        # Calculate column widths
+        rows: list[list[str]] = []
+        for model in self.get_ranking():
+            stats = self.stats[model]
+            # Format cost
+            if stats.avg_cost_usd is not None:
+                if stats.avg_cost_usd < 0.01:
+                    cost_str = f"${stats.avg_cost_usd:.4f}"
+                else:
+                    cost_str = f"${stats.avg_cost_usd:.3f}"
+            else:
+                cost_str = "-"
+
+            row = [
+                model[:30],  # Truncate long model names
+                f"{stats.avg_score:.1f}",
+                f"{stats.avg_latency_ms}ms",
+                f"{stats.avg_messages:.1f}",
+                f"{stats.avg_tool_calls:.1f}",
+                cost_str,
+            ]
+
+            # Add goal rates
+            for goal_id in goal_ids:
+                rate = stats.goal_rates.get(goal_id, 0)
+                row.append(f"{rate:.0f}%")
+
+            # Add judge score
+            if any(s.avg_judge_score is not None for s in self.stats.values()):
+                if stats.avg_judge_score is not None:
+                    row.append(f"{stats.avg_judge_score:.2f}")
+                else:
+                    row.append("-")
+
+            rows.append(row)
+
+        # Calculate column widths
+        widths = [max(len(h), max(len(r[i]) for r in rows)) for i, h in enumerate(headers)]
+
+        # Build table
+        sep = "+" + "+".join("-" * (w + 2) for w in widths) + "+"
+
+        lines.append(sep)
+        lines.append("|" + "|".join(f" {h.ljust(widths[i])} " for i, h in enumerate(headers)) + "|")
+        lines.append(sep)
+
+        for row in rows:
+            lines.append("|" + "|".join(f" {c.ljust(widths[i])} " for i, c in enumerate(row)) + "|")
+
+        lines.append(sep)
+
+        # Add summary
+        if self.stats:
+            winner = self.get_winner()
+            if winner:
+                lines.append("")
+                lines.append(f"Winner: {winner}")
+
+        return "\n".join(lines)
+
+    def to_markdown(self, max_goals: int = 5) -> str:
+        """Format as Markdown table.
+
+        Args:
+            max_goals: Maximum number of goals to show
+
+        Returns:
+            Markdown formatted string
+
+        """
+        if not self.stats:
+            self.compute_stats()
+
+        lines = [
+            f"## Model Comparison: {self.scenario_name}",
+            "",
+            f"Runs per model: {self.runs_per_model}",
+            "",
+        ]
+
+        # Collect all goal IDs
+        all_goals: set[str] = set()
+        for stats in self.stats.values():
+            all_goals.update(stats.goal_rates.keys())
+        goal_ids = sorted(all_goals)[:max_goals]
+
+        # Build header
+        headers = ["Model", "Avg Score", "Latency", "Msgs", "Tools", "Cost"]
+        if goal_ids:
+            headers.extend(goal_ids)
+        if any(s.avg_judge_score is not None for s in self.stats.values()):
+            headers.append("Judge")
+
+        lines.append("| " + " | ".join(headers) + " |")
+        lines.append("| " + " | ".join("-" * len(h) for h in headers) + " |")
+
+        for model in self.get_ranking():
+            stats = self.stats[model]
+            # Format cost
+            if stats.avg_cost_usd is not None:
+                if stats.avg_cost_usd < 0.01:
+                    cost_str = f"${stats.avg_cost_usd:.4f}"
+                else:
+                    cost_str = f"${stats.avg_cost_usd:.3f}"
+            else:
+                cost_str = "-"
+
+            row = [
+                f"`{model}`",
+                f"{stats.avg_score:.1f}",
+                f"{stats.avg_latency_ms}ms",
+                f"{stats.avg_messages:.1f}",
+                f"{stats.avg_tool_calls:.1f}",
+                cost_str,
+            ]
+
+            for goal_id in goal_ids:
+                rate = stats.goal_rates.get(goal_id, 0)
+                row.append(f"{rate:.0f}%")
+
+            if any(s.avg_judge_score is not None for s in self.stats.values()):
+                if stats.avg_judge_score is not None:
+                    row.append(f"{stats.avg_judge_score:.2f}")
+                else:
+                    row.append("-")
+
+            lines.append("| " + " | ".join(row) + " |")
+
+        # Add winner
+        winner = self.get_winner()
+        if winner:
+            lines.append("")
+            lines.append(f"**Winner:** `{winner}`")
+
+        return "\n".join(lines)
+
+    def pretty(self) -> str:
+        """Format for human-readable display (alias for to_table)."""
+        return self.to_table()
+
+
+async def run_comparison(
+    scenario: UnifiedScenarioSpec,
+    models: list[str],
+    runs_per_model: int = 1,
+    variables: dict[str, Any] | None = None,
+    max_turns: int = 20,
+    max_tokens: int = 1024,
+    temperature: float = 0.7,
+    parallel: bool = True,
+) -> ComparisonResult:
+    """Run a scenario with multiple models and compare results.
+
+    Args:
+        scenario: The scenario specification
+        models: List of model IDs to test
+        runs_per_model: Number of runs per model (for statistical significance)
+        variables: Variable substitutions
+        max_turns: Maximum conversation turns
+        max_tokens: Maximum tokens per response
+        temperature: Sampling temperature
+        parallel: Run models in parallel (True) or sequentially (False)
+
+    Returns:
+        ComparisonResult with all runs and statistics
+
+    """
+    import asyncio
+
+    from sandboxy.scenarios.unified import UnifiedRunner
+
+    runner = UnifiedRunner()
+    comparison = ComparisonResult(
+        scenario_id=scenario.id,
+        scenario_name=scenario.name or scenario.id,
+        models=models,
+        runs_per_model=runs_per_model,
+    )
+
+    async def run_single(model: str) -> RunResult:
+        """Run a single model iteration."""
+        return await runner.run(
+            scenario=scenario,
+            model=model,
+            variables=variables,
+            max_turns=max_turns,
+            max_tokens=max_tokens,
+            temperature=temperature,
+        )
+
+    if parallel:
+        # Run ALL iterations for ALL models in parallel
+        # Creates [(model, task), ...] for every model × runs_per_model
+        tasks = [run_single(model) for model in models for _ in range(runs_per_model)]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            if isinstance(result, Exception):
+                # Log error but continue with other results
+                logger.error(f"Run failed: {result}")
+            else:
+                comparison.add_result(result)
+    else:
+        # Run sequentially
+        for model in models:
+            for _ in range(runs_per_model):
+                result = await runner.run(
+                    scenario=scenario,
+                    model=model,
+                    variables=variables,
+                    max_turns=max_turns,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                )
+                comparison.add_result(result)
+
+    comparison.compute_stats()
+    return comparison
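
For orientation, a minimal sketch of how run_comparison and the ComparisonResult reporting helpers fit together; it is not part of the package diff. The construction of a UnifiedScenarioSpec is not shown in these hunks, so the spec below is assumed to be built elsewhere in the package, and the model IDs are placeholders.

    import asyncio

    from sandboxy.scenarios.comparison import run_comparison

    async def compare(spec) -> None:
        # spec: a UnifiedScenarioSpec built elsewhere (its loader is not shown in this diff).
        comparison = await run_comparison(
            scenario=spec,
            models=["openai/gpt-4o", "anthropic/claude-3-5-sonnet"],  # placeholder model IDs
            runs_per_model=3,
            parallel=True,
        )

        # Reporting helpers defined on ComparisonResult above.
        print(comparison.to_table())     # ASCII table, models ranked by average score
        print(comparison.to_markdown())  # Markdown variant of the same summary
        print(comparison.get_winner())   # model with the highest average score
        comparison.to_json()             # full results and per-model stats as JSON

    # asyncio.run(compare(spec)) once a scenario spec is available
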