promptum-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
promptum/__init__.py ADDED
@@ -0,0 +1,44 @@
+ from promptum.benchmark import Benchmark, Report
+ from promptum.core import Metrics, RetryConfig, RetryStrategy, TestCase, TestResult
+ from promptum.execution import Runner
+ from promptum.providers import LLMProvider, OpenRouterClient
+ from promptum.serialization import (
+     HTMLSerializer,
+     JSONSerializer,
+     Serializer,
+     YAMLSerializer,
+ )
+ from promptum.storage import FileStorage, ResultStorage
+ from promptum.validation import (
+     Contains,
+     ExactMatch,
+     JsonSchema,
+     Regex,
+     Validator,
+ )
+
+ __version__ = "0.1.0"
+
+ __all__ = [
+     "TestCase",
+     "TestResult",
+     "Metrics",
+     "RetryConfig",
+     "RetryStrategy",
+     "Validator",
+     "ExactMatch",
+     "Contains",
+     "Regex",
+     "JsonSchema",
+     "LLMProvider",
+     "OpenRouterClient",
+     "Runner",
+     "Benchmark",
+     "Report",
+     "Serializer",
+     "JSONSerializer",
+     "YAMLSerializer",
+     "HTMLSerializer",
+     "ResultStorage",
+     "FileStorage",
+ ]
promptum/benchmark/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from promptum.benchmark.benchmark import Benchmark
+ from promptum.benchmark.report import Report
+
+ __all__ = ["Benchmark", "Report"]
promptum/benchmark/benchmark.py ADDED
@@ -0,0 +1,50 @@
+ import asyncio
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ from promptum.benchmark.report import Report
+ from promptum.core.result import TestResult
+ from promptum.core.test_case import TestCase
+ from promptum.execution.runner import Runner
+ from promptum.providers.protocol import LLMProvider
+
+
+ class Benchmark:
+     def __init__(
+         self,
+         provider: LLMProvider,
+         name: str = "benchmark",
+         max_concurrent: int = 5,
+         progress_callback: Callable[[int, int, TestResult], None] | None = None,
+     ):
+         self.provider = provider
+         self.name = name
+         self.max_concurrent = max_concurrent
+         self.progress_callback = progress_callback
+         self._test_cases: list[TestCase] = []
+
+     def add_test(self, test_case: TestCase) -> None:
+         self._test_cases.append(test_case)
+
+     def add_tests(self, test_cases: Sequence[TestCase]) -> None:
+         self._test_cases.extend(test_cases)
+
+     def run(self, metadata: dict[str, Any] | None = None) -> Report:
+         return asyncio.run(self.run_async(metadata))
+
+     async def run_async(self, metadata: dict[str, Any] | None = None) -> Report:
+         if not self._test_cases:
+             return Report(results=[], metadata=metadata or {})
+
+         runner = Runner(
+             provider=self.provider,
+             max_concurrent=self.max_concurrent,
+             progress_callback=self.progress_callback,
+         )
+
+         results = await runner.run(self._test_cases)
+
+         return Report(
+             results=results,
+             metadata=metadata or {},
+         )
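
Usage sketch (not part of the packaged files): a minimal end-to-end run wiring OpenRouterClient into the Benchmark above. The Contains("Bonjour") call and the "openai/gpt-4o-mini" model id are illustrative assumptions; the validation classes are re-exported by promptum but their constructors are not shown in this diff.

import asyncio

from promptum import Benchmark, Contains, OpenRouterClient, TestCase


async def main() -> None:
    # The client must be entered as an async context manager before requests are made.
    async with OpenRouterClient(api_key="sk-or-...") as client:
        bench = Benchmark(provider=client, name="greeting-check", max_concurrent=3)
        bench.add_test(
            TestCase(
                name="french-greeting",
                prompt="Say hello in French.",
                model="openai/gpt-4o-mini",     # example model id (assumption)
                validator=Contains("Bonjour"),  # assumed Contains signature
                tags=("smoke",),
            )
        )
        report = await bench.run_async(metadata={"suite": "demo"})
        print(report.get_summary())


asyncio.run(main())

run_async is used here because the client is already inside an async context; the synchronous run() wraps the same coroutine with asyncio.run().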
promptum/benchmark/report.py ADDED
@@ -0,0 +1,75 @@
+ from collections.abc import Callable, Sequence
+ from dataclasses import dataclass
+ from typing import Any
+
+ from promptum.core.result import TestResult
+
+
+ @dataclass(frozen=True, slots=True)
+ class Report:
+     results: Sequence[TestResult]
+     metadata: dict[str, Any]
+
+     def get_summary(self) -> dict[str, Any]:
+         total = len(self.results)
+         passed = sum(1 for r in self.results if r.passed)
+
+         latencies = [r.metrics.latency_ms for r in self.results if r.metrics]
+         total_cost = sum(r.metrics.cost_usd or 0 for r in self.results if r.metrics)
+         total_tokens = sum(r.metrics.total_tokens or 0 for r in self.results if r.metrics)
+
+         return {
+             "total": total,
+             "passed": passed,
+             "failed": total - passed,
+             "pass_rate": passed / total if total > 0 else 0,
+             "avg_latency_ms": sum(latencies) / len(latencies) if latencies else 0,
+             "p50_latency_ms": self._percentile(latencies, 0.5) if latencies else 0,
+             "p95_latency_ms": self._percentile(latencies, 0.95) if latencies else 0,
+             "p99_latency_ms": self._percentile(latencies, 0.99) if latencies else 0,
+             "total_cost_usd": total_cost,
+             "total_tokens": total_tokens,
+         }
+
+     def filter(
+         self,
+         model: str | None = None,
+         tags: Sequence[str] | None = None,
+         passed: bool | None = None,
+     ) -> "Report":
+         filtered = list(self.results)
+
+         if model:
+             filtered = [r for r in filtered if r.test_case.model == model]
+
+         if tags:
+             tag_set = set(tags)
+             filtered = [r for r in filtered if tag_set.intersection(r.test_case.tags)]
+
+         if passed is not None:
+             filtered = [r for r in filtered if r.passed == passed]
+
+         return Report(results=filtered, metadata=self.metadata)
+
+     def group_by(self, key: Callable[[TestResult], str]) -> dict[str, "Report"]:
+         groups: dict[str, list[TestResult]] = {}
+
+         for result in self.results:
+             group_key = key(result)
+             if group_key not in groups:
+                 groups[group_key] = []
+             groups[group_key].append(result)
+
+         return {k: Report(results=v, metadata=self.metadata) for k, v in groups.items()}
+
+     def compare_models(self) -> dict[str, dict[str, Any]]:
+         by_model = self.group_by(lambda r: r.test_case.model)
+         return {model: report.get_summary() for model, report in by_model.items()}
+
+     @staticmethod
+     def _percentile(values: list[float], p: float) -> float:
+         if not values:
+             return 0
+         sorted_values = sorted(values)
+         index = int(len(sorted_values) * p)
+         return sorted_values[min(index, len(sorted_values) - 1)]
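
A short sketch of the Report helpers above (not part of the packaged files); `report` is assumed to be the Report returned by a benchmark run.

summary = report.get_summary()          # totals, pass rate, latency percentiles, cost
failures = report.filter(passed=False)  # new Report containing only failing results
by_model = report.compare_models()      # {model_id: summary dict}
by_tag = report.group_by(
    lambda r: r.test_case.tags[0] if r.test_case.tags else "untagged"
)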
promptum/core/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from promptum.core.metrics import Metrics
+ from promptum.core.result import TestResult
+ from promptum.core.retry import RetryConfig, RetryStrategy
+ from promptum.core.test_case import TestCase
+
+ __all__ = [
+     "Metrics",
+     "RetryConfig",
+     "RetryStrategy",
+     "TestCase",
+     "TestResult",
+ ]
promptum/core/metrics.py ADDED
@@ -0,0 +1,16 @@
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True, slots=True)
+ class Metrics:
+     latency_ms: float
+     prompt_tokens: int | None = None
+     completion_tokens: int | None = None
+     total_tokens: int | None = None
+     cost_usd: float | None = None
+     retry_delays: Sequence[float] = ()
+
+     @property
+     def total_attempts(self) -> int:
+         return len(self.retry_delays) + 1
promptum/core/result.py ADDED
@@ -0,0 +1,17 @@
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from typing import Any
+
+ from promptum.core.metrics import Metrics
+ from promptum.core.test_case import TestCase
+
+
+ @dataclass(frozen=True, slots=True)
+ class TestResult:
+     test_case: TestCase
+     response: str | None
+     passed: bool
+     metrics: Metrics | None
+     validation_details: dict[str, Any]
+     execution_error: str | None = None
+     timestamp: datetime = field(default_factory=datetime.now)
promptum/core/retry.py ADDED
@@ -0,0 +1,19 @@
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from enum import Enum
+
+
+ class RetryStrategy(Enum):
+     EXPONENTIAL_BACKOFF = "exponential_backoff"
+     FIXED_DELAY = "fixed_delay"
+
+
+ @dataclass(frozen=True, slots=True)
+ class RetryConfig:
+     max_attempts: int = 3
+     strategy: RetryStrategy = RetryStrategy.EXPONENTIAL_BACKOFF
+     initial_delay: float = 1.0
+     max_delay: float = 60.0
+     exponential_base: float = 2.0
+     retryable_status_codes: Sequence[int] = (429, 500, 502, 503, 504)
+     timeout: float = 60.0
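
With these defaults, the OpenRouterClient further below computes the wait after a failed attempt i (0-based) as initial_delay * exponential_base**i, capped at max_delay. A sketch of an explicit config (values are illustrative, not part of the packaged files):

config = RetryConfig(
    max_attempts=5,
    strategy=RetryStrategy.EXPONENTIAL_BACKOFF,
    initial_delay=0.5,
    max_delay=10.0,
)
# Waits before retries 1-4: 0.5 s, 1.0 s, 2.0 s, 4.0 s.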
promptum/core/test_case.py ADDED
@@ -0,0 +1,22 @@
+ from collections.abc import Sequence
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING, Any
+
+ if TYPE_CHECKING:
+     from promptum.validation.protocol import Validator
+
+ from promptum.core.retry import RetryConfig
+
+
+ @dataclass(frozen=True, slots=True)
+ class TestCase:
+     name: str
+     prompt: str
+     model: str
+     validator: "Validator"
+     tags: Sequence[str] = ()
+     system_prompt: str | None = None
+     temperature: float = 1.0
+     max_tokens: int | None = None
+     retry_config: RetryConfig | None = None
+     metadata: dict[str, Any] = field(default_factory=dict)
promptum/execution/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from promptum.execution.runner import Runner
+
+ __all__ = ["Runner"]
promptum/execution/runner.py ADDED
@@ -0,0 +1,75 @@
+ import asyncio
+ from collections.abc import Callable, Sequence
+
+ import httpx
+
+ from promptum.core.result import TestResult
+ from promptum.core.test_case import TestCase
+ from promptum.providers.protocol import LLMProvider
+
+
+ class Runner:
+     def __init__(
+         self,
+         provider: LLMProvider,
+         max_concurrent: int = 5,
+         progress_callback: Callable[[int, int, TestResult], None] | None = None,
+     ):
+         self.provider = provider
+         self.max_concurrent = max_concurrent
+         self.progress_callback = progress_callback
+
+     async def run(self, test_cases: Sequence[TestCase]) -> list[TestResult]:
+         semaphore = asyncio.Semaphore(self.max_concurrent)
+         completed = 0
+         total = len(test_cases)
+
+         async def run_with_semaphore(test_case: TestCase) -> TestResult:
+             async with semaphore:
+                 result = await self._run_single_test(test_case)
+
+             nonlocal completed
+             completed += 1
+             if self.progress_callback:
+                 self.progress_callback(completed, total, result)
+
+             return result
+
+         results = await asyncio.gather(
+             *[run_with_semaphore(tc) for tc in test_cases],
+             return_exceptions=False,
+         )
+
+         return list(results)
+
+     async def _run_single_test(self, test_case: TestCase) -> TestResult:
+         try:
+             response, metrics = await self.provider.generate(
+                 prompt=test_case.prompt,
+                 model=test_case.model,
+                 system_prompt=test_case.system_prompt,
+                 temperature=test_case.temperature,
+                 max_tokens=test_case.max_tokens,
+                 retry_config=test_case.retry_config,
+             )
+
+             passed, validation_details = test_case.validator.validate(response)
+
+             return TestResult(
+                 test_case=test_case,
+                 response=response,
+                 passed=passed,
+                 metrics=metrics,
+                 validation_details=validation_details,
+                 execution_error=None,
+             )
+
+         except (RuntimeError, ValueError, TypeError, httpx.HTTPError) as e:
+             return TestResult(
+                 test_case=test_case,
+                 response=None,
+                 passed=False,
+                 metrics=None,
+                 validation_details={},
+                 execution_error=str(e),
+             )
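
A progress-callback sketch for the Runner above (not part of the packaged files); `provider` and `test_cases` are assumed to be supplied by the caller.

from collections.abc import Sequence

from promptum import LLMProvider, Runner, TestCase, TestResult


def on_progress(done: int, total: int, result: TestResult) -> None:
    status = "PASS" if result.passed else "FAIL"
    print(f"[{done}/{total}] {result.test_case.name}: {status}")


async def run_suite(provider: LLMProvider, test_cases: Sequence[TestCase]) -> list[TestResult]:
    # Benchmark forwards the same callback here; it fires once per completed test case.
    runner = Runner(provider=provider, max_concurrent=10, progress_callback=on_progress)
    return await runner.run(test_cases)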
promptum/providers/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from promptum.providers.openrouter import OpenRouterClient
+ from promptum.providers.protocol import LLMProvider
+
+ __all__ = [
+     "LLMProvider",
+     "OpenRouterClient",
+ ]
promptum/providers/openrouter.py ADDED
@@ -0,0 +1,123 @@
+ import asyncio
+ import time
+ from typing import Any
+
+ import httpx
+
+ from promptum.core.metrics import Metrics
+ from promptum.core.retry import RetryConfig, RetryStrategy
+
+
+ class OpenRouterClient:
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str = "https://openrouter.ai/api/v1",
+         default_retry_config: RetryConfig | None = None,
+     ):
+         self.api_key = api_key
+         self.base_url = base_url
+         self.default_retry_config = default_retry_config or RetryConfig()
+         self._client: httpx.AsyncClient | None = None
+
+     async def __aenter__(self) -> "OpenRouterClient":
+         self._client = httpx.AsyncClient(
+             base_url=self.base_url,
+             headers={
+                 "Authorization": f"Bearer {self.api_key}",
+                 "Content-Type": "application/json",
+             },
+             timeout=self.default_retry_config.timeout,
+         )
+         return self
+
+     async def __aexit__(self, *args: Any) -> None:
+         if self._client:
+             await self._client.aclose()
+
+     async def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 1.0,
+         max_tokens: int | None = None,
+         retry_config: RetryConfig | None = None,
+         **kwargs: Any,
+     ) -> tuple[str, Metrics]:
+         if not self._client:
+             raise RuntimeError("Client not initialized. Use async context manager.")
+
+         config = retry_config or self.default_retry_config
+         retry_delays: list[float] = []
+
+         messages = []
+         if system_prompt:
+             messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": prompt})
+
+         payload: dict[str, Any] = {
+             "model": model,
+             "messages": messages,
+             "temperature": temperature,
+         }
+         if max_tokens:
+             payload["max_tokens"] = max_tokens
+         payload.update(kwargs)
+
+         for attempt in range(config.max_attempts):
+             start_time = time.perf_counter()
+             try:
+                 response = await self._client.post(
+                     "/chat/completions",
+                     json=payload,
+                     timeout=config.timeout,
+                 )
+
+                 if response.status_code == 200:
+                     latency_ms = (time.perf_counter() - start_time) * 1000
+                     try:
+                         data = response.json()
+                         content = data["choices"][0]["message"]["content"]
+                     except (KeyError, IndexError, TypeError) as e:
+                         raise RuntimeError(f"Invalid API response structure: {e}") from e
+
+                     usage = data.get("usage", {})
+                     metrics = Metrics(
+                         latency_ms=latency_ms,
+                         prompt_tokens=usage.get("prompt_tokens"),
+                         completion_tokens=usage.get("completion_tokens"),
+                         total_tokens=usage.get("total_tokens"),
+                         cost_usd=usage.get("cost") or usage.get("total_cost"),
+                         retry_delays=tuple(retry_delays),
+                     )
+
+                     return content, metrics
+
+                 if response.status_code not in config.retryable_status_codes:
+                     response.raise_for_status()
+
+                 if attempt < config.max_attempts - 1:
+                     delay = self._calculate_delay(attempt, config)
+                     retry_delays.append(delay)
+                     await asyncio.sleep(delay)
+
+             except (httpx.TimeoutException, httpx.NetworkError) as e:
+                 if attempt < config.max_attempts - 1:
+                     delay = self._calculate_delay(attempt, config)
+                     retry_delays.append(delay)
+                     await asyncio.sleep(delay)
+                 else:
+                     raise RuntimeError(
+                         f"Request failed after {config.max_attempts} attempts: {e}"
+                     ) from e
+             except httpx.HTTPStatusError as e:
+                 raise RuntimeError(f"HTTP error {e.response.status_code}: {e.response.text}") from e
+
+         raise RuntimeError(f"Request failed after {config.max_attempts} attempts")
+
+     def _calculate_delay(self, attempt: int, config: RetryConfig) -> float:
+         if config.strategy == RetryStrategy.EXPONENTIAL_BACKOFF:
+             delay = config.initial_delay * (config.exponential_base**attempt)
+             return min(delay, config.max_delay)
+         return config.initial_delay
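
Direct-use sketch for the client above (not part of the packaged files), overriding the retry policy per call; the model id is an illustrative assumption.

import asyncio

from promptum import OpenRouterClient, RetryConfig


async def ping() -> None:
    async with OpenRouterClient(api_key="sk-or-...") as client:
        text, metrics = await client.generate(
            prompt="Reply with the single word: pong",
            model="openai/gpt-4o-mini",  # example model id (assumption)
            temperature=0.0,
            retry_config=RetryConfig(max_attempts=2, initial_delay=0.5),
        )
        print(text, metrics.latency_ms, metrics.total_attempts)


asyncio.run(ping())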
promptum/providers/protocol.py ADDED
@@ -0,0 +1,22 @@
+ from typing import Any, Protocol
+
+ from promptum.core.metrics import Metrics
+
+
+ class LLMProvider(Protocol):
+     async def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 1.0,
+         max_tokens: int | None = None,
+         **kwargs: Any,
+     ) -> tuple[str, Metrics]:
+         """
+         Generates a response from the LLM.
+
+         Returns:
+             (response_text, metrics)
+         """
+         ...
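
Because LLMProvider is a typing.Protocol, any object with a structurally matching async generate method can stand in for a real client. A toy offline provider sketch (not part of the packaged files):

import time
from typing import Any

from promptum.core.metrics import Metrics


class EchoProvider:
    """Echoes the prompt back; handy for wiring tests without network access."""

    async def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 1.0,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> tuple[str, Metrics]:
        start = time.perf_counter()
        # No real work to do; report the (near-zero) elapsed time as latency.
        return prompt, Metrics(latency_ms=(time.perf_counter() - start) * 1000)

Such an object can be passed anywhere the package expects an LLMProvider, e.g. Benchmark(provider=EchoProvider()). The Runner passes retry_config as a keyword argument, which lands in **kwargs here.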
promptum/py.typed ADDED
File without changes
promptum/serialization/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from promptum.serialization.html import HTMLSerializer
+ from promptum.serialization.json import JSONSerializer
+ from promptum.serialization.protocol import Serializer
+ from promptum.serialization.yaml import YAMLSerializer
+
+ __all__ = [
+     "Serializer",
+     "JSONSerializer",
+     "YAMLSerializer",
+     "HTMLSerializer",
+ ]
promptum/serialization/base.py ADDED
@@ -0,0 +1,48 @@
+ """Base serializer with shared result serialization logic."""
+
+ from typing import Any
+
+ from promptum.core.result import TestResult
+
+
+ class BaseSerializer:
+     """
+     Base class for serializers with common result serialization logic.
+
+     Subclasses should implement:
+     - serialize(report: Report) -> str
+     - get_file_extension() -> str
+     """
+
+     @staticmethod
+     def _serialize_result(result: TestResult) -> dict[str, Any]:
+         """Convert TestResult to dictionary representation."""
+         return {
+             "test_case": {
+                 "name": result.test_case.name,
+                 "prompt": result.test_case.prompt,
+                 "model": result.test_case.model,
+                 "tags": list(result.test_case.tags),
+                 "system_prompt": result.test_case.system_prompt,
+                 "temperature": result.test_case.temperature,
+                 "max_tokens": result.test_case.max_tokens,
+                 "metadata": result.test_case.metadata,
+                 "validator": result.test_case.validator.describe(),
+             },
+             "response": result.response,
+             "passed": result.passed,
+             "metrics": {
+                 "latency_ms": result.metrics.latency_ms,
+                 "prompt_tokens": result.metrics.prompt_tokens,
+                 "completion_tokens": result.metrics.completion_tokens,
+                 "total_tokens": result.metrics.total_tokens,
+                 "cost_usd": result.metrics.cost_usd,
+                 "retry_delays": list(result.metrics.retry_delays),
+                 "total_attempts": result.metrics.total_attempts,
+             }
+             if result.metrics
+             else None,
+             "validation_details": result.validation_details,
+             "execution_error": result.execution_error,
+             "timestamp": result.timestamp.isoformat(),
+         }
promptum/serialization/html.py ADDED
@@ -0,0 +1,52 @@
+ import json
+ from pathlib import Path
+
+ from jinja2 import Template
+
+ from promptum.benchmark.report import Report
+
+
+ class HTMLSerializer:
+     def __init__(self) -> None:
+         template_path = Path(__file__).parent / "report_template.html"
+         self._template = Template(template_path.read_text())
+
+     def serialize(self, report: Report) -> str:
+         summary = report.get_summary()
+
+         results_data = []
+         for result in report.results:
+             results_data.append(
+                 {
+                     "test_case": {
+                         "name": result.test_case.name,
+                         "prompt": result.test_case.prompt,
+                         "model": result.test_case.model,
+                         "tags": list(result.test_case.tags),
+                         "system_prompt": result.test_case.system_prompt,
+                         "validator": result.test_case.validator.describe(),
+                     },
+                     "response": result.response,
+                     "passed": result.passed,
+                     "metrics": {
+                         "latency_ms": result.metrics.latency_ms,
+                         "prompt_tokens": result.metrics.prompt_tokens,
+                         "completion_tokens": result.metrics.completion_tokens,
+                         "total_tokens": result.metrics.total_tokens,
+                         "cost_usd": result.metrics.cost_usd,
+                         "total_attempts": result.metrics.total_attempts,
+                     }
+                     if result.metrics
+                     else None,
+                     "execution_error": result.execution_error,
+                 }
+             )
+
+         return self._template.render(
+             summary=summary,
+             results=results_data,
+             results_json=json.dumps(results_data),
+         )
+
+     def get_file_extension(self) -> str:
+         return "html"
promptum/serialization/json.py ADDED
@@ -0,0 +1,28 @@
+ import json
+ from datetime import datetime
+ from typing import Any
+
+ from promptum.benchmark.report import Report
+ from promptum.serialization.base import BaseSerializer
+
+
+ class JSONSerializer(BaseSerializer):
+     def __init__(self, indent: int = 2):
+         self.indent = indent
+
+     def serialize(self, report: Report) -> str:
+         data = {
+             "metadata": report.metadata,
+             "summary": report.get_summary(),
+             "results": [self._serialize_result(r) for r in report.results],
+         }
+         return json.dumps(data, indent=self.indent, default=self._json_default)
+
+     def get_file_extension(self) -> str:
+         return "json"
+
+     @staticmethod
+     def _json_default(obj: Any) -> Any:
+         if isinstance(obj, datetime):
+             return obj.isoformat()
+         raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
promptum/serialization/protocol.py ADDED
@@ -0,0 +1,13 @@
+ from typing import Protocol
+
+ from promptum.benchmark.report import Report
+
+
+ class Serializer(Protocol):
+     def serialize(self, report: Report) -> str:
+         """Serializes a Report to a string format."""
+         ...
+
+     def get_file_extension(self) -> str:
+         """Returns the file extension for this format (e.g., 'json', 'html')."""
+         ...
+ ...