aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/templates.py ADDED
@@ -0,0 +1,1037 @@
1
+ """Template and rejected template persistence with trust management and promotion workflows.
2
+
3
+ Manages the on-disk template store for accepted and rejected query patterns. Accepted templates carry trust levels (0/1/2) that are promoted or demoted based on user feedback ratios. Rejected templates accumulate category-specific rejection counters used to decide hard blocks. Provides lookup, similarity-based search, promotion and demotion workflows between the two stores, and serialisation helpers for the JSON store format.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ import os
10
+ from typing import Any
11
+
12
+ from .config import EngineConfig, PolicyConfig
13
+ from .contracts_base import SchemaGraph, TemplateStats
14
+ from .contracts_core import (
15
+ RejectedTemplate,
16
+ RejectedValueHistory,
17
+ RuntimeIntent,
18
+ Template,
19
+ ValueHistory,
20
+ runtime_intent_to_concrete,
21
+ )
22
+ from .core_utils import (
23
+ canonicalize_sql,
24
+ colmap_signature,
25
+ debug,
26
+ normalize_sql,
27
+ parameter_abstract,
28
+ sql_fp,
29
+ )
30
+ from .intent_process import compute_intent_union, intent_similarity, llm_ux_explain
31
+ from .dialect import get_dialect
32
+ from .utils import extract_tables_from_sql, flatten_param_values, intent_key, sql_shape
33
+
34
+
35
def increment_rejected_template(
    rt: RejectedTemplate,
    category: str,
    reason: str,
    q_norm: str,
    intent: RuntimeIntent,
    sql: str,
) -> None:
    """Record one additional rejection against ``rt`` in place.

    The flattened parameter values of ``intent``, the normalised question, the
    intent's natural-language text, the rejection reason and the category are
    handed to ``rt.value_history.add``. Other helpers in this module count
    rejections via ``value_history.rejection_categories``, so ``add`` is
    presumably what appends there — confirm in ``RejectedValueHistory``.

    Args:
        rt: The ``RejectedTemplate`` to mutate.
        category: Rejection category string (for example ``"schema_mismatch"``).
        reason: Human-readable explanation of the rejection.
        q_norm: Normalised question that triggered the rejection.
        intent: ``RuntimeIntent`` whose parameter values are recorded.
        sql: Rejected SQL; accepted for signature parity but not used here.
    """
    history = rt.value_history
    history.add(flatten_param_values(intent), q_norm, intent.natural_language, reason, category)
    debug(f"[templates.increment_rejected_template] updated: id={rt.id} entries={len(history.questions)}")
62
+
63
+
64
def get_hard_block_info(rt: RejectedTemplate) -> dict[str, Any]:
    """Decide whether a single rejected template should be hard-blocked.

    The template is blocked as soon as any category listed in
    ``PolicyConfig.STRUCTURAL_REJECT_CATEGORIES`` appears at least twice in its
    rejection history. Categories are checked in policy order and the first
    qualifying one is reported as the blocking category.

    Args:
        rt: The ``RejectedTemplate`` to inspect.

    Returns:
        Dict with ``should_block`` (bool), ``total_rejects`` (int, number of
        recorded categories), ``categories`` (category -> count),
        ``blocking_category`` (str or ``None``) and ``last_reason`` (the most
        recent reason, or ``""`` when there is none).
    """
    history = rt.value_history
    recorded = history.rejection_categories

    counts: dict[str, int] = {}
    for name in recorded:
        counts[name] = counts.get(name, 0) + 1

    blocking = next(
        (name for name in PolicyConfig.STRUCTURAL_REJECT_CATEGORIES if counts.get(name, 0) >= 2),
        None,
    )
    reasons = history.rejection_reasons

    return {
        "should_block": blocking is not None,
        "total_rejects": len(recorded),
        "categories": counts,
        "blocking_category": blocking,
        "last_reason": reasons[-1] if reasons else "",
    }
97
+
98
+
99
def get_aggregated_hard_block_info(
    rejected: dict[str, RejectedTemplate],
    intent: RuntimeIntent,
    similarity_threshold: float = 0.70,
) -> dict[str, Any]:
    """Pool rejection statistics over every rejected template similar to ``intent``.

    A rejected template contributes when ``intent_similarity`` against its
    ``intent_signature`` is at least ``similarity_threshold``. Templates without a
    signature score 0.0. Blocking is triggered once the pooled number of
    structural rejections (categories in
    ``PolicyConfig.STRUCTURAL_REJECT_CATEGORIES``) reaches two. The reported
    blocking category is the structural category with the highest pooled count,
    with ties going to the one listed first by the policy.

    Args:
        rejected: All rejected templates keyed by id.
        intent: The ``RuntimeIntent`` to compare against.
        similarity_threshold: Minimum similarity for a template to be pooled.

    Returns:
        Dict with ``should_block`` (bool), ``total_rejects``,
        ``total_structural_rejects``, ``categories`` (category -> pooled count),
        ``blocking_category`` (str or ``None``) and ``matched_templates`` (list
        of ``(RejectedTemplate, score)`` tuples in iteration order).
    """
    structural = PolicyConfig.STRUCTURAL_REJECT_CATEGORIES

    matched_templates: list[tuple[RejectedTemplate, float]] = []
    category_counts: dict[str, int] = {}
    total_rejections = 0
    total_structural_rejections = 0

    for candidate in rejected.values():
        score = intent_similarity(intent, candidate.intent_signature) if candidate.intent_signature else 0.0
        if score >= similarity_threshold:
            matched_templates.append((candidate, score))
            for name in candidate.value_history.rejection_categories:
                category_counts[name] = category_counts.get(name, 0) + 1
                total_rejections += 1
                if name in structural:
                    total_structural_rejections += 1

    debug(f"[templates.get_aggregated_hard_block_info] found {len(matched_templates)} similar rejected templates")
    debug(f"[templates.get_aggregated_hard_block_info] category_counts: {category_counts}")
    debug(f"[templates.get_aggregated_hard_block_info] total_structural_rejections: {total_structural_rejections}")

    should_block = total_structural_rejections >= 2
    blocking_category = None
    if should_block:
        present = [name for name in structural if name in category_counts]
        if present:
            # max() returns the first maximal element, so ties keep policy order.
            blocking_category = max(present, key=lambda name: category_counts[name])
        debug(
            f"[templates.get_aggregated_hard_block_info] BLOCK: total_structural={total_structural_rejections} primary_category={blocking_category}"
        )

    return {
        "should_block": should_block,
        "total_rejects": total_rejections,
        "total_structural_rejects": total_structural_rejections,
        "categories": category_counts,
        "blocking_category": blocking_category,
        "matched_templates": matched_templates,
    }
159
+
160
+
161
def find_rejected_by_intent_key(rejected: dict[str, RejectedTemplate], ikey: str) -> RejectedTemplate | None:
    """Return the first rejected template whose ``intent_key`` equals ``ikey``.

    "First" follows the iteration (insertion) order of ``rejected``.

    Args:
        rejected: All rejected templates keyed by id.
        ikey: The intent key to look for.

    Returns:
        The matching ``RejectedTemplate``, or ``None`` when no entry matches.
    """
    matches = (candidate for candidate in rejected.values() if candidate.intent_key == ikey)
    return next(matches, None)
178
+
179
+
180
def find_rejected_by_similarity(
    rejected: dict[str, RejectedTemplate],
    intent: RuntimeIntent,
    threshold: float = 0.85,
) -> RejectedTemplate | None:
    """Return the rejected template most similar to ``intent``, if any clears ``threshold``.

    Similarity is ``intent_similarity(intent, rt.intent_signature)``. Templates
    without a signature score 0.0. On equal scores the earliest template in
    iteration order wins, because a later one must score strictly higher to
    replace it.

    Args:
        rejected: All rejected templates keyed by id.
        intent: The ``RuntimeIntent`` to compare against.
        threshold: Minimum similarity required for a match.

    Returns:
        The best-scoring ``RejectedTemplate`` at or above ``threshold``, else ``None``.
    """
    winner: RejectedTemplate | None = None
    winner_score = 0.0

    for candidate in rejected.values():
        if not candidate.intent_signature:
            score = 0.0
        else:
            score = intent_similarity(intent, candidate.intent_signature)
        if score >= threshold and score > winner_score:
            winner, winner_score = candidate, score

    if winner:
        debug(f"[templates.find_rejected_by_similarity] found: id={winner.id} sim={winner_score:.3f}")
    return winner
211
+
212
+
213
def promote_rejected_to_template(
    store: dict[str, Any],
    templates: dict[str, Template],
    rejected: dict[str, RejectedTemplate],
    rt: RejectedTemplate,
    q_norm: str,
    intent: RuntimeIntent,
    sql: str,
    schema_hash: str,
) -> Template:
    """Turn a rejected template into an accepted one after the user overrode and accepted it.

    Steps:

    1. Allocate the next ``T####`` id from ``store["next_id"]`` and advance the counter.
    2. Build a ``Template`` with trust level 1, ``source="human"`` and one accept.
       It is built from ``intent``/``sql`` plus the rejected entry's intent key,
       display SQL and aliased SQL.
    3. Register the new template in ``templates``.
    4. Remove the rejected entry's traces from ``store["negative_memory"]``.
    5. Drop ``rt`` from ``rejected``.

    Args:
        store: Mutable template store; its ``next_id`` counter is advanced.
        templates: Accepted templates keyed by id; receives the new template.
        rejected: Rejected templates keyed by id; ``rt.id`` is removed if present.
        rt: The rejected template being promoted.
        q_norm: Normalised question seeding the new value history.
        intent: Intent at promotion time; a non-empty ``intent.sql_param`` is
            used as-is instead of re-parameterising ``sql``.
        sql: The SQL the user accepted.
        schema_hash: Schema hash stamped on the new template.

    Returns:
        The newly created ``Template``.
    """
    counter = int(store["next_id"])
    tid = f"T{counter:04d}"
    store["next_id"] = counter + 1

    sql_canon = canonicalize_sql(sql)
    sql_param = intent.sql_param
    if not sql_param:
        # No pre-parameterised SQL on the intent: derive it from the accepted SQL.
        sql_param, _ = parameter_abstract(normalize_sql(sql_canon))

    fingerprint = sql_fp(sql_param)
    column_signature = colmap_signature(intent.column_map)
    concrete_intent = runtime_intent_to_concrete(intent, "")
    param_values = flatten_param_values(intent)
    summary = llm_ux_explain(intent, q_norm)

    tables = sorted(intent.tables)
    # NOTE(review): spark_sql_param is copied from the rejected entry rather than
    # derived from sql_param (insert_template computes it via the dialect for
    # databricks) — confirm this is intended when the two SQL texts differ.
    spark_sql = getattr(rt, "spark_sql_param", "") or ""
    display_sql = rt.sql_display_param or intent.sql_display_param or sql_param
    shape = sql_shape(sql_canon, intent)
    history = ValueHistory(
        param_values=[param_values],
        questions=[q_norm],
        natural_language=[intent.natural_language or ""],
    )
    # Only parameters whose names start with "s" are kept as structural defaults.
    structural_defaults = {name: value for name, value in param_values.items() if name.startswith("s")}

    tmpl = Template(
        id=tid,
        schema_hash=schema_hash,
        intent_signature=concrete_intent,
        intent_key=rt.intent_key,
        tables_used=tables,
        sql_param=sql_param,
        spark_sql_param=spark_sql,
        sql_display_param=display_sql,
        sql_fp=fingerprint,
        shape=shape,
        colmap_sig=column_signature,
        value_history=history,
        stats=TemplateStats(accept=1, reject=0),
        ux_summary=summary,
        source="human",
        trust_level=1,
        structural_defaults=structural_defaults,
        aliased_sql=rt.aliased_sql or intent.sql_display_param or "",
    )

    templates[tid] = tmpl
    _cleanup_negative_memory_for_rejected(store, rt)

    if rt.id in rejected:
        del rejected[rt.id]
        debug(f"[templates.promote_rejected_to_template] removed_rejected: id={rt.id}")

    debug(f"[templates.promote_rejected_to_template] created_template: id={tid} from_rejected={rt.id}")
    return tmpl
302
+
303
+
304
def _cleanup_negative_memory_for_rejected(store: dict[str, Any], rt: RejectedTemplate) -> None:
    """Remove negative-memory entries contributed by a rejected template that is being promoted.

    Entries are cleaned as follows:

    - ``by_intent_key`` and ``by_sql_fp`` entries for ``rt`` are deleted outright.
    - ``by_join_sig`` and ``by_colmap_sig`` buckets are shared counters. For these,
      each of ``rt``'s recorded ``(reason, category)`` pairs is subtracted, and the
      bucket is deleted once its ``rejects`` count reaches zero.

    Args:
        store: Mutable template store; its ``negative_memory`` section is edited in place.
        rt: The ``RejectedTemplate`` whose contributions should be removed.
    """
    mem = store.get("negative_memory", {})
    history = rt.value_history
    reasons_to_remove = list(history.rejection_reasons)
    categories_to_remove = list(history.rejection_categories)

    if rt.intent_key and rt.intent_key in mem.get("by_intent_key", {}):
        del mem["by_intent_key"][rt.intent_key]
        debug(f"[templates.cleanup_negative_memory_for_rejected] removed by_intent_key: {rt.intent_key[:16]}...")

    if rt.sql_fp and rt.sql_fp in mem.get("by_sql_fp", {}):
        del mem["by_sql_fp"][rt.sql_fp]
        debug(f"[templates.cleanup_negative_memory_for_rejected] removed by_sql_fp: {rt.sql_fp[:16]}...")

    # Join-path signatures are keyed by their str() form in negative memory.
    join_sig_key = str(rt.chosen_join_path_signature) if rt.chosen_join_path_signature else None
    if join_sig_key:
        _decrement_negative_bucket(
            mem, "by_join_sig", join_sig_key, reasons_to_remove, categories_to_remove, preview_len=32
        )

    if rt.colmap_sig:
        _decrement_negative_bucket(
            mem, "by_colmap_sig", rt.colmap_sig, reasons_to_remove, categories_to_remove, preview_len=16
        )


def _decrement_negative_bucket(
    mem: dict[str, Any],
    section: str,
    key: str,
    reasons: list[str],
    categories: list[str],
    preview_len: int,
) -> None:
    """Subtract one rejection per ``(reason, category)`` pair from ``mem[section][key]``, dropping it when empty.

    Does nothing when the section or key is absent. Counters never go below zero.

    Args:
        mem: The ``negative_memory`` dict.
        section: Bucket family name, e.g. ``"by_join_sig"`` or ``"by_colmap_sig"``.
        key: Bucket key within ``section``.
        reasons: Reasons recorded by the template being cleaned up.
        categories: Categories recorded by the template, paired positionally with ``reasons``.
        preview_len: Number of key characters shown in debug output.
    """
    buckets = mem.get(section, {})
    if key not in buckets:
        return

    bucket = buckets[key]
    for reason, cat in zip(reasons, categories, strict=False):
        if reason in bucket.get("reasons", []):
            bucket["reasons"].remove(reason)
        bucket["rejects"] = max(0, bucket.get("rejects", 0) - 1)
        if cat in bucket.get("categories", {}):
            bucket["categories"][cat] = max(0, bucket["categories"].get(cat, 0) - 1)
            if bucket["categories"][cat] == 0:
                del bucket["categories"][cat]

    if bucket.get("rejects", 0) == 0:
        del buckets[key]
        debug(f"[templates.cleanup_negative_memory_for_rejected] removed {section}: {key[:preview_len]}...")
    else:
        debug(
            f"[templates.cleanup_negative_memory_for_rejected] decremented {section}: {key[:preview_len]}... rejects={bucket['rejects']}"
        )
364
+
365
+
366
def demote_template_to_rejected(
    store: dict[str, Any],
    templates: dict[str, Template],
    rejected: dict[str, RejectedTemplate],
    tmpl: Template,
    schema: SchemaGraph,
    intent: RuntimeIntent,
    sql: str,
    q_norm: str,
    category: str,
    reason: str,
) -> RejectedTemplate:
    """Move an accepted template into the rejected store.

    Steps:

    1. Allocate the next ``R####`` id from ``store["next_reject_id"]``, which
       defaults to 1 when missing, and advance the counter.
    2. Build a ``RejectedTemplate`` seeded with this single rejection.
    3. Insert it into ``rejected``.
    4. Delete ``tmpl`` from ``templates`` if it is still there.

    Args:
        store: Mutable template store; its ``next_reject_id`` counter is advanced.
        templates: Accepted templates keyed by id; ``tmpl.id`` is removed if present.
        rejected: Rejected templates keyed by id; receives the new entry.
        tmpl: The ``Template`` being demoted.
        schema: Schema graph supplying ``schema_hash`` and the known table names.
        intent: Intent at demotion time; a non-empty ``intent.sql_param`` is used
            as-is instead of re-parameterising ``sql``.
        sql: The SQL that was rejected.
        q_norm: Normalised question.
        category: Rejection category.
        reason: Human-readable rejection reason.

    Returns:
        The newly created ``RejectedTemplate``.
    """
    counter = int(store.get("next_reject_id", 1))
    rid = f"R{counter:04d}"
    store["next_reject_id"] = counter + 1

    sql_canon = canonicalize_sql(sql)
    sql_param = intent.sql_param
    if not sql_param:
        # No pre-parameterised SQL on the intent: derive it from the rejected SQL.
        sql_param, _ = parameter_abstract(normalize_sql(sql_canon))

    fingerprint = sql_fp(sql_param)
    used_tables = extract_tables_from_sql(sql_canon, list(schema.tables.keys()))
    column_signature = colmap_signature(intent.column_map)
    concrete_intent = runtime_intent_to_concrete(intent, "")
    param_values = flatten_param_values(intent)

    history = RejectedValueHistory(
        param_values=[param_values],
        questions=[q_norm],
        natural_language=[intent.natural_language or ""],
        rejection_reasons=[reason],
        rejection_categories=[category],
    )

    rt = RejectedTemplate(
        id=rid,
        schema_hash=schema.schema_hash,
        intent_signature=concrete_intent,
        intent_key=tmpl.intent_key,
        tables_used=sorted(used_tables),
        sql_param=sql_param,
        spark_sql_param=getattr(tmpl, "spark_sql_param", "") or "",
        sql_display_param=tmpl.sql_display_param or sql_param,
        sql_fp=fingerprint,
        shape=sql_shape(sql_canon, intent),
        colmap_sig=column_signature,
        value_history=history,
        aliased_sql=tmpl.aliased_sql or "",
    )
    rejected[rid] = rt

    templates.pop(tmpl.id, None)

    debug(f"[templates.demote_template_to_rejected] demoted: template={tmpl.id} to_rejected={rid}")
    return rt
454
+
455
+
456
def promote_trust(template: Template) -> bool:
    """Raise a template's trust level by at most one step.

    Rules:

    - Trust 0 → 1: unconditional.
    - Trust 1 → 2: requires all of the following.
      - At least one piece of feedback.
      - ``accept + reject`` of at least ``PolicyConfig.TRUST_PROMOTE_MIN_TOTAL``.
      - A reject ratio of at most ``PolicyConfig.TRUST_PROMOTE_MAX_REJECT_RATIO``.
    - Level 2 and any other value: never promoted.

    Args:
        template: The ``Template`` to update in place.

    Returns:
        ``True`` if the trust level was incremented, ``False`` otherwise.
    """
    current_trust = template.trust_level

    if current_trust >= 2:
        return False

    if current_trust == 0:
        template.trust_level = 1
        debug(f"[templates.promote_trust] promoted: id={template.id} level=1")
        return True

    if current_trust != 1:
        return False

    accept_count = template.stats.accept
    reject_count = template.stats.reject
    total_count = accept_count + reject_count

    # Guard on total_count explicitly: if TRUST_PROMOTE_MIN_TOTAL is configured as 0,
    # a template with no feedback would otherwise hit a ZeroDivisionError below.
    if total_count <= 0 or total_count < PolicyConfig.TRUST_PROMOTE_MIN_TOTAL:
        return False

    reject_ratio = reject_count / total_count
    if reject_ratio <= PolicyConfig.TRUST_PROMOTE_MAX_REJECT_RATIO:
        template.trust_level = 2
        debug(f"[templates.promote_trust] promoted: id={template.id} level=2 ratio={reject_ratio:.2f}")
        return True

    return False
490
+
491
+
492
def record_template_feedback(template: Template, accept: bool) -> None:
    """Bump exactly one of the template's feedback counters.

    Args:
        template: The ``Template`` whose ``stats`` are updated in place.
        accept: Truthy to count an acceptance, falsy to count a rejection.
    """
    counters = template.stats
    if not accept:
        counters.reject += 1
    else:
        counters.accept += 1
    debug(f"[templates.record_template_feedback] recorded: id={template.id} accept={accept}")
506
+
507
+
508
def maybe_demote_trust(template: Template) -> bool:
    """Lower a template's trust level by one step when its reject ratio is too high.

    Rules:

    - Trust 2 → 1: when ``reject / (accept + reject)`` exceeds
      ``PolicyConfig.TRUST_DEMOTE_REJECT_RATIO_T2``.
    - Trust 1 → 0: when the ratio exceeds ``PolicyConfig.TRUST_DEMOTE_REJECT_RATIO_T1``.
    - A template with no feedback at all is left untouched.

    Args:
        template: The ``Template`` to update in place.

    Returns:
        ``True`` if the trust level was decremented, ``False`` otherwise.
    """
    stats = template.stats
    total = stats.accept + stats.reject
    if total == 0:
        return False

    ratio = stats.reject / total
    level = template.trust_level

    if level == 2 and ratio > PolicyConfig.TRUST_DEMOTE_REJECT_RATIO_T2:
        new_level = 1
    elif level == 1 and ratio > PolicyConfig.TRUST_DEMOTE_REJECT_RATIO_T1:
        new_level = 0
    else:
        return False

    template.trust_level = new_level
    debug(f"[templates.maybe_demote_trust] demoted: id={template.id} level={new_level} ratio={ratio:.2f}")
    return True
541
+
542
+
543
def _empty_template_store(schema_hash: str) -> dict[str, Any]:
    """Return a brand-new, empty template store stamped with ``schema_hash``."""
    return {
        "schema_hash": schema_hash,
        "next_id": 1,
        "next_reject_id": 1,
        "templates": {},
        "rejected_templates": {},
        "negative_memory": {
            "by_intent_key": {},
            "by_join_sig": {},
            "by_sql_fp": {},
            "by_colmap_sig": {},
        },
    }


def load_template_store(schema_hash: str) -> dict[str, Any]:
    """Load the template store from ``EngineConfig.TEMPLATE_JSON_PATH``, or start a fresh one.

    A fresh empty store is returned when any of these hold:

    - ``PolicyConfig.REGENERATE_TEMPLATE_STORE`` is set.
    - The file does not exist.
    - The stored ``schema_hash`` differs from ``schema_hash``.

    Otherwise the loaded dict is returned with any missing counters and sections
    filled in.

    Args:
        schema_hash: The current schema hash; a mismatch invalidates the stored data.

    Returns:
        The store dict with ``templates``, ``rejected_templates``,
        ``negative_memory`` (four ``by_*`` sections) and the ``next_id`` and
        ``next_reject_id`` counters populated.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if PolicyConfig.REGENERATE_TEMPLATE_STORE:
        debug("[templates.load_template_store] REGENERATE_TEMPLATE_STORE: returning empty store")
        return _empty_template_store(schema_hash)

    template_json_path = EngineConfig.TEMPLATE_JSON_PATH

    # EAFP instead of os.path.exists() + open(): no window for the file to vanish in between.
    try:
        f = open(template_json_path, encoding="utf-8")
    except FileNotFoundError:
        debug(f"[templates.load_template_store] no_file: path={template_json_path}")
        return _empty_template_store(schema_hash)
    with f:
        d = json.load(f)

    if d.get("schema_hash") != schema_hash:
        debug("[templates.load_template_store] schema_hash_mismatch: resetting")
        return _empty_template_store(schema_hash)

    d.setdefault("next_id", 1)
    d.setdefault("next_reject_id", 1)
    d.setdefault("templates", {})
    d.setdefault("rejected_templates", {})
    negative_memory = d.setdefault("negative_memory", {})
    for section in ("by_intent_key", "by_join_sig", "by_sql_fp", "by_colmap_sig"):
        negative_memory.setdefault(section, {})

    debug(
        f"[templates.load_template_store] loaded: templates={len(d.get('templates', {}))} rejected={len(d.get('rejected_templates', {}))}"
    )
    return d
619
+
620
+
621
def save_template_store(store: dict[str, Any]) -> None:
    """Persist the template store to ``EngineConfig.TEMPLATE_JSON_PATH``.

    The store is converted (sets become lists) and fully serialised before the
    target file is touched. The text is then written to a sibling ``.tmp`` file
    and moved into place with ``os.replace``. A serialisation error or an
    interrupted write therefore leaves the previously saved store intact; the old
    code opened the target with ``"w"`` first and would truncate it on failure.
    The caller's ``store`` object is not modified.

    Args:
        store: The template store dict to serialise and save.
    """
    template_json_path = EngineConfig.TEMPLATE_JSON_PATH

    debug(
        f"[templates.save_template_store] saving: templates={len(store.get('templates', {}))} rejected={len(store.get('rejected_templates', {}))}"
    )
    _debug_check_types(store, "store")
    payload = json.dumps(
        _convert_to_json_serializable(store),
        indent=2,
        sort_keys=True,
        ensure_ascii=False,
    )

    tmp_path = f"{template_json_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, template_json_path)
    except BaseException:
        # Best-effort removal of the partial temp file; the real store is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    debug(f"[templates.save_template_store] complete: path={template_json_path}")
640
+
641
+
642
def store_to_templates(store: dict[str, Any]) -> dict[str, Template]:
    """Rebuild ``Template`` objects from the raw ``store["templates"]`` section.

    Each entry goes through ``Template.from_dict``. Its ``sql_fp`` is then
    recomputed from ``sql_param`` rather than trusted from disk.

    Args:
        store: The raw template store dict, as loaded from JSON.

    Returns:
        Dict mapping template id to ``Template`` (empty if the section is missing).
    """
    rebuilt: dict[str, Template] = {}
    for template_id, payload in store.get("templates", {}).items():
        template = Template.from_dict(payload)
        # NOTE(review): insert paths store sql_fp(sql_param) directly, whereas this
        # re-canonicalises and re-abstracts sql_param first. The two agree only if those
        # steps are idempotent on already-parameterised SQL — confirm.
        abstracted, _ = parameter_abstract(canonicalize_sql(template.sql_param))
        template.sql_fp = sql_fp(abstracted)
        rebuilt[template_id] = template
    return rebuilt
659
+
660
+
661
def store_to_rejected_templates(store: dict[str, Any]) -> dict[str, RejectedTemplate]:
    """Rebuild ``RejectedTemplate`` objects from the raw ``store["rejected_templates"]`` section.

    Each entry goes through ``RejectedTemplate.from_dict``. Its ``sql_fp`` is then
    recomputed from ``sql_param`` (canonicalised, then parameter-abstracted)
    instead of using the stored value.

    Args:
        store: The raw template store dict, as loaded from JSON.

    Returns:
        Dict mapping rejected-template id to ``RejectedTemplate`` (empty if the section is missing).
    """
    rebuilt: dict[str, RejectedTemplate] = {}
    for rejected_id, payload in store.get("rejected_templates", {}).items():
        entry = RejectedTemplate.from_dict(payload)
        abstracted, _ = parameter_abstract(canonicalize_sql(entry.sql_param))
        entry.sql_fp = sql_fp(abstracted)
        rebuilt[rejected_id] = entry
    return rebuilt
678
+
679
+
680
+ def _convert_to_json_serializable(obj: Any) -> Any:
681
+ if isinstance(obj, (set, frozenset)):
682
+ return list(obj)
683
+ elif isinstance(obj, dict):
684
+ return {k: _convert_to_json_serializable(v) for k, v in obj.items()}
685
+ elif isinstance(obj, list | tuple):
686
+ return [_convert_to_json_serializable(item) for item in obj]
687
+ return obj
688
+
689
+
690
+ def _debug_check_types(obj: Any, path: str = "root") -> None:
691
+ if isinstance(obj, set):
692
+ debug(f"[templates.debug_check_types] found_set: path={path}")
693
+ elif isinstance(obj, dict):
694
+ for k, v in obj.items():
695
+ _debug_check_types(v, f"{path}.{k}")
696
+ elif isinstance(obj, list | tuple):
697
+ for i, item in enumerate(obj):
698
+ _debug_check_types(item, f"{path}[{i}]")
699
+
700
+
701
def templates_to_store(store: dict[str, Any], templates: dict[str, Template]) -> dict[str, Any]:
    """Serialise ``Template`` objects into ``store["templates"]``.

    Each template is converted with ``to_dict()`` once. The result is
    type-checked for debugging and made JSON-serialisable. The old
    implementation called ``to_dict()`` twice per template. ``store["templates"]``
    is replaced only after every template converted successfully.

    Args:
        store: The mutable template store dict to update in place.
        templates: Template id -> ``Template`` to serialise.

    Returns:
        The same ``store`` object, with its ``templates`` key replaced.
    """
    debug(f"[templates.templates_to_store] converting: count={len(templates)}")
    converted: dict[str, Any] = {}
    for tid, t in templates.items():
        template_dict = t.to_dict()
        _debug_check_types(template_dict, f"template[{tid}]")
        converted[tid] = _convert_to_json_serializable(template_dict)
    store["templates"] = converted
    return store
720
+
721
+
722
def rejected_templates_to_store(store: dict[str, Any], rejected: dict[str, RejectedTemplate]) -> dict[str, Any]:
    """Serialise ``RejectedTemplate`` objects into ``store["rejected_templates"]``.

    Every entry is turned into a dict via ``to_dict()`` and made JSON-serialisable.
    The section is replaced only after all entries converted.

    Args:
        store: The mutable template store dict to update in place.
        rejected: Rejected-template id -> ``RejectedTemplate`` to serialise.

    Returns:
        The same ``store`` object, with its ``rejected_templates`` key replaced.
    """
    converted: dict[str, Any] = {}
    for rejected_id, entry in rejected.items():
        converted[rejected_id] = _convert_to_json_serializable(entry.to_dict())
    store["rejected_templates"] = converted
    return store
737
+
738
+
739
def top_rejected_examples(
    rejected: dict[str, RejectedTemplate],
    intent: RuntimeIntent,
    k: int = 3,
    exclude_id: str | None = None,
) -> list[dict[str, Any]]:
    """Return up to ``k`` repeatedly-rejected examples similar to ``intent``, for negative prompting.

    A rejected template qualifies when all of the following hold:

    - Its id is not ``exclude_id``.
    - It has at least two recorded rejections.
    - At least one of its categories is in ``PolicyConfig.STRUCTURAL_REJECT_CATEGORIES``.
    - Its intent similarity is at least ``PolicyConfig.AUTO_PROCEED_THRESHOLD``.

    Qualifying templates are ranked by descending similarity, then by id.

    Args:
        rejected: All rejected templates keyed by id.
        intent: The ``RuntimeIntent`` to compare against.
        k: Maximum number of examples. A value of 0 or less yields no examples;
            previously a negative ``k`` sliced off the tail and returned almost
            everything.
        exclude_id: Optional rejected-template id to leave out (falsy means no exclusion).

    Returns:
        One dict per example with ``sql_param``, ``join_path_signature``,
        ``tables``, ``grain``, ``select_cols``, ``group_by_cols``,
        ``order_by_cols``, ``filters_param``, ``having_param``, and the most
        recent ``category`` and ``reason``.
    """
    debug(f"[templates.top_rejected_examples] searching: k={k} rejected_count={len(rejected)}")

    structural_categories = PolicyConfig.STRUCTURAL_REJECT_CATEGORIES
    scored: list[tuple[RejectedTemplate, float]] = []
    for r in rejected.values():
        if exclude_id and r.id == exclude_id:
            debug(f"[templates.top_rejected_examples] skipping: id={r.id} excluded")
            continue

        categories = r.value_history.rejection_categories
        # Only templates rejected more than once, with at least one structural category.
        if len(categories) <= 1:
            continue
        if not any(cat in structural_categories for cat in categories):
            continue

        sim = intent_similarity(intent, r.intent_signature) if r.intent_signature else 0.0
        if sim >= PolicyConfig.AUTO_PROCEED_THRESHOLD:
            scored.append((r, sim))

    scored.sort(key=lambda x: (-x[1], x[0].id))
    debug(f"[templates.top_rejected_examples] candidates: {len(scored)}")

    # Clamp so a negative k cannot turn scored[:k] into "all but the last |k|".
    # None is tolerated for backward compatibility and keeps the old "no limit" slicing.
    limit = None if k is None else max(k, 0)

    out: list[dict[str, Any]] = []
    for r, sim in scored[:limit]:
        vh = r.value_history
        last_category = vh.rejection_categories[-1] if vh.rejection_categories else ""
        last_reason = vh.rejection_reasons[-1] if vh.rejection_reasons else ""
        intent_sig = r.intent_signature

        out.append(
            {
                "sql_param": r.sql_param,
                "join_path_signature": r.chosen_join_path_signature or [],
                "tables": intent_sig.tables if intent_sig else [],
                "grain": intent_sig.grain if intent_sig else "row_level",
                "select_cols": ([s.to_dict() for s in (intent_sig.select_cols or [])] if intent_sig else []),
                "group_by_cols": intent_sig.group_by_cols if intent_sig else [],
                "order_by_cols": ([o.to_dict() for o in (intent_sig.order_by_cols or [])] if intent_sig else []),
                "filters_param": ([f.to_dict() for f in (intent_sig.filters_param or [])] if intent_sig else []),
                "having_param": ([h.to_dict() for h in (intent_sig.having_param or [])] if intent_sig else []),
                "category": last_category,
                "reason": last_reason,
            }
        )
        debug(f"[templates.top_rejected_examples] selected: id={r.id} sim={sim:.3f}")

    debug(f"[templates.top_rejected_examples] returning: {len(out)} examples")
    return out
816
+
817
+
818
def insert_template(
    store: dict,
    templates: dict[str, Template],
    schema: SchemaGraph,
    q_norm: str,
    intent: RuntimeIntent,
    sql: str,
    ux_summary: str = "",
    dialect: Any | None = None,
) -> Template:
    """Create and insert a new accepted template, or merge the question into an existing one.

    An existing template is treated as a duplicate when it has the same ``intent_key``
    as ``intent`` AND ``compute_intent_union`` reports no conflicts between ``intent``
    and the template's stored ``intent_signature``. For a duplicate, the accept count
    is incremented, ``promote_trust`` is applied, and the question is appended to its
    value history. Otherwise a new ``Template`` (``source="human"``, ``trust_level=1``)
    is created under the next ``T####`` id and stored in ``templates``.

    Args:

        store: The mutable template store dict. Its ``next_id`` counter is read
            (defaulting to 1 when absent) and advanced when a new template is created.

        templates: The accepted templates dict, keyed by template id.

        schema: The schema graph (for ``schema_hash`` and table enumeration).

        q_norm: The normalised question.

        intent: The ``RuntimeIntent``.

        sql: The accepted SQL.

        ux_summary: Optional pre-computed UX summary; generated via LLM if empty.

        dialect: Optional dialect used to build ``spark_sql_param`` when
            ``EngineConfig.TYPE == "databricks"``; falls back to ``get_dialect()``.

    Returns:

        The existing or newly created ``Template``.
    """
    sql_canon = canonicalize_sql(sql)
    if intent.sql_param:
        sql_param = intent.sql_param
    else:
        # No parameterised SQL on the intent: derive it from the canonical SQL.
        sql_norm = normalize_sql(sql_canon)
        sql_param, _ = parameter_abstract(sql_norm)
    sql_fp_val = sql_fp(sql_param)
    tables_used = extract_tables_from_sql(sql_canon, list(schema.tables.keys()))
    colmap_sig_val = colmap_signature(intent.column_map)
    ikey = intent_key(intent)

    intent_sig = runtime_intent_to_concrete(intent, "")

    debug(f"[templates.insert_template] checking_duplicate: ikey={ikey[:32]} sql_fp={sql_fp_val[:16]}")

    for t in templates.values():
        if t.intent_key != ikey:
            continue
        # Same intent key is necessary but not sufficient: the intents must also
        # merge without conflicts to count as the same template.
        _, conflicts = compute_intent_union(intent, t.intent_signature)
        if conflicts:
            continue
        debug(f"[templates.insert_template] duplicate_found: id={t.id}")
        t.stats.accept += 1
        promote_trust(t)

        all_pv = flatten_param_values(intent)
        t.value_history.add(all_pv, q_norm, intent.natural_language)
        debug(f"[templates.insert_template] value_history_added: entries={len(t.value_history.questions)}")
        return t

    debug("[templates.insert_template] no_duplicate: creating_new")

    # Read the counter once; default to 1 like the rejected-template counter does.
    next_id = int(store.get("next_id", 1))
    tid = f"T{next_id:04d}"
    store["next_id"] = next_id + 1

    all_pv = flatten_param_values(intent)
    if not ux_summary:
        ux_summary = llm_ux_explain(intent, q_norm)

    spark_sql_param = ""
    if EngineConfig.TYPE == "databricks":
        d = dialect or get_dialect()
        if d:
            spark_sql_param = d.prepare_for_execution(sql_param)

    tmpl = Template(
        id=tid,
        schema_hash=schema.schema_hash,
        intent_signature=intent_sig,
        intent_key=ikey,
        tables_used=sorted(tables_used),
        sql_param=sql_param,
        spark_sql_param=spark_sql_param,
        sql_display_param=intent.sql_display_param or sql_param,
        sql_fp=sql_fp_val,
        # Prefer the shape already computed on the intent; derive it otherwise.
        shape=intent.sql_shape if intent.sql_shape else sql_shape(sql_canon, intent),
        colmap_sig=colmap_sig_val,
        value_history=ValueHistory(
            param_values=[all_pv],
            questions=[q_norm],
            natural_language=[intent.natural_language or ""],
        ),
        stats=TemplateStats(accept=1, reject=0),
        ux_summary=ux_summary,
        source="human",
        trust_level=1,
        # Parameter values whose keys start with "s" are kept as structural defaults.
        structural_defaults={k: v for k, v in all_pv.items() if k.startswith("s")},
        deterministic_sql=intent.deterministic_sql,
        aliased_sql=intent.sql_display_param or "",
    )

    debug(
        f"[templates.insert_template] CREATED template natural_language={tmpl.value_history.natural_language} chosen_join_candidate_id='{tmpl.chosen_join_candidate_id}' chosen_join_path_signature={tmpl.chosen_join_path_signature}"
    )

    templates[tid] = tmpl
    debug(f"[templates.insert_template] created: id={tid}")
    return tmpl
929
+
930
+
931
def insert_rejected_template(
    store: dict[str, Any],
    rejected: dict[str, RejectedTemplate],
    schema: SchemaGraph,
    q_norm: str,
    intent: RuntimeIntent,
    sql: str,
    category: str,
    reason: str,
    dialect: Any | None = None,
) -> RejectedTemplate:
    """Create and insert a new rejected template.

    This always creates a new entry; no duplicate check is performed here. The new
    template is stored in ``rejected`` under the next ``R####`` id, with the rejection
    reason and category recorded in its value history.

    Args:

        store: The mutable template store dict. Its ``next_reject_id`` counter is read
            (defaulting to 1 when absent) and advanced.

        rejected: The rejected templates dict to insert into.

        schema: The schema graph.

        q_norm: The normalised question.

        intent: The ``RuntimeIntent``.

        sql: The rejected SQL.

        category: The rejection category string.

        reason: The human-readable rejection reason.

        dialect: Optional dialect used to build ``spark_sql_param`` when
            ``EngineConfig.TYPE == "databricks"``; falls back to ``get_dialect()``.

    Returns:

        The newly created ``RejectedTemplate``.
    """
    # Read the counter once so the id and the advanced counter stay in sync.
    next_reject_id = int(store.get("next_reject_id", 1))
    rid = f"R{next_reject_id:04d}"
    store["next_reject_id"] = next_reject_id + 1

    sql_canon = canonicalize_sql(sql)
    if intent.sql_param:
        sql_param = intent.sql_param
    else:
        # No parameterised SQL on the intent: derive it from the canonical SQL.
        sql_norm = normalize_sql(sql_canon)
        sql_param, _ = parameter_abstract(sql_norm)
    sql_fp_val = sql_fp(sql_param)
    tables_used = extract_tables_from_sql(sql_canon, list(schema.tables.keys()))
    colmap_sig_val = colmap_signature(intent.column_map)
    ikey = intent_key(intent)

    intent_sig = runtime_intent_to_concrete(intent, "")

    all_pv = flatten_param_values(intent)

    debug(f"[templates.insert_rejected_template] creating: id={rid} category={category}")

    spark_sql_param = ""
    if EngineConfig.TYPE == "databricks":
        d = dialect or get_dialect()
        if d:
            spark_sql_param = d.prepare_for_execution(sql_param)

    rt = RejectedTemplate(
        id=rid,
        schema_hash=schema.schema_hash,
        intent_signature=intent_sig,
        intent_key=ikey,
        tables_used=sorted(tables_used),
        sql_param=sql_param,
        spark_sql_param=spark_sql_param,
        sql_display_param=intent.sql_display_param or sql_param,
        sql_fp=sql_fp_val,
        # Same shape rule as accepted templates: prefer the intent's shape, else derive it,
        # so shapes of accepted and rejected templates are computed consistently.
        shape=intent.sql_shape if intent.sql_shape else sql_shape(sql_canon, intent),
        colmap_sig=colmap_sig_val,
        value_history=RejectedValueHistory(
            param_values=[all_pv],
            questions=[q_norm],
            natural_language=[intent.natural_language or ""],
            rejection_reasons=[reason],
            rejection_categories=[category],
        ),
        aliased_sql=intent.sql_display_param or "",
    )
    rejected[rid] = rt
    debug(f"[templates.insert_rejected_template] total_rejected: {len(rejected)}")
    return rt
1016
+
1017
+
1018
def best_template_by_intent_key(templates: dict[str, Template], ikey: str) -> Template | None:
    """Return the first accepted template whose ``intent_key`` equals ``ikey``.

    Despite the name, no ranking is applied: templates are scanned in the dict's
    iteration (insertion) order and the first matching one is returned.

    Args:

        templates: The accepted templates dict, keyed by template id.

        ikey: The intent key string to search for.

    Returns:

        The first matching ``Template``, or ``None`` if no template has that key.
    """
    debug(f"[templates.best_template_by_intent_key] searching: ikey={ikey[:32]}")
    hit = next(
        ((template_id, candidate) for template_id, candidate in templates.items() if candidate.intent_key == ikey),
        None,
    )
    if hit is None:
        debug("[templates.best_template_by_intent_key] no_match")
        return None
    template_id, candidate = hit
    debug(f"[templates.best_template_by_intent_key] found: id={template_id}")
    return candidate