sqlglot 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,785 @@
1
+ # frozen_string_literal: true
2
+
3
module Sqlglot
  # High-level query metadata extraction, inspired by Python's sql-metadata.
  #
  # Parses SQL once via the Rust FFI, then walks the resulting AST Hash in
  # pure Ruby to extract tables, columns, aliases, subqueries, CTEs, etc.
  # All properties are lazy-evaluated and cached.
  #
  # @example
  #   q = Sqlglot::Query.new(
  #     "SELECT u.name, COUNT(o.id) AS cnt FROM users AS u " \
  #     "JOIN orders AS o ON u.id = o.user_id WHERE u.active = true",
  #     dialect: :postgres
  #   )
  #   q.query_type      # => :select
  #   q.tables          # => ["users", "orders"]
  #   q.tables_aliases  # => {"u" => "users", "o" => "orders"}
  #   q.columns         # => ["users.name", "orders.id", "users.id", ...]
  #   q.output_columns  # => ["name", "cnt"]
  #   q.columns_dict    # => {select: [...], join: [...], where: [...]}
  class Query
    # Maps the node type of the top-level AST node to the symbol reported by
    # {#query_type}. Node types missing from this map are reported as
    # +:unknown+.
    QUERY_TYPE_MAP = {
      "Select" => :select,
      "Insert" => :insert,
      "Update" => :update,
      "Delete" => :delete,
      "CreateTable" => :create_table,
      "CreateView" => :create_view,
      "DropTable" => :drop_table,
      "DropView" => :drop_view,
      "AlterTable" => :alter_table,
      "Truncate" => :truncate,
      "Merge" => :merge,
      "Begin" => :begin,
      "Commit" => :commit,
      "Rollback" => :rollback,
      "Explain" => :explain,
      "Use" => :use,
    }.freeze

    # @param sql [String] the SQL query
    # @param dialect [Symbol, String, nil] the SQL dialect
    def initialize(sql, dialect: nil)
      @sql = sql
      @dialect = dialect
      reset_cache!
    end

    # @return [String] the original SQL string
    attr_reader :sql

    # ── AST access ─────────────────────────────────────────────────────

    # The parsed AST, produced on first access by Sqlglot.parse.
    #
    # @return [Hash]
    def ast
      @ast ||= Sqlglot.parse(@sql, dialect: @dialect)
    end

    # ── Query type ─────────────────────────────────────────────────────

    # The type of SQL statement.
    #
    # @return [Symbol] one of :select, :insert, :update, :delete,
    #   :create_table, :create_view, :drop_table, :drop_view,
    #   :alter_table, :truncate, :merge, :begin, :commit, :rollback,
    #   :explain, :use, :unknown
    def query_type
      @query_type ||= detect_query_type
    end

    # ── Tables ─────────────────────────────────────────────────────────

    # All table names referenced in the query, with CTE names excluded.
    # Tables referenced in the FROM / JOIN of CTE bodies are included.
    #
    # @return [Array<String>]
    def tables
      @tables ||= extract_tables
    end

    # Map of table alias => real table name (top-level statement only).
    #
    # @return [Hash{String => String}]
    def tables_aliases
      @tables_aliases ||= extract_tables_aliases
    end

    # ── Columns ────────────────────────────────────────────────────────

    # All column references, alias-resolved and table-qualified where
    # possible.
    #
    # @return [Array<String>]
    def columns
      @columns ||= extract_all_columns
    end

    # Columns grouped by the clause they appear in. Clauses without any
    # column are omitted.
    #
    # @return [Hash{Symbol => Array<String>}]
    #   Keys: :select, :where, :join, :group_by, :order_by, :having,
    #   :insert, :update
    def columns_dict
      @columns_dict ||= extract_columns_dict
    end

    # The column names (or aliases) that the SELECT would produce.
    #
    # @return [Array<String>]
    def output_columns
      @output_columns ||= extract_output_columns
    end

    # ── Column aliases ─────────────────────────────────────────────────

    # Map of column alias => array of source columns.
    #
    # @return [Hash{String => Array<String>}]
    def columns_aliases
      @columns_aliases ||= extract_columns_aliases
    end

    # Just the alias names.
    #
    # @return [Array<String>]
    def columns_aliases_names
      columns_aliases.keys
    end

    # Which query clause each column alias appears in.
    #
    # @return [Hash{Symbol => Array<String>}]
    def columns_aliases_dict
      @columns_aliases_dict ||= extract_columns_aliases_dict
    end

    # ── CTEs (WITH clauses) ────────────────────────────────────────────

    # Names of CTE definitions.
    #
    # @return [Array<String>]
    def with_names
      @with_names ||= extract_with_names
    end

    # CTE name => regenerated SQL body.
    #
    # @return [Hash{String => String}]
    def with_queries
      @with_queries ||= extract_with_queries
    end

    # ── Subqueries ─────────────────────────────────────────────────────

    # Subquery alias => regenerated SQL body (from FROM / JOIN only).
    #
    # @return [Hash{String => String}]
    def subqueries
      @subqueries ||= extract_subqueries
    end

    # Just the subquery alias names.
    #
    # @return [Array<String>]
    def subqueries_names
      subqueries.keys
    end

    # ── LIMIT / OFFSET ─────────────────────────────────────────────────

    # @return [Array(Integer, Integer), nil] [limit, offset], or nil when the
    #   statement is not a SELECT or has no usable LIMIT
    def limit_and_offset
      # A nil result is not memoized by ||=, so it is recomputed on each call.
      @limit_and_offset ||= extract_limit_and_offset
    end

    # ── INSERT values ──────────────────────────────────────────────────

    # Values from the first row of an INSERT statement.
    #
    # @return [Array]
    def values
      @values ||= extract_values
    end

    # Column => value pairs for INSERT. Auto-generates column_N names
    # if the INSERT has no explicit column list.
    #
    # @return [Hash{String => Object}]
    def values_dict
      @values_dict ||= extract_values_dict
    end

    # ── Comments ───────────────────────────────────────────────────────

    # SQL comments found in the AST.
    #
    # @return [Array<String>]
    def comments
      @comments ||= extract_comments
    end

    # ── Normalization ──────────────────────────────────────────────────

    # Generalized SQL with literals replaced by placeholders.
    # Useful for query fingerprinting.
    #
    # @return [String]
    def generalize
      @generalize ||= build_generalized
    end

    private

    # Clears every memoized value, including the internal helpers
    # (@stmt, @alias_to_table) that are derived from the AST.
    def reset_cache!
      @ast = nil
      @stmt = nil
      @alias_to_table = nil
      @query_type = nil
      @tables = nil
      @tables_aliases = nil
      @columns = nil
      @columns_dict = nil
      @output_columns = nil
      @columns_aliases = nil
      @columns_aliases_dict = nil
      @with_names = nil
      @with_queries = nil
      @subqueries = nil
      @limit_and_offset = nil
      @values = nil
      @values_dict = nil
      @comments = nil
      @generalize = nil
    end

    # ── Query type detection ───────────────────────────────────────────

    # Looks up the AST's node type in QUERY_TYPE_MAP.
    #
    # @return [Symbol] the mapped type, or :unknown
    def detect_query_type
      QUERY_TYPE_MAP.fetch(AstWalker.node_type(ast), :unknown)
    end

    # ── Statement body accessor ────────────────────────────────────────

    # The inner statement hash (e.g. the SelectStatement contents).
    #
    # Returns an empty Hash when the AST is not a Hash (for example a bare
    # String node) or when its first value is not a Hash, so that the
    # extractors below can index it safely.
    #
    # @return [Hash]
    def stmt
      @stmt ||= begin
        inner = ast.is_a?(Hash) ? ast.values.first : nil
        inner.is_a?(Hash) ? inner : {}
      end
    end

    # ── Table extraction ───────────────────────────────────────────────

    # @return [Array<String>] unique table names minus CTE names
    def extract_tables
      refs = collect_table_refs(stmt)

      # Also collect tables referenced inside CTE bodies.
      (stmt["ctes"] || []).each do |cte|
        query_ast = cte["query"]
        next unless query_ast.is_a?(Hash)

        inner_stmt = query_ast.values.first
        refs.concat(collect_table_refs(inner_stmt)) if inner_stmt.is_a?(Hash)
      end

      refs.map { |ref| ref[:name] }.uniq - with_names
    end

    # @return [Hash{String => String}] alias => table name, skipping refs
    #   whose alias equals the table name
    def extract_tables_aliases
      collect_table_refs(stmt).each_with_object({}) do |ref, aliases|
        aliases[ref[:alias]] = ref[:name] if ref[:alias] && ref[:alias] != ref[:name]
      end
    end

    # Collect all TableRef-like nodes from the statement.
    #
    # @param node [Hash] a statement body
    # @return [Array<Hash>] [{name:, alias:}, ...]
    def collect_table_refs(node)
      return [] unless node.is_a?(Hash)

      refs = []

      # FROM clause
      if (from = node["from"])
        refs.concat(table_refs_from_source(from["source"] || from))
      end

      # JOINs
      if (joins = node["joins"])
        joins.each do |join|
          refs.concat(table_refs_from_source(join["table"] || join["source"] || join))
        end
      end

      # Target table (read for every statement type, so it covers both
      # UPDATE and INSERT targets).
      if (table = node["table"])
        refs.concat(table_refs_from_source(table))
      end

      # DELETE FROM
      if (from_table = node["from_table"])
        refs.concat(table_refs_from_source(from_table))
      end

      # USING (DELETE ... USING ...)
      if (using = node["using"])
        Array(using).each { |u| refs.concat(table_refs_from_source(u)) }
      end

      refs
    end

    # Turns one FROM/JOIN source node into table refs, following nested
    # "source" keys.
    #
    # @param source [Object] a source node; non-Hash values yield no refs
    # @return [Array<Hash>] [{name:, alias:}, ...]
    def table_refs_from_source(source)
      return [] unless source.is_a?(Hash)

      refs = []

      if source.key?("Table")
        table = source["Table"]
        name = build_table_name(table)
        refs << { name: name, alias: table["alias"] || name }
      elsif source.key?("name")
        # Direct table ref (not wrapped in "Table")
        name = build_table_name(source)
        refs << { name: name, alias: source["alias"] || name }
      elsif source.key?("Subquery")
        # Derived table: there is no real table name, so the alias is
        # recorded as the name.
        # NOTE(review): this makes the subquery alias show up in #tables —
        # confirm that is intended.
        alias_name = source["alias"]
        refs << { name: alias_name, alias: alias_name } if alias_name
      end

      # Recurse into source if it has its own "source" (nested structure)
      refs.concat(table_refs_from_source(source["source"])) if source.key?("source")

      refs
    end

    # Joins the non-empty catalog / schema (or db) / name parts with ".".
    #
    # @param t [Hash] a table node
    # @return [String]
    def build_table_name(t)
      [t["catalog"], t["schema"] || t["db"], t["name"]].compact.reject(&:empty?).join(".")
    end

    # ── Alias resolution helpers ───────────────────────────────────────

    # Reverse map: alias -> real table name. Unaliased tables map to
    # themselves.
    #
    # @return [Hash{String => String}]
    def alias_to_table
      @alias_to_table ||= collect_table_refs(stmt).each_with_object({}) do |ref, map|
        map[ref[:alias]] = ref[:name] if ref[:alias]
      end
    end

    # Resolve a potentially-aliased table prefix to the real table name.
    #
    # @param table_or_alias [String, nil]
    # @return [String, nil] the mapped name, the input when unmapped, or nil
    def resolve_table(table_or_alias)
      return nil if table_or_alias.nil?

      alias_to_table[table_or_alias] || table_or_alias
    end

    # Qualify a column with its resolved table name.
    #
    # @param name [String] the column name
    # @param table [String, nil] table name or alias prefix, if any
    # @return [String] "table.name", or just the name without a usable table
    def qualify_column(name, table)
      resolved = resolve_table(table)
      return name if resolved.nil? || resolved.empty?

      "#{resolved}.#{name}"
    end

    # ── Column extraction ──────────────────────────────────────────────

    # @return [Array<String>] every column in columns_dict, deduplicated
    def extract_all_columns
      columns_dict.values.flatten.uniq
    end

    # Builds the clause => columns map for the current statement type.
    #
    # @return [Hash{Symbol => Array<String>}]
    def extract_columns_dict
      result =
        case query_type
        when :select
          {
            select: columns_from_select_list,
            where: columns_from_expr(stmt["where_clause"]),
            join: columns_from_joins,
            group_by: columns_from_exprs(stmt["group_by"]),
            order_by: columns_from_order_by,
            having: columns_from_expr(stmt["having"]),
          }
        when :insert
          { insert: columns_from_insert }
        when :update
          {
            update: columns_from_update,
            where: columns_from_expr(stmt["where_clause"]),
          }
        when :delete
          { where: columns_from_expr(stmt["where_clause"]) }
        else
          {}
        end

      # Resolve aliases used in ORDER BY / GROUP BY / HAVING.
      alias_map = columns_aliases
      %i[order_by group_by having].each do |clause|
        next unless result[clause]

        result[clause] = result[clause].flat_map { |col| alias_map[col] || [col] }
      end

      result.reject { |_, cols| cols.nil? || cols.empty? }
    end

    # @return [Array<String>] columns referenced by the SELECT list
    def columns_from_select_list
      found = (stmt["columns"] || []).flat_map do |item|
        if item.is_a?(Hash) && item.key?("Expr")
          expr_data = item["Expr"]
          columns_from_expr(expr_data["expr"] || expr_data)
        elsif wildcard?(item)
          ["*"]
        elsif item.is_a?(Hash) && item.key?("QualifiedWildcard")
          columns_from_expr(item)
        else
          []
        end
      end

      found.uniq
    end

    # @return [Array<String>] columns referenced by JOIN conditions
    def columns_from_joins
      (stmt["joins"] || [])
        .flat_map { |join| columns_from_expr(join["condition"] || join["on"]) }
        .uniq
    end

    # @return [Array<String>] columns referenced by ORDER BY
    def columns_from_order_by
      order_by_exprs.flat_map { |expr| columns_from_expr(expr) }.uniq
    end

    # @return [Array<String>] the INSERT column list as names
    def columns_from_insert
      (stmt["columns"] || []).map do |col|
        col.is_a?(String) ? col : (col["name"] || col.to_s)
      end
    end

    # @return [Array<String>] columns assigned by an UPDATE
    def columns_from_update
      assignments = stmt["assignments"] || stmt["set"] || []
      targets = assignments.select { |a| a.is_a?(Hash) }
                           .map { |a| a["column"] || a["target"] }
      targets.compact.flat_map { |target| columns_from_expr(target) }.uniq
    end

    # @param exprs [Array, nil] expression nodes; non-Arrays yield []
    # @return [Array<String>] unique columns across all expressions
    def columns_from_exprs(exprs)
      return [] unless exprs.is_a?(Array)

      exprs.flat_map { |expr| columns_from_expr(expr) }.uniq
    end

    # Recursively extract qualified column names from an expression node.
    #
    # @param node [Hash, Array, Object, nil] any AST fragment
    # @return [Array<String>] columns in traversal order (may repeat)
    def columns_from_expr(node)
      case node
      when Hash
        if node.key?("Column")
          col = node["Column"]
          [qualify_column(col["name"], col["table"])]
        elsif node.key?("QualifiedWildcard")
          qw = node["QualifiedWildcard"]
          ["#{resolve_table(qw['table'] || qw['qualifier'])}.*"]
        else
          node.each_value.flat_map { |child| columns_from_expr(child) }
        end
      when Array
        node.flat_map { |child| columns_from_expr(child) }
      else
        []
      end
    end

    # Whether a SELECT-list item is an unqualified "*", which appears either
    # as the bare String "Wildcard" or as a Hash with that key.
    def wildcard?(item)
      item == "Wildcard" || (item.is_a?(Hash) && item.key?("Wildcard"))
    end

    # ORDER BY items unwrapped to their expression nodes.
    #
    # @return [Array]
    def order_by_exprs
      (stmt["order_by"] || []).map do |item|
        item.is_a?(Hash) ? (item["expr"] || item) : item
      end
    end

    # ── Output columns ─────────────────────────────────────────────────

    # @return [Array<String>] one entry per SELECT-list item; [] unless SELECT
    def extract_output_columns
      return [] unless query_type == :select

      (stmt["columns"] || []).map do |item|
        if item.is_a?(Hash) && item.key?("Expr")
          expr_data = item["Expr"]
          select_alias = expr_data["alias"]
          # Use alias if present, else derive name from expression.
          if select_alias && !select_alias.empty?
            select_alias
          else
            name_from_expr(expr_data["expr"] || expr_data)
          end
        elsif wildcard?(item)
          "*"
        elsif item.is_a?(Hash) && item.key?("QualifiedWildcard")
          qw = item["QualifiedWildcard"]
          "#{qw['table'] || qw['qualifier']}.*"
        else
          item.to_s
        end
      end
    end

    # Best-effort short name from an expression node.
    #
    # @param node [Object] an expression node
    # @return [String] column name, "fn(...)", the node type, or node.to_s
    def name_from_expr(node)
      return node.to_s unless node.is_a?(Hash)

      if node.key?("Column")
        node["Column"]["name"]
      elsif node.key?("Function")
        "#{node['Function']['name']}(...)"
      else
        AstWalker.node_type(node) || node.to_s
      end
    end

    # ── Column alias extraction ────────────────────────────────────────

    # @return [Hash{String => Array<String>}] alias => source columns, for
    #   aliased SELECT items that reference at least one column
    def extract_columns_aliases
      return {} unless query_type == :select

      (stmt["columns"] || []).each_with_object({}) do |item, aliases|
        next unless item.is_a?(Hash) && item.key?("Expr")

        expr_data = item["Expr"]
        alias_name = expr_data["alias"]
        next if alias_name.nil? || alias_name.empty?

        # Walk the expression to find all referenced columns.
        source_cols = columns_from_expr(expr_data["expr"] || expr_data)
        aliases[alias_name] = source_cols unless source_cols.empty?
      end
    end

    # @return [Hash{Symbol => Array<String>}] clause => alias names used there
    def extract_columns_aliases_dict
      return {} unless query_type == :select

      alias_names = columns_aliases_names
      having = stmt["having"]
      clause_nodes = {
        order_by: order_by_exprs,
        group_by: stmt["group_by"] || [],
        having: having ? [having] : [],
      }

      result = {}
      clause_nodes.each do |clause, nodes|
        used = nodes.flat_map { |node| AstWalker.find_all(node, "Column") }
                    .map { |col| col["name"] }
                    .select { |name| alias_names.include?(name) }
        result[clause] = used.uniq unless used.empty?
      end

      # The SELECT list itself defines every alias.
      result[:select] = alias_names unless alias_names.empty?
      result
    end

    # ── CTE extraction ─────────────────────────────────────────────────

    # @return [Array<String>] CTE names, in declaration order
    def extract_with_names
      (stmt["ctes"] || []).filter_map { |cte| cte["name"] || cte["alias"] }
    end

    # @return [Hash{String => String}] CTE name => SQL from Sqlglot.generate
    def extract_with_queries
      (stmt["ctes"] || []).each_with_object({}) do |cte, result|
        cte_name = cte["name"] || cte["alias"]
        query_ast = cte["query"]
        next unless cte_name && query_ast

        result[cte_name] = Sqlglot.generate(query_ast, dialect: @dialect)
      end
    end

    # ── Subquery extraction ────────────────────────────────────────────

    # @return [Hash{String => String}] alias => SQL for aliased subqueries
    #   in FROM and JOIN sources
    def extract_subqueries
      result = {}

      # Subqueries in FROM
      if (from = stmt["from"])
        collect_subqueries_from_source(from["source"] || from, result)
      end

      # Subqueries in JOINs
      (stmt["joins"] || []).each do |join|
        collect_subqueries_from_source(join["table"] || join["source"] || join, result)
      end

      result
    end

    # Adds the aliased subquery of +source+ (if any) to +result+, then
    # follows a nested "source" key.
    #
    # @param source [Object] a source node; non-Hash values are ignored
    # @param result [Hash{String => String}] accumulator, mutated in place
    # @return [void]
    def collect_subqueries_from_source(source, result)
      return unless source.is_a?(Hash)

      if source.key?("Subquery") && (alias_name = source["alias"])
        result[alias_name] = Sqlglot.generate(source["Subquery"], dialect: @dialect)
      end

      collect_subqueries_from_source(source["source"], result) if source.key?("source")
    end

    # ── LIMIT / OFFSET extraction ──────────────────────────────────────

    # @return [Array(Integer, Integer), nil] [limit, offset]; offset is 0
    #   when the statement has no OFFSET
    def extract_limit_and_offset
      return nil unless query_type == :select

      limit_node = stmt["limit"]
      return nil unless limit_node

      offset_node = stmt["offset"]
      limit = node_to_int(limit_node)
      offset = offset_node ? node_to_int(offset_node) : 0
      return nil unless limit

      [limit, offset]
    end

    # Converts a LIMIT/OFFSET node to an Integer.
    #
    # @param node [Integer, Hash, Object]
    # @return [Integer, nil] nil when a Hash contains no "Number"
    def node_to_int(node)
      return node if node.is_a?(Integer)
      return node.to_i unless node.is_a?(Hash)
      return node["Number"].to_i if node.key?("Number")

      # Try to find a Number anywhere in the node.
      AstWalker.find_all(node, "Number").first&.to_i
    end

    # ── INSERT values extraction ───────────────────────────────────────

    # @return [Array] values of the first VALUES row; [] when unavailable
    def extract_values
      return [] unless query_type == :insert

      source = stmt["source"]
      return [] unless source.is_a?(Hash) && source.key?("Values")

      rows = source["Values"]
      return [] unless rows.is_a?(Array) && !rows.empty?

      # Take the first row of values.
      rows.first.map { |value| AstWalker.extract_value(value) }
    end

    # @return [Hash{String => Object}] column name => first-row value
    def extract_values_dict
      return {} unless query_type == :insert

      vals = values
      return {} if vals.empty?

      col_names = columns_from_insert
      # Auto-generate column names if not specified.
      col_names = vals.each_index.map { |i| "column_#{i + 1}" } if col_names.empty?

      col_names.zip(vals).to_h
    end

    # ── Comment extraction ─────────────────────────────────────────────

    # @return [Array<String>] unique comments from the statement and from
    #   "Commented" nodes anywhere in the AST
    def extract_comments
      found = []

      # Comments on the statement itself.
      found.concat(stmt["comments"]) if stmt["comments"].is_a?(Array)

      # Walk the entire AST for Commented nodes.
      AstWalker.walk(ast) do |key, value, _|
        next unless key == "Commented" && value.is_a?(Hash)

        found.concat(Array(value["comment"] || value["comments"]))
      end

      found.uniq
    end

    # ── Query generalization ───────────────────────────────────────────

    # Regenerates SQL from a literal-free copy of the AST.
    #
    # @return [String]
    def build_generalized
      Sqlglot.generate(deep_generalize(ast), dialect: @dialect)
    rescue Sqlglot::Error
      # If generation fails on the modified AST, fall back to regex.
      # A doubled quote ('') inside a literal is an escaped quote, so it is
      # consumed as part of the same literal instead of ending it.
      @sql
        .gsub(/'(?:[^']|'')*'/, "'X'")
        .gsub(/\b\d+(\.\d+)?\b/, "N")
    end

    # Recursively replace all literals in the AST with placeholder values.
    #
    # @param node [Object] any AST fragment
    # @return [Object] a copy with Number => "N" and StringLiteral => "X"
    def deep_generalize(node)
      case node
      when Hash
        if node.key?("Number")
          { "Number" => "N" }
        elsif node.key?("StringLiteral")
          { "StringLiteral" => "X" }
        elsif node.key?("Boolean")
          { "Boolean" => node["Boolean"] } # keep booleans as-is
        else
          node.transform_values { |child| deep_generalize(child) }
        end
      when Array
        node.map { |child| deep_generalize(child) }
      else
        node
      end
    end
  end
end