rooq 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.tool-versions +1 -0
- data/.yardopts +10 -0
- data/CHANGELOG.md +33 -0
- data/CLAUDE.md +54 -0
- data/Gemfile +20 -0
- data/Gemfile.lock +116 -0
- data/LICENSE +661 -0
- data/README.md +98 -0
- data/Rakefile +130 -0
- data/USAGE.md +850 -0
- data/exe/rooq +7 -0
- data/lib/rooq/adapters/postgresql.rb +117 -0
- data/lib/rooq/adapters.rb +3 -0
- data/lib/rooq/cli.rb +230 -0
- data/lib/rooq/condition.rb +104 -0
- data/lib/rooq/configuration.rb +56 -0
- data/lib/rooq/connection.rb +131 -0
- data/lib/rooq/context.rb +141 -0
- data/lib/rooq/dialect/base.rb +27 -0
- data/lib/rooq/dialect/postgresql.rb +531 -0
- data/lib/rooq/dialect.rb +9 -0
- data/lib/rooq/dsl/delete_query.rb +37 -0
- data/lib/rooq/dsl/insert_query.rb +43 -0
- data/lib/rooq/dsl/select_query.rb +301 -0
- data/lib/rooq/dsl/update_query.rb +44 -0
- data/lib/rooq/dsl.rb +28 -0
- data/lib/rooq/executor.rb +65 -0
- data/lib/rooq/expression.rb +494 -0
- data/lib/rooq/field.rb +71 -0
- data/lib/rooq/generator/code_generator.rb +91 -0
- data/lib/rooq/generator/introspector.rb +265 -0
- data/lib/rooq/generator.rb +9 -0
- data/lib/rooq/parameter_converter.rb +98 -0
- data/lib/rooq/query_validator.rb +176 -0
- data/lib/rooq/result.rb +248 -0
- data/lib/rooq/schema_validator.rb +56 -0
- data/lib/rooq/table.rb +69 -0
- data/lib/rooq/version.rb +5 -0
- data/lib/rooq.rb +25 -0
- data/rooq.gemspec +35 -0
- data/sorbet/config +4 -0
- metadata +115 -0
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
# typed: strict
|
|
2
|
+
# frozen_string_literal: true
|
|
3
|
+
|
|
4
|
+
require "sorbet-runtime"
|
|
5
|
+
|
|
6
|
+
module Rooq
|
|
7
|
+
module Generator
|
|
8
|
+
# Reads table, column, primary-key and foreign-key metadata for a
# PostgreSQL schema via information_schema, using parameterized
# queries on the supplied connection.
class Introspector
  extend T::Sig

  # Maps PostgreSQL data_type names (lowercased) to Rooq type symbols.
  # Anything not listed here is reported as :unknown by #map_pg_type.
  PG_TYPE_MAP = T.let({
    # integer family
    "integer" => :integer,
    "bigint" => :bigint,
    "smallint" => :smallint,
    "serial" => :integer,
    "bigserial" => :bigint,
    # floating point / exact numerics
    "real" => :float,
    "double precision" => :double,
    "numeric" => :decimal,
    "decimal" => :decimal,
    # character types
    "character varying" => :string,
    "varchar" => :string,
    "character" => :string,
    "char" => :string,
    "text" => :text,
    # other scalar types
    "boolean" => :boolean,
    "date" => :date,
    "timestamp without time zone" => :datetime,
    "timestamp with time zone" => :datetime_tz,
    "time without time zone" => :time,
    "time with time zone" => :time_tz,
    "uuid" => :uuid,
    "json" => :json,
    "jsonb" => :jsonb,
    "bytea" => :binary,
    "inet" => :inet,
    "cidr" => :cidr,
    "macaddr" => :macaddr
  }.freeze, T::Hash[String, Symbol])

  sig { params(connection: T.untyped).void }
  def initialize(connection)
    # Any object responding to exec_params(sql, params) and yielding
    # string-keyed rows — presumably a PG::Connection (TODO confirm).
    @connection = connection
  end

  # List the base tables (views excluded) of a schema, sorted by name.
  sig { params(schema: String).returns(T::Array[String]) }
  def introspect_tables(schema: "public")
    sql = <<~SQL
      SELECT table_name
      FROM information_schema.tables
      WHERE table_schema = $1
        AND table_type = 'BASE TABLE'
      ORDER BY table_name
    SQL

    @connection.exec_params(sql, [schema]).map { |row| row["table_name"] }
  end

  # Describe every column of a table, in ordinal position order.
  sig { params(table_name: String, schema: String).returns(T::Array[ColumnInfo]) }
  def introspect_columns(table_name, schema: "public")
    sql = <<~SQL
      SELECT
        column_name,
        data_type,
        is_nullable,
        column_default,
        character_maximum_length,
        numeric_precision,
        numeric_scale
      FROM information_schema.columns
      WHERE table_schema = $1
        AND table_name = $2
      ORDER BY ordinal_position
    SQL

    rows = @connection.exec_params(sql, [schema, table_name])
    rows.map do |col|
      ColumnInfo.new(
        name: col["column_name"],
        type: map_pg_type(col["data_type"]),
        pg_type: col["data_type"],
        nullable: col["is_nullable"] == "YES",
        default: col["column_default"],
        # Catalog numbers arrive as strings; nil stays nil via &.
        max_length: col["character_maximum_length"]&.to_i,
        precision: col["numeric_precision"]&.to_i,
        scale: col["numeric_scale"]&.to_i
      )
    end
  end

  # List the primary-key column names of a table, in key column order.
  sig { params(table_name: String, schema: String).returns(T::Array[String]) }
  def introspect_primary_keys(table_name, schema: "public")
    sql = <<~SQL
      SELECT kcu.column_name
      FROM information_schema.table_constraints tc
      JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
      WHERE tc.constraint_type = 'PRIMARY KEY'
        AND tc.table_schema = $1
        AND tc.table_name = $2
      ORDER BY kcu.ordinal_position
    SQL

    @connection.exec_params(sql, [schema, table_name]).map { |row| row["column_name"] }
  end

  # List the foreign keys declared on a table as local-column ->
  # foreign-table.foreign-column references.
  sig { params(table_name: String, schema: String).returns(T::Array[ForeignKeyInfo]) }
  def introspect_foreign_keys(table_name, schema: "public")
    sql = <<~SQL
      SELECT
        kcu.column_name,
        ccu.table_name AS foreign_table_name,
        ccu.column_name AS foreign_column_name
      FROM information_schema.table_constraints tc
      JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
      JOIN information_schema.constraint_column_usage ccu
        ON ccu.constraint_name = tc.constraint_name
        AND ccu.table_schema = tc.table_schema
      WHERE tc.constraint_type = 'FOREIGN KEY'
        AND tc.table_schema = $1
        AND tc.table_name = $2
    SQL

    rows = @connection.exec_params(sql, [schema, table_name])
    rows.map do |fk|
      ForeignKeyInfo.new(
        column_name: fk["column_name"],
        foreign_table: fk["foreign_table_name"],
        foreign_column: fk["foreign_column_name"]
      )
    end
  end

  # Build the full TableInfo description for every table in the schema.
  # Issues three additional queries per table.
  sig { params(schema: String).returns(T::Array[TableInfo]) }
  def introspect_schema(schema: "public")
    introspect_tables(schema: schema).map do |table_name|
      TableInfo.new(
        name: table_name,
        columns: introspect_columns(table_name, schema: schema),
        primary_keys: introspect_primary_keys(table_name, schema: schema),
        foreign_keys: introspect_foreign_keys(table_name, schema: schema)
      )
    end
  end

  private

  # Normalize a PostgreSQL data_type string to a Rooq type symbol,
  # falling back to :unknown for unmapped types.
  sig { params(pg_type: String).returns(Symbol) }
  def map_pg_type(pg_type)
    PG_TYPE_MAP.fetch(pg_type.downcase, :unknown)
  end
end
|
|
159
|
+
|
|
160
|
+
# Immutable description of a single database column, built from an
# information_schema.columns row. Instances freeze themselves on
# construction.
class ColumnInfo
  extend T::Sig

  # Column name as reported by the catalog.
  sig { returns(String) }
  attr_reader :name

  # Normalized Rooq type symbol (e.g. :integer, :string, :unknown).
  sig { returns(Symbol) }
  attr_reader :type

  # Raw PostgreSQL data_type string (e.g. "character varying").
  sig { returns(String) }
  attr_reader :pg_type

  # True when the column accepts NULL (catalog is_nullable == "YES").
  sig { returns(T::Boolean) }
  attr_reader :nullable

  # Column default expression string from the catalog, if any.
  sig { returns(T.nilable(String)) }
  attr_reader :default

  # character_maximum_length, when defined for the type.
  sig { returns(T.nilable(Integer)) }
  attr_reader :max_length

  # numeric_precision, when defined for the type.
  sig { returns(T.nilable(Integer)) }
  attr_reader :precision

  # numeric_scale, when defined for the type.
  sig { returns(T.nilable(Integer)) }
  attr_reader :scale

  sig do
    params(
      name: String,
      type: Symbol,
      pg_type: String,
      nullable: T::Boolean,
      default: T.nilable(String),
      max_length: T.nilable(Integer),
      precision: T.nilable(Integer),
      scale: T.nilable(Integer)
    ).void
  end
  def initialize(name:, type:, pg_type:, nullable:, default:, max_length:, precision:, scale:)
    @name = name
    @type = type
    @pg_type = pg_type
    @nullable = nullable
    @default = default
    @max_length = max_length
    @precision = precision
    @scale = scale
    # Value object: no mutation after construction.
    freeze
  end
end
|
|
211
|
+
|
|
212
|
+
# Immutable description of one foreign-key column reference:
# local column -> foreign_table.foreign_column. Instances freeze
# themselves on construction.
class ForeignKeyInfo
  extend T::Sig

  # Name of the referencing column on the local table.
  sig { returns(String) }
  attr_reader :column_name

  # Name of the referenced table.
  sig { returns(String) }
  attr_reader :foreign_table

  # Name of the referenced column on the foreign table.
  sig { returns(String) }
  attr_reader :foreign_column

  sig { params(column_name: String, foreign_table: String, foreign_column: String).void }
  def initialize(column_name:, foreign_table:, foreign_column:)
    @column_name = column_name
    @foreign_table = foreign_table
    @foreign_column = foreign_column
    # Value object: no mutation after construction.
    freeze
  end
end
|
|
232
|
+
|
|
233
|
+
# Immutable description of a database table: its columns, primary-key
# column names, and foreign-key references. Both the instance and its
# collections are frozen on construction.
class TableInfo
  extend T::Sig

  # Table name as reported by the catalog.
  sig { returns(String) }
  attr_reader :name

  # Column descriptions for this table.
  sig { returns(T::Array[ColumnInfo]) }
  attr_reader :columns

  # Primary-key column names.
  sig { returns(T::Array[String]) }
  attr_reader :primary_keys

  # Foreign-key references declared on this table.
  sig { returns(T::Array[ForeignKeyInfo]) }
  attr_reader :foreign_keys

  sig do
    params(
      name: String,
      columns: T::Array[ColumnInfo],
      primary_keys: T::Array[String],
      foreign_keys: T::Array[ForeignKeyInfo]
    ).void
  end
  def initialize(name:, columns:, primary_keys:, foreign_keys:)
    @name = name
    # Freeze the collections as well so callers cannot append to them
    # after construction (freeze on self alone would not prevent that).
    @columns = T.let(columns.freeze, T::Array[ColumnInfo])
    @primary_keys = T.let(primary_keys.freeze, T::Array[String])
    @foreign_keys = T.let(foreign_keys.freeze, T::Array[ForeignKeyInfo])
    freeze
  end
end
|
|
264
|
+
end
|
|
265
|
+
end
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "time"
|
|
5
|
+
require "date"
|
|
6
|
+
|
|
7
|
+
module Rooq
|
|
8
|
+
# ParameterConverter converts Ruby objects to PostgreSQL-compatible parameter values.
|
|
9
|
+
#
|
|
10
|
+
# Conversions:
|
|
11
|
+
# - Time/DateTime -> ISO 8601 string
|
|
12
|
+
# - Date -> ISO 8601 date string
|
|
13
|
+
# - Hash -> JSON string
|
|
14
|
+
# - Array of primitives -> PostgreSQL array literal
|
|
15
|
+
# - Array of hashes -> JSON array string
|
|
16
|
+
# - Symbol -> String
|
|
17
|
+
# - Other types pass through unchanged
|
|
18
|
+
#
|
|
19
|
+
# @example
|
|
20
|
+
# converter = ParameterConverter.new
|
|
21
|
+
# converter.convert(Time.now) # => "2024-01-15T10:30:45+00:00"
|
|
22
|
+
# converter.convert({ key: "value" }) # => '{"key":"value"}'
|
|
23
|
+
# converter.convert([1, 2, 3]) # => "{1,2,3}"
|
|
24
|
+
class ParameterConverter
  # Convert a single parameter value to its PostgreSQL-compatible form.
  #
  # @param value [Object] the value to convert
  # @return [Object] the converted value
  def convert(value)
    case value
    when nil, true, false, Integer, Float, String
      value
    when Time, DateTime
      # DateTime is matched before Date (DateTime < Date) so the time
      # portion is preserved in the ISO 8601 output.
      value.iso8601
    when Date
      value.iso8601
    when Hash
      JSON.generate(value)
    when Array
      convert_array(value)
    when Symbol
      value.to_s
    else
      # Unknown types pass through unchanged; the driver decides.
      value
    end
  end

  # Convert an array of parameter values.
  #
  # @param params [Array] the parameters to convert
  # @return [Array] the converted parameters
  def convert_all(params)
    params.map { |p| convert(p) }
  end

  private

  # Convert a Ruby array into either a JSON array string (when it
  # contains hashes) or a PostgreSQL array literal such as "{1,2,3}".
  def convert_array(array)
    return "{}" if array.empty?

    # Arrays containing hashes are serialized as one JSON array rather
    # than as a PostgreSQL array literal.
    return JSON.generate(array) if array.any? { |el| el.is_a?(Hash) }

    elements = array.map { |el| format_pg_array_element(el) }
    "{#{elements.join(',')}}"
  end

  # Format one element of a PostgreSQL array literal, double-quoting it
  # when it contains characters significant to the array parser.
  def format_pg_array_element(value)
    return "NULL" if value.nil?

    str = convert(value).to_s

    if needs_quoting?(str)
      "\"#{escape_pg_string(str)}\""
    else
      str
    end
  end

  # Whether an element must be double-quoted inside an array literal.
  # Per PostgreSQL array input syntax: empty strings; anything holding
  # commas, quotes, backslashes, braces or whitespace (not just spaces —
  # the original checked only " "); and the bare word NULL, which the
  # parser would otherwise read as an SQL NULL element.
  def needs_quoting?(str)
    str.empty? ||
      str.casecmp?("NULL") ||
      str.match?(/[,"\\{}\s]/)
  end

  # Escape embedded quotes and backslashes for a double-quoted array
  # element. BUGFIX: uses the block form of gsub — with a plain-string
  # replacement, "\\\\" is itself subject to backreference escaping and
  # collapses to a single backslash, so the previous
  # gsub('\\', '\\\\') never doubled backslashes at all.
  def escape_pg_string(str)
    str.gsub(/["\\]/) { |char| "\\#{char}" }
  end
end
|
|
98
|
+
end
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Rooq
|
|
4
|
+
# Validates DSL queries against introspected schema metadata: every
# referenced table, field and condition must exist in the tables given
# at construction time.
class QueryValidator
  # @param tables [Array] table metadata objects responding to #name
  #   and #fields (a Hash keyed by field name)
  def initialize(tables)
    # Index tables by name for O(1) lookups during validation.
    @tables = tables.each_with_object({}) do |table, index|
      index[table.name] = table
    end
  end

  # Validate a SELECT query's FROM table, selected fields, conditions,
  # order specs and joins.
  # @return [true] when valid
  # @raise [QueryValidationError] listing every problem found
  def validate_select(query)
    errors = []

    collect_table_error(errors, query.from_table) if query.from_table
    collect_field_errors(errors, query.selected_fields)
    errors.concat(validate_condition(query.conditions)) if query.conditions

    query.order_specs.each do |spec|
      # NOTE(review): unlike selected_fields, order-spec fields are not
      # guarded with is_a?(Field) — presumably they are always Fields.
      error = validate_field(spec.field)
      errors << error if error
    end

    query.joins.each do |join|
      collect_table_error(errors, join.table)
      errors.concat(validate_condition(join.condition))
    end

    finish(errors)
  end

  # Validate an INSERT query's target table and column list.
  # @return [true] when valid
  # @raise [QueryValidationError] listing every problem found
  def validate_insert(query)
    errors = []
    collect_table_error(errors, query.table)
    collect_field_errors(errors, query.column_list)
    finish(errors)
  end

  # Validate an UPDATE query's target table, SET fields and conditions.
  # @return [true] when valid
  # @raise [QueryValidationError] listing every problem found
  def validate_update(query)
    errors = []
    collect_table_error(errors, query.table)
    collect_field_errors(errors, query.set_values.keys)
    errors.concat(validate_condition(query.conditions)) if query.conditions
    finish(errors)
  end

  # Validate a DELETE query's target table and conditions.
  # @return [true] when valid
  # @raise [QueryValidationError] listing every problem found
  def validate_delete(query)
    errors = []
    collect_table_error(errors, query.table)
    errors.concat(validate_condition(query.conditions)) if query.conditions
    finish(errors)
  end

  private

  # Append the table's validation error to errors, if any.
  def collect_table_error(errors, table)
    error = validate_table(table)
    errors << error if error
  end

  # Append a validation error for each Field in fields; non-Field
  # entries (expressions, literals) are skipped.
  def collect_field_errors(errors, fields)
    fields.each do |field|
      next unless field.is_a?(Field)

      error = validate_field(field)
      errors << error if error
    end
  end

  # Raise on any accumulated errors, otherwise return true.
  def finish(errors)
    raise QueryValidationError.new(errors) unless errors.empty?

    true
  end

  # Return an error message for an unknown table, or nil when the table
  # is known (non-Table values are ignored).
  def validate_table(table)
    return nil unless table.is_a?(Table)
    return nil if @tables.key?(table.name)

    "Unknown table '#{table.name}'. Known tables: #{@tables.keys.join(', ')}"
  end

  # Return an error message when the field's table or the field itself
  # is unknown, or nil when valid.
  def validate_field(field)
    table = @tables[field.table_name]
    return "Unknown table '#{field.table_name}' for field '#{field.name}'" unless table
    return nil if table.fields.key?(field.name)

    "Unknown field '#{field.name}' on table '#{field.table_name}'. Available: #{table.fields.keys.join(', ')}"
  end

  # Recursively collect field errors from a condition tree; unknown
  # condition types contribute no errors.
  def validate_condition(condition)
    case condition
    when Condition
      error = validate_field(condition.field)
      error ? [error] : []
    when CombinedCondition
      condition.conditions.flat_map { |c| validate_condition(c) }
    else
      []
    end
  end
end
|
|
140
|
+
|
|
141
|
+
# Raised when a query references tables or fields that are not present
# in the schema the QueryValidator was built from. Carries the full
# list of individual failures in addition to the combined message.
class QueryValidationError < Error
  # @return [Array<String>] human-readable messages, one per failure
  attr_reader :validation_errors

  # @param errors [Array<String>] the collected validation messages
  def initialize(errors)
    @validation_errors = errors
    super("Query validation failed:\n - #{errors.join("\n - ")}")
  end
end
|
|
149
|
+
|
|
150
|
+
# Executor variant that checks each query against the known schema
# before delegating execution to the superclass.
class ValidatingExecutor < Executor
  # @param connection [Object] database connection passed to Executor
  # @param tables [Array] schema metadata used to build the validator
  # @param dialect [Object] SQL dialect, PostgreSQL by default
  def initialize(connection, tables, dialect: Dialect::PostgreSQL.new)
    super(connection, dialect: dialect)
    @validator = QueryValidator.new(tables)
  end

  # Validate the query (raising QueryValidationError on failure), then
  # execute it via the superclass.
  def execute(query)
    validate_query(query)
    super
  end

  private

  # Dispatch to the validator method matching the query's DSL class.
  # Query objects of any other class pass through unvalidated.
  def validate_query(query)
    case query
    when DSL::SelectQuery then @validator.validate_select(query)
    when DSL::InsertQuery then @validator.validate_insert(query)
    when DSL::UpdateQuery then @validator.validate_update(query)
    when DSL::DeleteQuery then @validator.validate_delete(query)
    end
  end
end
|
|
176
|
+
end
|