aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1117 @@
1
+ """Reusable live-testing framework for the text2sql pipeline.
2
+
3
+ Provides data models (scenarios, expected outcomes, step results), a soft-assertion collector, a pipeline-capture context manager that monkey-patches interactive I/O, and a LiveTestRunner that orchestrates end-to-end execution of a scenario against a real LLM and live database.
4
+
5
+ All classes and helpers are database-agnostic; the only database-specific pieces live in the scenario definitions and conftest fixtures provided by the caller.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import time
12
+ import traceback
13
+ from collections.abc import Callable
14
+ from contextlib import contextmanager
15
+ from dataclasses import dataclass, field, replace as dc_replace
16
+ from typing import Any
17
+ from unittest.mock import patch
18
+
19
+ from .config import PolicyConfig
20
+ from .contracts_core import RuntimeIntent
21
+ from .core_utils import debug, normalize_question, substitute_params
22
+ from .intent_process import match_template_for_union
23
+ from .pipeline import (
24
+ check_and_handle_hard_block,
25
+ check_template_reuse,
26
+ compute_final_metrics,
27
+ confirm_intent_with_user,
28
+ display_final_results,
29
+ generate_and_validate_sql,
30
+ generate_join_candidates,
31
+ handle_direct_sql_reuse,
32
+ handle_user_feedback,
33
+ load_pipeline_resources,
34
+ parse_intent_via_llm,
35
+ save_result_csv,
36
+ )
37
+ from .templates import save_template_store
38
+ from .utils import flatten_param_values, intent_key, validate_question
39
+ from .validation_execute import execute_sql, get_spark_sql_for_execution
40
+
41
+
42
@dataclass
class Expected:
    """Declarative expectations for a single pipeline run.

    Every field is optional. When a field is ``None`` or left at its default, the corresponding assertion is skipped by the runner.

    When both a single-value field and its "one of" variant are set, ``_assert_scenario`` evaluates only one of them: ``status`` takes precedence over ``status_in``, ``tables_one_of`` over ``tables``, and ``grain`` over ``grain_in``.

    Args:

        tables: Expected tables referenced by the generated intent; compared order-independently.

        tables_one_of: When set, tables must equal one of the given lists (order-independent).

        grain_in: When set, grain must be one of the given values.

        min_rows: Minimum row count (inclusive) for the result set; skipped when no rows were captured.

        max_rows: Maximum row count (inclusive) for the result set; skipped when no rows were captured.

        min_confidence: Lower-bound confidence score (inclusive); skipped when no confidence was computed.

        reuse_type: Expected template-match reuse type, or a tuple of acceptable values. The runner itself sets ``"direct_reuse"``, ``"intent_reuse"`` or ``"intent_direct_reuse"``. Any other value (for example ``"none"``) is whatever template matching reported.

        contains_join: ``True`` requires ``JOIN`` in the SQL; ``False`` requires its absence.

        contains_group_by: ``True`` requires ``GROUP BY`` in the SQL; ``False`` requires its absence.

        contains_cte: ``True`` requires the SQL to start with ``WITH``; ``False`` requires that it does not.

        sql_contains: Substrings that must all appear in the final SQL (case-insensitive).

        sql_excludes: Substrings that must not appear in the final SQL (case-insensitive).

        grain: Expected intent grain (for example ``"row_level"`` or ``"grouped"``), or a tuple of acceptable grains.

        should_hard_block: When ``True`` the pipeline should trigger a hard-block.

        should_fail_validation: When ``True`` schema or result validation should fail.

        column_names_one_of: Allowed column-name lists. The intent's select-column names (alias, otherwise the last dotted component of the primary term) must match one list, ignoring order.

        row_value_check: Optional callable ``(rows) -> bool`` for custom value checks.

        min_semantic_warnings: Minimum number of semantic warnings expected.

        status: Expected pipeline exit status when the default ``"ok"`` is not appropriate (for example ``"restricted"`` or ``"hard_blocked"``).

        status_in: When set (and ``status`` is not), status must be one of the given values.
    """

    tables: list[str] | None = None
    tables_one_of: list[list[str]] | None = None
    grain_in: tuple[str, ...] | None = None
    min_rows: int | None = None
    max_rows: int | None = None
    min_confidence: float | None = None
    reuse_type: str | tuple[str, ...] | None = None
    contains_join: bool | None = None
    contains_group_by: bool | None = None
    contains_cte: bool | None = None
    sql_contains: list[str] | None = None
    sql_excludes: list[str] | None = None
    grain: str | tuple[str, ...] | None = None
    should_hard_block: bool = False
    should_fail_validation: bool = False
    column_names_one_of: list[list[str]] | None = None
    row_value_check: Callable[[list[tuple]], bool] | None = None
    min_semantic_warnings: int | None = None
    status: str | None = None
    status_in: tuple[str, ...] | None = None
+
112
+
113
@dataclass
class Scenario:
    """One live-pipeline test case: a question together with what the run should produce.

    Args:

        id: Unique identifier for the case (for example ``"ST-001"``).

        question: The natural-language question handed to the pipeline.

        expected: The ``Expected`` specification the run is checked against.

        category: Grouping label, used for pytest parametrization.

        auto_responses: Scripted ``y``/``n`` answers, handed out first-to-last to each ``ask_user_choice`` prompt. ``None`` lets the runner substitute ``["y", "y", "y"]`` (accept everything).

        feedback: Scripted answer to the final "is this correct?" prompt; ``"y"`` unless overridden.

        reject_reason: Canned text typed in when *feedback* is ``"n"`` and the pipeline asks why.

        sequence_id: Id of the enclosing sequence when this case runs as a sequence step, so stored results can be tied back to the sequence logs.
    """

    id: str
    question: str
    expected: Expected
    category: str = field(default="")
    auto_responses: list[str] | None = field(default=None)
    feedback: str = field(default="y")
    reject_reason: str = field(default="incorrect results")
    sequence_id: str | None = field(default=None)
144
+
145
+
146
@dataclass
class SequenceScenario:
    """Several scenarios that run back-to-back to exercise state carried between runs.

    Args:

        id: Unique identifier for the sequence.

        steps: The ``Scenario`` objects, executed in list order against one shared template store.

        category: Grouping label.
    """

    id: str
    steps: list[Scenario]
    category: str = field(default="")
162
+
163
+
164
@dataclass
class SoftFailure:
    """Record of one failed soft assertion.

    Args:

        field: Name of the property that was checked (for example ``"min_rows"``).

        expected: Value the check wanted.

        actual: Value the run actually produced.

        message: Readable explanation of the mismatch.
    """

    # NOTE: the attribute is named ``field``; it is annotation-only here, so it does not rebind dataclasses.field.
    field: str
    expected: Any
    actual: Any
    message: str
183
+
184
+
185
@dataclass
class StepResult:
    """Captured output from a single pipeline execution.

    Args:

        scenario_id: The ``Scenario.id`` that produced this result (filled in by the runner; empty when built directly by the pipeline core).

        question: The raw question string.

        status: Pipeline exit status. Starts as ``"unknown"``. Values produced by this module include:
            - ``"ok"``
            - ``"restricted"``
            - ``"invalid_question"``
            - ``"intent_parse_failed"``
            - ``"intent_rejected"``
            - ``"join_failed"``
            - ``"hard_blocked"``
            - ``"validation_failed"``
            - ``"error"``

        intent: Parsed ``RuntimeIntent``, if the pipeline got far enough. On the direct-reuse path this is a lightweight intent rebuilt from the template's intent signature.

        sql: Generated SQL string, if available. It is also set when SQL validation failed.

        rows: Result-set rows, if execution ran. Not populated on the direct-reuse path.

        confidence: Final confidence score, if computed.

        reuse_type: Template-match reuse type. Initially the value reported by template matching; replaced by ``"intent_reuse"`` or ``"intent_direct_reuse"`` when a union template match is found.

        template_id: Id of the reused template (set on the direct-reuse path).

        hard_blocked: Whether a hard-block fired.

        validation_failed: Whether schema or result validation failed.

        feedback: The feedback value (``"y"`` or ``"n"``) that was actually applied. It is forced to ``"y"`` when confidence reaches the auto-accept threshold.

        error: Formatted traceback when the run raised an exception.

        duration_seconds: Wall-clock time for the run.

        captured_logs: All ``[LOG]`` and ``[DEBUG]`` messages captured during the run.

        semantic_warnings: Semantic warning messages collected during parsing.

        soft_warnings: Informational result-validation messages (non-blocking).

        hard_remaining: Unresolved hard result-validation failures.

        rejection_classifications: Keyword categories for rejection reasons.

        llm_calls: Number of LLM API calls during intent parsing (0 when the intent came from template matching).

        reject_reason_actual: When feedback is ``"n"``, the raw reason text (the scenario's canned reason if given).

        classified_category: When feedback is ``"n"``, the LLM-classified category.

        classified_reason: When feedback is ``"n"``, the normalized reason summary.

        generation_path: Which SQL generation branch was taken: ``"direct_reuse"``, ``"union_match"`` or ``"fresh"``. ``None`` when the pipeline exited before reaching that decision.
    """

    scenario_id: str
    question: str
    status: str = "unknown"
    intent: RuntimeIntent | None = None
    sql: str | None = None
    rows: list[tuple] | None = None
    confidence: float | None = None
    reuse_type: str | None = None
    template_id: str | None = None
    hard_blocked: bool = False
    validation_failed: bool = False
    feedback: str | None = None
    error: str | None = None
    duration_seconds: float = 0.0
    captured_logs: list[str] = field(default_factory=list)
    semantic_warnings: list[str] = field(default_factory=list)
    soft_warnings: list[str] = field(default_factory=list)
    hard_remaining: list[str] = field(default_factory=list)
    rejection_classifications: list[str] = field(default_factory=list)
    llm_calls: int = 0
    reject_reason_actual: str | None = None
    classified_category: str | None = None
    classified_reason: str | None = None
    generation_path: str | None = None
264
+
265
+
266
class SoftAssert:
    """Collect assertion failures and raise them together at the end.

    Each :meth:`check` call records a ``SoftFailure`` when its condition is false. :meth:`report` raises one ``AssertionError`` that summarises every recorded failure, and does nothing when none were recorded.
    """

    def __init__(self) -> None:
        self.failures: list[SoftFailure] = []

    def check(
        self,
        condition: bool,
        field_name: str,
        expected: Any,
        actual: Any,
        message: str = "",
    ) -> None:
        """Store a failure unless *condition* holds.

        Args:

            condition: Outcome of the assertion expression.

            field_name: Label of the property under test.

            expected: Value the check wanted (used in the report).

            actual: Value that was observed (used in the report).

            message: Custom explanation. When empty, an ``expected/got`` message is generated.
        """
        if condition:
            return
        text = message if message else f"{field_name}: expected {expected!r}, got {actual!r}"
        self.failures.append(
            SoftFailure(field=field_name, expected=expected, actual=actual, message=text)
        )

    @property
    def passed(self) -> bool:
        """Whether every check so far has succeeded."""
        return not self.failures

    def report(self, header: str = "") -> None:
        """Raise a single ``AssertionError`` describing all recorded failures.

        No-op while :attr:`passed` is ``True``.

        Args:

            header: Optional first line of the error message.
        """
        if self.passed:
            return
        prefix = [header] if header else []
        details = [f" [{failure.field}] {failure.message}" for failure in self.failures]
        raise AssertionError("\n".join(prefix + details))
321
+
322
+
323
def _make_auto_responder(responses: list[str]) -> Callable:
    """Create a stand-in for ``ask_user_choice`` that replays scripted answers.

    Args:

        responses: Answers handed out one per call, first to last. The list is copied up front, so later changes by the caller have no effect.

    Returns:

        A function accepting ``(prompt, options, silent_no=False)``. It returns the next scripted answer, or ``"y"`` once the script is exhausted.

    """
    pending = iter(list(responses))

    def _responder(prompt: str, options: list[str], silent_no: bool = False) -> str | None:
        # prompt/options/silent_no exist only to mirror ask_user_choice's signature.
        return next(pending, "y")

    return _responder
343
+
344
+
345
def _make_input_responder(reject_reason: str = "incorrect results") -> Callable:
    """Create a stand-in for ``builtins.input`` that types canned answers.

    The very first call yields *reject_reason*. Every later call yields ``"n"``, declining any further interactive prompt.

    Args:

        reject_reason: Text handed back by the first ``input()`` call.

    Returns:

        A function usable in place of the built-in ``input``.

    """
    calls_made = 0

    def _fake_input(prompt: str = "") -> str:
        nonlocal calls_made
        calls_made += 1
        return reject_reason if calls_made == 1 else "n"

    return _fake_input
368
+
369
+
370
@contextmanager
def _pipeline_capture(
    auto_responses: list[str],
    reject_reason: str = "incorrect results",
    csv_dir: str = "",
):
    """Context manager that patches interactive I/O for programmatic pipeline runs.

    Replaces ``ask_user_choice`` (in ``core_utils``, ``pipeline`` and ``main_execution``) with a FIFO auto-responder, and ``builtins.input`` with a canned-text responder, so the pipeline runs without blocking on stdin. ``debug`` and ``log`` in the listed ``text2sql`` modules are wrapped so each message is recorded and then forwarded to the original implementation. When *csv_dir* is set, ``save_result_csv`` is redirected so the CSV file lands in that directory instead of the current working directory.

    All patches are entered on a single ``ExitStack``. If any patch fails to start, those already applied are undone before the error propagates. All of them are undone when the ``with`` block exits, normally or via an exception.

    Args:

        auto_responses: FIFO list of ``"y"``/``"n"`` strings for ``ask_user_choice``.

        reject_reason: Canned rejection reason for ``input()`` prompts.

        csv_dir: If non-empty, redirect ``results.csv`` writes into this directory.

    Yields:

        A dict ``{"logs": []}`` whose ``"logs"`` list is populated with ``"[LOG] ..."`` / ``"[DEBUG] ..."`` lines during the run.

    """
    from contextlib import ExitStack

    import text2sql.core_utils as _cu
    import text2sql.dialect as _di
    import text2sql.expansion_ops as _eo
    import text2sql.intent_expr as _ie
    import text2sql.intent_process as _ip
    import text2sql.intent_repair as _ir
    import text2sql.intent_resolve as _irs
    import text2sql.main_execution as _me
    import text2sql.pipeline as _pl
    import text2sql.qsim_ops as _qo
    import text2sql.qsim_sample as _qs
    import text2sql.qsim_struct as _qst
    import text2sql.schema as _sc
    import text2sql.schema_profiling as _sp
    import text2sql.simulator as _sim
    import text2sql.sql_gen as _sg
    import text2sql.templates as _tp
    import text2sql.utils as _ut
    import text2sql.validation_agg as _va
    import text2sql.validation_execute as _ve
    import text2sql.validation_schema as _vs
    import text2sql.validation_semantic as _vsm

    capture: dict[str, Any] = {"logs": []}
    responder = _make_auto_responder(auto_responses)
    input_responder = _make_input_responder(reject_reason)

    # Grab the real implementations before anything is patched so the wrappers forward to them.
    original_log = _cu.log
    original_debug = _cu.debug

    def _capturing_log(msg: str) -> None:
        capture["logs"].append(f"[LOG] {msg}")
        original_log(msg)

    def _capturing_debug(msg: str) -> None:
        capture["logs"].append(f"[DEBUG] {msg}")
        original_debug(msg)

    debug_modules = [
        _cu,
        _pl,
        _sg,
        _va,
        _ve,
        _vs,
        _vsm,
        _ie,
        _ip,
        _ir,
        _irs,
        _di,
        _eo,
        _ut,
        _tp,
        _sc,
        _sp,
        _qs,
        _qst,
        _qo,
        _me,
        _sim,
    ]
    log_modules = [_cu, _pl, _me, _eo, _sim]

    patches: list[Any] = [
        patch.object(_cu, "ask_user_choice", responder),
        patch.object(_pl, "ask_user_choice", responder),
        patch.object(_me, "ask_user_choice", responder),
        patch("builtins.input", input_responder),
    ]
    # Each module's own ``debug``/``log`` binding is patched (presumably because they are
    # imported by name); hasattr skips modules that do not expose the function.
    patches.extend(
        patch.object(mod, "debug", _capturing_debug) for mod in debug_modules if hasattr(mod, "debug")
    )
    patches.extend(
        patch.object(mod, "log", _capturing_log) for mod in log_modules if hasattr(mod, "log")
    )

    if csv_dir:
        original_save = _pl.save_result_csv

        def _redirected_save(rows, intent, sql):
            # NOTE: os.chdir changes the process-wide cwd, so this redirect is not
            # safe when pipelines run concurrently in threads.
            orig_cwd = os.getcwd()
            try:
                os.chdir(csv_dir)
                original_save(rows, intent, sql)
            finally:
                os.chdir(orig_cwd)

        import text2sql.live_testing as _lt

        patches.append(patch.object(_pl, "save_result_csv", _redirected_save))
        patches.append(patch.object(_lt, "save_result_csv", _redirected_save))

    with ExitStack() as stack:
        for p in patches:
            # Registered one by one so a failure here unwinds only what was already applied.
            stack.enter_context(p)
        yield capture
494
+
495
+
496
def _extract_reuse_sql(tmpl: Any, q_norm: str) -> str:
    """Rebuild the SQL that direct template reuse yields for *q_norm*.

    Scans the template's value history for the first non-empty recorded question equal to *q_norm*. If that entry has parameter values, they are substituted into ``tmpl.sql_param``. Otherwise the parameterised SQL is returned unchanged.
    """
    history = tmpl.value_history
    params: dict[str, str] = next(
        (
            dict(history.param_values[idx])
            for idx, past_question in enumerate(history.questions)
            if past_question and q_norm == past_question
        ),
        {},
    )
    if not params:
        return tmpl.sql_param
    return substitute_params(tmpl.sql_param, params)
507
+
508
+
509
def _build_reuse_intent(tmpl: Any) -> RuntimeIntent:
    """Create a minimal ``RuntimeIntent`` mirroring *tmpl*'s intent signature.

    Missing or falsy list fields become ``[]``, a missing grain becomes ``"row_level"``, and ``column_map`` falls back to ``{}``. ``having_param`` and ``column_map`` are read with ``getattr`` because the signature may not define them.
    """
    signature = tmpl.intent_signature
    intent_fields: dict[str, Any] = {
        "tables": signature.tables or [],
        "grain": signature.grain or "row_level",
        "select_cols": signature.select_cols or [],
        "group_by_cols": signature.group_by_cols or [],
        "order_by_cols": signature.order_by_cols or [],
        "filters_param": signature.filters_param or [],
        "having_param": getattr(signature, "having_param", None) or [],
        "column_map": getattr(signature, "column_map", None) or {},
        "natural_language": "",
    }
    return RuntimeIntent(**intent_fields)
523
+
524
+
525
def _run_pipeline_core(
    question: str,
    schema: Any,
    store: dict[str, Any],
    templates: dict,
    rejected: dict,
    schema_terms: set[str],
    feedback: str,
    captured_logs: list[str],
    reject_reason: str = "",
) -> StepResult:
    """Drive the pipeline for one question and return everything it produced.

    Follows the same sequence of steps as ``interactive_run_once``, but takes every input as an argument and returns a ``StepResult`` instead of printing. Each early exit sets a distinct ``status`` on the result. The result's ``scenario_id`` is left empty for the caller to fill in.

    Args:

        question: Natural-language question string.

        schema: Loaded ``SchemaGraph``.

        store: Mutable template store dict.

        templates: Accepted templates dict.

        rejected: Rejected templates dict.

        schema_terms: Set of schema term tokens.

        feedback: Scripted ``"y"``/``"n"`` feedback value. It is overridden with ``"y"`` when confidence reaches ``PolicyConfig.FINAL_SQL_AUTO_ACCEPT_THRESHOLD``.

        captured_logs: Mutable list that collects log lines; stored on the result as-is.

        reject_reason: When feedback is ``"n"``, the canned reason supplied to the pipeline (recorded on the result).

    Returns:

        Populated ``StepResult``.

    """
    outcome = StepResult(scenario_id="", question=question, captured_logs=captured_logs)

    is_valid, query_type, corrected = validate_question(question)
    if not is_valid:
        outcome.status = "restricted" if query_type == "restricted" else "invalid_question"
        return outcome

    # Continue with the corrected text whenever validation rewrote the question.
    if corrected != question:
        question = corrected

    q_norm = normalize_question(question)

    dialect, engine, schema, store, templates, rejected, schema_terms = load_pipeline_resources(
        schema, store, templates, rejected, schema_terms
    )

    tmpl_match = check_template_reuse(q_norm, templates)
    outcome.reuse_type = tmpl_match.reuse_type

    if tmpl_match.reuse_type == "direct_reuse":
        best = tmpl_match.best_template
        outcome.reuse_type = "direct_reuse"
        outcome.template_id = best.id if best else None
        outcome.generation_path = "direct_reuse"
        reused = handle_direct_sql_reuse(
            q_norm,
            best,
            dialect,
            store,
            templates,
            schema,
            engine=engine,
            existing_nl=None,
        )
        if reused:
            outcome.status = "ok"
            outcome.sql = _extract_reuse_sql(best, q_norm)
            outcome.intent = _build_reuse_intent(best)
            return outcome
        # Not handled: fall through to the normal intent-driven path below.

    intent = tmpl_match.intent
    semantic_warnings: list[dict[str, Any]] = []

    if intent is None:
        parsed = parse_intent_via_llm(q_norm, schema, templates, store)
        if parsed is None:
            outcome.status = "intent_parse_failed"
            return outcome
        intent, semantic_warnings, llm_calls = parsed
        outcome.llm_calls = llm_calls
        debug(f"[live_testing] intent parsed: tables={intent.tables} grain={intent.grain} llm_calls={llm_calls}")
        if llm_calls > 2:
            debug(f"[live_testing] WARNING: intent parse required {llm_calls} LLM calls for: {q_norm}")

    outcome.intent = intent

    ikey = intent_key(intent)
    outcome.semantic_warnings = [
        warning.get("message", "") if isinstance(warning, dict) else str(warning)
        for warning in semantic_warnings
    ]

    union_result = match_template_for_union(intent, templates)
    has_union_match = union_result is not None
    if has_union_match:
        matched_template, union_select_cols, cols_changed = union_result
        outcome.reuse_type = "intent_reuse" if cols_changed else "intent_direct_reuse"
        debug(f"[live_testing] union match: template={matched_template.id if matched_template else None} cols_changed={cols_changed}")
    else:
        matched_template, union_select_cols, cols_changed = None, None, False

    outcome.generation_path = "union_match" if has_union_match else "fresh"

    if not confirm_intent_with_user(
        intent, store, semantic_warnings, has_union_match=has_union_match,
    ):
        outcome.status = "intent_rejected"
        return outcome

    join_candidates, cmap, cte_join_hints = generate_join_candidates(intent, schema)
    if join_candidates is None:
        save_template_store(store)
        outcome.status = "join_failed"
        return outcome

    hard_block_override, hard_block_rejected_template, matched_rejected_template, proceed = (
        check_and_handle_hard_block(rejected, ikey, intent)
    )
    outcome.hard_blocked = not proceed
    if not proceed:
        save_template_store(store)
        outcome.status = "hard_blocked"
        return outcome

    sql, sql_ok = generate_and_validate_sql(
        q_norm,
        intent,
        schema,
        join_candidates,
        cmap,
        dialect,
        store,
        engine=engine,
        cte_join_hints=cte_join_hints,
        matched_template=matched_template,
        union_select_cols=union_select_cols,
        cols_changed=cols_changed,
    )
    # Recorded even on failure so assertions can inspect the rejected SQL.
    outcome.sql = sql
    if not sql_ok:
        debug(f"[live_testing] SQL validation failed for: {q_norm}")
        outcome.status = "validation_failed"
        outcome.validation_failed = True
        return outcome

    spark_sql = get_spark_sql_for_execution(
        intent.sql_param or "",
        dict(flatten_param_values(intent)),
        schema,
        intent,
        dialect,
    )
    rows = execute_sql(
        dialect,
        sql,
        spark_sql_for_execution=spark_sql or None,
    )
    outcome.rows = rows
    debug(f"[live_testing] SQL generated ({outcome.generation_path}): rows={len(rows) if rows else 0}")

    confidence = compute_final_metrics(sql, intent, schema, templates, join_candidates, store)
    outcome.confidence = confidence

    ux_summary = display_final_results(q_norm, intent, sql, rows)

    # High-confidence results are auto-accepted regardless of the scripted feedback.
    applied_feedback = "y" if confidence >= PolicyConfig.FINAL_SQL_AUTO_ACCEPT_THRESHOLD else feedback
    outcome.feedback = applied_feedback

    if applied_feedback == "y":
        save_result_csv(rows, intent, sql)

    reject_info = handle_user_feedback(
        applied_feedback,
        intent,
        sql,
        schema,
        store,
        templates,
        rejected,
        q_norm,
        hard_block_override,
        hard_block_rejected_template,
        matched_rejected_template,
        ux_summary,
        dialect=dialect,
    )
    if applied_feedback == "n" and reject_info:
        outcome.reject_reason_actual = reject_reason or reject_info.get("reject_reason")
        outcome.classified_category = reject_info.get("category")
        outcome.classified_reason = reject_info.get("normalized_reason")

    outcome.status = "ok"
    return outcome
733
+
734
+
735
class LiveTestRunner:
    """Orchestrates single-scenario and sequence-scenario execution.

    Holds pre-loaded pipeline resources and exposes :meth:`run` and :meth:`run_sequence`. These wrap the pipeline in a capture context, execute the question or questions, and return ``StepResult`` objects ready for assertion.

    Args:

        schema: Profiled ``SchemaGraph``.

        store: Mutable template store dict.

        templates: Accepted templates dict.

        rejected: Rejected templates dict.

        schema_terms: Set of schema term tokens.

        csv_dir: Directory for ``results.csv`` output (empty means current working directory).
    """

    def __init__(
        self,
        schema: Any,
        store: dict[str, Any],
        templates: dict,
        rejected: dict,
        schema_terms: set[str],
        csv_dir: str = "",
    ) -> None:
        self.schema = schema
        self.store = store
        self.templates = templates
        self.rejected = rejected
        self.schema_terms = schema_terms
        self.csv_dir = csv_dir

    def _run_attempt(self, scenario: Scenario, auto: list[str]) -> StepResult:
        """Run *scenario* once inside a fresh capture context and return its result.

        Any exception raised while entering the capture context or running the pipeline becomes a ``StepResult`` with ``status="error"`` and the formatted traceback.
        """
        started = time.monotonic()
        # Reset every attempt: logs from an earlier attempt must never be attributed to
        # this one, and if the capture context fails to enter there are no logs at all.
        cap: dict[str, Any] | None = None
        try:
            with _pipeline_capture(list(auto), scenario.reject_reason, csv_dir=self.csv_dir) as cap:
                step = _run_pipeline_core(
                    question=scenario.question,
                    schema=self.schema,
                    store=self.store,
                    templates=self.templates,
                    rejected=self.rejected,
                    schema_terms=self.schema_terms,
                    feedback=scenario.feedback,
                    captured_logs=cap["logs"],
                    reject_reason=scenario.reject_reason,
                )
        except Exception:
            step = StepResult(
                scenario_id=scenario.id,
                question=scenario.question,
                status="error",
                error=traceback.format_exc(),
                captured_logs=cap.get("logs", []) if cap is not None else [],
            )

        step.scenario_id = scenario.id
        step.duration_seconds = time.monotonic() - started
        return step

    def run(self, scenario: Scenario, retries: int = 0) -> StepResult:
        """Execute a single scenario against the live pipeline.

        Args:

            scenario: The ``Scenario`` to execute.

            retries: Number of additional attempts when a run does not finish with status ``"ok"`` (0 means a single try; negative values are treated as 0).

        Returns:

            ``StepResult`` from the last attempt.

        """
        auto = scenario.auto_responses if scenario.auto_responses is not None else ["y", "y", "y"]
        # Always make at least one attempt so a StepResult (never None) is returned.
        attempts = 1 + max(0, retries)

        result = self._run_attempt(scenario, auto)
        for _ in range(attempts - 1):
            if result.status == "ok":
                break
            result = self._run_attempt(scenario, auto)
        return result

    def run_sequence(self, seq: SequenceScenario, retries: int = 0) -> list[StepResult]:
        """Execute an ordered list of scenarios sharing template state.

        Args:

            seq: The ``SequenceScenario`` whose steps are run in order.

            retries: Per-step retry count passed to :meth:`run`.

        Returns:

            List of ``StepResult`` objects, one per step.

        """
        results: list[StepResult] = []
        for idx, step_scenario in enumerate(seq.steps):
            debug(
                f"[LiveTestRunner.run_sequence] step {idx}/{len(seq.steps)} "
                f"id={step_scenario.id} rejected_keys={sorted(self.rejected.keys())}"
            )
            step_with_seq = dc_replace(step_scenario, sequence_id=seq.id)
            r = self.run(step_with_seq, retries=retries)
            debug(f"[LiveTestRunner.run_sequence] step {idx} result: status={r.status} hard_blocked={r.hard_blocked}")
            results.append(r)
        return results
846
+
847
+
848
+ def _assert_scenario(result: StepResult, expected: Expected, soft: SoftAssert | None = None) -> SoftAssert:
849
+ """Evaluate a ``StepResult`` against an ``Expected`` specification.
850
+
851
+ When *soft* is ``None`` a new ``SoftAssert`` is created. All applicable assertions are run and failures accumulated on the returned ``SoftAssert``.
852
+
853
+ Args:
854
+
855
+ result: The ``StepResult`` from a pipeline run.
856
+
857
+ expected: The ``Expected`` to assert against.
858
+
859
+ soft: Optional existing ``SoftAssert`` to append to.
860
+
861
+ Returns:
862
+
863
+ The ``SoftAssert`` instance (same object when one was passed in).
864
+
865
+ """
866
+ if soft is None:
867
+ soft = SoftAssert()
868
+
869
+ if expected.status is not None:
870
+ soft.check(
871
+ result.status == expected.status,
872
+ "status",
873
+ expected.status,
874
+ result.status,
875
+ )
876
+ elif expected.status_in is not None:
877
+ soft.check(
878
+ result.status in expected.status_in,
879
+ "status_in",
880
+ expected.status_in,
881
+ result.status,
882
+ )
883
+
884
+ if result.intent is not None:
885
+ actual_tables = sorted(result.intent.tables or [])
886
+ if expected.tables_one_of is not None:
887
+ allowed = [sorted(t) for t in expected.tables_one_of]
888
+ soft.check(
889
+ actual_tables in allowed,
890
+ "tables",
891
+ expected.tables_one_of,
892
+ actual_tables,
893
+ )
894
+ elif expected.tables is not None:
895
+ expected_tables = sorted(expected.tables)
896
+ soft.check(
897
+ actual_tables == expected_tables,
898
+ "tables",
899
+ expected_tables,
900
+ actual_tables,
901
+ )
902
+
903
+ if expected.grain is not None and result.intent is not None:
904
+ if isinstance(expected.grain, tuple):
905
+ soft.check(
906
+ result.intent.grain in expected.grain,
907
+ "grain",
908
+ expected.grain,
909
+ result.intent.grain,
910
+ )
911
+ else:
912
+ soft.check(
913
+ result.intent.grain == expected.grain,
914
+ "grain",
915
+ expected.grain,
916
+ result.intent.grain,
917
+ )
918
+ elif expected.grain_in is not None and result.intent is not None:
919
+ soft.check(
920
+ result.intent.grain in expected.grain_in,
921
+ "grain",
922
+ expected.grain_in,
923
+ result.intent.grain,
924
+ )
925
+
926
+ if expected.reuse_type is not None:
927
+ if isinstance(expected.reuse_type, tuple):
928
+ soft.check(
929
+ result.reuse_type in expected.reuse_type,
930
+ "reuse_type",
931
+ expected.reuse_type,
932
+ result.reuse_type,
933
+ )
934
+ else:
935
+ soft.check(
936
+ result.reuse_type == expected.reuse_type,
937
+ "reuse_type",
938
+ expected.reuse_type,
939
+ result.reuse_type,
940
+ )
941
+
942
+ sql_upper = (result.sql or "").upper()
943
+
944
+ if expected.contains_join is not None:
945
+ has_join = "JOIN" in sql_upper
946
+ soft.check(
947
+ has_join == expected.contains_join,
948
+ "contains_join",
949
+ expected.contains_join,
950
+ has_join,
951
+ )
952
+
953
+ if expected.contains_group_by is not None:
954
+ has_gb = "GROUP BY" in sql_upper
955
+ soft.check(
956
+ has_gb == expected.contains_group_by,
957
+ "contains_group_by",
958
+ expected.contains_group_by,
959
+ has_gb,
960
+ )
961
+
962
+ if expected.contains_cte is not None:
963
+ has_cte = sql_upper.lstrip().startswith("WITH ")
964
+ soft.check(
965
+ has_cte == expected.contains_cte,
966
+ "contains_cte",
967
+ expected.contains_cte,
968
+ has_cte,
969
+ )
970
+
971
+ if expected.sql_contains is not None and result.sql is not None:
972
+ for substr in expected.sql_contains:
973
+ found = substr.upper() in sql_upper
974
+ soft.check(found, "sql_contains", substr, f"not found in: {result.sql[:120]}")
975
+
976
+ if expected.sql_excludes is not None and result.sql is not None:
977
+ for substr in expected.sql_excludes:
978
+ found = substr.upper() in sql_upper
979
+ soft.check(
980
+ not found,
981
+ "sql_excludes",
982
+ f"absent: {substr}",
983
+ f"found in: {result.sql[:120]}",
984
+ )
985
+
986
+ if expected.min_rows is not None and result.rows is not None:
987
+ soft.check(
988
+ len(result.rows) >= expected.min_rows,
989
+ "min_rows",
990
+ expected.min_rows,
991
+ len(result.rows),
992
+ )
993
+
994
+ if expected.max_rows is not None and result.rows is not None:
995
+ soft.check(
996
+ len(result.rows) <= expected.max_rows,
997
+ "max_rows",
998
+ expected.max_rows,
999
+ len(result.rows),
1000
+ )
1001
+
1002
+ if expected.min_confidence is not None and result.confidence is not None:
1003
+ soft.check(
1004
+ result.confidence >= expected.min_confidence,
1005
+ "min_confidence",
1006
+ expected.min_confidence,
1007
+ result.confidence,
1008
+ )
1009
+
1010
+ if expected.column_names_one_of is not None and result.rows is not None and result.intent is not None:
1011
+ actual_cols = []
1012
+ for c in result.intent.select_cols or []:
1013
+ name = getattr(c, "alias", None) or c.expr.primary_term
1014
+ actual_cols.append(name.split(".")[-1] if name and "." in name else (name or ""))
1015
+ allowed = [sorted(cols) for cols in expected.column_names_one_of]
1016
+ soft.check(
1017
+ sorted(actual_cols) in allowed,
1018
+ "column_names",
1019
+ expected.column_names_one_of,
1020
+ actual_cols,
1021
+ )
1022
+
1023
+ if expected.row_value_check is not None and result.rows is not None:
1024
+ check_ok = expected.row_value_check(result.rows)
1025
+ soft.check(check_ok, "row_value_check", "True", check_ok)
1026
+
1027
+ if expected.min_semantic_warnings is not None:
1028
+ soft.check(
1029
+ len(result.semantic_warnings) >= expected.min_semantic_warnings,
1030
+ "min_semantic_warnings",
1031
+ expected.min_semantic_warnings,
1032
+ len(result.semantic_warnings),
1033
+ )
1034
+
1035
+ if expected.should_hard_block:
1036
+ soft.check(
1037
+ result.hard_blocked,
1038
+ "should_hard_block",
1039
+ True,
1040
+ result.hard_blocked,
1041
+ )
1042
+
1043
+ if expected.should_fail_validation:
1044
+ soft.check(
1045
+ result.validation_failed,
1046
+ "should_fail_validation",
1047
+ True,
1048
+ result.validation_failed,
1049
+ )
1050
+
1051
+ return soft
1052
+
1053
+
1054
+ def run_and_assert(
1055
+ runner: LiveTestRunner,
1056
+ scenario: Scenario,
1057
+ header: str,
1058
+ max_attempts: int = 2,
1059
+ retries: int = 1,
1060
+ ) -> None:
1061
+ """Run a scenario and assert expectations, retrying from scratch on failure.
1062
+
1063
+ On the first attempt the pipeline is run and assertions are
1064
+ checked. When any assertion fails and *max_attempts* > 1 the
1065
+ entire pipeline is re-executed from scratch and assertions are
1066
+ re-evaluated on the fresh result.
1067
+
1068
+ Args:
1069
+
1070
+ runner: ``LiveTestRunner`` instance configured for the target
1071
+ database.
1072
+ scenario: The ``Scenario`` to execute.
1073
+ header: Label used in the ``AssertionError`` message.
1074
+ max_attempts: Total attempts including the initial run.
1075
+ retries: Per-attempt pipeline retry count passed to
1076
+ ``runner.run``.
1077
+ """
1078
+ last_soft: SoftAssert | None = None
1079
+ for _ in range(max_attempts):
1080
+ result = runner.run(scenario, retries=retries)
1081
+ last_soft = _assert_scenario(result, scenario.expected)
1082
+ if last_soft.passed:
1083
+ return
1084
+ if last_soft is not None:
1085
+ last_soft.report(header=header)
1086
+
1087
+
1088
+ def run_sequence_and_assert(
1089
+ runner: LiveTestRunner,
1090
+ seq: SequenceScenario,
1091
+ max_attempts: int = 2,
1092
+ retries: int = 1,
1093
+ ) -> None:
1094
+ """Run a sequence of scenarios and assert each step, retrying on failure.
1095
+
1096
+ When any step's assertions fail and *max_attempts* > 1, the
1097
+ entire sequence is re-executed from scratch.
1098
+
1099
+ Args:
1100
+
1101
+ runner: ``LiveTestRunner`` instance configured for the target
1102
+ database.
1103
+ seq: The ``SequenceScenario`` whose steps are run in order.
1104
+ max_attempts: Total attempts including the initial run.
1105
+ retries: Per-step pipeline retry count passed to
1106
+ ``runner.run``.
1107
+ """
1108
+ last_soft: SoftAssert | None = None
1109
+ for _ in range(max_attempts):
1110
+ results = runner.run_sequence(seq, retries=retries)
1111
+ last_soft = SoftAssert()
1112
+ for step_scenario, result in zip(seq.steps, results):
1113
+ _assert_scenario(result, step_scenario.expected, soft=last_soft)
1114
+ if last_soft.passed:
1115
+ return
1116
+ if last_soft is not None:
1117
+ last_soft.report(header=f"[{seq.id}]")