aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/simulator.py ADDED
@@ -0,0 +1,970 @@
1
+ """Coverage simulator for synthetic template generation.
2
+
3
+ Orchestrates the deterministic simulation pipeline: gold intent generation
4
+ from seed questions, fully deterministic multi-depth expansion, join
5
+ resolution (cached per table-set, LLM only for ambiguous multi-table),
6
+ SQL building and validation, value instantiation, and LLM-based NL
7
+ question generation with a realism quality gate.
8
+
9
+ Also provides helpers for persisting simulator artefacts.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import csv
15
+ import glob
16
+ import json
17
+ import os
18
+ import re
19
+ from dataclasses import replace
20
+ from typing import Any
21
+
22
+ from .config import EngineConfig, SimulatorConfig
23
+ from .contracts_base import SchemaGraph, TemplateStats, ValueDomain
24
+ from .contracts_core import (
25
+ ConcreteIntent,
26
+ FilterParam,
27
+ HavingParam,
28
+ RuntimeIntent,
29
+ SimulatorIntent,
30
+ SimulatorResult,
31
+ Template,
32
+ ValueHistory,
33
+ )
34
+ from .core_utils import (
35
+ ask_user_choice,
36
+ canonicalize_sql,
37
+ colmap_signature,
38
+ debug,
39
+ log,
40
+ normalize_question,
41
+ parameter_abstract,
42
+ print_info,
43
+ sql_fp,
44
+ substitute_params,
45
+ )
46
+ from .intent_expr import (
47
+ assign_param_keys,
48
+ collect_raw_param_values,
49
+ extract_structural_params,
50
+ )
51
+ from .intent_process import full_intent_parse
52
+ from .qsim_sample import (
53
+ deterministic_having_value,
54
+ sample_coordinated_range,
55
+ sample_value_from_domain,
56
+ )
57
+ from .sql_gen import (
58
+ build_deterministic_sql,
59
+ get_join_choice_from_llm,
60
+ inject_join_into_deterministic_sql,
61
+ join_candidate_map,
62
+ join_hints_multi,
63
+ physical_tables_for_join_hints,
64
+ )
65
+ from .utils import (
66
+ extract_tables_from_sql,
67
+ flatten_param_values,
68
+ generate_question_from_sql,
69
+ intent_key,
70
+ sql_shape,
71
+ )
72
+ from .validation_execute import execute_sql, get_spark_sql_for_execution, validate_sql
73
+
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # Versioning / file helpers
77
+ # ---------------------------------------------------------------------------
78
+
79
def get_next_simulator_version(output_dir: str) -> int:
    """Compute the version number for the next batch of simulator output files.

    Every file in *output_dir* named ``gold_intents_v*.json`` is inspected.
    The text that follows the first ``_v`` in the file name (up to the next
    ``_v`` or ``.``) is parsed as an integer; names whose version part is
    not an integer are skipped.

    Args:
        output_dir: Directory to scan for versioned output files.

    Returns:
        One more than the largest version found, or ``1`` when no file
        carries a parseable version.
    """
    found: list[int] = []
    for path in glob.glob(os.path.join(output_dir, "gold_intents_v*.json")):
        file_name = os.path.basename(path)
        try:
            found.append(int(file_name.split("_v")[1].split(".")[0]))
        except (IndexError, ValueError):
            # Not a well-formed versioned name; ignore it.
            continue
    return max(found, default=0) + 1
105
+
106
+
107
+ # ---------------------------------------------------------------------------
108
+ # Seed / gold intent generation
109
+ # ---------------------------------------------------------------------------
110
+
111
def _load_seed_questions(filepath: str) -> list[dict[str, Any]]:
    """Read seed questions from a plain-text file.

    Each non-empty line is interpreted, in this order, as a numbered
    question (``1. question`` or ``1) question``), a quoted string, or a
    plain question.  Lines starting with ``#`` are comments; a comment
    containing ``Phase N`` switches the phase assigned to the questions
    that follow.  Questions without an explicit number receive a running
    auto-number.

    Args:
        filepath: Path to the seed questions text file.

    Returns:
        List of dicts with ``number``, ``question``, and ``phase`` keys.

    Raises:
        FileNotFoundError: If *filepath* does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Seed questions file not found: {filepath}")

    numbered_patterns = (r"^(\d+)\.\s*(.+)$", r"^(\d+)\)\s*(.+)$")
    loaded: list[dict[str, Any]] = []
    phase = "unknown"
    next_auto = 1

    with open(filepath, encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                continue

            if line.startswith("#"):
                header = re.search(r"Phase\s+(\d+)", line, re.IGNORECASE)
                if header:
                    phase = f"phase_{header.group(1)}"
                continue

            text = None
            number = None

            # Explicitly numbered forms are tried first, "N." before "N)".
            for pattern in numbered_patterns:
                hit = re.match(pattern, line)
                if hit:
                    number = int(hit.group(1))
                    text = hit.group(2).strip()
                if text:
                    break

            # Then a line fully wrapped in quotes.
            if not text:
                quoted = re.match(r'^["\'](.+)["\']$', line)
                if quoted:
                    text = quoted.group(1).strip()
                    number = next_auto
                    next_auto += 1

            # Finally, treat the whole line as the question.
            if not text:
                text = line
                number = next_auto
                next_auto += 1

            text = text.strip("\"'")
            if text:
                loaded.append({
                    "number": number,
                    "question": text,
                    "phase": phase,
                })

    debug(
        f"[simulator.load_seed_questions] loaded: {len(loaded)} questions"
    )
    return loaded
188
+
189
+
190
def _parse_gold_intent(
    question: str, schema: SchemaGraph,
) -> RuntimeIntent | None:
    """Run a seed question through the full intent parser.

    The question is normalized first.  Semantic issues reported by the
    parser are only logged at debug level; they do not reject the intent.

    Args:
        question: Raw seed question text.
        schema: Schema graph handed to the parser.

    Returns:
        The parsed intent, or ``None`` when parsing produced no intent.
    """
    normalized = normalize_question(question)
    parsed, issues, _unused_llm_calls = full_intent_parse(normalized, schema)
    if parsed is None:
        debug(f"[simulator.parse_gold_intent] parse_failed: {normalized}")
    elif issues:
        debug(
            f"[simulator.parse_gold_intent] semantic_issues: "
            f"{len(issues)} issues (ignored)"
        )
    return parsed
205
+
206
+
207
def _confirm_gold_intent(
    question: str, intent: RuntimeIntent,
) -> tuple[bool, RuntimeIntent | None]:
    """Show a parsed intent to the user and ask whether it is correct.

    Prints a summary (tables, aggregations, expected result size) and then
    prompts for ``y`` / ``n``.

    Args:
        question: The original seed question, shown for context.
        intent: The parsed intent to confirm.

    Returns:
        ``(True, intent)`` when the user answers ``y``; ``(False, None)``
        when the user answers anything else or no choice is returned.
    """

    def _describe(expr: Any) -> Any:
        # Prefer the expression's own aggregate; otherwise fall back to the
        # aggregate of its first additive group, if there is one.
        func = expr.agg_func
        if not func and expr.add_groups:
            func = expr.add_groups[0].agg_func or None
        term = expr.primary_term
        return f"{func.upper()}({term})" if func else term

    summary = intent.natural_language or (
        f"Query {', '.join(intent.tables or [])} "
        f"for {intent.grain or 'data'}"
    )

    if (intent.grain or "row_level") == "scalar":
        expected = "scalar"
    else:
        expected = f"{intent.expected_rows or 'many'} row(s)"

    aggregations = [
        _describe(col.expr)
        for col in (intent.select_cols or [])
        if col.is_aggregated
    ]

    print_info(
        f"Question: {question}\n\nI understood: {summary}",
        items={
            "Tables": intent.tables or [],
            "Aggregations": aggregations or ["none"],
            "Expected": expected,
        },
    )

    answer = ask_user_choice("\nIs this correct?", ["y", "n"])
    if answer is None:
        # No choice came back: treated as a silent rejection.
        return False, None
    if answer == "y":
        print("\nIntent accepted.")
        return True, intent
    print("\nIntent rejected.")
    return False, None
255
+
256
+
257
def _abstract_values(intent: RuntimeIntent) -> RuntimeIntent:
    """Build a copy of *intent* whose ``param_values`` is an empty dict.

    Only the fields listed below are carried over to the new
    ``RuntimeIntent``; any other field takes the constructor's default.
    """
    carried_fields = (
        "tables",
        "grain",
        "select_cols",
        "group_by_cols",
        "order_by_cols",
        "filters_param",
        "having_param",
        "cte_steps",
        "column_map",
        "natural_language",
        "limit",
    )
    carried = {name: getattr(intent, name) for name in carried_fields}
    return RuntimeIntent(param_values={}, **carried)
273
+
274
+
275
def _save_gold_questions(
    gold_intents: list[dict[str, Any]], filepath: str,
) -> None:
    """Write the gold intents to *filepath* as JSON.

    The file contains an object with a ``count`` and the ``intents`` list;
    an existing file is overwritten.

    Args:
        gold_intents: Gold intent dicts to persist.
        filepath: Destination JSON path.
    """
    total = len(gold_intents)
    payload = {"count": total, "intents": gold_intents}
    with open(filepath, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)
    log(
        f"save_gold_questions: saved {total} gold intents "
        f"to {filepath}"
    )
286
+
287
+
288
+ def _load_gold_questions(filepath: str) -> list[dict[str, Any]]:
289
+ """Load previously saved gold intents from a JSON file."""
290
+ if not os.path.exists(filepath):
291
+ return []
292
+ with open(filepath, encoding="utf-8") as f:
293
+ data = json.load(f)
294
+ intents = data.get("intents", [])
295
+ debug(f"[simulator.load_gold_questions] loaded: {len(intents)} intents")
296
+ return intents
297
+
298
+
299
def run_gold_intent_generation(
    schema: SchemaGraph,
    seed_filepath: str,
    output_filepath: str,
    interactive: bool = True,
) -> list[dict[str, Any]]:
    """Turn seed questions into persisted gold intents.

    Seeds whose normalized question already appears in the output file are
    skipped.  Each remaining seed is parsed, optionally confirmed by the
    user, stripped of its parameter values, and appended to the gold list.
    The list is saved after every tenth new intent and once more at the
    end.

    Args:
        schema: Schema graph used for intent parsing.
        seed_filepath: Path to the seed questions text file.
        output_filepath: JSON file path for saving gold intents.
        interactive: When True, interactively confirm each parsed intent.

    Returns:
        List of all gold intent dicts (previously saved plus new ones).
    """
    log("gold_intent_generation: starting")
    seeds = _load_seed_questions(seed_filepath)
    log(f"gold_intent_generation: {len(seeds)} seed questions loaded")

    previous = _load_gold_questions(output_filepath)
    seen = {
        prior_question
        for item in previous
        if (prior_question := item.get("normalized_question"))
    }
    log(f"gold_intent_generation: {len(previous)} existing gold intents")

    gold_intents = [*previous]
    added = 0
    skipped = 0
    failed = 0

    for seed in seeds:
        text = seed["question"]
        normalized = normalize_question(text)
        if normalized in seen:
            skipped += 1
            continue

        log(f"gold_intent_generation: processing [{seed['number']}] {text}")
        intent = _parse_gold_intent(text, schema)
        if intent is None:
            log(
                f"gold_intent_generation: FAILED to parse [{seed['number']}]"
            )
            failed += 1
            continue

        if interactive:
            accepted, reviewed = _confirm_gold_intent(text, intent)
            if reviewed is None or not accepted:
                # Rejections are logged but not counted as failures.
                log(
                    f"gold_intent_generation: user rejected "
                    f"[{seed['number']}]"
                )
                continue
            intent = reviewed

        record = _abstract_values(intent).to_dict()
        record["normalized_question"] = normalized
        gold_intents.append(record)
        seen.add(normalized)
        added += 1

        # Periodic checkpoint so a long session is not lost entirely.
        if added % 10 == 0:
            _save_gold_questions(gold_intents, output_filepath)

    _save_gold_questions(gold_intents, output_filepath)
    log(
        f"gold_intent_generation: complete. "
        f"new={added}, skipped={skipped}, failed={failed}"
    )
    print("\nGold intent generation complete.")
    print(f" New: {added}")
    print(f" Skipped: {skipped}")
    print(f" Failed: {failed}")
    print(f" Total: {len(gold_intents)}")
    return gold_intents
381
+
382
+
383
+ # ---------------------------------------------------------------------------
384
+ # Join resolution cache
385
+ # ---------------------------------------------------------------------------
386
+
387
# Cached join resolution for one table set:
# (chosen join candidate id, join path signature string, join candidates).
JoinCacheEntry = tuple[str, str, list[Any]]
388
+
389
+
390
def resolve_joins_for_table_set(
    tables: list[str],
    schema: SchemaGraph,
    question_hint: str,
    join_cache: dict[frozenset[str], JoinCacheEntry],
) -> JoinCacheEntry:
    """Pick the join path for a set of tables, memoized in *join_cache*.

    The cache key is the frozenset of table names, so table order does not
    matter.  A single table, or a table set with no join candidates,
    resolves to the trivial entry ``("J00", "", [])``.  When at most one
    non-``J00`` candidate exists it is chosen directly; the LLM is asked
    only when several candidates compete.

    Args:
        tables: Table list for the intent.
        schema: Schema graph for join hint generation.
        question_hint: Text hint passed to the LLM for disambiguation.
        join_cache: Shared cache keyed by frozenset of table names; updated
            in place with the resolved entry.

    Returns:
        Tuple of (join_id, signature_string, candidates_list).
    """
    cache_key = frozenset(tables)
    if cache_key in join_cache:
        return join_cache[cache_key]

    def _store(join_id: str, signature: str, options: list[Any]) -> JoinCacheEntry:
        resolved = (join_id, signature, options)
        join_cache[cache_key] = resolved
        return resolved

    if len(tables) <= 1:
        return _store("J00", "", [])

    physical = physical_tables_for_join_hints(tables, schema)
    candidates = join_hints_multi(schema, physical)
    by_id = join_candidate_map(candidates)
    if not candidates:
        return _store("J00", "", [])

    real_ids = [cid for cid in by_id if cid != "J00"]
    if len(real_ids) > 1:
        # Genuinely ambiguous: let the LLM choose among the candidates.
        chosen = get_join_choice_from_llm(
            question_hint, "", candidates, task="sql",
        )
    elif real_ids:
        chosen = real_ids[0]
    else:
        chosen = "J00"

    return _store(chosen, str(by_id.get(chosen, [])), candidates)
446
+
447
+
448
+ # ---------------------------------------------------------------------------
449
+ # Value instantiation
450
+ # ---------------------------------------------------------------------------
451
+
452
+ def _decompose_between_filter_param(f: FilterParam) -> list[FilterParam]:
453
+ """Decompose a BETWEEN filter into >= and <= pair."""
454
+ if f.op != "between":
455
+ return [f]
456
+ return [
457
+ replace(
458
+ f, op=">=",
459
+ param_key=f"{f.param_key}_lower" if f.param_key else None,
460
+ ),
461
+ replace(
462
+ f, op="<=",
463
+ param_key=f"{f.param_key}_upper" if f.param_key else None,
464
+ ),
465
+ ]
466
+
467
+
468
+ def _identify_range_pairs(
469
+ filters: list[FilterParam],
470
+ ) -> dict[str, dict[str, int]]:
471
+ """Identify columns with paired lower/upper bound filters."""
472
+ column_ops: dict[str, dict[str, int]] = {}
473
+ for idx, f in enumerate(filters):
474
+ if f.right_expr:
475
+ continue
476
+ if f.op in (">", ">="):
477
+ column_ops.setdefault(f.left_expr.primary_column, {})[
478
+ "lower_idx"
479
+ ] = idx
480
+ elif f.op in ("<", "<="):
481
+ column_ops.setdefault(f.left_expr.primary_column, {})[
482
+ "upper_idx"
483
+ ] = idx
484
+ return {
485
+ col: ops for col, ops in column_ops.items()
486
+ if "lower_idx" in ops and "upper_idx" in ops
487
+ }
488
+
489
+
490
def instantiate_intent(
    intent: SimulatorIntent,
    value_domains: dict[str, ValueDomain],
) -> SimulatorIntent | None:
    """Fill in concrete filter and HAVING values for an abstract intent.

    ``between`` filters are first split into lower/upper halves.  Columns
    that end up with both a lower and an upper bound get one coordinated
    range sample so the two values are consistent; every other value
    filter is sampled independently from its column's domain.  Filters
    that compare against another expression, null checks, date-window /
    date-diff filters, and filters on columns without a domain receive no
    sampled value.  HAVING clauses without a right-hand expression get a
    deterministic value.

    Args:
        intent: Abstract simulator intent to instantiate.
        value_domains: ``table.column`` -> ValueDomain map.

    Returns:
        A new SimulatorIntent with ``param_values`` populated and an empty
        ``question``.  (The annotation allows ``None``, but no path in this
        function returns it.)
    """
    decomposed: list[FilterParam] = [
        piece
        for original in intent.filters_param
        for piece in _decompose_between_filter_param(original)
    ]

    # One coordinated (lower, upper) sample per column that has both bounds.
    range_values: dict[str, tuple[str, str]] = {}
    for column, pair in _identify_range_pairs(decomposed).items():
        domain = value_domains.get(column)
        if domain is None:
            continue
        lower_type = decomposed[pair["lower_idx"]].value_type
        low, high = sample_coordinated_range(domain, lower_type, 0)
        if low is not None and high is not None:
            range_values[column] = (low, high)

    keyed_filters: list[FilterParam] = []
    sampled: dict[str, Any] = {}

    for position, flt in enumerate(decomposed):
        column = flt.left_expr.primary_column
        key = flt.param_key or f"filter_{position}"

        if flt.right_expr:
            # Column-to-expression comparison: nothing to sample.
            keyed_filters.append(flt)
            continue

        if flt.value_type == "null" or flt.op in ("is null", "is not null"):
            keyed_filters.append(
                replace(flt, op=flt.op, value_type="null", param_key=key)
            )
            continue

        if flt.value_type in ("date_window", "date_diff"):
            domain = None
        else:
            domain = value_domains.get(column)

        if domain is not None:
            bounds = range_values.get(column)
            if bounds is not None and flt.op in (">", ">="):
                value = bounds[0]
            elif bounds is not None and flt.op in ("<", "<="):
                value = bounds[1]
            else:
                value = sample_value_from_domain(
                    domain, flt.value_type, flt.op, 0,
                )
            if value is not None:
                sampled[key] = value

        keyed_filters.append(replace(flt, param_key=key))

    keyed_having: list[HavingParam] = []
    for position, clause in enumerate(intent.having_param):
        key = clause.param_key or f"having_{position}"
        if clause.right_expr is None:
            sampled[key] = deterministic_having_value(
                clause.left_expr.primary_term, 0, position,
            )
        keyed_having.append(replace(clause, param_key=key))

    return SimulatorIntent(
        intent_id=intent.intent_id,
        tables=intent.tables,
        grain=intent.grain,
        select_cols=intent.select_cols,
        group_by_cols=intent.group_by_cols,
        order_by_cols=intent.order_by_cols,
        filters_param=keyed_filters,
        having_param=keyed_having,
        param_values=sampled,
        cte_steps=intent.cte_steps,
        question="",
        expansion_metadata=intent.expansion_metadata,
        limit=intent.limit,
    )
596
+
597
+
598
+ # ---------------------------------------------------------------------------
599
+ # Template creation
600
+ # ---------------------------------------------------------------------------
601
+
602
+ def _create_template_from_result(
603
+ result: SimulatorResult,
604
+ schema: SchemaGraph,
605
+ next_id: int,
606
+ dialect: Any | None = None,
607
+ source: str = "synthetic",
608
+ trust_level: int = 0,
609
+ ) -> Template | None:
610
+ """Create a Template from a successful simulation result."""
611
+ if not result.success or not result.sql:
612
+ return None
613
+
614
+ sql_canon = canonicalize_sql(result.sql)
615
+ sql_param_str, _ = parameter_abstract(sql_canon)
616
+ sql_fp_val = sql_fp(sql_param_str)
617
+
618
+ intent = result.intent
619
+ tables_used = extract_tables_from_sql(
620
+ sql_canon, list(schema.tables.keys()),
621
+ )
622
+
623
+ param_values = (
624
+ intent.param_values if hasattr(intent, "param_values") else {}
625
+ )
626
+ natural_language = (
627
+ intent.natural_language
628
+ if hasattr(intent, "natural_language") else ""
629
+ )
630
+ chosen_join_id = (
631
+ intent.chosen_join_candidate_id
632
+ if hasattr(intent, "chosen_join_candidate_id") else ""
633
+ )
634
+
635
+ ikey = intent_key(intent)
636
+ tid = f"T{next_id:04d}"
637
+
638
+ intent_sig = ConcreteIntent(
639
+ intent_id=ikey,
640
+ tables=intent.tables or [],
641
+ grain=intent.grain or "row_level",
642
+ select_cols=intent.select_cols or [],
643
+ group_by_cols=intent.group_by_cols or [],
644
+ order_by_cols=intent.order_by_cols or [],
645
+ filters_param=intent.filters_param or [],
646
+ having_param=intent.having_param or [],
647
+ cte_steps=(
648
+ intent.cte_steps if hasattr(intent, "cte_steps") else []
649
+ ),
650
+ limit=intent.limit if hasattr(intent, "limit") else None,
651
+ chosen_join_candidate_id=chosen_join_id,
652
+ chosen_join_path_signature=(
653
+ intent.chosen_join_path_signature
654
+ if hasattr(intent, "chosen_join_path_signature")
655
+ else []
656
+ ),
657
+ )
658
+
659
+ spark_sql_param = ""
660
+ if EngineConfig.TYPE == "databricks" and dialect:
661
+ spark_sql_param = dialect.prepare_for_execution(sql_param_str)
662
+
663
+ tmpl = Template(
664
+ id=tid,
665
+ schema_hash=schema.schema_hash,
666
+ intent_signature=intent_sig,
667
+ intent_key=ikey,
668
+ tables_used=sorted(tables_used),
669
+ sql_param=sql_param_str,
670
+ spark_sql_param=spark_sql_param,
671
+ sql_display_param=sql_param_str,
672
+ sql_fp=sql_fp_val,
673
+ shape=sql_shape(sql_canon, intent),
674
+ colmap_sig=colmap_signature({}),
675
+ value_history=ValueHistory(
676
+ param_values=[param_values],
677
+ questions=[result.question],
678
+ natural_language=[natural_language or ""],
679
+ ),
680
+ stats=TemplateStats(accept=0, reject=0),
681
+ source=source,
682
+ trust_level=trust_level,
683
+ aliased_sql="",
684
+ )
685
+ debug(f"[simulator.create_template] created: id={tmpl.id}")
686
+ return tmpl
687
+
688
+
689
+ # ---------------------------------------------------------------------------
690
+ # Deterministic simulation pipeline
691
+ # ---------------------------------------------------------------------------
692
+
693
+ def _build_value_domains(
694
+ schema: SchemaGraph,
695
+ ) -> dict[str, ValueDomain]:
696
+ """Build value domains from schema column metadata."""
697
+ domains: dict[str, ValueDomain] = {}
698
+ for table_name, table_meta in schema.tables.items():
699
+ for col_name, col_meta in table_meta.columns.items():
700
+ col_key = f"{table_name}.{col_name}"
701
+ domains[col_key] = ValueDomain(
702
+ values=col_meta.top_k_values or [],
703
+ min_val=col_meta.min_val,
704
+ max_val=col_meta.max_val,
705
+ )
706
+ return domains
707
+
708
+
709
def run_deterministic_simulation(
    intents: list[SimulatorIntent],
    schema: SchemaGraph,
    dialect: Any,
    next_id: int,
    join_cache: dict[frozenset[str], JoinCacheEntry] | None = None,
    csv_output_path: str | None = None,
) -> tuple[list[SimulatorResult], list[Template], int]:
    """Run every expanded intent through the deterministic pipeline.

    Per intent, in order: resolve joins (cached), assign parameter keys and
    build the deterministic SQL, instantiate values, substitute them,
    validate, execute, and finally generate a natural-language question
    that must pass the realism gate.  The first stage that fails records
    an error on the result and moves on to the next intent; a result is
    appended for every intent either way.

    Args:
        intents: Expanded and deduplicated SimulatorIntents.
        schema: Schema graph.
        dialect: SQL dialect executor.
        next_id: Starting template ID counter.
        join_cache: Shared join cache; created if None.
        csv_output_path: Optional CSV summary output path.

    Returns:
        Tuple of (results, templates, updated_next_id).
    """
    cache = {} if join_cache is None else join_cache
    domains = _build_value_domains(schema)

    results: list[SimulatorResult] = []
    templates: list[Template] = []
    template_counter = next_id
    tally = {
        "success": 0,
        "failed": 0,
        "validation_drop": 0,
        "realism_drop": 0,
    }

    def _reject(outcome: SimulatorResult, message: str, *also: str) -> None:
        # Record a failed result and bump the failure counters.
        outcome.error = message
        results.append(outcome)
        tally["failed"] += 1
        for bucket in also:
            tally[bucket] += 1

    log(
        f"run_deterministic_simulation: processing {len(intents)} intents"
    )

    for intent in intents:
        runtime = intent.to_runtime_intent()
        outcome = SimulatorResult(runtime, "")

        # Stage 1: join resolution (cached per table set).
        try:
            join_id, join_sig, candidates = resolve_joins_for_table_set(
                intent.tables or [], schema,
                intent.intent_id, cache,
            )
        except Exception as exc:
            _reject(outcome, f"join_resolution_failed: {exc}")
            continue

        # Stage 2: parameter keys and deterministic SQL.
        try:
            keyed_filters, keyed_having, keyed_ctes, _ = assign_param_keys(
                runtime.filters_param or [],
                runtime.having_param or [],
                runtime.cte_steps or [],
            )
            runtime.filters_param = keyed_filters
            runtime.having_param = keyed_having
            runtime.cte_steps = keyed_ctes
            extract_structural_params(runtime)
            collect_raw_param_values(runtime)

            det_sql = build_deterministic_sql(runtime)
            if candidates and join_id != "J00":
                det_sql = inject_join_into_deterministic_sql(
                    det_sql, join_id, candidates,
                )
            runtime.sql_param = det_sql
            runtime.deterministic_sql = det_sql
            runtime.chosen_join_candidate_id = join_id
            runtime.chosen_join_path_signature = join_sig
        except Exception as exc:
            _reject(outcome, f"sql_build_failed: {exc}")
            continue

        # Stage 3: value instantiation.
        instantiated = instantiate_intent(intent, domains)
        if instantiated is None:
            _reject(outcome, "instantiation_failed")
            continue

        # Values already on the runtime intent win over sampled ones.
        all_params = {
            **instantiated.param_values,
            **(runtime.param_values or {}),
        }

        # Stage 4: parameter substitution.
        try:
            final_sql = substitute_params(det_sql, all_params)
        except Exception as exc:
            _reject(outcome, f"substitution_failed: {exc}")
            continue
        if not final_sql or not final_sql.strip():
            _reject(outcome, "empty_sql_after_substitution")
            continue

        # Stage 5: validation.
        try:
            ok, err = validate_sql(dialect, final_sql)
        except Exception as exc:
            _reject(
                outcome, f"validation_exception: {exc}", "validation_drop",
            )
            continue
        if not ok:
            _reject(outcome, f"validation_failed: {err}", "validation_drop")
            continue

        # Stage 6: execution (failures also count as validation drops).
        try:
            spark_sql = get_spark_sql_for_execution(
                runtime.sql_param or "",
                all_params,
                schema,
                runtime,
                dialect,
            )
            rows = execute_sql(
                dialect, final_sql,
                spark_sql_for_execution=spark_sql or None,
            )
        except Exception as exc:
            _reject(outcome, f"execution_failed: {exc}", "validation_drop")
            continue

        # Stage 7: question generation with realism gate.
        generated = generate_question_from_sql(
            final_sql, schema, intent.tables or [],
        )
        if generated is None:
            _reject(outcome, "question_generation_failed")
            continue
        if not generated.get("is_realistic", False):
            _reject(
                outcome,
                f"realism_dropped: {generated.get('drop_reason', '')}",
                "realism_drop",
            )
            continue

        question_text = generated["question"]
        outcome.sql = final_sql
        outcome.question = question_text
        outcome.rows = rows
        outcome.success = True
        outcome.confidence = 1.0

        runtime.param_values = all_params

        template = _create_template_from_result(
            outcome, schema, template_counter, dialect,
        )
        if template:
            templates.append(template)
            template_counter += 1
        tally["success"] += 1
        results.append(outcome)

    log(
        f"run_deterministic_simulation: "
        f"{tally['success']} success, {tally['failed']} failed "
        f"(validation_drop={tally['validation_drop']}, "
        f"realism_drop={tally['realism_drop']}), "
        f"{len(templates)} templates"
    )

    if csv_output_path:
        with open(csv_output_path, "w", newline="", encoding="utf-8") as out:
            writer = csv.writer(out)
            writer.writerow(["question", "sql", "status", "error"])
            writer.writerows(
                [
                    item.question,
                    item.sql or "",
                    "success" if item.success else "failed",
                    item.error or "",
                ]
                for item in results
            )

    return results, templates, template_counter
910
+
911
+
912
+ # ---------------------------------------------------------------------------
913
+ # Report persistence
914
+ # ---------------------------------------------------------------------------
915
+
916
def save_simulation_report(
    results: list[SimulatorResult], filepath: str,
) -> None:
    """Save simulation results to a JSON report file.

    The report holds aggregate counts (``total``, ``success``, ``failed``)
    followed by one entry per result.

    Intent objects are converted with their ``to_dict()`` method when they
    provide one: ``json`` cannot encode arbitrary objects and would raise
    ``TypeError``.  The report is encoded in full before the destination is
    opened, so an encoding failure does not truncate an existing report.

    Args:
        results: Simulation results to persist.
        filepath: Destination path; an existing file is overwritten.

    Raises:
        TypeError: If a result still contains a value that is not JSON
            serializable.
    """

    def _intent_payload(intent: Any) -> Any:
        # Intents that are already plain data (e.g. dicts) pass through.
        to_dict = getattr(intent, "to_dict", None)
        return to_dict() if callable(to_dict) else intent

    def result_to_dict(r: SimulatorResult) -> dict[str, Any]:
        return {
            "question": r.question,
            "intent": _intent_payload(r.intent),
            "sql": r.sql,
            "success": r.success,
            "error": r.error,
            "confidence": r.confidence,
            "validation_issues": r.validation_issues,
        }

    report = {
        "total": len(results),
        "success": sum(1 for r in results if r.success),
        "failed": sum(1 for r in results if not r.success),
        "results": [result_to_dict(r) for r in results],
    }
    # Encode first: a failure here must not leave a half-written file.
    encoded = json.dumps(report, indent=2, ensure_ascii=False)
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(encoded)
    log(f"save_simulation_report: saved {len(results)} results to {filepath}")
941
+
942
+
943
def save_simulation_failures(
    results: list[SimulatorResult], filepath: str,
) -> None:
    """Save detailed failure information to a JSON file.

    Only results that are unsuccessful *and* carry an error message are
    written.  When there are none, no file is written and any existing
    file at *filepath* is left untouched.

    Intent objects are converted with their ``to_dict()`` method when they
    provide one: ``json`` cannot encode arbitrary objects and would raise
    ``TypeError``.  The payload is encoded in full before the destination
    is opened, so an encoding failure does not truncate an existing file.

    Args:
        results: Simulation results to filter for failures.
        filepath: Destination path for the failure list.

    Raises:
        TypeError: If a failure entry still contains a value that is not
            JSON serializable.
    """

    def _intent_payload(intent: Any) -> Any:
        # Intents that are already plain data (e.g. dicts) pass through.
        to_dict = getattr(intent, "to_dict", None)
        return to_dict() if callable(to_dict) else intent

    failures = []
    for r in results:
        if r.success or not r.error:
            continue
        # The category is the text before the first colon.
        # NOTE(review): errors without a colon (e.g. "instantiation_failed")
        # are filed under "unknown_error" -- confirm this is intended.
        error_category = (
            r.error.split(":")[0] if ":" in r.error else "unknown_error"
        )
        failures.append({
            "question": r.question,
            "intent": _intent_payload(r.intent),
            "sql": r.sql or "",
            "llm_response": r.llm_response or "",
            "error_message": r.error,
            "error_category": error_category,
            "validation_issues": r.validation_issues,
        })

    if not failures:
        log("save_simulation_failures: no failures to save")
        return

    # Encode first: a failure here must not leave a half-written file.
    encoded = json.dumps(failures, indent=2, ensure_ascii=False)
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(encoded)
    log(
        f"save_simulation_failures: saved {len(failures)} "
        f"failures to {filepath}"
    )