evaldata 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evaldata/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ """evaldata — AI evals framework for data and analytics engineering teams."""
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from evaldata.core import assert_eval
6
+ from evaldata.loaders import eval_case
7
+ from evaldata.scorers import ExpectationSuiteScorer, ResultSetEquivalence
8
+ from evaldata.solvers import CallableSolver
9
+ from evaldata.types import EvalCase, PlatformRef
10
+
11
+ if TYPE_CHECKING:
12
+ from evaldata.solvers import PromptSolver as PromptSolver
13
+
14
+ __all__ = [
15
+ "CallableSolver",
16
+ "EvalCase",
17
+ "ExpectationSuiteScorer",
18
+ "PlatformRef",
19
+ "ResultSetEquivalence",
20
+ "assert_eval",
21
+ "eval_case",
22
+ ]
23
+
24
+
25
def __getattr__(name: str) -> Any:
    """Serve `PromptSolver` lazily on first access (PEP 562 module `__getattr__`).

    The import of `evaldata.solvers.PromptSolver` is deferred until the name is
    actually requested, rather than happening when `evaldata` itself is imported.

    Args:
        name: The attribute that normal module lookup failed to find.

    Returns:
        The `PromptSolver` class when `name` is `"PromptSolver"`.

    Raises:
        AttributeError: For every other name.
    """
    if name != "PromptSolver":
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    from evaldata.solvers import PromptSolver as lazy_prompt_solver

    return lazy_prompt_solver
32
+
33
+
34
def __dir__() -> list[str]:
    """List the module namespace plus the lazily served `PromptSolver`, sorted."""
    names = list(globals())
    names.append("PromptSolver")
    names.sort()
    return names
evaldata/cli.py ADDED
@@ -0,0 +1,191 @@
1
+ """The `evaldata` command-line interface."""
2
+
3
+ import subprocess
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import typer
8
+ from rich.console import Console
9
+ from rich.table import Table
10
+ from rich.text import Text
11
+
12
+ from evaldata.platforms.registry import (
13
+ close_all,
14
+ databricks_platform,
15
+ duckdb_platform,
16
+ postgres_platform,
17
+ resolve,
18
+ )
19
+ from evaldata.types import PlatformRef
20
+
21
# Root Typer application; the `run` and `doctor` commands register onto it via `@app.command`.
# `no_args_is_help=True` makes a bare `evaldata` invocation print the help text.
app = typer.Typer(help="AI evals for data & analytics engineering teams.", no_args_is_help=True)
22
+
23
+
24
@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
def run(
    ctx: typer.Context,
    path: str | None = typer.Argument(None, help="Path or test id to run; omit to use pytest's testpaths."),
    json_path: Path | None = typer.Option(
        None,
        "--json",
        metavar="PATH",
        help="Also write the structured evaldata results JSON to PATH (off by default).",
    ),
) -> None:
    """Run the eval suite via pytest, forwarding any extra pytest arguments verbatim.

    Args:
        ctx: The Typer context; its extra args are forwarded straight to pytest.
        path: A path or test id to run; omit to use pytest's `testpaths`.
        json_path: If given, also write the structured results JSON to this path.

    Raises:
        Exit: Always, carrying pytest's return code as the process exit code.
    """
    # The docstring above doubles as the command's `--help` text, so it is kept as-is.
    # Argument order handed to pytest: optional path, optional JSON flag, then every
    # argument Typer did not recognise (collected in `ctx.args`).
    forwarded: list[str] = [] if path is None else [path]
    if json_path is not None:
        forwarded.append(f"--evaldata-json={json_path}")
    forwarded.extend(ctx.args)
    # Run pytest under the same interpreter and mirror its return code as our exit status.
    completed = subprocess.run(  # noqa: PLW1510 - exit code is forwarded, not raised on
        [sys.executable, "-m", "pytest", *forwarded],
    )
    raise typer.Exit(completed.returncode)
53
+
54
+
55
def _build_refs(
    *,
    duckdb: str | None,
    postgres: str | None,
    databricks_server_hostname: str | None = None,
    databricks_http_path: str | None = None,
) -> list[PlatformRef]:
    """Turn the platform flags that were supplied into `PlatformRef`s.

    Every ref comes from the matching typed builder in `evaldata.platforms.registry`,
    so a flag can only ever name a real platform kind. Databricks has no single-value
    form: a ref is produced only when both its hostname and its HTTP path are set.

    Args:
        duckdb: DuckDB database path, or `None` when that flag was not given.
        postgres: PostgreSQL conninfo, or `None` when that flag was not given.
        databricks_server_hostname: Databricks workspace hostname, or `None`.
        databricks_http_path: Databricks SQL Warehouse HTTP path, or `None`.

    Returns:
        Only the refs for the platforms whose flag(s) were supplied, always ordered
        DuckDB, then PostgreSQL, then Databricks.
    """
    duckdb_refs = [] if duckdb is None else [duckdb_platform(name="duckdb", path=duckdb)]
    postgres_refs = [] if postgres is None else [postgres_platform(name="postgres", conninfo=postgres)]
    databricks_refs = (
        [
            databricks_platform(
                name="databricks",
                server_hostname=databricks_server_hostname,
                http_path=databricks_http_path,
            )
        ]
        if databricks_server_hostname is not None and databricks_http_path is not None
        else []
    )
    return [*duckdb_refs, *postgres_refs, *databricks_refs]
91
+
92
+
93
def _probe(ref: PlatformRef) -> tuple[bool, str]:
    """Check one platform by resolving it to an adapter and executing `SELECT 1`.

    Any exception raised while resolving the adapter or running the query is caught
    and turned into a failure, so `doctor` can list it instead of crashing (e.g. a
    driver that fails to connect or an optional driver that is missing). A query that
    comes back with a non-`None` `error` value is a failure as well.

    Args:
        ref: The platform to check.

    Returns:
        `(True, "connected")` on success; otherwise `(False, message)`, where `message`
        is the exception text or the result's `error`.
    """
    try:
        adapter = resolve(ref)
        result = adapter.execute("SELECT 1")
    except Exception as exc:  # noqa: BLE001 - diagnostics: any failure is a reported FAIL
        return False, str(exc)
    failure = result.error
    if failure is None:
        return True, "connected"
    return False, failure
115
+
116
+
117
@app.command()
def doctor(
    duckdb: str | None = typer.Option(
        None, "--duckdb", metavar="PATH", envvar="EVALDATA_DUCKDB_PATH", help="DuckDB database path to check."
    ),
    postgres: str | None = typer.Option(
        None,
        "--postgres",
        metavar="CONNINFO",
        envvar="EVALDATA_POSTGRES_CONNINFO",
        help='PostgreSQL libpq conninfo to check (empty "" uses PG* env vars / libpq defaults).',
    ),
    databricks_server_hostname: str | None = typer.Option(
        None,
        "--databricks-server-hostname",
        metavar="HOST",
        envvar="DATABRICKS_SERVER_HOSTNAME",
        help="Databricks workspace hostname to check (paired with --databricks-http-path).",
    ),
    databricks_http_path: str | None = typer.Option(
        None,
        "--databricks-http-path",
        metavar="PATH",
        envvar="DATABRICKS_HTTP_PATH",
        help="Databricks SQL Warehouse HTTP path to check (paired with --databricks-server-hostname).",
    ),
) -> None:
    """Check that the given platform connections work (one --<kind> flag per platform).

    Args:
        duckdb: A DuckDB database path to check (also read from `EVALDATA_DUCKDB_PATH`).
        postgres: A PostgreSQL conninfo to check (also read from
            `EVALDATA_POSTGRES_CONNINFO`).
        databricks_server_hostname: A Databricks workspace hostname to check (also read from
            `DATABRICKS_SERVER_HOSTNAME`); required together with `databricks_http_path`.
        databricks_http_path: A Databricks SQL Warehouse HTTP path to check (also read from
            `DATABRICKS_HTTP_PATH`); required together with `databricks_server_hostname`.

    Raises:
        BadParameter: If no platform flag is provided, or only one of the two Databricks flags is.
        Exit: With code 1 if any platform connection fails.
    """
    # The docstring above doubles as the command's `--help` text, so it is kept as-is.
    # The two Databricks settings only work as a pair: reject exactly one of them being set.
    databricks_given = (databricks_server_hostname is not None, databricks_http_path is not None)
    if any(databricks_given) and not all(databricks_given):
        msg = "--databricks-server-hostname and --databricks-http-path must be given together"
        raise typer.BadParameter(msg)

    refs = _build_refs(
        duckdb=duckdb,
        postgres=postgres,
        databricks_server_hostname=databricks_server_hostname,
        databricks_http_path=databricks_http_path,
    )
    if not refs:
        msg = "specify at least one platform, e.g. --duckdb PATH or --postgres CONNINFO"
        raise typer.BadParameter(msg)

    console = Console()
    table = Table(title="evaldata doctor", title_justify="left")
    for heading in ("platform", "kind", "status"):
        table.add_column(heading)

    failed = False
    try:
        for ref in refs:
            ok, detail = _probe(ref)
            failed = failed or not ok
            label, colour = ("OK", "green") if ok else ("FAIL", "red")
            # A `Text` (not a markup string) so bracketed driver messages print verbatim.
            table.add_row(ref.name, ref.kind, Text(f"{label} {detail}", style=colour))
    finally:
        # This CLI invocation owns the adapters it resolved; release all of them.
        close_all()

    console.print(table)
    if failed:
        raise typer.Exit(1)
@@ -0,0 +1,5 @@
1
+ """Core orchestration: the runner and the pytest-facing `assert_eval`."""
2
+
3
+ from evaldata.core.runner import assert_eval
4
+
5
+ __all__ = ["assert_eval"]
@@ -0,0 +1,60 @@
1
+ """Eval orchestration and the pytest-facing `assert_eval`."""
2
+
3
+ from collections.abc import Sequence
4
+
5
+ from evaldata.platforms.base import PlatformAdapter
6
+ from evaldata.platforms.registry import resolve
7
+ from evaldata.reporting.collector import CaseReport, record
8
+ from evaldata.reporting.terminal import render_failure, render_solver_error
9
+ from evaldata.scorers.base import Scorer
10
+ from evaldata.scorers.context import ScoreContext
11
+ from evaldata.scorers.query import QueryRunner
12
+ from evaldata.solvers.base import Solver
13
+ from evaldata.types import EvalCase
14
+
15
+
16
def assert_eval(
    case: EvalCase,
    solver: Solver,
    *,
    scorers: Sequence[Scorer],
    adapter: PlatformAdapter | None = None,
) -> None:
    """Solve `case`, execute the produced SQL, score the result, and raise on failure.

    Solver failure: if `solver` reports an error, a failed `CaseReport` is recorded
    and an `AssertionError` is raised immediately.

    Execution: otherwise the SQL runs through a `QueryRunner` bound to `adapter`, or,
    when `adapter` is `None`, to the adapter returned by `resolve(case.platform)`.
    The runner is given two settings:
    - the platform's `dialect`, falling back to its `kind`;
    - `case.cost_budget.max_seconds` when a budget is set, otherwise `None`. Enforcing
      that time limit is `QueryRunner`'s job.

    Scoring: every scorer then scores the result, and a `CaseReport` holding all scores
    is recorded before any failure is raised.

    Args:
        case: The eval case under test.
        solver: Produces the SQL for `case`.
        scorers: Applied in order to the execution result; every one must pass.
        adapter: Platform adapter to execute against; when omitted, one is obtained
            from `case.platform` through the platform registry.

    Raises:
        AssertionError: When the solver errors or any scorer fails, carrying the
            rendered diagnostic. Raising is how pytest registers the failure.
    """
    output = solver.solve(case)
    solver_error = output.error
    if solver_error is not None:
        record(
            CaseReport(
                id=case.id,
                input=case.input,
                passed=False,
                error=f"solver error [{solver_error.kind}]",
            )
        )
        raise AssertionError(render_solver_error(case, solver_error))

    sql = output.output
    if sql is None:  # pragma: no cover - unreachable: SolverOutput's validator guarantees output XOR error
        msg = f"evaldata case {case.id!r}: solver returned neither output nor error"
        raise AssertionError(msg)

    platform = case.platform
    live = resolve(platform) if adapter is None else adapter
    budget = case.cost_budget
    max_seconds = None if budget is None else budget.max_seconds
    dialect = platform.dialect or platform.kind
    queries = QueryRunner(live, sql, dialect, max_seconds)
    result = queries.run(sql)

    context = ScoreContext(queries=queries)
    scores = [scorer.score(case, output, result, context=context) for scorer in scorers]
    failed = [score for score in scores if not score.passed]
    record(CaseReport(id=case.id, input=case.input, passed=not failed, scores=list(scores)))
    if failed:
        raise AssertionError(render_failure(case, output, result, failed))
@@ -0,0 +1,10 @@
1
+ """Result-set equivalence engine: column reconciliation plus the pure `build_result_set_diff` assembly seam."""
2
+
3
+ from evaldata.equivalence.columns import ColumnReconciliation, reconcile_columns
4
+ from evaldata.equivalence.compare import build_result_set_diff
5
+
6
+ __all__ = [
7
+ "ColumnReconciliation",
8
+ "build_result_set_diff",
9
+ "reconcile_columns",
10
+ ]
@@ -0,0 +1,49 @@
1
+ """Column reconciliation between actual and expected schemas."""
2
+
3
+ from typing import Literal, NamedTuple
4
+
5
+
6
class ColumnReconciliation(NamedTuple):
    """Result of comparing an actual column-name list against an expected one.

    Attributes:
        in_both: Names found in both lists, in expected order; rows are compared on these.
        missing: Expected names absent from the actual list, in expected order.
        unexpected: Actual names that were not expected, in actual order.
        order_mismatch: `True` only under `column_order == "strict"`, when the two
            name lists are not identical. Because the whole lists are compared, a
            missing or extra column also sets this flag.
    """

    in_both: list[str]
    missing: list[str]
    unexpected: list[str]
    order_mismatch: bool


def reconcile_columns(
    actual: list[str],
    expected: list[str],
    column_order: Literal["ignore", "strict"],
) -> ColumnReconciliation:
    """Split actual and expected column names into shared, missing, and unexpected lists.

    Rows are always matched by column name (they are dicts), so column order is
    reported as a separate signal rather than constraining row matching.

    Args:
        actual: Column names of the actual result set.
        expected: Column names of the expected result set.
        column_order: `"strict"` to report any difference between the two lists as an
            order mismatch; `"ignore"` to never report one.

    Returns:
        A `ColumnReconciliation`. The sets built below are used only for membership
        tests, so every list keeps the order of the sequence it was drawn from.
    """
    present = frozenset(actual)
    wanted = frozenset(expected)

    shared: list[str] = []
    absent: list[str] = []
    for name in expected:
        if name in present:
            shared.append(name)
        else:
            absent.append(name)

    extra = [name for name in actual if name not in wanted]
    strict = column_order == "strict"
    return ColumnReconciliation(
        in_both=shared,
        missing=absent,
        unexpected=extra,
        order_mismatch=strict and actual != expected,
    )
@@ -0,0 +1,78 @@
1
+ """Pure assembly of a `ResultSetDiff` from precomputed counts, samples, and column signals."""
2
+
3
+ from typing import Any
4
+
5
+ from evaldata.equivalence.columns import ColumnReconciliation
6
+ from evaldata.types import ColumnMismatch, ResultSetDiff, TypeMismatch
7
+
8
+
9
def build_result_set_diff(
    *,
    expected_row_count: int,
    actual_row_count: int,
    missing_row_count: int,
    extra_row_count: int,
    sample_missing_rows: list[dict[str, Any]],
    sample_extra_rows: list[dict[str, Any]],
    columns: ColumnReconciliation,
    type_mismatches: list[TypeMismatch],
    column_mismatches: list[ColumnMismatch],
) -> ResultSetDiff | None:
    """Combine precomputed diff signals into a `ResultSetDiff`, or `None` when equal.

    Nothing here touches a warehouse. The engine supplies the row counts and samples,
    and the column and type signals are computed in Python beforehand.
    `column_mismatches` is filled only by the keyed `FULL OUTER JOIN` path; the keyless
    `EXCEPT ALL` path passes it empty.

    Args:
        expected_row_count: How many rows the expected result set has.
        actual_row_count: How many rows the actual result set has.
        missing_row_count: Expected rows that do not appear in actual.
        extra_row_count: Actual rows that do not appear in expected.
        sample_missing_rows: A bounded sample of the missing rows.
        sample_extra_rows: A bounded sample of the extra rows.
        columns: Actual-vs-expected column-name reconciliation.
        type_mismatches: Type differences across the shared columns.
        column_mismatches: Per-column counts of key-matched rows with differing values;
            empty on the keyless path.

    Returns:
        `None` when no signal records a difference (the result sets are equal);
        otherwise the populated `ResultSetDiff`.
    """
    fields: dict[str, Any] = {
        # Row-level signals.
        "expected_row_count": expected_row_count,
        "actual_row_count": actual_row_count,
        "missing_row_count": missing_row_count,
        "extra_row_count": extra_row_count,
        "sample_missing_rows": sample_missing_rows,
        "sample_extra_rows": sample_extra_rows,
        # Column-level signals, lifted out of the reconciliation.
        "missing_columns": columns.missing,
        "unexpected_columns": columns.unexpected,
        "column_order_mismatch": columns.order_mismatch,
        # Per-column type and value signals.
        "type_mismatches": type_mismatches,
        "column_mismatches": column_mismatches,
    }
    diff = ResultSetDiff(**fields)
    return None if _is_equal(diff) else diff
59
+
60
+
61
def _is_equal(d: ResultSetDiff) -> bool:
    """Report whether `d` records no differences, i.e. the two result sets are equal.

    Args:
        d: The assembled diff.

    Returns:
        `False` when any signal shows a difference, otherwise `True`. The signals are:
        - a non-zero missing or extra row count;
        - any missing or unexpected column;
        - any type or per-column value mismatch;
        - a column-order mismatch.
    """
    differences = (
        d.missing_row_count != 0,
        d.extra_row_count != 0,
        bool(d.missing_columns),
        bool(d.unexpected_columns),
        bool(d.type_mismatches),
        bool(d.column_mismatches),
        bool(d.column_order_mismatch),
    )
    return not any(differences)
@@ -0,0 +1,5 @@
1
+ """Loaders: build `EvalCase`s from authoring surfaces (Python decorator first; YAML in v1.x)."""
2
+
3
+ from evaldata.loaders.python import eval_case, read_eval_case
4
+
5
+ __all__ = ["eval_case", "read_eval_case"]
@@ -0,0 +1,69 @@
1
+ """`@eval_case`: the Python authoring decorator for test cases."""
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any, TypeVar
5
+ from weakref import WeakKeyDictionary
6
+
7
+ from pydantic import TypeAdapter
8
+
9
+ from evaldata.types import ComparisonConfig, CostBudget, EvalCase, Expected, PlatformRef
10
+
11
# Type of the decorated callable, so `eval_case(...)` hands back the function with its
# static type unchanged.
_TestFn = TypeVar("_TestFn", bound=Callable[..., Any])

# Built once at import time and shared by every `eval_case` call; validates a dict into the
# discriminated `Expected` union.
_EXPECTED_ADAPTER: TypeAdapter[Expected] = TypeAdapter(Expected)

# Maps each decorated test function to its `EvalCase`.
# - Keys are weak, so a collected test function that goes away releases its entry.
# - Lookup is by the function object itself. That matches pytest's `request.function` for
#   module-level tests.
# NOTE(review): for tests defined in a class, pytest's `request.function` is a bound method.
# A direct `_CASES.get` on it does not match the plain function registered here.
_CASES: WeakKeyDictionary[Callable[..., Any], EvalCase] = WeakKeyDictionary()
19
+
20
+
21
def eval_case(
    *,
    input: str,
    expected: dict[str, Any] | Expected,
    platform: PlatformRef,
    id: str | None = None,
    metadata: dict[str, Any] | None = None,
    comparison: ComparisonConfig | None = None,
    cost_budget: CostBudget | None = None,
) -> Callable[[_TestFn], _TestFn]:
    """Build a decorator that registers an `EvalCase` for the decorated test function.

    A dict `expected` is validated into `Expected` right away, when `eval_case(...)` is
    called. This happens before any function is decorated, so a malformed expectation
    fails at import time.

    Args:
        input: The natural-language question or instruction under test.
        expected: The expected outcome, either a typed `Expected` or a dict that is
            validated into one.
        platform: A `PlatformRef` (e.g. from `duckdb_platform`, `postgres_platform` or
            `databricks_platform`).
        id: Case identifier; when falsy, the decorated function's `__name__` is used.
        metadata: Optional free-form tags/owner/source metadata.
        comparison: Optional result-set comparison rules; when `None`, `EvalCase`'s own
            default applies.
        cost_budget: Optional ceiling on platform resource consumption for the case.

    Returns:
        A decorator that registers the case and returns the function itself, unchanged.
    """
    if isinstance(expected, dict):
        coerced: Expected = _EXPECTED_ADAPTER.validate_python(expected)
    else:
        coerced = expected

    # Forward only the optional fields the caller set, leaving the rest to EvalCase.
    optional_fields = {"metadata": metadata, "comparison": comparison, "cost_budget": cost_budget}
    overrides: dict[str, Any] = {key: value for key, value in optional_fields.items() if value is not None}

    def decorator(func: _TestFn) -> _TestFn:
        # Keyed by the function object; `read_eval_case` retrieves the case from here.
        _CASES[func] = EvalCase(
            id=id or getattr(func, "__name__", ""),
            input=input,
            expected=coerced,
            platform=platform,
            **overrides,
        )
        return func

    return decorator
65
+
66
+
67
def read_eval_case(func: Callable[..., Any]) -> EvalCase | None:
    """Return the `EvalCase` attached to `func` by `@eval_case`, or `None`.

    `func` itself is looked up first, exactly as before. If that misses, the objects
    `func` wraps are tried in turn:
    - a bound method's `__func__`. For tests defined in a class, pytest's
      `request.function` is a bound method, while `@eval_case` registered the plain
      function.
    - any `__wrapped__` chain left by `functools.wraps`-style decorators stacked on top
      of `@eval_case`.

    Args:
        func: The test callable to look up, typically pytest's `request.function`.

    Returns:
        The registered `EvalCase`, or `None` if neither `func` nor anything it wraps was
        decorated with `@eval_case`.
    """
    visited: list[Any] = []
    current: Any = func
    # Walk the wrapper chain, stopping on a missing link or a self-referential cycle.
    while current is not None and not any(current is seen for seen in visited):
        visited.append(current)
        try:
            case = _CASES.get(current)
        except TypeError:
            # The object cannot be weak-referenced or hashed, so it can never be a key in
            # `_CASES`; keep unwrapping instead of failing.
            case = None
        if case is not None:
            return case
        # Prefer a bound method's underlying function; otherwise follow `functools.wraps`.
        current = getattr(current, "__func__", None) or getattr(current, "__wrapped__", None)
    return None
@@ -0,0 +1,46 @@
1
+ """Platform adapters: per-platform integrations that execute SQL against a data platform."""
2
+
3
+ import importlib
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from evaldata.platforms.base import PlatformAdapter
7
+ from evaldata.platforms.duckdb import DuckDBAdapter
8
+ from evaldata.platforms.registry import databricks_platform, duckdb_platform, postgres_platform, resolve
9
+
10
+ if TYPE_CHECKING:
11
+ from evaldata.platforms.databricks import DatabricksAdapter
12
+ from evaldata.platforms.postgres import PostgresAdapter
13
+
14
+ __all__ = [
15
+ "DatabricksAdapter",
16
+ "DuckDBAdapter",
17
+ "PlatformAdapter",
18
+ "PostgresAdapter",
19
+ "databricks_platform",
20
+ "duckdb_platform",
21
+ "postgres_platform",
22
+ "resolve",
23
+ ]
24
+
25
+ _LAZY_ADAPTERS = {
26
+ "PostgresAdapter": ("evaldata.platforms.postgres", "postgres"),
27
+ "DatabricksAdapter": ("evaldata.platforms.databricks", "databricks"),
28
+ }
29
+
30
+
31
def __getattr__(name: str) -> Any:
    """Import an optional adapter class on first access (PEP 562 module `__getattr__`).

    Args:
        name: The attribute that normal module lookup failed to find.

    Returns:
        The adapter class named `name`, taken from the module registered for it in
        `_LAZY_ADAPTERS`.

    Raises:
        ImportError: If the adapter's module cannot be imported. The message names the
            `evaldata[...]` extra to install, and the original error is chained.
        AttributeError: If `name` is not a lazily provided adapter.
    """
    if name not in _LAZY_ADAPTERS:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    module_path, extra = _LAZY_ADAPTERS[name]
    try:
        adapter_module = importlib.import_module(module_path)
    except ImportError as exc:
        msg = f"{name} requires the {extra!r} extra: install evaldata[{extra}]"
        raise ImportError(msg) from exc
    return getattr(adapter_module, name)
43
+
44
+
45
def __dir__() -> list[str]:
    """List the module namespace plus the lazily imported adapter names, sorted."""
    return sorted(list(globals()) + list(_LAZY_ADAPTERS))