rooq 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,265 @@
1
+ # typed: strict
2
+ # frozen_string_literal: true
3
+
4
+ require "sorbet-runtime"
5
+
6
+ module Rooq
7
+ module Generator
8
# Reads table, column, primary-key and foreign-key metadata for one schema
# through a database connection and wraps it in TableInfo / ColumnInfo /
# ForeignKeyInfo values.
class Introspector
  extend T::Sig

  # information_schema.columns.data_type (lower-cased) => generator type symbol.
  # Anything not listed here (e.g. "ARRAY", "USER-DEFINED") maps to :unknown.
  PG_TYPE_MAP = T.let({
    "integer" => :integer,
    "bigint" => :bigint,
    "smallint" => :smallint,
    "serial" => :integer,
    "bigserial" => :bigint,
    "real" => :float,
    "double precision" => :double,
    "numeric" => :decimal,
    "decimal" => :decimal,
    "character varying" => :string,
    "varchar" => :string,
    "character" => :string,
    "char" => :string,
    "text" => :text,
    "boolean" => :boolean,
    "date" => :date,
    "timestamp without time zone" => :datetime,
    "timestamp with time zone" => :datetime_tz,
    "time without time zone" => :time,
    "time with time zone" => :time_tz,
    "uuid" => :uuid,
    "json" => :json,
    "jsonb" => :jsonb,
    "bytea" => :binary,
    "inet" => :inet,
    "cidr" => :cidr,
    "macaddr" => :macaddr
  }.freeze, T::Hash[String, Symbol])

  # @param connection [Object] anything responding to #exec_params(sql, params)
  #   and returning an enumerable of string-keyed rows (presumably a
  #   PG::Connection -- TODO confirm with callers)
  sig { params(connection: T.untyped).void }
  def initialize(connection)
    @connection = connection
  end

  # Lists the tables whose table_type is 'BASE TABLE' in a schema.
  # @param schema [String] schema to inspect
  # @return [Array<String>] table names ordered by name
  sig { params(schema: String).returns(T::Array[String]) }
  def introspect_tables(schema: "public")
    tables_sql = <<~SQL
      SELECT table_name
      FROM information_schema.tables
      WHERE table_schema = $1
        AND table_type = 'BASE TABLE'
      ORDER BY table_name
    SQL

    result = @connection.exec_params(tables_sql, [schema])
    result.map { |row| row["table_name"] }
  end

  # Describes every column of a table.
  # @param table_name [String] table to inspect
  # @param schema [String] schema containing the table
  # @return [Array<ColumnInfo>] one entry per column, in ordinal_position order
  sig { params(table_name: String, schema: String).returns(T::Array[ColumnInfo]) }
  def introspect_columns(table_name, schema: "public")
    columns_sql = <<~SQL
      SELECT
        column_name,
        data_type,
        is_nullable,
        column_default,
        character_maximum_length,
        numeric_precision,
        numeric_scale
      FROM information_schema.columns
      WHERE table_schema = $1
        AND table_name = $2
      ORDER BY ordinal_position
    SQL

    result = @connection.exec_params(columns_sql, [schema, table_name])
    result.map do |row|
      ColumnInfo.new(
        name: row["column_name"],
        type: map_pg_type(row["data_type"]),
        pg_type: row["data_type"],
        nullable: row["is_nullable"] == "YES",
        default: row["column_default"],
        # Numeric metadata arrives as text (or nil when not applicable).
        max_length: row["character_maximum_length"]&.to_i,
        precision: row["numeric_precision"]&.to_i,
        scale: row["numeric_scale"]&.to_i
      )
    end
  end

  # Returns the table's primary-key column names in key order.
  # @param table_name [String] table to inspect
  # @param schema [String] schema containing the table
  # @return [Array<String>] empty when the table has no primary key
  sig { params(table_name: String, schema: String).returns(T::Array[String]) }
  def introspect_primary_keys(table_name, schema: "public")
    # kcu is also matched on table_name: PostgreSQL scopes constraint names
    # per table, so schema + name alone does not identify a constraint.
    pk_sql = <<~SQL
      SELECT kcu.column_name
      FROM information_schema.table_constraints tc
      JOIN information_schema.key_column_usage kcu
        ON tc.constraint_name = kcu.constraint_name
        AND tc.table_schema = kcu.table_schema
        AND tc.table_name = kcu.table_name
      WHERE tc.constraint_type = 'PRIMARY KEY'
        AND tc.table_schema = $1
        AND tc.table_name = $2
      ORDER BY kcu.ordinal_position
    SQL

    result = @connection.exec_params(pk_sql, [schema, table_name])
    result.map { |row| row["column_name"] }
  end

  # Returns one ForeignKeyInfo per (local column -> referenced column) pair of
  # every foreign key on the table whose referenced table lives in the same
  # schema.
  #
  # Reads pg_catalog instead of information_schema because:
  # * FK constraint names are only unique per table in PostgreSQL, so joining
  #   the information_schema views on constraint_name can mix in columns of
  #   unrelated constraints;
  # * joining key_column_usage with constraint_column_usage cross-multiplies
  #   the columns of composite keys;
  # * constraint_column_usage only lists tables owned by the current role.
  # unnest(conkey, confkey) pairs local and referenced columns positionally.
  #
  # @param table_name [String] table to inspect
  # @param schema [String] schema containing the table
  # @return [Array<ForeignKeyInfo>] ordered by constraint name, then key position
  sig { params(table_name: String, schema: String).returns(T::Array[ForeignKeyInfo]) }
  def introspect_foreign_keys(table_name, schema: "public")
    fk_sql = <<~SQL
      SELECT
        src_att.attname AS column_name,
        ref_tbl.relname AS foreign_table_name,
        ref_att.attname AS foreign_column_name
      FROM pg_catalog.pg_constraint con
      JOIN pg_catalog.pg_class tbl ON tbl.oid = con.conrelid
      JOIN pg_catalog.pg_namespace ns ON ns.oid = tbl.relnamespace
      JOIN pg_catalog.pg_class ref_tbl ON ref_tbl.oid = con.confrelid
      CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
        WITH ORDINALITY AS k(attnum, ref_attnum, ord)
      JOIN pg_catalog.pg_attribute src_att
        ON src_att.attrelid = con.conrelid AND src_att.attnum = k.attnum
      JOIN pg_catalog.pg_attribute ref_att
        ON ref_att.attrelid = con.confrelid AND ref_att.attnum = k.ref_attnum
      WHERE con.contype = 'f'
        AND ns.nspname = $1
        AND tbl.relname = $2
        AND ref_tbl.relnamespace = tbl.relnamespace
      ORDER BY con.conname, k.ord
    SQL

    result = @connection.exec_params(fk_sql, [schema, table_name])
    result.map do |row|
      ForeignKeyInfo.new(
        column_name: row["column_name"],
        foreign_table: row["foreign_table_name"],
        foreign_column: row["foreign_column_name"]
      )
    end
  end

  # Introspects every base table in a schema. Each table costs three further
  # queries (columns, primary keys, foreign keys).
  # @param schema [String] schema to inspect
  # @return [Array<TableInfo>] ordered like #introspect_tables
  sig { params(schema: String).returns(T::Array[TableInfo]) }
  def introspect_schema(schema: "public")
    tables = introspect_tables(schema: schema)
    tables.map do |table_name|
      TableInfo.new(
        name: table_name,
        columns: introspect_columns(table_name, schema: schema),
        primary_keys: introspect_primary_keys(table_name, schema: schema),
        foreign_keys: introspect_foreign_keys(table_name, schema: schema)
      )
    end
  end

  private

  # Maps a data_type string to a generator type symbol; unmapped names
  # yield :unknown instead of raising.
  sig { params(pg_type: String).returns(Symbol) }
  def map_pg_type(pg_type)
    PG_TYPE_MAP.fetch(pg_type.downcase, :unknown)
  end
end
159
+
160
# Frozen value object describing a single table column. The instance is frozen
# as the last step of construction.
class ColumnInfo
  extend T::Sig

  sig do
    params(
      name: String,
      type: Symbol,
      pg_type: String,
      nullable: T::Boolean,
      default: T.nilable(String),
      max_length: T.nilable(Integer),
      precision: T.nilable(Integer),
      scale: T.nilable(Integer)
    ).void
  end
  def initialize(name:, type:, pg_type:, nullable:, default:, max_length:, precision:, scale:)
    @name = name
    @type = type
    @pg_type = pg_type
    @nullable = nullable
    @default = default
    @max_length = max_length
    @precision = precision
    @scale = scale
    freeze
  end

  # Column name.
  sig { returns(String) }
  attr_reader :name

  # Logical type symbol chosen for the column (e.g. :integer, :text).
  sig { returns(Symbol) }
  attr_reader :type

  # Database type name as reported by the database.
  sig { returns(String) }
  attr_reader :pg_type

  # Whether the column accepts NULL.
  sig { returns(T::Boolean) }
  attr_reader :nullable

  # Column default expression text, or nil when there is none.
  sig { returns(T.nilable(String)) }
  attr_reader :default

  # Maximum character length, or nil when not applicable.
  sig { returns(T.nilable(Integer)) }
  attr_reader :max_length

  # Numeric precision, or nil when not applicable.
  sig { returns(T.nilable(Integer)) }
  attr_reader :precision

  # Numeric scale, or nil when not applicable.
  sig { returns(T.nilable(Integer)) }
  attr_reader :scale
end
211
+
212
# Frozen value object describing one column-level foreign-key link from a
# local column to a column of another table.
class ForeignKeyInfo
  extend T::Sig

  sig do
    params(
      column_name: String,
      foreign_table: String,
      foreign_column: String
    ).void
  end
  def initialize(column_name:, foreign_table:, foreign_column:)
    @column_name = column_name
    @foreign_table = foreign_table
    @foreign_column = foreign_column
    freeze
  end

  # Referencing column on the local table.
  sig { returns(String) }
  attr_reader :column_name

  # Name of the referenced table.
  sig { returns(String) }
  attr_reader :foreign_table

  # Referenced column on the foreign table.
  sig { returns(String) }
  attr_reader :foreign_column
end
232
+
233
# Frozen value object bundling a table's name, columns, primary-key column
# names and foreign keys.
class TableInfo
  extend T::Sig

  # Table name.
  sig { returns(String) }
  attr_reader :name

  # Frozen list of the table's columns.
  sig { returns(T::Array[ColumnInfo]) }
  attr_reader :columns

  # Frozen list of primary-key column names.
  sig { returns(T::Array[String]) }
  attr_reader :primary_keys

  # Frozen list of foreign-key links.
  sig { returns(T::Array[ForeignKeyInfo]) }
  attr_reader :foreign_keys

  sig do
    params(
      name: String,
      columns: T::Array[ColumnInfo],
      primary_keys: T::Array[String],
      foreign_keys: T::Array[ForeignKeyInfo]
    ).void
  end
  def initialize(name:, columns:, primary_keys:, foreign_keys:)
    @name = name
    # Freeze private copies: freezing the arguments directly would also freeze
    # the caller's arrays, so a later `<<` on them would raise FrozenError.
    @columns = T.let(columns.dup.freeze, T::Array[ColumnInfo])
    @primary_keys = T.let(primary_keys.dup.freeze, T::Array[String])
    @foreign_keys = T.let(foreign_keys.dup.freeze, T::Array[ForeignKeyInfo])
    freeze
  end
end
264
+ end
265
+ end
@@ -0,0 +1,9 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "generator/introspector"
4
+ require_relative "generator/code_generator"
5
+
6
# Root namespace of the library.
module Rooq
  # Namespace populated by generator/introspector and generator/code_generator.
  module Generator; end
end
@@ -0,0 +1,98 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "time"
5
+ require "date"
6
+
7
module Rooq
  # ParameterConverter converts Ruby objects to PostgreSQL-compatible parameter values.
  #
  # Conversions:
  # - Time/DateTime -> ISO 8601 string (microseconds kept when non-zero)
  # - Date -> ISO 8601 date string
  # - Hash -> JSON string
  # - Array of primitives -> PostgreSQL array literal (nested arrays become
  #   multi-dimensional literals)
  # - Array containing hashes -> JSON array string
  # - Symbol -> String
  # - Other types pass through unchanged
  #
  # @example
  #   converter = ParameterConverter.new
  #   converter.convert(Time.utc(2024, 1, 15, 10, 30, 45)) # => "2024-01-15T10:30:45Z"
  #   converter.convert({ key: "value" })                  # => '{"key":"value"}'
  #   converter.convert([1, 2, 3])                         # => "{1,2,3}"
  class ParameterConverter
    # Characters that force an array element to be double-quoted: the
    # delimiter, braces, double quotes, backslashes and any whitespace
    # (PostgreSQL discards unquoted leading/trailing whitespace).
    ARRAY_SPECIAL_CHARS = /[,{}"\\\s]/.freeze
    private_constant :ARRAY_SPECIAL_CHARS

    # Convert a single parameter value.
    # @param value [Object] the value to convert
    # @return [Object] the converted value
    def convert(value)
      case value
      when nil, true, false, Integer, Float, String
        value
      when Time, DateTime
        # Checked before Date because DateTime is a Date subclass.
        format_timestamp(value)
      when Date
        value.iso8601
      when Hash
        JSON.generate(value)
      when Array
        convert_array(value)
      when Symbol
        value.to_s
      else
        value
      end
    end

    # Convert an array of parameter values.
    # @param params [Array] the parameters to convert
    # @return [Array] the converted parameters
    def convert_all(params)
      params.map { |p| convert(p) }
    end

    private

    # ISO 8601 timestamp. Whole seconds keep the plain form; fractional
    # seconds are written with microsecond precision instead of being
    # silently truncated.
    def format_timestamp(value)
      fraction = value.is_a?(Time) ? value.subsec : value.sec_fraction
      fraction.zero? ? value.iso8601 : value.iso8601(6)
    end

    # Array -> JSON text when any element is a Hash, else a PostgreSQL array literal.
    def convert_array(array)
      return "{}" if array.empty?
      return JSON.generate(array) if array.any? { |el| el.is_a?(Hash) }

      elements = array.map { |el| format_pg_array_element(el) }
      "{#{elements.join(',')}}"
    end

    # Formats one element of a PostgreSQL array literal.
    def format_pg_array_element(value)
      return "NULL" if value.nil?
      # A nested array of primitives is a sub-array (multi-dimensional
      # literal), not a string element, so it must not be quoted.
      return convert_array(value) if pg_sub_array?(value)

      str = convert(value).to_s
      needs_quoting?(str) ? "\"#{escape_pg_string(str)}\"" : str
    end

    def pg_sub_array?(value)
      value.is_a?(Array) && value.none? { |el| el.is_a?(Hash) }
    end

    # Strings equal to NULL (any case) must be quoted or PostgreSQL reads
    # them as SQL NULL.
    def needs_quoting?(str)
      str.empty? || str.match?(ARRAY_SPECIAL_CHARS) || str.casecmp?("NULL")
    end

    # Backslash-escapes backslashes and double quotes. The block form is
    # required: in a replacement *string*, "\\\\" collapses to a single
    # backslash, which would make backslash escaping a no-op.
    def escape_pg_string(str)
      str.gsub(/[\\"]/) { |char| "\\#{char}" }
    end
  end
end
@@ -0,0 +1,176 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Rooq
4
# Checks that the tables and columns referenced by a query exist in a known
# set of table definitions, collecting every problem before raising.
class QueryValidator
  # @param tables [Enumerable] table definitions; each must respond to #name
  #   (used as the lookup key) and #fields (a Hash keyed by column name)
  def initialize(tables)
    @tables = tables.each_with_object({}) do |table, hash|
      hash[table.name] = table
    end
  end

  # Validates the FROM table, selected Field columns, WHERE conditions,
  # ORDER BY columns and every JOIN's table and condition.
  # @return [true] when every reference resolves
  # @raise [QueryValidationError] listing all unresolved references
  def validate_select(query)
    errors = [validate_table(query.from_table)]
    query.selected_fields.each do |field|
      errors << validate_field(field) if field.is_a?(Field)
    end
    errors.concat(validate_condition(query.conditions))
    query.order_specs.each { |spec| errors << validate_field(spec.field) }
    query.joins.each do |join|
      errors << validate_table(join.table)
      errors.concat(validate_condition(join.condition))
    end
    raise_if_invalid(errors)
  end

  # Validates the target table and the Field entries of the column list.
  # @return [true] when every reference resolves
  # @raise [QueryValidationError] listing all unresolved references
  def validate_insert(query)
    errors = [validate_table(query.table)]
    query.column_list.each do |field|
      errors << validate_field(field) if field.is_a?(Field)
    end
    raise_if_invalid(errors)
  end

  # Validates the target table, the Field keys of the SET clause and the
  # WHERE conditions.
  # @return [true] when every reference resolves
  # @raise [QueryValidationError] listing all unresolved references
  def validate_update(query)
    errors = [validate_table(query.table)]
    query.set_values.each_key do |field|
      errors << validate_field(field) if field.is_a?(Field)
    end
    errors.concat(validate_condition(query.conditions))
    raise_if_invalid(errors)
  end

  # Validates the target table and the WHERE conditions.
  # @return [true] when every reference resolves
  # @raise [QueryValidationError] listing all unresolved references
  def validate_delete(query)
    errors = [validate_table(query.table)]
    errors.concat(validate_condition(query.conditions))
    raise_if_invalid(errors)
  end

  private

  # Drops the nil "no error" results and raises if anything remains.
  def raise_if_invalid(errors)
    messages = errors.compact
    raise QueryValidationError.new(messages) unless messages.empty?

    true
  end

  # nil when +table+ is not a Table (e.g. nil or a subquery) or is known;
  # otherwise an error message.
  def validate_table(table)
    return nil unless table.is_a?(Table)
    return nil if @tables.key?(table.name)

    "Unknown table '#{table.name}'. Known tables: #{@tables.keys.join(', ')}"
  end

  # nil when the field resolves (or is not a column reference); otherwise an
  # error message.
  def validate_field(field)
    # ORDER BY specs and conditions may hold non-column expressions; they
    # carry no table reference to check, so skip them instead of raising
    # NoMethodError on #table_name.
    return nil unless field.respond_to?(:table_name)

    table = @tables[field.table_name]
    return "Unknown table '#{field.table_name}' for field '#{field.name}'" unless table
    return nil if table.fields.key?(field.name)

    "Unknown field '#{field.name}' on table '#{field.table_name}'. Available: #{table.fields.keys.join(', ')}"
  end

  # Recursively collects field errors from a Condition / CombinedCondition
  # tree; anything else (including nil) yields no errors.
  def validate_condition(condition)
    case condition
    when Condition
      error = validate_field(condition.field)
      error ? [error] : []
    when CombinedCondition
      condition.conditions.flat_map { |c| validate_condition(c) }
    else
      []
    end
  end
end
140
+
141
# Raised by QueryValidator when a query references unknown tables or fields.
class QueryValidationError < Error
  # The individual validation messages, in the order they were collected.
  attr_reader :validation_errors

  # @param errors [Array<String>] one message per failed check; they are
  #   rendered as a bulleted list under a fixed header line.
  def initialize(errors)
    header = "Query validation failed:"
    bullets = errors.join("\n - ")
    @validation_errors = errors
    super([header, bullets].join("\n - "))
  end
end
149
+
150
# Executor that checks a query against known table definitions before
# handing it to Executor#execute.
class ValidatingExecutor < Executor
  # @param connection passed through to Executor
  # @param tables [Enumerable] table definitions given to QueryValidator
  # @param dialect SQL dialect forwarded to Executor (PostgreSQL by default)
  def initialize(connection, tables, dialect: Dialect::PostgreSQL.new)
    super(connection, dialect: dialect)
    @validator = QueryValidator.new(tables)
  end

  # Validates +query+, then delegates to Executor#execute with the same argument.
  # @raise [QueryValidationError] before execution when validation fails
  def execute(query)
    validate_query(query)
    super
  end

  private

  # Picks the QueryValidator check for the query's class and runs it; query
  # objects of any other class pass through unvalidated (returns nil).
  def validate_query(query)
    check =
      case query
      when DSL::SelectQuery then :validate_select
      when DSL::InsertQuery then :validate_insert
      when DSL::UpdateQuery then :validate_update
      when DSL::DeleteQuery then :validate_delete
      end
    @validator.public_send(check, query) if check
  end
end
176
+ end