promptum-0.0.1-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- promptum/__init__.py +44 -0
- promptum/benchmark/__init__.py +4 -0
- promptum/benchmark/benchmark.py +50 -0
- promptum/benchmark/report.py +75 -0
- promptum/core/__init__.py +12 -0
- promptum/core/metrics.py +16 -0
- promptum/core/result.py +17 -0
- promptum/core/retry.py +19 -0
- promptum/core/test_case.py +22 -0
- promptum/execution/__init__.py +3 -0
- promptum/execution/runner.py +75 -0
- promptum/providers/__init__.py +7 -0
- promptum/providers/openrouter.py +123 -0
- promptum/providers/protocol.py +22 -0
- promptum/py.typed +0 -0
- promptum/serialization/__init__.py +11 -0
- promptum/serialization/base.py +48 -0
- promptum/serialization/html.py +52 -0
- promptum/serialization/json.py +28 -0
- promptum/serialization/protocol.py +13 -0
- promptum/serialization/report_template.html +293 -0
- promptum/serialization/yaml.py +17 -0
- promptum/storage/__init__.py +7 -0
- promptum/storage/file.py +157 -0
- promptum/storage/protocol.py +23 -0
- promptum/validation/__init__.py +15 -0
- promptum/validation/protocol.py +16 -0
- promptum/validation/validators.py +108 -0
- promptum-0.0.1.dist-info/METADATA +280 -0
- promptum-0.0.1.dist-info/RECORD +32 -0
- promptum-0.0.1.dist-info/WHEEL +4 -0
- promptum-0.0.1.dist-info/licenses/LICENSE +21 -0
promptum/__init__.py
ADDED
@@ -0,0 +1,44 @@
from promptum.benchmark import Benchmark, Report
from promptum.core import Metrics, RetryConfig, RetryStrategy, TestCase, TestResult
from promptum.execution import Runner
from promptum.providers import LLMProvider, OpenRouterClient
from promptum.serialization import (
    HTMLSerializer,
    JSONSerializer,
    Serializer,
    YAMLSerializer,
)
from promptum.storage import FileStorage, ResultStorage
from promptum.validation import (
    Contains,
    ExactMatch,
    JsonSchema,
    Regex,
    Validator,
)

__version__ = "0.1.0"

__all__ = [
    "TestCase",
    "TestResult",
    "Metrics",
    "RetryConfig",
    "RetryStrategy",
    "Validator",
    "ExactMatch",
    "Contains",
    "Regex",
    "JsonSchema",
    "LLMProvider",
    "OpenRouterClient",
    "Runner",
    "Benchmark",
    "Report",
    "Serializer",
    "JSONSerializer",
    "YAMLSerializer",
    "HTMLSerializer",
    "ResultStorage",
    "FileStorage",
]
promptum/benchmark/benchmark.py
ADDED
@@ -0,0 +1,50 @@
import asyncio
from collections.abc import Callable, Sequence
from typing import Any

from promptum.benchmark.report import Report
from promptum.core.result import TestResult
from promptum.core.test_case import TestCase
from promptum.execution.runner import Runner
from promptum.providers.protocol import LLMProvider


class Benchmark:
    def __init__(
        self,
        provider: LLMProvider,
        name: str = "benchmark",
        max_concurrent: int = 5,
        progress_callback: Callable[[int, int, TestResult], None] | None = None,
    ):
        self.provider = provider
        self.name = name
        self.max_concurrent = max_concurrent
        self.progress_callback = progress_callback
        self._test_cases: list[TestCase] = []

    def add_test(self, test_case: TestCase) -> None:
        self._test_cases.append(test_case)

    def add_tests(self, test_cases: Sequence[TestCase]) -> None:
        self._test_cases.extend(test_cases)

    def run(self, metadata: dict[str, Any] | None = None) -> Report:
        return asyncio.run(self.run_async(metadata))

    async def run_async(self, metadata: dict[str, Any] | None = None) -> Report:
        if not self._test_cases:
            return Report(results=[], metadata=metadata or {})

        runner = Runner(
            provider=self.provider,
            max_concurrent=self.max_concurrent,
            progress_callback=self.progress_callback,
        )

        results = await runner.run(self._test_cases)

        return Report(
            results=results,
            metadata=metadata or {},
        )
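For orientation, a minimal usage sketch of the classes above, not part of the package diff. It assumes a valid OPENROUTER_API_KEY in the environment, a placeholder model slug, and that the Contains validator (whose constructor is not shown in this section) takes the expected substring as its argument.

# Usage sketch (not part of the diff). Assumptions: OPENROUTER_API_KEY is set,
# the model slug is a placeholder, and Contains("4") reflects the assumed
# constructor of the Contains validator, which is not shown in this section.
import asyncio
import os

from promptum import Benchmark, Contains, OpenRouterClient, TestCase


async def main() -> None:
    async with OpenRouterClient(api_key=os.environ["OPENROUTER_API_KEY"]) as client:
        benchmark = Benchmark(provider=client, name="arithmetic", max_concurrent=2)
        benchmark.add_test(
            TestCase(
                name="basic-addition",
                prompt="What is 2 + 2? Answer with a single number.",
                model="openai/gpt-4o-mini",  # placeholder model slug
                validator=Contains("4"),  # assumed constructor
            )
        )
        report = await benchmark.run_async(metadata={"suite": "smoke"})
        print(report.get_summary())


asyncio.run(main())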
promptum/benchmark/report.py
ADDED
@@ -0,0 +1,75 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import Any

from promptum.core.result import TestResult


@dataclass(frozen=True, slots=True)
class Report:
    results: Sequence[TestResult]
    metadata: dict[str, Any]

    def get_summary(self) -> dict[str, Any]:
        total = len(self.results)
        passed = sum(1 for r in self.results if r.passed)

        latencies = [r.metrics.latency_ms for r in self.results if r.metrics]
        total_cost = sum(r.metrics.cost_usd or 0 for r in self.results if r.metrics)
        total_tokens = sum(r.metrics.total_tokens or 0 for r in self.results if r.metrics)

        return {
            "total": total,
            "passed": passed,
            "failed": total - passed,
            "pass_rate": passed / total if total > 0 else 0,
            "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
            "p50_latency_ms": self._percentile(latencies, 0.5) if latencies else 0,
            "p95_latency_ms": self._percentile(latencies, 0.95) if latencies else 0,
            "p99_latency_ms": self._percentile(latencies, 0.99) if latencies else 0,
            "total_cost_usd": total_cost,
            "total_tokens": total_tokens,
        }

    def filter(
        self,
        model: str | None = None,
        tags: Sequence[str] | None = None,
        passed: bool | None = None,
    ) -> "Report":
        filtered = list(self.results)

        if model:
            filtered = [r for r in filtered if r.test_case.model == model]

        if tags:
            tag_set = set(tags)
            filtered = [r for r in filtered if tag_set.intersection(r.test_case.tags)]

        if passed is not None:
            filtered = [r for r in filtered if r.passed == passed]

        return Report(results=filtered, metadata=self.metadata)

    def group_by(self, key: Callable[[TestResult], str]) -> dict[str, "Report"]:
        groups: dict[str, list[TestResult]] = {}

        for result in self.results:
            group_key = key(result)
            if group_key not in groups:
                groups[group_key] = []
            groups[group_key].append(result)

        return {k: Report(results=v, metadata=self.metadata) for k, v in groups.items()}

    def compare_models(self) -> dict[str, dict[str, Any]]:
        by_model = self.group_by(lambda r: r.test_case.model)
        return {model: report.get_summary() for model, report in by_model.items()}

    @staticmethod
    def _percentile(values: list[float], p: float) -> float:
        if not values:
            return 0
        sorted_values = sorted(values)
        index = int(len(sorted_values) * p)
        return sorted_values[min(index, len(sorted_values) - 1)]
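Note that _percentile is a nearest-rank style lookup: index = int(len(values) * p), clamped to the last element, so with four latencies p95 resolves to int(4 * 0.95) = 3, the largest sample, and p50 to int(4 * 0.5) = 2, the upper of the two middle values. Below is a small sketch, not part of the package, exercising filter() and compare_models() on a hand-built Report; the stub validator exists only because TestCase requires one, and none of its methods are called here.

# Sketch (not part of the diff): exercising Report.filter()/compare_models()
# on a hand-built Report.
from promptum import Metrics, Report, TestCase, TestResult


class _StubValidator:
    def validate(self, response: str) -> tuple[bool, dict]:
        return True, {}

    def describe(self) -> str:
        return "stub"


def _result(model: str, passed: bool, latency: float) -> TestResult:
    case = TestCase(name=f"case-{model}", prompt="p", model=model, validator=_StubValidator())
    return TestResult(
        test_case=case,
        response="ok",
        passed=passed,
        metrics=Metrics(latency_ms=latency),
        validation_details={},
    )


report = Report(
    results=[_result("model-a", True, 120.0), _result("model-b", False, 480.0)],
    metadata={},
)
print(report.filter(passed=False).get_summary()["total"])  # 1
print(report.compare_models()["model-a"]["pass_rate"])  # 1.0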
promptum/core/__init__.py
ADDED
@@ -0,0 +1,12 @@
from promptum.core.metrics import Metrics
from promptum.core.result import TestResult
from promptum.core.retry import RetryConfig, RetryStrategy
from promptum.core.test_case import TestCase

__all__ = [
    "Metrics",
    "RetryConfig",
    "RetryStrategy",
    "TestCase",
    "TestResult",
]
promptum/core/metrics.py
ADDED
@@ -0,0 +1,16 @@
from collections.abc import Sequence
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class Metrics:
    latency_ms: float
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None
    cost_usd: float | None = None
    retry_delays: Sequence[float] = ()

    @property
    def total_attempts(self) -> int:
        return len(self.retry_delays) + 1
promptum/core/result.py
ADDED
@@ -0,0 +1,17 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

from promptum.core.metrics import Metrics
from promptum.core.test_case import TestCase


@dataclass(frozen=True, slots=True)
class TestResult:
    test_case: TestCase
    response: str | None
    passed: bool
    metrics: Metrics | None
    validation_details: dict[str, Any]
    execution_error: str | None = None
    timestamp: datetime = field(default_factory=datetime.now)
promptum/core/retry.py
ADDED
@@ -0,0 +1,19 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum


class RetryStrategy(Enum):
    EXPONENTIAL_BACKOFF = "exponential_backoff"
    FIXED_DELAY = "fixed_delay"


@dataclass(frozen=True, slots=True)
class RetryConfig:
    max_attempts: int = 3
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL_BACKOFF
    initial_delay: float = 1.0
    max_delay: float = 60.0
    exponential_base: float = 2.0
    retryable_status_codes: Sequence[int] = (429, 500, 502, 503, 504)
    timeout: float = 60.0
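These fields are consumed by OpenRouterClient._calculate_delay later in this diff. A small sketch, not part of the package, of the delay schedule the defaults produce:

# Sketch (not part of the diff): the retry schedule implied by the defaults,
# mirroring the exponential-backoff branch of _calculate_delay below.
from promptum import RetryConfig, RetryStrategy

config = RetryConfig()  # max_attempts=3, initial_delay=1.0, exponential_base=2.0

for attempt in range(config.max_attempts - 1):  # delays occur between attempts
    if config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
        delay = min(config.initial_delay * config.exponential_base**attempt, config.max_delay)
    else:
        delay = config.initial_delay
    print(f"after attempt {attempt + 1}: sleep {delay:.0f}s")
# after attempt 1: sleep 1s
# after attempt 2: sleep 2s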
promptum/core/test_case.py
ADDED
@@ -0,0 +1,22 @@
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from promptum.validation.protocol import Validator

from promptum.core.retry import RetryConfig


@dataclass(frozen=True, slots=True)
class TestCase:
    name: str
    prompt: str
    model: str
    validator: "Validator"
    tags: Sequence[str] = ()
    system_prompt: str | None = None
    temperature: float = 1.0
    max_tokens: int | None = None
    retry_config: RetryConfig | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
promptum/execution/runner.py
ADDED
@@ -0,0 +1,75 @@
import asyncio
from collections.abc import Callable, Sequence

import httpx

from promptum.core.result import TestResult
from promptum.core.test_case import TestCase
from promptum.providers.protocol import LLMProvider


class Runner:
    def __init__(
        self,
        provider: LLMProvider,
        max_concurrent: int = 5,
        progress_callback: Callable[[int, int, TestResult], None] | None = None,
    ):
        self.provider = provider
        self.max_concurrent = max_concurrent
        self.progress_callback = progress_callback

    async def run(self, test_cases: Sequence[TestCase]) -> list[TestResult]:
        semaphore = asyncio.Semaphore(self.max_concurrent)
        completed = 0
        total = len(test_cases)

        async def run_with_semaphore(test_case: TestCase) -> TestResult:
            async with semaphore:
                result = await self._run_single_test(test_case)

                nonlocal completed
                completed += 1
                if self.progress_callback:
                    self.progress_callback(completed, total, result)

                return result

        results = await asyncio.gather(
            *[run_with_semaphore(tc) for tc in test_cases],
            return_exceptions=False,
        )

        return list(results)

    async def _run_single_test(self, test_case: TestCase) -> TestResult:
        try:
            response, metrics = await self.provider.generate(
                prompt=test_case.prompt,
                model=test_case.model,
                system_prompt=test_case.system_prompt,
                temperature=test_case.temperature,
                max_tokens=test_case.max_tokens,
                retry_config=test_case.retry_config,
            )

            passed, validation_details = test_case.validator.validate(response)

            return TestResult(
                test_case=test_case,
                response=response,
                passed=passed,
                metrics=metrics,
                validation_details=validation_details,
                execution_error=None,
            )

        except (RuntimeError, ValueError, TypeError, httpx.HTTPError) as e:
            return TestResult(
                test_case=test_case,
                response=None,
                passed=False,
                metrics=None,
                validation_details={},
                execution_error=str(e),
            )
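The progress_callback parameter on Runner (and on Benchmark, which forwards it) receives (completed, total, result). A minimal sketch of a compatible callback, not part of the package:

# Sketch (not part of the diff): a callback matching
# Callable[[int, int, TestResult], None], usable as the progress_callback
# argument of Runner or Benchmark.
from promptum import TestResult


def print_progress(completed: int, total: int, result: TestResult) -> None:
    status = "PASS" if result.passed else "FAIL"
    print(f"[{completed}/{total}] {result.test_case.name}: {status}")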
promptum/providers/openrouter.py
ADDED
@@ -0,0 +1,123 @@
import asyncio
import time
from typing import Any

import httpx

from promptum.core.metrics import Metrics
from promptum.core.retry import RetryConfig, RetryStrategy


class OpenRouterClient:
    def __init__(
        self,
        api_key: str,
        base_url: str = "https://openrouter.ai/api/v1",
        default_retry_config: RetryConfig | None = None,
    ):
        self.api_key = api_key
        self.base_url = base_url
        self.default_retry_config = default_retry_config or RetryConfig()
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self) -> "OpenRouterClient":
        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            timeout=self.default_retry_config.timeout,
        )
        return self

    async def __aexit__(self, *args: Any) -> None:
        if self._client:
            await self._client.aclose()

    async def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
        retry_config: RetryConfig | None = None,
        **kwargs: Any,
    ) -> tuple[str, Metrics]:
        if not self._client:
            raise RuntimeError("Client not initialized. Use async context manager.")

        config = retry_config or self.default_retry_config
        retry_delays: list[float] = []

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            payload["max_tokens"] = max_tokens
        payload.update(kwargs)

        for attempt in range(config.max_attempts):
            start_time = time.perf_counter()
            try:
                response = await self._client.post(
                    "/chat/completions",
                    json=payload,
                    timeout=config.timeout,
                )

                if response.status_code == 200:
                    latency_ms = (time.perf_counter() - start_time) * 1000
                    try:
                        data = response.json()
                        content = data["choices"][0]["message"]["content"]
                    except (KeyError, IndexError, TypeError) as e:
                        raise RuntimeError(f"Invalid API response structure: {e}") from e

                    usage = data.get("usage", {})
                    metrics = Metrics(
                        latency_ms=latency_ms,
                        prompt_tokens=usage.get("prompt_tokens"),
                        completion_tokens=usage.get("completion_tokens"),
                        total_tokens=usage.get("total_tokens"),
                        cost_usd=usage.get("cost") or usage.get("total_cost"),
                        retry_delays=tuple(retry_delays),
                    )

                    return content, metrics

                if response.status_code not in config.retryable_status_codes:
                    response.raise_for_status()

                if attempt < config.max_attempts - 1:
                    delay = self._calculate_delay(attempt, config)
                    retry_delays.append(delay)
                    await asyncio.sleep(delay)

            except (httpx.TimeoutException, httpx.NetworkError) as e:
                if attempt < config.max_attempts - 1:
                    delay = self._calculate_delay(attempt, config)
                    retry_delays.append(delay)
                    await asyncio.sleep(delay)
                else:
                    raise RuntimeError(
                        f"Request failed after {config.max_attempts} attempts: {e}"
                    ) from e
            except httpx.HTTPStatusError as e:
                raise RuntimeError(f"HTTP error {e.response.status_code}: {e.response.text}") from e

        raise RuntimeError(f"Request failed after {config.max_attempts} attempts")

    def _calculate_delay(self, attempt: int, config: RetryConfig) -> float:
        if config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
            delay = config.initial_delay * (config.exponential_base**attempt)
            return min(delay, config.max_delay)
        return config.initial_delay
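A short sketch, not part of the package, of calling the client directly outside Benchmark/Runner; the model slug is a placeholder, and generate() raises RuntimeError when the client is used outside its async context manager.

# Sketch (not part of the diff): calling the client directly.
import asyncio
import os

from promptum import OpenRouterClient


async def main() -> None:
    async with OpenRouterClient(api_key=os.environ["OPENROUTER_API_KEY"]) as client:
        text, metrics = await client.generate(
            prompt="Say hello in one word.",
            model="openai/gpt-4o-mini",  # placeholder
            max_tokens=8,
        )
        print(text, metrics.latency_ms, metrics.total_attempts)


asyncio.run(main())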
promptum/providers/protocol.py
ADDED
@@ -0,0 +1,22 @@
from typing import Any, Protocol

from promptum.core.metrics import Metrics


class LLMProvider(Protocol):
    async def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> tuple[str, Metrics]:
        """
        Generates a response from the LLM.

        Returns:
            (response_text, metrics)
        """
        ...
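Because LLMProvider is a Protocol, any object with a matching async generate can be passed to Runner or Benchmark. Note that Runner also forwards retry_config=..., which is not named in the protocol and therefore must be absorbed by **kwargs or accepted explicitly. A sketch of a canned provider and stub validator, not part of the package; the real Validator protocol lives in promptum/validation/protocol.py, whose body is not shown in this section.

# Sketch (not part of the diff): a canned provider that structurally satisfies
# LLMProvider, plus a stub validator. Runner's retry_config=... lands in
# **kwargs here.
from typing import Any

from promptum import Benchmark, Metrics, TestCase


class CannedProvider:
    async def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> tuple[str, Metrics]:
        return "canned answer", Metrics(latency_ms=0.1)


class AlwaysPass:
    def validate(self, response: str) -> tuple[bool, dict]:
        return True, {"response": response}

    def describe(self) -> str:
        return "always-pass (test stub)"


benchmark = Benchmark(provider=CannedProvider())
benchmark.add_test(TestCase(name="t", prompt="p", model="canned", validator=AlwaysPass()))
print(benchmark.run().get_summary()["passed"])  # 1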
promptum/py.typed
ADDED
File without changes
promptum/serialization/__init__.py
ADDED
@@ -0,0 +1,11 @@
from promptum.serialization.html import HTMLSerializer
from promptum.serialization.json import JSONSerializer
from promptum.serialization.protocol import Serializer
from promptum.serialization.yaml import YAMLSerializer

__all__ = [
    "Serializer",
    "JSONSerializer",
    "YAMLSerializer",
    "HTMLSerializer",
]
promptum/serialization/base.py
ADDED
@@ -0,0 +1,48 @@
"""Base serializer with shared result serialization logic."""

from typing import Any

from promptum.core.result import TestResult


class BaseSerializer:
    """
    Base class for serializers with common result serialization logic.

    Subclasses should implement:
    - serialize(report: Report) -> str
    - get_file_extension() -> str
    """

    @staticmethod
    def _serialize_result(result: TestResult) -> dict[str, Any]:
        """Convert TestResult to dictionary representation."""
        return {
            "test_case": {
                "name": result.test_case.name,
                "prompt": result.test_case.prompt,
                "model": result.test_case.model,
                "tags": list(result.test_case.tags),
                "system_prompt": result.test_case.system_prompt,
                "temperature": result.test_case.temperature,
                "max_tokens": result.test_case.max_tokens,
                "metadata": result.test_case.metadata,
                "validator": result.test_case.validator.describe(),
            },
            "response": result.response,
            "passed": result.passed,
            "metrics": {
                "latency_ms": result.metrics.latency_ms,
                "prompt_tokens": result.metrics.prompt_tokens,
                "completion_tokens": result.metrics.completion_tokens,
                "total_tokens": result.metrics.total_tokens,
                "cost_usd": result.metrics.cost_usd,
                "retry_delays": list(result.metrics.retry_delays),
                "total_attempts": result.metrics.total_attempts,
            }
            if result.metrics
            else None,
            "validation_details": result.validation_details,
            "execution_error": result.execution_error,
            "timestamp": result.timestamp.isoformat(),
        }
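A sketch, not part of the package, of a custom serializer built on BaseSerializer._serialize_result and shaped to satisfy the Serializer protocol (serialize plus get_file_extension) defined later in this diff.

# Sketch (not part of the diff): an illustrative Markdown serializer.
from promptum.benchmark.report import Report
from promptum.serialization.base import BaseSerializer


class MarkdownSerializer(BaseSerializer):
    def serialize(self, report: Report) -> str:
        lines = ["| test | model | passed |", "| --- | --- | --- |"]
        for row in (self._serialize_result(r) for r in report.results):
            tc = row["test_case"]
            lines.append(f"| {tc['name']} | {tc['model']} | {row['passed']} |")
        return "\n".join(lines)

    def get_file_extension(self) -> str:
        return "md"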
promptum/serialization/html.py
ADDED
@@ -0,0 +1,52 @@
import json
from pathlib import Path

from jinja2 import Template

from promptum.benchmark.report import Report


class HTMLSerializer:
    def __init__(self) -> None:
        template_path = Path(__file__).parent / "report_template.html"
        self._template = Template(template_path.read_text())

    def serialize(self, report: Report) -> str:
        summary = report.get_summary()

        results_data = []
        for result in report.results:
            results_data.append(
                {
                    "test_case": {
                        "name": result.test_case.name,
                        "prompt": result.test_case.prompt,
                        "model": result.test_case.model,
                        "tags": list(result.test_case.tags),
                        "system_prompt": result.test_case.system_prompt,
                        "validator": result.test_case.validator.describe(),
                    },
                    "response": result.response,
                    "passed": result.passed,
                    "metrics": {
                        "latency_ms": result.metrics.latency_ms,
                        "prompt_tokens": result.metrics.prompt_tokens,
                        "completion_tokens": result.metrics.completion_tokens,
                        "total_tokens": result.metrics.total_tokens,
                        "cost_usd": result.metrics.cost_usd,
                        "total_attempts": result.metrics.total_attempts,
                    }
                    if result.metrics
                    else None,
                    "execution_error": result.execution_error,
                }
            )

        return self._template.render(
            summary=summary,
            results=results_data,
            results_json=json.dumps(results_data),
        )

    def get_file_extension(self) -> str:
        return "html"
promptum/serialization/json.py
ADDED
@@ -0,0 +1,28 @@
import json
from datetime import datetime
from typing import Any

from promptum.benchmark.report import Report
from promptum.serialization.base import BaseSerializer


class JSONSerializer(BaseSerializer):
    def __init__(self, indent: int = 2):
        self.indent = indent

    def serialize(self, report: Report) -> str:
        data = {
            "metadata": report.metadata,
            "summary": report.get_summary(),
            "results": [self._serialize_result(r) for r in report.results],
        }
        return json.dumps(data, indent=self.indent, default=self._json_default)

    def get_file_extension(self) -> str:
        return "json"

    @staticmethod
    def _json_default(obj: Any) -> Any:
        if isinstance(obj, datetime):
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
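A sketch, not part of the package, of writing a serialized report to disk by hand; the package also ships promptum.storage.FileStorage for this, but its implementation is not shown in this section.

# Sketch (not part of the diff): serializing a report and writing it out
# with pathlib, using get_file_extension() to pick the suffix.
from pathlib import Path

from promptum import JSONSerializer, Report

serializer = JSONSerializer(indent=2)
report = Report(results=[], metadata={"note": "empty report, for illustration"})
path = Path(f"benchmark_report.{serializer.get_file_extension()}")
path.write_text(serializer.serialize(report))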
promptum/serialization/protocol.py
ADDED
@@ -0,0 +1,13 @@
from typing import Protocol

from promptum.benchmark.report import Report


class Serializer(Protocol):
    def serialize(self, report: Report) -> str:
        """Serializes a Report to a string format."""
        ...

    def get_file_extension(self) -> str:
        """Returns the file extension for this format (e.g., 'json', 'html')."""
        ...