evaldata 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evaldata/__init__.py +35 -0
- evaldata/cli.py +191 -0
- evaldata/core/__init__.py +5 -0
- evaldata/core/runner.py +60 -0
- evaldata/equivalence/__init__.py +10 -0
- evaldata/equivalence/columns.py +49 -0
- evaldata/equivalence/compare.py +78 -0
- evaldata/loaders/__init__.py +5 -0
- evaldata/loaders/python.py +69 -0
- evaldata/platforms/__init__.py +46 -0
- evaldata/platforms/base.py +145 -0
- evaldata/platforms/databricks.py +152 -0
- evaldata/platforms/duckdb.py +78 -0
- evaldata/platforms/postgres.py +86 -0
- evaldata/platforms/registry.py +159 -0
- evaldata/py.typed +0 -0
- evaldata/pytest_plugin/__init__.py +1 -0
- evaldata/pytest_plugin/plugin.py +89 -0
- evaldata/reporting/__init__.py +5 -0
- evaldata/reporting/collector.py +64 -0
- evaldata/reporting/terminal.py +177 -0
- evaldata/scorers/__init__.py +9 -0
- evaldata/scorers/base.py +17 -0
- evaldata/scorers/context.py +16 -0
- evaldata/scorers/expectation_suite.py +235 -0
- evaldata/scorers/query.py +134 -0
- evaldata/scorers/result_set_equivalence.py +479 -0
- evaldata/scorers/sql.py +764 -0
- evaldata/solvers/__init__.py +27 -0
- evaldata/solvers/base.py +14 -0
- evaldata/solvers/callable.py +24 -0
- evaldata/solvers/prompt.py +178 -0
- evaldata/types.py +510 -0
- evaldata-0.1.0.dist-info/METADATA +113 -0
- evaldata-0.1.0.dist-info/RECORD +38 -0
- evaldata-0.1.0.dist-info/WHEEL +4 -0
- evaldata-0.1.0.dist-info/entry_points.txt +6 -0
- evaldata-0.1.0.dist-info/licenses/LICENSE +201 -0
evaldata/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""evaldata — AI evals framework for data and analytics engineering teams."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
from evaldata.core import assert_eval
|
|
6
|
+
from evaldata.loaders import eval_case
|
|
7
|
+
from evaldata.scorers import ExpectationSuiteScorer, ResultSetEquivalence
|
|
8
|
+
from evaldata.solvers import CallableSolver
|
|
9
|
+
from evaldata.types import EvalCase, PlatformRef
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from evaldata.solvers import PromptSolver as PromptSolver
|
|
13
|
+
|
|
14
|
+
# Eagerly imported public API of the package. `PromptSolver` is not listed here: it is
# resolved on first access by the module-level `__getattr__` instead.
__all__ = [
    "CallableSolver",
    "EvalCase",
    "ExpectationSuiteScorer",
    "PlatformRef",
    "ResultSetEquivalence",
    "assert_eval",
    "eval_case",
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def __getattr__(name: str) -> Any:
    """Resolve lazily exported package attributes on first access (PEP 562).

    Only `PromptSolver` is lazy: `evaldata.solvers.PromptSolver` is imported when the
    attribute is first requested rather than when `evaldata` itself is imported.

    Args:
        name: The attribute name that regular module lookup did not find.

    Returns:
        The `PromptSolver` class when `name` is `"PromptSolver"`.

    Raises:
        AttributeError: For every other name.
    """
    if name != "PromptSolver":
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    from evaldata.solvers import PromptSolver as lazy_solver

    return lazy_solver
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def __dir__() -> list[str]:
    """Return the sorted module attribute names, including the lazy `PromptSolver`."""
    names = list(globals())
    names.append("PromptSolver")
    return sorted(names)
|
evaldata/cli.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""The `evaldata` command-line interface."""
|
|
2
|
+
|
|
3
|
+
import subprocess
|
|
4
|
+
import sys
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import typer
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
from rich.text import Text
|
|
11
|
+
|
|
12
|
+
from evaldata.platforms.registry import (
|
|
13
|
+
close_all,
|
|
14
|
+
databricks_platform,
|
|
15
|
+
duckdb_platform,
|
|
16
|
+
postgres_platform,
|
|
17
|
+
resolve,
|
|
18
|
+
)
|
|
19
|
+
from evaldata.types import PlatformRef
|
|
20
|
+
|
|
21
|
+
# Root Typer application; running `evaldata` without arguments prints the help screen.
app = typer.Typer(
    no_args_is_help=True,
    help="AI evals for data & analytics engineering teams.",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
def run(
    ctx: typer.Context,
    path: str | None = typer.Argument(None, help="Path or test id to run; omit to use pytest's testpaths."),
    json_path: Path | None = typer.Option(
        None,
        "--json",
        metavar="PATH",
        help="Also write the structured evaldata results JSON to PATH (off by default).",
    ),
) -> None:
    """Run the eval suite via pytest, forwarding any extra pytest arguments verbatim.

    Args:
        ctx: The Typer context; its extra args are forwarded straight to pytest.
        path: A path or test id to run; omit to use pytest's `testpaths`.
        json_path: If given, also write the structured results JSON to this path.

    Raises:
        Exit: Always, carrying pytest's return code as the process exit code.
    """
    # NOTE: the docstring above doubles as the `--help` text, so it is left verbatim.
    cmd = [sys.executable, "-m", "pytest"]
    # Our own flag goes *before* the user's tokens. With `ignore_unknown_options`, click
    # assigns the first leftover token to `path` even when it is an unknown pytest option
    # such as `-k`, and that option's value then lands in `ctx.args`. `path` followed by
    # `ctx.args` therefore reproduces the user's original token order only if nothing is
    # inserted between them. For example, `--json out.json -k foo` must not become
    # `-k --evaldata-json=... foo`.
    if json_path is not None:
        cmd.append(f"--evaldata-json={json_path}")
    if path is not None:
        cmd.append(path)
    cmd.extend(ctx.args)
    completed = subprocess.run(cmd)  # noqa: PLW1510 - exit code is forwarded, not raised on
    raise typer.Exit(completed.returncode)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _build_refs(
    *,
    duckdb: str | None,
    postgres: str | None,
    databricks_server_hostname: str | None = None,
    databricks_http_path: str | None = None,
) -> list[PlatformRef]:
    """Translate the platform flags that were supplied into `PlatformRef`s.

    Every ref is created through a typed registry builder (`duckdb_platform`,
    `postgres_platform`, `databricks_platform`), so a flag can only name a real platform
    kind. Databricks has no single-value form: a ref is produced only when both the server
    hostname and the HTTP path are present.

    Args:
        duckdb: DuckDB database path, or `None` when the flag is absent.
        postgres: PostgreSQL conninfo, or `None` when the flag is absent.
        databricks_server_hostname: Databricks workspace hostname, or `None`.
        databricks_http_path: Databricks SQL Warehouse HTTP path, or `None`.

    Returns:
        The refs in flag order: DuckDB, then PostgreSQL, then Databricks.
    """
    built: list[PlatformRef] = []
    if duckdb is not None:
        built.append(duckdb_platform(name="duckdb", path=duckdb))
    if postgres is not None:
        built.append(postgres_platform(name="postgres", conninfo=postgres))
    if databricks_server_hostname is None or databricks_http_path is None:
        # A lone hostname or HTTP path is not enough to address a warehouse.
        return built
    built.append(
        databricks_platform(
            name="databricks",
            server_hostname=databricks_server_hostname,
            http_path=databricks_http_path,
        )
    )
    return built
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _probe(ref: PlatformRef) -> tuple[bool, str]:
    """Check one platform by resolving its adapter and executing `SELECT 1`.

    Failures are caught broadly on purpose. Building the adapter can itself raise (for
    example, when a connection is refused or an optional driver is not installed), and
    `doctor` must report that as a FAIL row instead of crashing. A query whose
    `ExecutionResult.error` is set also counts as a FAIL.

    Args:
        ref: The platform reference to check.

    Returns:
        `(ok, detail)`: `ok` says whether the probe succeeded, and `detail` is either
        `"connected"` or the error message.
    """
    try:
        outcome = resolve(ref).execute("SELECT 1")
    except Exception as exc:  # noqa: BLE001 - diagnostics: any failure is a reported FAIL
        return False, str(exc)
    if outcome.error is None:
        return True, "connected"
    return False, outcome.error
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@app.command()
def doctor(
    duckdb: str | None = typer.Option(
        None, "--duckdb", metavar="PATH", envvar="EVALDATA_DUCKDB_PATH", help="DuckDB database path to check."
    ),
    postgres: str | None = typer.Option(
        None,
        "--postgres",
        metavar="CONNINFO",
        envvar="EVALDATA_POSTGRES_CONNINFO",
        help='PostgreSQL libpq conninfo to check (empty "" uses PG* env vars / libpq defaults).',
    ),
    databricks_server_hostname: str | None = typer.Option(
        None,
        "--databricks-server-hostname",
        metavar="HOST",
        envvar="DATABRICKS_SERVER_HOSTNAME",
        help="Databricks workspace hostname to check (paired with --databricks-http-path).",
    ),
    databricks_http_path: str | None = typer.Option(
        None,
        "--databricks-http-path",
        metavar="PATH",
        envvar="DATABRICKS_HTTP_PATH",
        help="Databricks SQL Warehouse HTTP path to check (paired with --databricks-server-hostname).",
    ),
) -> None:
    """Check that the given platform connections work (one --<kind> flag per platform).

    Args:
        duckdb: A DuckDB database path to check (also read from `EVALDATA_DUCKDB_PATH`).
        postgres: A PostgreSQL conninfo to check (also read from
            `EVALDATA_POSTGRES_CONNINFO`).
        databricks_server_hostname: A Databricks workspace hostname to check (also read from
            `DATABRICKS_SERVER_HOSTNAME`); required together with `databricks_http_path`.
        databricks_http_path: A Databricks SQL Warehouse HTTP path to check (also read from
            `DATABRICKS_HTTP_PATH`); required together with `databricks_server_hostname`.

    Raises:
        BadParameter: If no platform flag is provided, or only one of the two Databricks flags is.
        Exit: With code 1 if any platform connection fails.
    """
    # NOTE: the docstring above doubles as the `--help` text, so it is left verbatim.
    # Exactly one of the two Databricks values present is a usage error.
    if (databricks_server_hostname is None) ^ (databricks_http_path is None):
        msg = "--databricks-server-hostname and --databricks-http-path must be given together"
        raise typer.BadParameter(msg)
    refs = _build_refs(
        duckdb=duckdb,
        postgres=postgres,
        databricks_server_hostname=databricks_server_hostname,
        databricks_http_path=databricks_http_path,
    )
    if len(refs) == 0:
        msg = "specify at least one platform, e.g. --duckdb PATH or --postgres CONNINFO"
        raise typer.BadParameter(msg)

    table = Table(title="evaldata doctor", title_justify="left")
    for heading in ("platform", "kind", "status"):
        table.add_column(heading)

    failed = False
    try:
        for ref in refs:
            ok, detail = _probe(ref)  # every ref is probed, even after a failure
            failed = failed or not ok
            label = "OK" if ok else "FAIL"
            colour = "green" if ok else "red"
            # Text (not markup) so bracketed driver messages render verbatim.
            table.add_row(ref.name, ref.kind, Text(f"{label} {detail}", style=colour))
    finally:
        close_all()  # this CLI invocation owns the adapters it resolved

    Console().print(table)
    if failed:
        raise typer.Exit(1)
|
evaldata/core/runner.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Eval orchestration and the pytest-facing `assert_eval`."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
from evaldata.platforms.base import PlatformAdapter
|
|
6
|
+
from evaldata.platforms.registry import resolve
|
|
7
|
+
from evaldata.reporting.collector import CaseReport, record
|
|
8
|
+
from evaldata.reporting.terminal import render_failure, render_solver_error
|
|
9
|
+
from evaldata.scorers.base import Scorer
|
|
10
|
+
from evaldata.scorers.context import ScoreContext
|
|
11
|
+
from evaldata.scorers.query import QueryRunner
|
|
12
|
+
from evaldata.solvers.base import Solver
|
|
13
|
+
from evaldata.types import EvalCase
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def assert_eval(
    case: EvalCase,
    solver: Solver,
    *,
    scorers: Sequence[Scorer],
    adapter: PlatformAdapter | None = None,
) -> None:
    """Solve, execute, and score one eval case, raising `AssertionError` on any failure.

    The solver turns the case into SQL. The SQL runs on `adapter` when one is passed;
    otherwise it runs on the session-cached adapter resolved from `case.platform`. Every
    scorer then judges the execution result. `case.cost_budget.max_seconds` (when a budget
    is set) bounds execution, so an overrunning query is cancelled and scored as an
    execution failure. Each outcome, pass or fail, is recorded as a `CaseReport`.

    Args:
        case: The eval case under test.
        solver: Produces the SQL for `case`.
        scorers: Applied to the execution result; the case passes only if all pass.
        adapter: Explicit platform adapter. When omitted, one is resolved from
            `case.platform`.

    Raises:
        AssertionError: When the solver reports an error or any scorer fails. The message
            is the rendered diagnostic; raising is pytest's failure protocol.
    """
    output = solver.solve(case)
    solver_error = output.error
    if solver_error is not None:
        record(CaseReport(id=case.id, input=case.input, passed=False, error=f"solver error [{solver_error.kind}]"))
        raise AssertionError(render_solver_error(case, solver_error))

    sql = output.output
    if sql is None:  # pragma: no cover - unreachable: SolverOutput's validator guarantees output XOR error
        msg = f"evaldata case {case.id!r}: solver returned neither output nor error"
        raise AssertionError(msg)

    live = resolve(case.platform) if adapter is None else adapter
    budget = case.cost_budget
    max_seconds = None if budget is None else budget.max_seconds
    dialect = case.platform.dialect or case.platform.kind  # fall back to the platform kind
    queries = QueryRunner(live, sql, dialect, max_seconds)
    result = queries.run(sql)

    context = ScoreContext(queries=queries)
    scores = []
    for scorer in scorers:
        scores.append(scorer.score(case, output, result, context=context))
    failures = [score for score in scores if not score.passed]

    record(CaseReport(id=case.id, input=case.input, passed=not failures, scores=list(scores)))
    if failures:
        raise AssertionError(render_failure(case, output, result, failures))
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""Result-set equivalence engine: column reconciliation plus the pure `build_result_set_diff` assembly seam."""
|
|
2
|
+
|
|
3
|
+
from evaldata.equivalence.columns import ColumnReconciliation, reconcile_columns
|
|
4
|
+
from evaldata.equivalence.compare import build_result_set_diff
|
|
5
|
+
|
|
6
|
+
# Public surface of the equivalence engine: column reconciliation plus diff assembly.
__all__ = [
    "ColumnReconciliation",
    "build_result_set_diff",
    "reconcile_columns",
]
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Column reconciliation between actual and expected schemas."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, NamedTuple
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ColumnReconciliation(NamedTuple):
    """Result of matching the actual column names against the expected ones.

    Attributes:
        in_both: Names found in both sequences, in expected order. Rows are compared on
            these columns.
        missing: Names the expected result has but the actual one lacks, in expected order.
        unexpected: Names the actual result has but the expected one lacks, in actual order.
        order_mismatch: `True` only under `column_order == "strict"` when the two name
            sequences are not positionally identical.
    """

    in_both: list[str]  # shared columns, expected order
    missing: list[str]  # expected-only columns, expected order
    unexpected: list[str]  # actual-only columns, actual order
    order_mismatch: bool  # strict positional order differs
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def reconcile_columns(
    actual: list[str],
    expected: list[str],
    column_order: Literal["ignore", "strict"],
) -> ColumnReconciliation:
    """Compare the actual column-name sequence against the expected one.

    Rows are always matched by column name (they are dicts), so column order is reported
    as a separate signal rather than constraining row matching.

    Args:
        actual: Column names of the actual result set.
        expected: Column names of the expected result set.
        column_order: `"strict"` flags any positional difference between the sequences;
            `"ignore"` never flags order.

    Returns:
        A `ColumnReconciliation`. Each list keeps the order of the sequence it was drawn
        from; the frozensets below are used only for membership tests.
    """
    actual_names = frozenset(actual)
    expected_names = frozenset(expected)

    shared: list[str] = []
    absent: list[str] = []
    for column in expected:
        if column in actual_names:
            shared.append(column)
        else:
            absent.append(column)

    surplus = [column for column in actual if column not in expected_names]
    positional_diff = column_order == "strict" and actual != expected
    return ColumnReconciliation(shared, absent, surplus, positional_diff)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Pure assembly of a `ResultSetDiff` from precomputed counts, samples, and column signals."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from evaldata.equivalence.columns import ColumnReconciliation
|
|
6
|
+
from evaldata.types import ColumnMismatch, ResultSetDiff, TypeMismatch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def build_result_set_diff(
    *,
    expected_row_count: int,
    actual_row_count: int,
    missing_row_count: int,
    extra_row_count: int,
    sample_missing_rows: list[dict[str, Any]],
    sample_extra_rows: list[dict[str, Any]],
    columns: ColumnReconciliation,
    type_mismatches: list[TypeMismatch],
    column_mismatches: list[ColumnMismatch],
) -> ResultSetDiff | None:
    """Package precomputed diff signals into a `ResultSetDiff`, or `None` when equal.

    This step touches no warehouse. The engine has already computed the row counts and
    samples, and Python has already computed the column/type signals. Only the keyed
    `FULL OUTER JOIN` path produces `column_mismatches`; the keyless `EXCEPT ALL` path
    passes an empty list.

    Args:
        expected_row_count: Row count of the expected result.
        actual_row_count: Row count of the actual result.
        missing_row_count: Expected rows that the actual result lacks.
        extra_row_count: Actual rows that the expected result lacks.
        sample_missing_rows: Bounded sample of the missing rows.
        sample_extra_rows: Bounded sample of the extra rows.
        columns: Reconciliation of actual versus expected column names.
        type_mismatches: Type differences over the shared columns.
        column_mismatches: Per-column counts of key-matched rows with differing values;
            empty on the keyless path.

    Returns:
        `None` if the assembled diff shows no difference at all, otherwise the diff.
    """
    diff = ResultSetDiff(
        expected_row_count=expected_row_count,
        actual_row_count=actual_row_count,
        missing_row_count=missing_row_count,
        extra_row_count=extra_row_count,
        sample_missing_rows=sample_missing_rows,
        sample_extra_rows=sample_extra_rows,
        missing_columns=columns.missing,
        unexpected_columns=columns.unexpected,
        type_mismatches=type_mismatches,
        column_mismatches=column_mismatches,
        column_order_mismatch=columns.order_mismatch,
    )
    return None if _is_equal(diff) else diff
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _is_equal(d: ResultSetDiff) -> bool:
    """Report whether `d` records no difference, meaning the result sets match.

    Args:
        d: The assembled diff.

    Returns:
        `False` if any row count, column-set, type, per-column value, or column-order
        difference is recorded; `True` otherwise.
    """
    has_row_diff = d.missing_row_count != 0 or d.extra_row_count != 0
    has_column_diff = bool(d.missing_columns or d.unexpected_columns or d.column_order_mismatch)
    has_value_diff = bool(d.type_mismatches or d.column_mismatches)
    return not (has_row_diff or has_column_diff or has_value_diff)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""`@eval_case`: the Python authoring decorator for test cases."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, TypeVar
|
|
5
|
+
from weakref import WeakKeyDictionary
|
|
6
|
+
|
|
7
|
+
from pydantic import TypeAdapter
|
|
8
|
+
|
|
9
|
+
from evaldata.types import ComparisonConfig, CostBudget, EvalCase, Expected, PlatformRef
|
|
10
|
+
|
|
11
|
+
# Type of the decorated test callable; `@eval_case` hands it back with its type intact.
_TestFn = TypeVar("_TestFn", bound=Callable[..., Any])

# Validator that coerces a plain dict into the discriminated `Expected` union.
# It is constructed once at import rather than on every decoration.
_EXPECTED_ADAPTER: TypeAdapter[Expected] = TypeAdapter(Expected)

# Registry mapping each decorated test function to its `EvalCase`. Keys are weak, so an
# entry disappears once its test function is garbage-collected. Lookups match on the
# function object itself, which is what pytest passes as `request.function` for
# module-level test functions.
_CASES: WeakKeyDictionary[Callable[..., Any], EvalCase] = WeakKeyDictionary()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def eval_case(
    *,
    input: str,
    expected: dict[str, Any] | Expected,
    platform: PlatformRef,
    id: str | None = None,
    metadata: dict[str, Any] | None = None,
    comparison: ComparisonConfig | None = None,
    cost_budget: CostBudget | None = None,
) -> Callable[[_TestFn], _TestFn]:
    """Build a decorator that registers an `EvalCase` for the `case` fixture to inject.

    A dict `expected` is validated into the `Expected` union once, when `eval_case` is
    called. Optional settings left as `None` are not forwarded, so `EvalCase` falls back
    to its own defaults for them.

    Args:
        input: The natural-language question or instruction under test.
        expected: The expected outcome, as an `Expected` instance or a dict to coerce.
        platform: A `PlatformRef` (for example from `duckdb_platform` / `postgres_platform`).
        id: Case identifier; the decorated function's `__name__` is used when falsy.
        metadata: Optional free-form tags/owner/source metadata.
        comparison: Optional result-set comparison rules; defaults to `ComparisonConfig()`.
        cost_budget: Optional ceiling on platform resource consumption for the case.

    Returns:
        A decorator that records the case for the function and returns the function as-is.
    """
    coerced: Expected = expected if not isinstance(expected, dict) else _EXPECTED_ADAPTER.validate_python(expected)

    def decorator(func: _TestFn) -> _TestFn:
        optional = {"metadata": metadata, "comparison": comparison, "cost_budget": cost_budget}
        supplied: dict[str, Any] = {key: value for key, value in optional.items() if value is not None}
        case_id = id or getattr(func, "__name__", "")
        _CASES[func] = EvalCase(id=case_id, input=input, expected=coerced, platform=platform, **supplied)
        return func

    return decorator
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def read_eval_case(func: Callable[..., Any]) -> EvalCase | None:
    """Return the `EvalCase` attached to `func` by `@eval_case`, or `None`.

    `@eval_case` registers the plain function object. For a test defined inside a class,
    pytest's `request.function` is a bound method instead: a fresh wrapper whose hash and
    equality differ from the wrapped function, so a direct lookup would always miss. The
    underlying function is therefore taken from `__func__` when present. An object that
    cannot be weakly referenced could never have been registered, so it simply has no case.

    Args:
        func: The test callable, either the decorated function or a method bound to it.

    Returns:
        The registered `EvalCase`, or `None` if `func` was not decorated.
    """
    target = getattr(func, "__func__", func)  # unwrap bound methods to the registered function
    try:
        return _CASES.get(target)
    except TypeError:  # WeakKeyDictionary cannot form a weak reference to `target`
        return None
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""Platform adapters: per-platform integrations that execute SQL against a data platform."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from evaldata.platforms.base import PlatformAdapter
|
|
7
|
+
from evaldata.platforms.duckdb import DuckDBAdapter
|
|
8
|
+
from evaldata.platforms.registry import databricks_platform, duckdb_platform, postgres_platform, resolve
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from evaldata.platforms.databricks import DatabricksAdapter
|
|
12
|
+
from evaldata.platforms.postgres import PostgresAdapter
|
|
13
|
+
|
|
14
|
+
# Public API of the platforms package. `DatabricksAdapter` and `PostgresAdapter` are
# resolved lazily through the module-level `__getattr__`.
# NOTE(review): a star-import fetches every name listed here, which triggers those lazy
# imports. `from evaldata.platforms import *` therefore raises ImportError unless both the
# postgres and databricks extras are installed. Confirm this is intended.
__all__ = [
    "DatabricksAdapter",
    "DuckDBAdapter",
    "PlatformAdapter",
    "PostgresAdapter",
    "databricks_platform",
    "duckdb_platform",
    "postgres_platform",
    "resolve",
]
|
|
24
|
+
|
|
25
|
+
# Adapters imported only on first attribute access. Each maps to
# (module path, name of the evaldata extra shown in the install hint).
_LAZY_ADAPTERS = {
    "PostgresAdapter": ("evaldata.platforms.postgres", "postgres"),
    "DatabricksAdapter": ("evaldata.platforms.databricks", "databricks"),
}


def __getattr__(name: str) -> Any:
    """Import an optional platform adapter the first time it is accessed (PEP 562).

    Args:
        name: The attribute name that regular module lookup did not find.

    Returns:
        The adapter class exported by the adapter's module.

    Raises:
        ImportError: If the adapter's module fails to import. The message names the
            extra to install, and the original error is chained as the cause.
        AttributeError: If `name` is not a lazily exported adapter.
    """
    if name not in _LAZY_ADAPTERS:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
    module_path, extra = _LAZY_ADAPTERS[name]
    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        msg = f"{name} requires the {extra!r} extra: install evaldata[{extra}]"
        raise ImportError(msg) from e
    return getattr(module, name)


def __dir__() -> list[str]:
    """Return the sorted module attribute names, including the lazily imported adapters."""
    return sorted(list(globals()) + list(_LAZY_ADAPTERS))
|