aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/live_testing.py
ADDED
|
@@ -0,0 +1,1117 @@
|
|
|
1
|
+
"""Reusable live-testing framework for the text2sql pipeline.
|
|
2
|
+
|
|
3
|
+
Provides data models (scenarios, expected outcomes, step results), a soft-assertion collector, a pipeline-capture context manager that monkey-patches interactive I/O, and a LiveTestRunner that orchestrates end-to-end execution of a scenario against a real LLM and live database.
|
|
4
|
+
|
|
5
|
+
All classes and helpers are database-agnostic; the only database-specific pieces live in the scenario definitions and conftest fixtures provided by the caller.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import time
|
|
12
|
+
import traceback
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from contextlib import contextmanager
|
|
15
|
+
from dataclasses import dataclass, field, replace as dc_replace
|
|
16
|
+
from typing import Any
|
|
17
|
+
from unittest.mock import patch
|
|
18
|
+
|
|
19
|
+
from .config import PolicyConfig
|
|
20
|
+
from .contracts_core import RuntimeIntent
|
|
21
|
+
from .core_utils import debug, normalize_question, substitute_params
|
|
22
|
+
from .intent_process import match_template_for_union
|
|
23
|
+
from .pipeline import (
|
|
24
|
+
check_and_handle_hard_block,
|
|
25
|
+
check_template_reuse,
|
|
26
|
+
compute_final_metrics,
|
|
27
|
+
confirm_intent_with_user,
|
|
28
|
+
display_final_results,
|
|
29
|
+
generate_and_validate_sql,
|
|
30
|
+
generate_join_candidates,
|
|
31
|
+
handle_direct_sql_reuse,
|
|
32
|
+
handle_user_feedback,
|
|
33
|
+
load_pipeline_resources,
|
|
34
|
+
parse_intent_via_llm,
|
|
35
|
+
save_result_csv,
|
|
36
|
+
)
|
|
37
|
+
from .templates import save_template_store
|
|
38
|
+
from .utils import flatten_param_values, intent_key, validate_question
|
|
39
|
+
from .validation_execute import execute_sql, get_spark_sql_for_execution
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class Expected:
    """Declarative description of what one pipeline run should produce.

    No field is mandatory. A field left at its default (``None`` for the
    optional values, ``False`` for the two ``should_*`` flags) means the
    runner skips the corresponding assertion.

    Args:
        tables: Tables the generated intent is expected to reference.
        tables_one_of: Alternative table lists; when set, the intent's tables
            must equal one of them, ignoring order.
        grain_in: Accepted grain values; when set, the intent grain must be
            one of them.
        min_rows: Inclusive lower bound on the number of result rows.
        max_rows: Inclusive upper bound on the number of result rows.
        min_confidence: Inclusive lower bound on the confidence score.
        reuse_type: Expected template-match type, one of ``"direct_reuse"``,
            ``"intent_direct_reuse"`` or ``"none"``. The annotation also
            permits a tuple of strings — presumably "any of these"; confirm
            against the evaluation code.
        contains_join: ``True`` requires a JOIN clause in the SQL.
        contains_group_by: ``True`` requires GROUP BY in the SQL.
        contains_cte: ``True`` requires a CTE (``WITH``) in the SQL.
        sql_contains: Substrings that must all be present in the final SQL.
        sql_excludes: Substrings that must be absent from the final SQL.
        grain: Expected intent grain (``"row_level"`` or ``"grouped"``). The
            annotation also permits a tuple of strings — confirm semantics
            against the evaluation code.
        should_hard_block: ``True`` when the run is expected to hard-block.
        should_fail_validation: ``True`` when schema or result validation is
            expected to fail.
        column_names_one_of: Allowed column-header lists; the result headers
            must match one of them exactly.
        row_value_check: Optional ``(rows) -> bool`` callable for custom
            checks on the returned values.
        min_semantic_warnings: Minimum number of semantic warnings expected.
        status: Expected exit status when the default ``"ok"`` does not apply
            (for example ``"restricted"`` or ``"hard_blocked"``).
        status_in: When set, the exit status must be one of these values.
    """

    # NOTE: field order is part of the positional-constructor interface.

    # -- intent shape ---------------------------------------------------
    tables: list[str] | None = None
    tables_one_of: list[list[str]] | None = None
    grain_in: tuple[str, ...] | None = None

    # -- result size / confidence ---------------------------------------
    min_rows: int | None = None
    max_rows: int | None = None
    min_confidence: float | None = None

    # -- template reuse --------------------------------------------------
    reuse_type: str | tuple[str, ...] | None = None

    # -- SQL text ----------------------------------------------------------
    contains_join: bool | None = None
    contains_group_by: bool | None = None
    contains_cte: bool | None = None
    sql_contains: list[str] | None = None
    sql_excludes: list[str] | None = None

    grain: str | tuple[str, ...] | None = None

    # -- expected failure modes ------------------------------------------
    should_hard_block: bool = False
    should_fail_validation: bool = False

    # -- result content ----------------------------------------------------
    column_names_one_of: list[list[str]] | None = None
    row_value_check: Callable[[list[tuple]], bool] | None = None
    min_semantic_warnings: int | None = None

    # -- exit status ---------------------------------------------------------
    status: str | None = None
    status_in: tuple[str, ...] | None = None
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class Scenario:
    """One question plus the expectations checked against its pipeline run.

    Args:
        id: Unique scenario identifier (for example ``"ST-001"``).
        question: Natural-language question handed to the pipeline.
        expected: Declarative expectations asserted against the result.
        category: Grouping label used for pytest parametrization.
        auto_responses: Scripted ``y``/``n`` answers, consumed first-in
            first-out by each ``ask_user_choice`` call during the run.
            ``None`` (the default) makes the runner use ``["y", "y", "y"]``,
            i.e. accept everything.
        feedback: Scripted ``y``/``n`` answer to the final "is this correct?"
            prompt; ``"y"`` by default.
        reject_reason: Canned rejection text supplied when *feedback* is
            ``"n"`` and the pipeline asks for a reason.
        sequence_id: Id of the enclosing sequence when the scenario runs as a
            sequence step (used for result storage so failures show full logs).
    """

    # Required, in positional order.
    id: str
    question: str
    expected: Expected

    # Optional knobs.
    category: str = ""
    auto_responses: list[str] | None = None
    feedback: str = "y"
    reject_reason: str = "incorrect results"
    sequence_id: str | None = None
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
@dataclass
class SequenceScenario:
    """Scenarios run back to back to exercise state carried between runs.

    Args:
        id: Unique sequence identifier.
        steps: ``Scenario`` objects executed in list order against a shared
            template store.
        category: Grouping label.
    """

    id: str
    steps: list[Scenario]
    category: str = ""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@dataclass
class SoftFailure:
    """Record of one soft assertion that did not hold.

    Args:
        field: Name of the property that was checked (for example
            ``"min_rows"``).
        expected: Value the check wanted.
        actual: Value that was observed.
        message: Human-readable description of the mismatch.
    """

    field: str
    expected: Any
    actual: Any
    message: str
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@dataclass
class StepResult:
    """Everything captured from one pipeline execution.

    Args:
        scenario_id: ``Scenario.id`` of the scenario that produced the result.
        question: The raw question string.
        status: Pipeline exit status (for example ``"ok"``, ``"restricted"``,
            ``"empty_input"``, ``"validation_failed"``, ``"hard_blocked"`` or
            ``"error"``).
        intent: Parsed ``RuntimeIntent`` when the pipeline got that far.
        sql: Generated SQL string, when available.
        rows: Result-set rows, when execution succeeded.
        confidence: Final confidence score, when computed.
        reuse_type: Template-match reuse type string.
        template_id: Id of the matched or created template.
        hard_blocked: Whether a hard-block fired.
        validation_failed: Whether schema or result validation failed.
        feedback: Feedback value (``"y"`` or ``"n"``) that was applied.
        error: Stringified traceback when the run raised an exception.
        duration_seconds: Wall-clock duration of the run.
        captured_logs: ``[LOG]`` and ``[DEBUG]`` messages collected during the run.
        semantic_warnings: Semantic warning messages collected during parsing.
        soft_warnings: Informational (non-blocking) result-validation messages.
        hard_remaining: Hard result-validation failures left unresolved.
        rejection_classifications: Keyword categories for rejection reasons.
        llm_calls: Number of LLM API calls made during intent parsing.
        reject_reason_actual: Raw reason text when feedback is ``"n"`` (e.g.
            taken from the scenario).
        classified_category: LLM-classified category when feedback is ``"n"``.
        classified_reason: Normalized reason summary when feedback is ``"n"``.
        generation_path: SQL generation branch taken — ``"direct_reuse"``,
            ``"union_match"`` or ``"fresh"``; ``None`` when the pipeline
            exited before SQL generation.
    """

    # -- identity ----------------------------------------------------------
    scenario_id: str
    question: str

    # -- outcome -------------------------------------------------------------
    status: str = "unknown"
    intent: RuntimeIntent | None = None
    sql: str | None = None
    rows: list[tuple] | None = None
    confidence: float | None = None

    # -- template reuse ------------------------------------------------------
    reuse_type: str | None = None
    template_id: str | None = None

    # -- failure flags / diagnostics -----------------------------------------
    hard_blocked: bool = False
    validation_failed: bool = False
    feedback: str | None = None
    error: str | None = None
    duration_seconds: float = 0.0

    # -- collected messages (a fresh list per instance) ----------------------
    captured_logs: list[str] = field(default_factory=list)
    semantic_warnings: list[str] = field(default_factory=list)
    soft_warnings: list[str] = field(default_factory=list)
    hard_remaining: list[str] = field(default_factory=list)
    rejection_classifications: list[str] = field(default_factory=list)

    # -- LLM usage and rejection details -------------------------------------
    llm_calls: int = 0
    reject_reason_actual: str | None = None
    classified_category: str | None = None
    classified_reason: str | None = None
    generation_path: str | None = None
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class SoftAssert:
    """Collector that records assertion failures rather than raising at once.

    Feed every condition through :meth:`check`; once all checks are done,
    :meth:`report` raises one ``AssertionError`` naming each failure, or
    returns quietly when nothing failed.
    """

    def __init__(self) -> None:
        self.failures: list[SoftFailure] = []

    def check(
        self,
        condition: bool,
        field_name: str,
        expected: Any,
        actual: Any,
        message: str = "",
    ) -> None:
        """Store a ``SoftFailure`` if *condition* is falsy.

        Args:
            condition: Outcome of the assertion expression.
            field_name: Symbolic name of the property under test.
            expected: Expected value, kept for reporting.
            actual: Observed value, kept for reporting.
            message: Optional explanation; when empty a default
                "expected …, got …" text is generated.
        """
        if condition:
            return
        text = message if message else f"{field_name}: expected {expected!r}, got {actual!r}"
        failure = SoftFailure(field=field_name, expected=expected, actual=actual, message=text)
        self.failures.append(failure)

    @property
    def passed(self) -> bool:
        """``True`` while no failure has been recorded."""
        return not self.failures

    def report(self, header: str = "") -> None:
        """Raise ``AssertionError`` describing every recorded failure.

        Returns without raising when :attr:`passed` is ``True``.

        Args:
            header: Optional first line of the error message.
        """
        if self.passed:
            return
        prefix = [header] if header else []
        details = [f" [{failure.field}] {failure.message}" for failure in self.failures]
        raise AssertionError("\n".join(prefix + details))
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _make_auto_responder(responses: list[str]) -> Callable:
    """Create a stand-in for ``ask_user_choice`` fed from scripted answers.

    The answers are copied when the responder is built, so later changes to
    *responses* have no effect. Once every scripted answer has been handed
    out, each further call returns ``"y"``.

    Args:
        responses: ``"y"``/``"n"`` strings, returned first-in first-out.

    Returns:
        A callable accepting ``(prompt, options, silent_no=False)``; all
        three arguments are ignored.
    """
    # Snapshot first, then iterate: equivalent to popping from the front of a copy.
    pending = iter(list(responses))

    def _responder(prompt: str, options: list[str], silent_no: bool = False) -> str | None:
        return next(pending, "y")

    return _responder
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def _make_input_responder(reject_reason: str = "incorrect results") -> Callable:
    """Create a stand-in for ``builtins.input`` that answers with canned text.

    The very first call yields *reject_reason*; every call after that yields
    ``"n"`` so any further interactive prompt is declined.

    Args:
        reject_reason: Text handed back by the first ``input()`` call.

    Returns:
        A callable usable in place of the built-in ``input``.
    """
    reason_pending = True

    def _fake_input(prompt: str = "") -> str:
        nonlocal reason_pending
        if not reason_pending:
            return "n"
        reason_pending = False
        return reject_reason

    return _fake_input
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@contextmanager
def _pipeline_capture(
    auto_responses: list[str],
    reject_reason: str = "incorrect results",
    csv_dir: str = "",
):
    """Context manager that patches interactive I/O for programmatic pipeline runs.

    While the context is active:

    * ``ask_user_choice`` (as bound in ``core_utils``, ``pipeline`` and
      ``main_execution``) is replaced by a FIFO auto-responder;
    * ``builtins.input`` is replaced by a canned-text responder;
    * every listed module that exposes a ``debug`` / ``log`` attribute has it
      wrapped so each message is appended to the capture buffer (prefixed
      with ``[DEBUG]`` / ``[LOG]``) and then forwarded to the original
      ``core_utils`` function;
    * when *csv_dir* is non-empty, ``save_result_csv`` (as bound in
      ``pipeline`` and in this module) runs with the working directory
      temporarily switched to *csv_dir*.

    All patches are registered on a single ``ExitStack``, so they are undone
    on exit even if installing one of them fails part-way through.

    Args:
        auto_responses: FIFO list of ``"y"``/``"n"`` strings for ``ask_user_choice``.
        reject_reason: Canned rejection reason for ``input()`` prompts.
        csv_dir: If non-empty, redirect ``results.csv`` writes into this directory.

    Yields:
        A dict ``{"logs": []}`` whose list is filled with captured log lines
        during the run.
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import ExitStack

    capture: dict[str, Any] = {"logs": []}
    responder = _make_auto_responder(auto_responses)
    input_responder = _make_input_responder(reject_reason)

    import text2sql.core_utils as _cu
    import text2sql.dialect as _di
    import text2sql.expansion_ops as _eo
    import text2sql.intent_expr as _ie
    import text2sql.intent_process as _ip
    import text2sql.intent_repair as _ir
    import text2sql.intent_resolve as _irs
    import text2sql.main_execution as _me
    import text2sql.pipeline as _pl
    import text2sql.qsim_ops as _qo
    import text2sql.qsim_sample as _qs
    import text2sql.qsim_struct as _qst
    import text2sql.schema as _sc
    import text2sql.schema_profiling as _sp
    import text2sql.simulator as _sim
    import text2sql.sql_gen as _sg
    import text2sql.templates as _tp
    import text2sql.utils as _ut
    import text2sql.validation_agg as _va
    import text2sql.validation_execute as _ve
    import text2sql.validation_schema as _vs
    import text2sql.validation_semantic as _vsm

    # Keep the real functions so captured messages are still emitted as usual.
    original_log = _cu.log
    original_debug = _cu.debug

    def _capturing_log(msg: str) -> None:
        capture["logs"].append(f"[LOG] {msg}")
        original_log(msg)

    def _capturing_debug(msg: str) -> None:
        capture["logs"].append(f"[DEBUG] {msg}")
        original_debug(msg)

    # Each module's own ``debug`` / ``log`` attribute is patched individually.
    # NOTE(review): this module (live_testing) is not in the list, so the
    # ``[live_testing]`` messages emitted via its own ``debug`` binding appear
    # not to be captured — confirm whether that is intended.
    _debug_modules = [
        _cu,
        _pl,
        _sg,
        _va,
        _ve,
        _vs,
        _vsm,
        _ie,
        _ip,
        _ir,
        _irs,
        _di,
        _eo,
        _ut,
        _tp,
        _sc,
        _sp,
        _qs,
        _qst,
        _qo,
        _me,
        _sim,
    ]
    _log_modules = [_cu, _pl, _me, _eo, _sim]

    extra_patches: list[Any] = [
        patch.object(mod, "debug", _capturing_debug) for mod in _debug_modules if hasattr(mod, "debug")
    ]
    extra_patches.extend(
        patch.object(mod, "log", _capturing_log) for mod in _log_modules if hasattr(mod, "log")
    )

    if csv_dir:
        _original_save = _pl.save_result_csv

        def _redirected_save(rows, intent, sql):
            # Run the real writer from inside csv_dir, always restoring the
            # previous working directory afterwards.
            orig_cwd = os.getcwd()
            try:
                os.chdir(csv_dir)
                # Forward the result so the wrapper is transparent to callers.
                return _original_save(rows, intent, sql)
            finally:
                os.chdir(orig_cwd)

        import text2sql.live_testing as _lt

        extra_patches.append(patch.object(_pl, "save_result_csv", _redirected_save))
        extra_patches.append(patch.object(_lt, "save_result_csv", _redirected_save))

    with ExitStack() as stack:
        stack.enter_context(patch.object(_cu, "ask_user_choice", responder))
        stack.enter_context(patch.object(_pl, "ask_user_choice", responder))
        stack.enter_context(patch.object(_me, "ask_user_choice", responder))
        stack.enter_context(patch("builtins.input", input_responder))
        # Previously these were start()ed outside any try block, so a failure
        # while starting one left the earlier ones permanently installed.
        for p in extra_patches:
            stack.enter_context(p)
        yield capture
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _extract_reuse_sql(tmpl: Any, q_norm: str) -> str:
    """Rebuild the final SQL that ``handle_direct_sql_reuse`` would produce.

    Looks up the first non-empty question in the template's value history
    that equals *q_norm*. If that entry carries parameter values they are
    substituted into ``tmpl.sql_param``; otherwise (no match, or an empty
    parameter mapping) ``tmpl.sql_param`` is returned unchanged.

    Args:
        tmpl: Template exposing ``value_history`` (with parallel
            ``questions`` / ``param_values`` sequences) and ``sql_param``.
        q_norm: Normalized question text to look up.

    Returns:
        The SQL string, with parameters substituted when a match was found.
    """
    history = tmpl.value_history
    hit = next(
        (idx for idx, past_q in enumerate(history.questions) if past_q and q_norm == past_q),
        None,
    )
    params: dict[str, str] = {} if hit is None else dict(history.param_values[hit])
    if not params:
        return tmpl.sql_param
    return substitute_params(tmpl.sql_param, params)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _build_reuse_intent(tmpl: Any) -> RuntimeIntent:
    """Create a minimal ``RuntimeIntent`` out of a template's intent signature.

    Each field is copied from ``tmpl.intent_signature``; a falsy value is
    replaced by an empty container (or ``"row_level"`` for the grain).
    ``having_param`` and ``column_map`` are read with ``getattr`` so a
    signature lacking them is tolerated. ``natural_language`` is left empty.

    Args:
        tmpl: Template object exposing ``intent_signature``.

    Returns:
        The assembled ``RuntimeIntent``.
    """
    signature = tmpl.intent_signature
    # Dict literals evaluate in order, so attributes are read in the same
    # sequence as a direct keyword call would read them.
    intent_fields = {
        "tables": signature.tables or [],
        "grain": signature.grain or "row_level",
        "select_cols": signature.select_cols or [],
        "group_by_cols": signature.group_by_cols or [],
        "order_by_cols": signature.order_by_cols or [],
        "filters_param": signature.filters_param or [],
        "having_param": getattr(signature, "having_param", None) or [],
        "column_map": getattr(signature, "column_map", None) or {},
        "natural_language": "",
    }
    return RuntimeIntent(**intent_fields)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _run_pipeline_core(
    question: str,
    schema: Any,
    store: dict[str, Any],
    templates: dict,
    rejected: dict,
    schema_terms: set[str],
    feedback: str,
    captured_logs: list[str],
    reject_reason: str = "",
) -> StepResult:
    """Execute the pipeline steps for a single question and return captured state.

    This mirrors the control flow of ``interactive_run_once`` but accepts all
    arguments programmatically and returns a ``StepResult`` instead of printing.

    The returned ``status`` is one of ``"restricted"``, ``"invalid_question"``,
    ``"intent_parse_failed"``, ``"intent_rejected"``, ``"join_failed"``,
    ``"hard_blocked"``, ``"validation_failed"`` or ``"ok"``. ``scenario_id``
    is left empty here.

    Args:
        question: Natural-language question string.
        schema: Loaded ``SchemaGraph``.
        store: Mutable template store dict.
        templates: Accepted templates dict.
        rejected: Rejected templates dict.
        schema_terms: Set of schema term tokens.
        feedback: Pre-determined ``"y"``/``"n"`` feedback value; overridden by
            ``"y"`` when the confidence reaches the auto-accept threshold.
        captured_logs: Mutable list to which log lines are appended; stored on
            the result by reference.
        reject_reason: When feedback is ``"n"``, the canned reason supplied to
            the pipeline (for recording on ``StepResult``).

    Returns:
        Populated ``StepResult``.
    """
    result = StepResult(scenario_id="", question=question, captured_logs=captured_logs)

    # --- 1. Question validation -------------------------------------------
    valid, query_type, corrected = validate_question(question)
    if not valid:
        result.status = "restricted" if query_type == "restricted" else "invalid_question"
        return result

    # Continue with the corrected text when validation rewrote the question.
    if corrected != question:
        question = corrected

    q_norm = normalize_question(question)

    # --- 2. Resources -------------------------------------------------------
    # The loader's return values replace the arguments that were passed in.
    dialect, engine, schema, store, templates, rejected, schema_terms = load_pipeline_resources(
        schema, store, templates, rejected, schema_terms
    )

    # --- 3. Direct template reuse ---------------------------------------------
    tmpl_match = check_template_reuse(q_norm, templates)
    result.reuse_type = tmpl_match.reuse_type

    if tmpl_match.reuse_type == "direct_reuse":
        result.reuse_type = "direct_reuse"
        result.template_id = tmpl_match.best_template.id if tmpl_match.best_template else None
        result.generation_path = "direct_reuse"
        handled = handle_direct_sql_reuse(
            q_norm,
            tmpl_match.best_template,
            dialect,
            store,
            templates,
            schema,
            engine=engine,
            existing_nl=None,
        )
        if handled:
            result.status = "ok"
            # SQL and intent are rebuilt from the template for the result record.
            result.sql = _extract_reuse_sql(tmpl_match.best_template, q_norm)
            result.intent = _build_reuse_intent(tmpl_match.best_template)
            return result
        # Not handled: fall through to the regular intent path below.

    # --- 4. Intent (from the template match, else parsed by the LLM) ----------
    intent = tmpl_match.intent
    semantic_warnings: list[dict[str, Any]] = []

    if intent is None:
        parsed = parse_intent_via_llm(q_norm, schema, templates, store)
        if parsed is None:
            result.status = "intent_parse_failed"
            return result
        intent, semantic_warnings, llm_calls = parsed
        result.llm_calls = llm_calls
        debug(f"[live_testing] intent parsed: tables={intent.tables} grain={intent.grain} llm_calls={llm_calls}")
        if llm_calls > 2:
            debug(f"[live_testing] WARNING: intent parse required {llm_calls} LLM calls for: {q_norm}")

    result.intent = intent

    ikey = intent_key(intent)
    # Warnings may be dicts (their "message" is kept) or other objects (stringified).
    result.semantic_warnings = [w.get("message", "") if isinstance(w, dict) else str(w) for w in semantic_warnings]

    # --- 5. Intent-level (union) template match -------------------------------
    union_result = match_template_for_union(intent, templates)
    matched_template = None
    union_select_cols = None
    cols_changed = False
    has_union_match = union_result is not None
    if union_result is not None:
        matched_template, union_select_cols, cols_changed = union_result
        result.reuse_type = "intent_reuse" if cols_changed else "intent_direct_reuse"
        debug(f"[live_testing] union match: template={matched_template.id if matched_template else None} cols_changed={cols_changed}")

    # Overwrites the "direct_reuse" marker set above when that path was not handled.
    result.generation_path = "union_match" if has_union_match else "fresh"

    # --- 6. Intent confirmation -------------------------------------------------
    if not confirm_intent_with_user(
        intent, store, semantic_warnings, has_union_match=has_union_match,
    ):
        result.status = "intent_rejected"
        return result

    # --- 7. Join planning ---------------------------------------------------------
    join_candidates, cmap, cte_join_hints = generate_join_candidates(intent, schema)
    if join_candidates is None:
        save_template_store(store)
        result.status = "join_failed"
        return result

    # --- 8. Hard-block check against rejected templates ---------------------------
    (
        hard_block_override,
        hard_block_rejected_template,
        matched_rejected_template,
        proceed,
    ) = check_and_handle_hard_block(rejected, ikey, intent)
    result.hard_blocked = not proceed
    if not proceed:
        save_template_store(store)
        result.status = "hard_blocked"
        return result

    # --- 9. SQL generation and validation -----------------------------------------
    sql, ok = generate_and_validate_sql(
        q_norm,
        intent,
        schema,
        join_candidates,
        cmap,
        dialect,
        store,
        engine=engine,
        cte_join_hints=cte_join_hints,
        matched_template=matched_template,
        union_select_cols=union_select_cols,
        cols_changed=cols_changed,
    )
    # Recorded before the validity check so a failed run still exposes the SQL.
    result.sql = sql
    if not ok:
        debug(f"[live_testing] SQL validation failed for: {q_norm}")
        result.status = "validation_failed"
        result.validation_failed = True
        return result

    # --- 10. Execution ----------------------------------------------------------------
    spark_sql = get_spark_sql_for_execution(
        intent.sql_param or "",
        dict(flatten_param_values(intent)),
        schema,
        intent,
        dialect,
    )
    rows = execute_sql(
        dialect,
        sql,
        spark_sql_for_execution=spark_sql if spark_sql else None,
    )
    result.sql = sql
    result.rows = rows
    debug(f"[live_testing] SQL generated ({result.generation_path}): rows={len(rows) if rows else 0}")

    # --- 11. Confidence, display and feedback ---------------------------------------
    conf = compute_final_metrics(sql, intent, schema, templates, join_candidates, store)
    result.confidence = conf

    ux_summary = display_final_results(q_norm, intent, sql, rows)

    # At or above the auto-accept threshold the scripted feedback is ignored.
    if conf >= PolicyConfig.FINAL_SQL_AUTO_ACCEPT_THRESHOLD:
        effective_feedback = "y"
    else:
        effective_feedback = feedback

    result.feedback = effective_feedback

    if effective_feedback == "y":
        save_result_csv(rows, intent, sql)

    reject_info = handle_user_feedback(
        effective_feedback,
        intent,
        sql,
        schema,
        store,
        templates,
        rejected,
        q_norm,
        hard_block_override,
        hard_block_rejected_template,
        matched_rejected_template,
        ux_summary,
        dialect=dialect,
    )
    if effective_feedback == "n" and reject_info:
        # The scenario's canned reason wins; the pipeline's value is the fallback.
        result.reject_reason_actual = reject_reason or reject_info.get("reject_reason")
        result.classified_category = reject_info.get("category")
        result.classified_reason = reject_info.get("normalized_reason")

    # The run completed end to end, so the status is "ok" even for "n" feedback.
    result.status = "ok"
    return result
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
class LiveTestRunner:
|
|
736
|
+
"""Orchestrates single-scenario and sequence-scenario execution.
|
|
737
|
+
|
|
738
|
+
Holds pre-loaded pipeline resources and exposes :meth:`run` and :meth:`run_sequence` which wrap the pipeline in a capture context, execute the question or questions, and return ``StepResult`` objects ready for assertion.
|
|
739
|
+
|
|
740
|
+
Args:
|
|
741
|
+
|
|
742
|
+
schema: Profiled ``SchemaGraph``.
|
|
743
|
+
|
|
744
|
+
store: Mutable template store dict.
|
|
745
|
+
|
|
746
|
+
templates: Accepted templates dict.
|
|
747
|
+
|
|
748
|
+
rejected: Rejected templates dict.
|
|
749
|
+
|
|
750
|
+
schema_terms: Set of schema term tokens.
|
|
751
|
+
|
|
752
|
+
csv_dir: Directory for ``results.csv`` output (empty means current working directory).
|
|
753
|
+
"""
|
|
754
|
+
|
|
755
|
+
def __init__(
    self,
    schema: Any,
    store: dict[str, Any],
    templates: dict,
    rejected: dict,
    schema_terms: set[str],
    csv_dir: str = "",
) -> None:
    """Keep references to the pre-loaded pipeline resources.

    Nothing is copied or validated; each argument is stored on the instance
    under the attribute of the same name.

    Args:
        schema: Profiled ``SchemaGraph``.
        store: Mutable template store dict.
        templates: Accepted templates dict.
        rejected: Rejected templates dict.
        schema_terms: Set of schema term tokens.
        csv_dir: Directory for ``results.csv`` output (empty means current
            working directory).
    """
    self.schema, self.store, self.templates = schema, store, templates
    self.rejected, self.schema_terms, self.csv_dir = rejected, schema_terms, csv_dir
|
|
770
|
+
|
|
771
|
+
def run(self, scenario: Scenario, retries: int = 0) -> StepResult:
|
|
772
|
+
"""Execute a single scenario against the live pipeline.
|
|
773
|
+
|
|
774
|
+
Args:
|
|
775
|
+
|
|
776
|
+
scenario: The ``Scenario`` to execute.
|
|
777
|
+
|
|
778
|
+
retries: Number of additional attempts on failure (0 means a single try).
|
|
779
|
+
|
|
780
|
+
Returns:
|
|
781
|
+
|
|
782
|
+
``StepResult`` from the last attempt.
|
|
783
|
+
|
|
784
|
+
"""
|
|
785
|
+
auto = scenario.auto_responses if scenario.auto_responses is not None else ["y", "y", "y"]
|
|
786
|
+
last_result: StepResult | None = None
|
|
787
|
+
|
|
788
|
+
for _ in range(1 + retries):
|
|
789
|
+
t0 = time.monotonic()
|
|
790
|
+
try:
|
|
791
|
+
with _pipeline_capture(list(auto), scenario.reject_reason, csv_dir=self.csv_dir) as cap:
|
|
792
|
+
step = _run_pipeline_core(
|
|
793
|
+
question=scenario.question,
|
|
794
|
+
schema=self.schema,
|
|
795
|
+
store=self.store,
|
|
796
|
+
templates=self.templates,
|
|
797
|
+
rejected=self.rejected,
|
|
798
|
+
schema_terms=self.schema_terms,
|
|
799
|
+
feedback=scenario.feedback,
|
|
800
|
+
captured_logs=cap["logs"],
|
|
801
|
+
reject_reason=scenario.reject_reason,
|
|
802
|
+
)
|
|
803
|
+
except Exception:
|
|
804
|
+
step = StepResult(
|
|
805
|
+
scenario_id=scenario.id,
|
|
806
|
+
question=scenario.question,
|
|
807
|
+
status="error",
|
|
808
|
+
error=traceback.format_exc(),
|
|
809
|
+
captured_logs=cap.get("logs", []) if "cap" in dir() else [],
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
step.scenario_id = scenario.id
|
|
813
|
+
step.duration_seconds = time.monotonic() - t0
|
|
814
|
+
last_result = step
|
|
815
|
+
|
|
816
|
+
if step.status == "ok":
|
|
817
|
+
break
|
|
818
|
+
|
|
819
|
+
return last_result
|
|
820
|
+
|
|
821
|
+
def run_sequence(self, seq: SequenceScenario, retries: int = 0) -> list[StepResult]:
|
|
822
|
+
"""Execute an ordered list of scenarios sharing template state.
|
|
823
|
+
|
|
824
|
+
Args:
|
|
825
|
+
|
|
826
|
+
seq: The ``SequenceScenario`` whose steps are run in order.
|
|
827
|
+
|
|
828
|
+
retries: Per-step retry count passed to :meth:`run`.
|
|
829
|
+
|
|
830
|
+
Returns:
|
|
831
|
+
|
|
832
|
+
List of ``StepResult`` objects, one per step.
|
|
833
|
+
|
|
834
|
+
"""
|
|
835
|
+
results: list[StepResult] = []
|
|
836
|
+
for idx, step_scenario in enumerate(seq.steps):
|
|
837
|
+
debug(
|
|
838
|
+
f"[LiveTestRunner.run_sequence] step {idx}/{len(seq.steps)} "
|
|
839
|
+
f"id={step_scenario.id} rejected_keys={sorted(self.rejected.keys())}"
|
|
840
|
+
)
|
|
841
|
+
step_with_seq = dc_replace(step_scenario, sequence_id=seq.id)
|
|
842
|
+
r = self.run(step_with_seq, retries=retries)
|
|
843
|
+
debug(f"[LiveTestRunner.run_sequence] step {idx} result: status={r.status} hard_blocked={r.hard_blocked}")
|
|
844
|
+
results.append(r)
|
|
845
|
+
return results
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def _assert_scenario(result: StepResult, expected: Expected, soft: SoftAssert | None = None) -> SoftAssert:
|
|
849
|
+
"""Evaluate a ``StepResult`` against an ``Expected`` specification.
|
|
850
|
+
|
|
851
|
+
When *soft* is ``None`` a new ``SoftAssert`` is created. All applicable assertions are run and failures accumulated on the returned ``SoftAssert``.
|
|
852
|
+
|
|
853
|
+
Args:
|
|
854
|
+
|
|
855
|
+
result: The ``StepResult`` from a pipeline run.
|
|
856
|
+
|
|
857
|
+
expected: The ``Expected`` to assert against.
|
|
858
|
+
|
|
859
|
+
soft: Optional existing ``SoftAssert`` to append to.
|
|
860
|
+
|
|
861
|
+
Returns:
|
|
862
|
+
|
|
863
|
+
The ``SoftAssert`` instance (same object when one was passed in).
|
|
864
|
+
|
|
865
|
+
"""
|
|
866
|
+
if soft is None:
|
|
867
|
+
soft = SoftAssert()
|
|
868
|
+
|
|
869
|
+
if expected.status is not None:
|
|
870
|
+
soft.check(
|
|
871
|
+
result.status == expected.status,
|
|
872
|
+
"status",
|
|
873
|
+
expected.status,
|
|
874
|
+
result.status,
|
|
875
|
+
)
|
|
876
|
+
elif expected.status_in is not None:
|
|
877
|
+
soft.check(
|
|
878
|
+
result.status in expected.status_in,
|
|
879
|
+
"status_in",
|
|
880
|
+
expected.status_in,
|
|
881
|
+
result.status,
|
|
882
|
+
)
|
|
883
|
+
|
|
884
|
+
if result.intent is not None:
|
|
885
|
+
actual_tables = sorted(result.intent.tables or [])
|
|
886
|
+
if expected.tables_one_of is not None:
|
|
887
|
+
allowed = [sorted(t) for t in expected.tables_one_of]
|
|
888
|
+
soft.check(
|
|
889
|
+
actual_tables in allowed,
|
|
890
|
+
"tables",
|
|
891
|
+
expected.tables_one_of,
|
|
892
|
+
actual_tables,
|
|
893
|
+
)
|
|
894
|
+
elif expected.tables is not None:
|
|
895
|
+
expected_tables = sorted(expected.tables)
|
|
896
|
+
soft.check(
|
|
897
|
+
actual_tables == expected_tables,
|
|
898
|
+
"tables",
|
|
899
|
+
expected_tables,
|
|
900
|
+
actual_tables,
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
if expected.grain is not None and result.intent is not None:
|
|
904
|
+
if isinstance(expected.grain, tuple):
|
|
905
|
+
soft.check(
|
|
906
|
+
result.intent.grain in expected.grain,
|
|
907
|
+
"grain",
|
|
908
|
+
expected.grain,
|
|
909
|
+
result.intent.grain,
|
|
910
|
+
)
|
|
911
|
+
else:
|
|
912
|
+
soft.check(
|
|
913
|
+
result.intent.grain == expected.grain,
|
|
914
|
+
"grain",
|
|
915
|
+
expected.grain,
|
|
916
|
+
result.intent.grain,
|
|
917
|
+
)
|
|
918
|
+
elif expected.grain_in is not None and result.intent is not None:
|
|
919
|
+
soft.check(
|
|
920
|
+
result.intent.grain in expected.grain_in,
|
|
921
|
+
"grain",
|
|
922
|
+
expected.grain_in,
|
|
923
|
+
result.intent.grain,
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
if expected.reuse_type is not None:
|
|
927
|
+
if isinstance(expected.reuse_type, tuple):
|
|
928
|
+
soft.check(
|
|
929
|
+
result.reuse_type in expected.reuse_type,
|
|
930
|
+
"reuse_type",
|
|
931
|
+
expected.reuse_type,
|
|
932
|
+
result.reuse_type,
|
|
933
|
+
)
|
|
934
|
+
else:
|
|
935
|
+
soft.check(
|
|
936
|
+
result.reuse_type == expected.reuse_type,
|
|
937
|
+
"reuse_type",
|
|
938
|
+
expected.reuse_type,
|
|
939
|
+
result.reuse_type,
|
|
940
|
+
)
|
|
941
|
+
|
|
942
|
+
sql_upper = (result.sql or "").upper()
|
|
943
|
+
|
|
944
|
+
if expected.contains_join is not None:
|
|
945
|
+
has_join = "JOIN" in sql_upper
|
|
946
|
+
soft.check(
|
|
947
|
+
has_join == expected.contains_join,
|
|
948
|
+
"contains_join",
|
|
949
|
+
expected.contains_join,
|
|
950
|
+
has_join,
|
|
951
|
+
)
|
|
952
|
+
|
|
953
|
+
if expected.contains_group_by is not None:
|
|
954
|
+
has_gb = "GROUP BY" in sql_upper
|
|
955
|
+
soft.check(
|
|
956
|
+
has_gb == expected.contains_group_by,
|
|
957
|
+
"contains_group_by",
|
|
958
|
+
expected.contains_group_by,
|
|
959
|
+
has_gb,
|
|
960
|
+
)
|
|
961
|
+
|
|
962
|
+
if expected.contains_cte is not None:
|
|
963
|
+
has_cte = sql_upper.lstrip().startswith("WITH ")
|
|
964
|
+
soft.check(
|
|
965
|
+
has_cte == expected.contains_cte,
|
|
966
|
+
"contains_cte",
|
|
967
|
+
expected.contains_cte,
|
|
968
|
+
has_cte,
|
|
969
|
+
)
|
|
970
|
+
|
|
971
|
+
if expected.sql_contains is not None and result.sql is not None:
|
|
972
|
+
for substr in expected.sql_contains:
|
|
973
|
+
found = substr.upper() in sql_upper
|
|
974
|
+
soft.check(found, "sql_contains", substr, f"not found in: {result.sql[:120]}")
|
|
975
|
+
|
|
976
|
+
if expected.sql_excludes is not None and result.sql is not None:
|
|
977
|
+
for substr in expected.sql_excludes:
|
|
978
|
+
found = substr.upper() in sql_upper
|
|
979
|
+
soft.check(
|
|
980
|
+
not found,
|
|
981
|
+
"sql_excludes",
|
|
982
|
+
f"absent: {substr}",
|
|
983
|
+
f"found in: {result.sql[:120]}",
|
|
984
|
+
)
|
|
985
|
+
|
|
986
|
+
if expected.min_rows is not None and result.rows is not None:
|
|
987
|
+
soft.check(
|
|
988
|
+
len(result.rows) >= expected.min_rows,
|
|
989
|
+
"min_rows",
|
|
990
|
+
expected.min_rows,
|
|
991
|
+
len(result.rows),
|
|
992
|
+
)
|
|
993
|
+
|
|
994
|
+
if expected.max_rows is not None and result.rows is not None:
|
|
995
|
+
soft.check(
|
|
996
|
+
len(result.rows) <= expected.max_rows,
|
|
997
|
+
"max_rows",
|
|
998
|
+
expected.max_rows,
|
|
999
|
+
len(result.rows),
|
|
1000
|
+
)
|
|
1001
|
+
|
|
1002
|
+
if expected.min_confidence is not None and result.confidence is not None:
|
|
1003
|
+
soft.check(
|
|
1004
|
+
result.confidence >= expected.min_confidence,
|
|
1005
|
+
"min_confidence",
|
|
1006
|
+
expected.min_confidence,
|
|
1007
|
+
result.confidence,
|
|
1008
|
+
)
|
|
1009
|
+
|
|
1010
|
+
if expected.column_names_one_of is not None and result.rows is not None and result.intent is not None:
|
|
1011
|
+
actual_cols = []
|
|
1012
|
+
for c in result.intent.select_cols or []:
|
|
1013
|
+
name = getattr(c, "alias", None) or c.expr.primary_term
|
|
1014
|
+
actual_cols.append(name.split(".")[-1] if name and "." in name else (name or ""))
|
|
1015
|
+
allowed = [sorted(cols) for cols in expected.column_names_one_of]
|
|
1016
|
+
soft.check(
|
|
1017
|
+
sorted(actual_cols) in allowed,
|
|
1018
|
+
"column_names",
|
|
1019
|
+
expected.column_names_one_of,
|
|
1020
|
+
actual_cols,
|
|
1021
|
+
)
|
|
1022
|
+
|
|
1023
|
+
if expected.row_value_check is not None and result.rows is not None:
|
|
1024
|
+
check_ok = expected.row_value_check(result.rows)
|
|
1025
|
+
soft.check(check_ok, "row_value_check", "True", check_ok)
|
|
1026
|
+
|
|
1027
|
+
if expected.min_semantic_warnings is not None:
|
|
1028
|
+
soft.check(
|
|
1029
|
+
len(result.semantic_warnings) >= expected.min_semantic_warnings,
|
|
1030
|
+
"min_semantic_warnings",
|
|
1031
|
+
expected.min_semantic_warnings,
|
|
1032
|
+
len(result.semantic_warnings),
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
if expected.should_hard_block:
|
|
1036
|
+
soft.check(
|
|
1037
|
+
result.hard_blocked,
|
|
1038
|
+
"should_hard_block",
|
|
1039
|
+
True,
|
|
1040
|
+
result.hard_blocked,
|
|
1041
|
+
)
|
|
1042
|
+
|
|
1043
|
+
if expected.should_fail_validation:
|
|
1044
|
+
soft.check(
|
|
1045
|
+
result.validation_failed,
|
|
1046
|
+
"should_fail_validation",
|
|
1047
|
+
True,
|
|
1048
|
+
result.validation_failed,
|
|
1049
|
+
)
|
|
1050
|
+
|
|
1051
|
+
return soft
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def run_and_assert(
    runner: LiveTestRunner,
    scenario: Scenario,
    header: str,
    max_attempts: int = 2,
    retries: int = 1,
) -> None:
    """Run one scenario and check its expectations, re-running on failure.

    Each attempt calls ``runner.run`` and evaluates the outcome against
    ``scenario.expected`` with ``_assert_scenario``. The function returns as
    soon as an attempt has no failed assertions. If every attempt fails, the
    collector of the final attempt is asked to report its failures
    (``SoftAssert.report`` is defined elsewhere; presumably this raises the
    ``AssertionError`` that *header* labels).

    Args:
        runner: ``LiveTestRunner`` instance configured for the target
            database.
        scenario: The ``Scenario`` to execute.
        header: Label passed to ``SoftAssert.report``.
        max_attempts: Total attempts including the initial run.
        retries: Per-attempt pipeline retry count passed to
            ``runner.run``.
    """
    outcome: SoftAssert | None = None
    for _attempt in range(max_attempts):
        outcome = _assert_scenario(
            runner.run(scenario, retries=retries),
            scenario.expected,
        )
        if outcome.passed:
            return

    # No attempt was made at all (max_attempts < 1): nothing to report.
    if outcome is None:
        return
    outcome.report(header=header)
|
|
1086
|
+
|
|
1087
|
+
|
|
1088
|
+
def run_sequence_and_assert(
    runner: LiveTestRunner,
    seq: SequenceScenario,
    max_attempts: int = 2,
    retries: int = 1,
) -> None:
    """Run a scenario sequence and check every step, re-running on failure.

    Each attempt calls ``runner.run_sequence`` and feeds every
    (step, result) pair into one shared ``SoftAssert``. The function returns
    as soon as an attempt produces no failed assertions. If every attempt
    fails, the collector of the final attempt is asked to report, using
    ``[<seq.id>]`` as the header.

    Args:
        runner: ``LiveTestRunner`` instance configured for the target
            database.
        seq: The ``SequenceScenario`` whose steps are run in order.
        max_attempts: Total attempts including the initial run.
        retries: Per-step pipeline retry count passed to
            ``runner.run_sequence``.
    """
    collector: SoftAssert | None = None
    for _attempt in range(max_attempts):
        step_results = runner.run_sequence(seq, retries=retries)
        # Fresh collector per attempt so earlier failures do not leak
        # into a later, passing run.
        collector = SoftAssert()
        for planned_step, step_result in zip(seq.steps, step_results):
            _assert_scenario(step_result, planned_step.expected, soft=collector)
        if collector.passed:
            return

    # No attempt was made at all (max_attempts < 1): nothing to report.
    if collector is None:
        return
    collector.report(header=f"[{seq.id}]")
|