pgsqlarbiter 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: fc099691abb195da6832fe16fd79c4e1d83f97581765819cd6d8605814fb0388
4
+ data.tar.gz: 4d11363cf230c211863ccfc37576fe1053a85cead407ab2c35160b52f731bd1e
5
+ SHA512:
6
+ metadata.gz: 16b87f1563b3e452090e7325bd8d6fa547dface7edf9efa08cd3c165c2a7dedb2c8f2f3921a24498062e167e963199c8f19410015f442d9bb6030698f2fd3627
7
+ data.tar.gz: 5b931ba4d264b4a8e759d8572616e291056e19f867d9f116f3c0f9900832e03a41105384795676c0f1d3d05c5a880081237322bf81de619e3933438385674f8b
@@ -0,0 +1,14 @@
1
+ # frozen_string_literal: true
2
+
3
module Pgsqlarbiter
  # Value object describing what the analyzer extracted from one SQL statement.
  #
  # Defined with +Data.define+: attributes are read-only and two instances with
  # equal attributes compare equal.
  #
  # @!attribute [r] statement_type
  #   @return [Symbol] kind of statement (+:select+, +:insert+, +:update+,
  #     +:delete+, +:merge+, or +:values+)
  # @!attribute [r] tables
  #   @return [Array<String>] sorted names of the tables and views referenced
  # @!attribute [r] functions
  #   @return [Array<String>] sorted names of the functions called
  Analysis = Data.define(*%i[statement_type tables functions])
end
@@ -0,0 +1,504 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+
5
+ module Pgsqlarbiter
6
+ # SQL query analyzer that lexes and walks a SQL string to extract its statement type,
7
+ # referenced tables, and function calls. Uses a hand-written lexer and single-pass
8
+ # token walker rather than a full parse tree.
9
+ class Analyzer
10
+ include TokenType
11
+
12
+ JOIN_PREFIXES = Set["INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL"].freeze
13
+ FUNCTIONS_WITH_FROM_SYNTAX = Set["EXTRACT", "TRIM", "SUBSTRING"].freeze
14
+
15
+ # Analyze a SQL query string.
16
+ #
17
+ # The analysis proceeds in four phases:
18
+ # 1. Reject multiple statements
19
+ # 2. Determine the statement type
20
+ # 3. Pre-collect CTE names (so they are not treated as table references)
21
+ # 4. Walk all tokens to extract table and function references
22
+ #
23
+ # @param sql [String] the SQL query to analyze
24
+ # @return [Analysis] analysis result with statement_type, tables, and functions
25
+ # @raise [ParseError] if the SQL is empty or cannot be parsed
26
+ # @raise [MultipleStatementsError] if the SQL contains more than one statement
27
+ # @raise [DisallowedStatementError] if the statement type is not SELECT, INSERT,
28
+ # UPDATE, DELETE, MERGE, or VALUES
29
def analyze(sql)
  @tokens = Lexer.new.tokenize(sql)
  @tables = Set.new
  @functions = Set.new
  @cte_names = Set.new
  @suppress_from_depths = []

  # Every phase scans from the first token with no open parentheses.
  rewind = lambda do
    @pos = 0
    @paren_depth = 0
  end

  rewind.call
  reject_multiple_statements!
  kind = determine_statement_type!

  rewind.call
  pre_collect_cte_names!

  rewind.call
  walk!

  Analysis.new(statement_type: kind, tables: @tables.sort, functions: @functions.sort)
end
55
+
56
+ private
57
+
58
+ # --- Token access helpers ---
59
+
60
# Token under the cursor (+nil+ once the cursor has moved past the last token).
def current = @tokens[@pos]
63
+
64
# Token immediately after the cursor, without moving it (+nil+ if there is none).
def peek = @tokens[@pos.succ]
67
+
68
# Moves the cursor one token forward and returns the new position.
def advance
  @pos = @pos + 1
end
71
+
72
# True when the cursor sits on the EOF token.
def at_end? = current.type == EOF
75
+
76
# True when the current token is a keyword whose value equals +value+.
def keyword?(value)
  token = current
  token.type == KEYWORD && token.value == value
end
79
+
80
# True when the current token is a keyword whose value is any of +values+.
def keyword_one_of?(*values)
  token = current
  return false unless token.type == KEYWORD

  values.include?(token.value)
end
83
+
84
# True when the current token is a plain or a double-quoted identifier.
def ident_or_quoted?
  case current.type
  when IDENT, QUOTED_IDENT then true
  else false
  end
end
87
+
88
# True when the current token may serve as a table/alias name: any identifier,
# or a keyword that is not one of the clause keywords.
def can_be_name?
  return true if ident_or_quoted?

  current.type == KEYWORD && !clause_keyword?(current.value)
end
91
+
92
# Returns the name a token contributes to a table or function reference.
#
# Plain and quoted identifiers are returned as lexed, keywords are downcased,
# and any other token is stringified.
#
# @param token [#type, #value] the token to convert
# @return [String] the name text
# @raise [ParseError] if a quoted identifier contains a dot, since the result
#   could not be told apart from a dot-qualified name
def identifier_value(token)
  text = token.value
  case token.type
  when IDENT then text
  when KEYWORD then text.downcase
  when QUOTED_IDENT
    if text.include?(".")
      raise ParseError, "quoted identifier containing a dot is not supported: \"#{text}\""
    end

    text
  else
    text.to_s
  end
end
105
+
106
+ # --- Phase 1: Reject multiple statements ---
107
+
108
# Raises unless the token stream holds a single statement.
#
# Trailing semicolons are allowed; anything other than semicolons and EOF after
# the first semicolon means a second statement. Only the first semicolon needs
# checking: if a real token follows any later semicolon it also follows the
# first one. This keeps the check linear even for input such as ";;;;;...".
#
# @return [nil]
# @raise [MultipleStatementsError] if a token follows the first semicolon
def reject_multiple_statements!
  first_semicolon = @tokens.index { |token| token.type == SEMICOLON }
  return if first_semicolon.nil?

  second_statement = @tokens.drop(first_semicolon + 1).any? do |token|
    token.type != EOF && token.type != SEMICOLON
  end
  raise MultipleStatementsError, "multiple statements are not allowed" if second_statement
end
118
+
119
+ # --- Phase 2: Determine statement type ---
120
+
121
# Classifies the statement from its first token.
#
# A leading WITH is resolved by scanning past the CTE list.
#
# @return [Symbol] +:select+, +:insert+, +:update+, +:delete+, +:merge+ or +:values+
# @raise [ParseError] if there are no tokens before EOF
# @raise [DisallowedStatementError] if the first token is not a keyword, or is
#   a keyword for any other kind of statement
def determine_statement_type!
  raise ParseError, "empty query" if at_end?

  first = current
  unless first.type == KEYWORD
    raise DisallowedStatementError, "expected a SQL statement, got #{first.value.inspect}"
  end

  word = first.value
  return determine_cte_statement_type! if word == "WITH"

  types = {
    "SELECT" => :select, "INSERT" => :insert, "UPDATE" => :update,
    "DELETE" => :delete, "MERGE" => :merge, "VALUES" => :values
  }
  types.fetch(word) { raise DisallowedStatementError, "#{word} statements are not allowed" }
end
137
+
138
# Finds the main statement that follows a leading WITH clause.
#
# Scans forward from WITH and returns the type of the first
# SELECT/INSERT/UPDATE/DELETE/MERGE/VALUES keyword outside all parentheses,
# i.e. after the CTE list.
#
# PostgreSQL lets a CTE body itself be INSERT/UPDATE/DELETE/MERGE. Only one
# statement type is reported, so a write hidden in a CTE would otherwise be
# judged under the main statement's type (`WITH d AS (DELETE ...) SELECT ...`
# would pass as +:select+). Such a statement is rejected unless every
# data-modifying CTE has the same type as the main statement.
#
# @return [Symbol] the main statement's type
# @raise [DisallowedStatementError] if a CTE body starts with a data-modifying
#   keyword whose type differs from the main statement
# @raise [ParseError] if no main statement keyword is found
def determine_cte_statement_type!
  writers = { "INSERT" => :insert, "UPDATE" => :update, "DELETE" => :delete, "MERGE" => :merge }
  statements = writers.merge("SELECT" => :select, "VALUES" => :values)
  cte_writes = []
  depth = 0

  advance # past WITH
  advance if keyword?("RECURSIVE")

  until at_end?
    token = current
    case token.type
    when LPAREN
      if depth.zero?
        write = cte_body_write_keyword
        cte_writes << write if write
      end
      depth += 1
    when RPAREN
      depth -= 1
    when KEYWORD
      type = depth.zero? ? statements[token.value] : nil
      if type
        hidden = cte_writes.find { |word| writers[word] != type }
        if hidden
          raise DisallowedStatementError,
                "data-modifying WITH clause (#{hidden}) is not allowed in a #{token.value} statement"
        end
        return type
      end
    end
    advance
  end

  raise ParseError, "could not determine statement type in WITH clause"
end

# With the cursor on a top-level "(", returns the data-modifying keyword that
# opens a CTE body ("AS (" / "MATERIALIZED (" followed by INSERT, UPDATE,
# DELETE or MERGE), or nil when the parenthesis is anything else.
def cte_body_write_keyword
  opener = @tokens[@pos - 1] # always exists: WITH has already been consumed
  return nil unless opener.type == KEYWORD && %w[AS MATERIALIZED].include?(opener.value)

  first = @tokens[@pos + 1]
  return nil unless first && first.type == KEYWORD

  %w[INSERT UPDATE DELETE MERGE].include?(first.value) ? first.value : nil
end
169
+
170
+ # --- Phase 3: Pre-collect CTE names (index-based, no @pos modification) ---
171
+
172
# Collects CTE names from the whole token stream before the main walk, so the
# walk can tell CTE references apart from table references.
def pre_collect_cte_names!
  collect_cte_names_in_range!(0, @tokens.size)
end
175
+
176
# Records the name of every CTE declared between token indexes +from+
# (inclusive) and +to+ (exclusive), recursing into CTE bodies for nested WITH
# clauses. Works purely on indexes; the cursor is not touched.
#
# A name is registered only when it is followed by an optional parenthesised
# column list and then AS. Other uses of the WITH keyword (for example
# `timestamp WITH TIME ZONE`, `WITH ORDINALITY`, `WITH TIES`) therefore cannot
# register a bogus CTE name that would hide a real table of that name.
#
# NOTE(review): names go into one flat set for the whole statement. A real
# table sharing a CTE's name is not reported even where the CTE is out of
# scope, e.g. inside its own non-recursive body or in an unrelated subquery.
# Closing that gap needs scope tracking in the walker as well.
#
# @param from [Integer] first token index to scan
# @param to [Integer] index one past the last token to scan
# @return [nil]
def collect_cte_names_in_range!(from, to)
  i = from
  while i < to
    if cte_keyword_at?(i, to, "WITH")
      i = collect_with_list!(i + 1, to)
    else
      i += 1
    end
  end
end

# True when index +i+ is in range and holds the keyword +word+.
def cte_keyword_at?(i, to, word)
  i < to && @tokens[i]&.type == KEYWORD && @tokens[i].value == word
end

# True when index +i+ is in range and holds a token of one of +types+.
def cte_type_at?(i, to, *types)
  i < to && types.include?(@tokens[i]&.type)
end

# Given the index of an opening parenthesis, returns the index just past its
# matching closing parenthesis (or +to+ if it is never closed).
def cte_group_end(i, to)
  depth = 1
  i += 1
  while depth > 0 && i < to
    depth += 1 if @tokens[i].type == LPAREN
    depth -= 1 if @tokens[i].type == RPAREN
    i += 1
  end
  i
end

# Parses the comma-separated CTE list that starts at +i+ (just after WITH) and
# returns the index at which the outer scan should resume.
def collect_with_list!(i, to)
  i += 1 if cte_keyword_at?(i, to, "RECURSIVE")

  loop do
    break unless cte_type_at?(i, to, IDENT, QUOTED_IDENT)

    # Look past an optional column list; a real CTE has AS right after it.
    as_index = i + 1
    as_index = cte_group_end(as_index, to) if cte_type_at?(as_index, to, LPAREN)
    break unless cte_keyword_at?(as_index, to, "AS")

    @cte_names << identifier_value(@tokens[i])
    i = as_index + 1

    # Optional [NOT] MATERIALIZED
    if cte_keyword_at?(i, to, "NOT")
      i += 1
      i += 1 if cte_keyword_at?(i, to, "MATERIALIZED")
    elsif cte_keyword_at?(i, to, "MATERIALIZED")
      i += 1
    end

    # Scan the body for nested WITH clauses, then step over it.
    if cte_type_at?(i, to, LPAREN)
      body_start = i + 1
      i = cte_group_end(i, to)
      collect_cte_names_in_range!(body_start, i - 1)
    end

    break unless cte_type_at?(i, to, COMMA)

    i += 1
  end
  i
end
227
+
228
+ # --- Phase 4: Main walk ---
229
+
230
# Main pass: dispatches tokens one after another until EOF is reached.
def walk!
  dispatch_token! until at_end?
end
235
+
236
# Routes the current token to its handler. Every branch consumes at least one
# token, which is what lets the walking loops terminate.
def dispatch_token!
  type = current.type
  if type == KEYWORD
    handle_keyword!
  elsif type == IDENT || type == QUOTED_IDENT
    maybe_extract_function_call!
  elsif type == LPAREN
    @paren_depth += 1
    advance
  elsif type == RPAREN
    handle_rparen!
  else
    advance
  end
end
245
+
246
# Consumes a closing parenthesis. If it closes the argument list of the
# innermost function recorded in @suppress_from_depths, that entry is dropped
# so later FROM keywords are treated as clauses again.
def handle_rparen!
  innermost = @suppress_from_depths.last
  @suppress_from_depths.pop if innermost && @paren_depth == innermost + 1
  @paren_depth -= 1
  advance
end
253
+
254
# True while the cursor is inside the parentheses of a function recorded in
# @suppress_from_depths, where FROM is part of the call syntax.
def suppress_from?
  innermost = @suppress_from_depths.last
  return false if innermost.nil?

  @paren_depth > innermost
end
257
+
258
# Handles the keyword under the cursor: extracts table references introduced by
# FROM, JOIN, INTO, UPDATE, USING and TABLE, records keyword-named function
# calls, and otherwise just steps over the keyword.
#
# Cases covered beyond plain clause handling:
# * a keyword right after a dot is the tail of a qualified name (+t.set+), never a clause;
# * +IS [NOT] DISTINCT FROM+ is an operator, so its FROM opens no table list;
# * +FOR [NO KEY] UPDATE [OF ...] [NOWAIT | SKIP LOCKED]+ is a locking clause, not an UPDATE target;
# * +left(...)+ / +right(...)+ are function calls, not join prefixes;
# * +TABLE name+ reads a table;
# * +DELETE ... USING a, b+ is a comma-separated list, and a comma after a
#   join (+a JOIN b ON ... , c+) continues the FROM list.
#
# Always consumes at least the keyword itself.
def handle_keyword!
  val = current.value

  if qualified_name_part?
    advance
    return
  end

  case val
  when "FROM"
    operand_from = suppress_from? || distinct_from_operator?
    advance
    extract_from_list! unless operand_from
  when "JOIN"
    advance
    extract_joined_item!
  when "INNER", "CROSS"
    advance
    consume_join!
  when "LEFT", "RIGHT", "FULL"
    if val != "FULL" && peek&.type == LPAREN
      @functions << val.downcase
      advance # past keyword; the walk handles the argument list
    else
      advance
      advance if keyword?("OUTER")
      consume_join!
    end
  when "NATURAL"
    advance
    if keyword_one_of?("LEFT", "RIGHT", "FULL", "INNER")
      advance
      advance if keyword?("OUTER")
    end
    consume_join!
  when "INTO"
    advance
    read_table_ref! if can_be_name?
  when "TABLE"
    advance
    advance if keyword?("ONLY")
    read_table_ref! if can_be_name?
  when "UPDATE"
    locking = locking_clause_update?
    advance
    unless locking
      advance if keyword?("ONLY")
      read_table_ref! if can_be_name?
    end
  when "USING"
    advance
    # JOIN ... USING (col, ...) is a column list; DELETE/MERGE ... USING names from-items.
    from_items = current.type == LPAREN ? !using_column_list? : can_be_name?
    extract_from_list! if from_items
  else
    if peek&.type == LPAREN
      if FUNCTIONS_WITH_FROM_SYNTAX.include?(val)
        @functions << val.downcase
        @suppress_from_depths.push(@paren_depth)
      elsif Keywords::FUNCTION_KEYWORDS.include?(val)
        @functions << val.downcase
      end
    end
    advance # past keyword; the walk handles any LPAREN
  end
end

# True when the token before the cursor is a dot.
def qualified_name_part?
  @pos.positive? && @tokens[@pos - 1].type == DOT
end

# True when the token +offset+ positions before the cursor is a keyword or
# plain identifier spelling one of +words+ (compared case-insensitively).
def word_before?(offset, *words)
  return false if @pos < offset

  token = @tokens[@pos - offset]
  (token.type == KEYWORD || token.type == IDENT) && words.include?(token.value.to_s.upcase)
end

# True when the FROM under the cursor completes IS [NOT] DISTINCT FROM.
def distinct_from_operator?
  word_before?(1, "DISTINCT") && word_before?(2, "IS", "NOT")
end

# True when the UPDATE under the cursor belongs to FOR UPDATE / FOR NO KEY UPDATE.
def locking_clause_update?
  word_before?(1, "FOR", "KEY")
end

# With the cursor on "(", true when the group looks like a JOIN ... USING
# column list: a single token followed by "," or ")".
def using_column_list?
  return false unless current.type == LPAREN

  after_first = @tokens[@pos + 2]
  !after_first.nil? && (after_first.type == COMMA || after_first.type == RPAREN)
end

# Consumes "JOIN <from-item> [qualification]" if the cursor is on JOIN.
def consume_join!
  return unless keyword?("JOIN")

  advance
  extract_joined_item!
end

# Reads the joined from-item, then its ON/USING qualification.
def extract_joined_item!
  extract_single_from_item!
  walk_join_qualification!
end

# Walks the tokens after a joined item (ON expression or USING list) until
# the FROM clause ends or the next join starts. A comma at the join's own
# parenthesis depth continues the FROM list, so those items are extracted.
#
# Over-approximates deliberately: a comma outside parentheses, such as in a
# bare ARRAY[a, b] in the ON expression, is also read as a list separator.
# That can report a spurious table but never hides a real one.
def walk_join_qualification!
  base_depth = @paren_depth
  case_depth = 0

  until at_end?
    if @paren_depth == base_depth
      token = current
      return if token.type == RPAREN || token.type == SEMICOLON

      if token.type == COMMA
        advance
        extract_from_list!
        return
      elsif token.type == KEYWORD && !qualified_name_part?
        case token.value
        when "CASE" then case_depth += 1
        when "END" then case_depth -= 1 if case_depth.positive?
        else
          # WHEN inside CASE ... END belongs to the expression, not to MERGE.
          return if case_depth.zero? && join_qualification_terminator?(token.value)
        end
      end
    end
    dispatch_token!
  end
end

# True when +word+ ends a join qualification: a clause keyword, or the start
# of another join. A join prefix directly followed by "(" is a function call.
def join_qualification_terminator?(word)
  return peek&.type != LPAREN if JOIN_PREFIXES.include?(word)

  %w[WHERE GROUP HAVING ORDER LIMIT OFFSET FETCH UNION INTERSECT EXCEPT
     WINDOW FOR RETURNING SET WHEN JOIN].include?(word)
end
327
+
328
+ # --- Function call detection ---
329
+
330
# Handles an identifier in expression position.
#
# +name(+ records +name+ as a function and +qualifier.name(+ records
# +qualifier.name+. In both cases the cursor stops on the opening parenthesis.
# For +qualifier.other+ the cursor is left on +other+ so the walk re-examines
# it; any other identifier is simply skipped.
def maybe_extract_function_call!
  follower = peek&.type

  if follower == LPAREN
    @functions << identifier_value(current)
    advance
    return
  end

  unless follower == DOT
    advance
    return
  end

  qualifier = identifier_value(current)
  advance # past the qualifier
  advance # past the dot
  return unless (ident_or_quoted? || current.type == KEYWORD) && peek&.type == LPAREN

  @functions << "#{qualifier}.#{identifier_value(current)}"
  advance
end
348
+
349
+ # --- FROM list extraction ---
350
+
351
# Reads a comma-separated list of from-items starting at the cursor. Stops at
# EOF, at a token that ends the FROM list, or when an item is not followed by
# a comma.
def extract_from_list!
  until at_end? || should_end_from_list?
    extract_from_item!
    break unless current.type == COMMA

    advance
  end
end
365
+
366
# True when the current token cannot start another from-item: a closing
# parenthesis, semicolon or EOF, a clause keyword, JOIN, or a join prefix that
# is actually followed by JOIN.
def should_end_from_list?
  token = current
  return true if [RPAREN, SEMICOLON, EOF].include?(token.type)
  return false unless token.type == KEYWORD

  word = token.value
  enders = %w[WHERE GROUP HAVING ORDER LIMIT OFFSET FETCH
              UNION INTERSECT EXCEPT WINDOW FOR RETURNING
              ON SET WHEN]
  return true if word == "JOIN" || enders.include?(word)

  JOIN_PREFIXES.include?(word) && has_join_ahead?
end
380
+
381
# Looks at up to three tokens after the cursor and reports whether JOIN is
# reached with only OUTER/INNER keywords in between.
def has_join_ahead?
  lookahead = @tokens[(@pos + 1)..(@pos + 3)] || []
  lookahead.each do |token|
    return false unless token.type == KEYWORD
    return true if token.value == "JOIN"
    return false unless %w[OUTER INNER].include?(token.value)
  end
  false
end
390
+
391
# Reads exactly one from-item (the target of a JOIN or USING); no comma list.
def extract_single_from_item! = extract_from_item!
394
+
395
# Reads one from-item at the cursor and records what it references.
#
# * +(SELECT ...)+, +(WITH ...)+, +(VALUES ...)+, +(TABLE ...)+ — a subquery;
#   its contents are walked like any other tokens.
# * any other parenthesised item — a parenthesised join such as
#   +(a JOIN b ON ...)+ or +ONLY (t)+; its first element is a from-item in its
#   own right, so it is read as a FROM list. It used to be walked as an
#   expression, which left those tables unreported.
# * +name(...)+ — a table-valued function: recorded in @functions, arguments walked.
# * +name+ — recorded in @tables.
#
# Names matching a collected CTE name are not recorded. A trailing alias is skipped.
def extract_from_item!
  advance if keyword?("LATERAL")
  advance if keyword?("ONLY")

  if current.type == LPAREN
    if subquery_group?
      walk_balanced_group!
    else
      walk_parenthesized_from_list!
    end
    skip_alias!
    return
  end

  return if at_end?
  return unless can_be_name?

  name = read_qualified_name!
  is_cte = @cte_names.include?(name)

  if current.type == LPAREN
    @functions << name unless is_cte
    walk_balanced_group! # walk function args for nested refs
  else
    @tables << name unless is_cte
  end
  skip_alias!
end

# With the cursor on "(", true when the group is a subquery (or is empty or
# truncated) rather than a parenthesised join.
def subquery_group?
  inner = peek
  return true if inner.nil? || inner.type == RPAREN || inner.type == EOF

  inner.type == KEYWORD && %w[SELECT WITH VALUES TABLE].include?(inner.value)
end

# Consumes a parenthesised join: reads the leading from-item(s) as a FROM
# list, then walks the rest of the group (JOIN clauses, ON expressions) through
# the closing parenthesis.
def walk_parenthesized_from_list!
  outer_depth = @paren_depth
  @paren_depth += 1
  advance # past LPAREN
  extract_from_list!
  dispatch_token! while !at_end? && @paren_depth > outer_depth
end
421
+
422
+ # Walk tokens inside balanced parens, extracting refs. Consumes through closing RPAREN.
423
def walk_balanced_group!
  return if current.type != LPAREN

  outer_depth = @paren_depth
  @paren_depth = outer_depth + 1
  advance # step over the opening paren

  # Dispatch everything inside; the RPAREN that brings the depth back to
  # +outer_depth+ is consumed by the dispatcher as well.
  dispatch_token! until at_end? || @paren_depth <= outer_depth
end
433
+
434
+ # --- Table reference reading ---
435
+
436
# Reads a (possibly qualified) table name at the cursor, records it in @tables
# unless it names a collected CTE, and skips a trailing alias. Does nothing at
# EOF or when the current token cannot be a name.
def read_table_ref!
  return if at_end? || !can_be_name?

  table = read_qualified_name!
  @tables.add(table) unless @cte_names.include?(table)
  skip_alias!
end
444
+
445
# Reads a name of one to three dot-separated parts (+table+, +schema.table+,
# +catalog.schema.table+) starting at the cursor and leaves the cursor on the
# first token after it.
#
# A keyword is accepted in any position after a dot. Previously it was
# accepted as the second part but not the third, so a reference such as
# +db.public.user+ was truncated to +db.public+ and the real table was never
# reported.
#
# @return [String] the parts joined with "."
# @raise [ParseError] if a part is a quoted identifier containing a dot
def read_qualified_name!
  parts = [identifier_value(current)]
  advance

  2.times do
    break unless current.type == DOT && name_part_ahead?

    advance # past dot
    parts << identifier_value(current)
    advance
  end

  parts.join(".")
end

# True when the token after the cursor can be a name component.
def name_part_ahead?
  following = peek
  !following.nil? && [IDENT, QUOTED_IDENT, KEYWORD].include?(following.type)
end
465
+
466
+ # --- Alias skipping ---
467
+
468
# Steps over an optional alias after a from-item: +AS name+ or a bare
# identifier, each optionally followed by a parenthesised column list.
#
# After AS the alias may be an identifier or any keyword that is not a clause
# keyword. Without AS only identifiers are treated as aliases; a bare keyword
# is left in place for the walk.
def skip_alias!
  return if at_end?

  if keyword?("AS")
    advance
    advance if can_be_name?
  elsif ident_or_quoted?
    advance
  else
    return
  end

  skip_balanced_parens! if current.type == LPAREN
end
484
+
485
# True when +value+ is a keyword that starts or continues a clause and so can
# never be read as a table name or alias.
def clause_keyword?(value)
  case value
  when "WHERE", "GROUP", "HAVING", "ORDER", "LIMIT", "OFFSET", "FETCH",
       "UNION", "INTERSECT", "EXCEPT",
       "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL", "ON", "USING",
       "WINDOW", "FOR", "RETURNING", "SET", "WHEN", "MATCHED", "FROM", "INTO"
    true
  else
    false
  end
end
490
+
491
# Skips a parenthesised group without inspecting its contents. The cursor ends
# just past the matching closing parenthesis, or on EOF if it is never closed.
# Does nothing unless the cursor is on an opening parenthesis.
def skip_balanced_parens!
  return unless current.type == LPAREN

  advance
  open = 1
  until at_end? || open.zero?
    type = current.type
    if type == LPAREN
      open += 1
    elsif type == RPAREN
      open -= 1
    end
    advance
  end
end
503
+ end
504
+ end
@@ -0,0 +1,84 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+
5
+ module Pgsqlarbiter
6
+ # Reusable query permission checker with pre-configured whitelists.
7
+ #
8
+ # Use this class when you need to check multiple queries against the same set of
9
+ # allowed statement types, tables, and functions. For one-off checks, see
10
+ # {Pgsqlarbiter.allow?}.
11
class Arbiter
  # @return [Set<Symbol>] every statement type that may be whitelisted
  VALID_STATEMENT_TYPES = Set[:select, :insert, :update, :delete, :merge, :values].freeze

  # @!attribute [r] allowed_statement_types
  #   @return [Set<Symbol>] frozen set of permitted statement types
  # @!attribute [r] allowed_tables
  #   @return [Set<String>] frozen set of permitted table and view names
  # @!attribute [r] allowed_functions
  #   @return [Set<String>] frozen set of permitted function names
  attr_reader :allowed_statement_types, :allowed_tables, :allowed_functions

  # Builds an arbiter from three whitelists. Each list is copied into a new
  # frozen Set, so later changes to the arguments do not affect the arbiter.
  #
  # @param allowed_statement_types [Enumerable<Symbol>] permitted statement
  #   types; each must be a member of {VALID_STATEMENT_TYPES}
  # @param allowed_tables [Enumerable<String>] permitted table and view names
  # @param allowed_functions [Enumerable<String>] permitted function names
  #   (default: {Pgsqlarbiter::DEFAULT_QUERY_FUNCTIONS})
  # @raise [ArgumentError] if a statement type is not in {VALID_STATEMENT_TYPES}
  def initialize(allowed_statement_types:, allowed_tables:, allowed_functions: Pgsqlarbiter::DEFAULT_QUERY_FUNCTIONS)
    @allowed_statement_types = validate_statement_types(allowed_statement_types)
    @allowed_tables = Set.new(allowed_tables).freeze
    @allowed_functions = Set.new(allowed_functions).freeze
  end

  # Analyzes +sql+ and reports how it fares against each whitelist.
  #
  # @param sql [String] the SQL query to judge
  # @return [Verdict] overall decision plus the statement type and the tables
  #   and functions that were not whitelisted (both lists frozen)
  # @raise [ParseError] if the SQL cannot be parsed
  # @raise [MultipleStatementsError] if the SQL contains more than one statement
  # @raise [DisallowedStatementError] if the statement is not a supported DML type
  def judge(sql)
    analysis = Pgsqlarbiter.analyze(sql)
    kind = analysis.statement_type

    type_permitted = @allowed_statement_types.include?(kind)
    rejected_tables = analysis.tables.reject { |name| @allowed_tables.include?(name) }
    rejected_functions = analysis.functions.reject { |name| @allowed_functions.include?(name) }
    rejected_tables.freeze
    rejected_functions.freeze

    Verdict.new(
      allowed: type_permitted && rejected_tables.empty? && rejected_functions.empty?,
      statement_type_allowed: type_permitted,
      statement_type: kind,
      disallowed_tables: rejected_tables,
      disallowed_functions: rejected_functions
    )
  end

  # Shorthand for +judge(sql).allowed?+.
  #
  # @param sql [String] the SQL query to check
  # @return [Boolean] +true+ only if the statement type and every referenced
  #   table and function are whitelisted
  # @raise [ParseError] if the SQL cannot be parsed
  # @raise [MultipleStatementsError] if the SQL contains more than one statement
  # @raise [DisallowedStatementError] if the statement is not a supported DML type
  def allow?(sql)
    judge(sql).allowed?
  end

  private

  # Copies +types+ into a frozen Set after checking each one is known.
  def validate_statement_types(types)
    requested = Set.new(types)
    unknown = requested.reject { |type| VALID_STATEMENT_TYPES.include?(type) }
    return requested.freeze if unknown.empty?

    raise ArgumentError, "unknown statement type(s): #{unknown.map(&:inspect).join(", ")}. " \
                         "Valid types: #{VALID_STATEMENT_TYPES.map(&:inspect).join(", ")}"
  end
end
84
+ end