better_auth 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +106 -16
- data/lib/better_auth/adapters/base.rb +49 -0
- data/lib/better_auth/adapters/internal_adapter.rb +439 -0
- data/lib/better_auth/adapters/memory.rb +232 -0
- data/lib/better_auth/adapters/mongodb.rb +369 -0
- data/lib/better_auth/adapters/mssql.rb +42 -0
- data/lib/better_auth/adapters/mysql.rb +33 -0
- data/lib/better_auth/adapters/postgres.rb +17 -0
- data/lib/better_auth/adapters/sql.rb +425 -0
- data/lib/better_auth/adapters/sqlite.rb +20 -0
- data/lib/better_auth/api.rb +226 -0
- data/lib/better_auth/api_error.rb +53 -0
- data/lib/better_auth/auth.rb +42 -0
- data/lib/better_auth/configuration.rb +399 -0
- data/lib/better_auth/context.rb +210 -0
- data/lib/better_auth/cookies.rb +278 -0
- data/lib/better_auth/core.rb +37 -1
- data/lib/better_auth/crypto/jwe.rb +76 -0
- data/lib/better_auth/crypto.rb +191 -0
- data/lib/better_auth/database_hooks.rb +114 -0
- data/lib/better_auth/endpoint.rb +326 -0
- data/lib/better_auth/error.rb +52 -0
- data/lib/better_auth/middleware/origin_check.rb +128 -0
- data/lib/better_auth/password.rb +120 -0
- data/lib/better_auth/plugin.rb +129 -0
- data/lib/better_auth/plugin_context.rb +16 -0
- data/lib/better_auth/plugin_registry.rb +67 -0
- data/lib/better_auth/plugins/access.rb +87 -0
- data/lib/better_auth/plugins/additional_fields.rb +29 -0
- data/lib/better_auth/plugins/admin/schema.rb +28 -0
- data/lib/better_auth/plugins/admin.rb +518 -0
- data/lib/better_auth/plugins/anonymous.rb +198 -0
- data/lib/better_auth/plugins/api_key.rb +16 -0
- data/lib/better_auth/plugins/bearer.rb +128 -0
- data/lib/better_auth/plugins/captcha.rb +159 -0
- data/lib/better_auth/plugins/custom_session.rb +84 -0
- data/lib/better_auth/plugins/device_authorization.rb +302 -0
- data/lib/better_auth/plugins/email_otp.rb +536 -0
- data/lib/better_auth/plugins/expo.rb +88 -0
- data/lib/better_auth/plugins/generic_oauth.rb +780 -0
- data/lib/better_auth/plugins/have_i_been_pwned.rb +94 -0
- data/lib/better_auth/plugins/jwt.rb +482 -0
- data/lib/better_auth/plugins/last_login_method.rb +92 -0
- data/lib/better_auth/plugins/magic_link.rb +181 -0
- data/lib/better_auth/plugins/mcp.rb +342 -0
- data/lib/better_auth/plugins/multi_session.rb +173 -0
- data/lib/better_auth/plugins/oauth_protocol.rb +348 -0
- data/lib/better_auth/plugins/oauth_provider.rb +16 -0
- data/lib/better_auth/plugins/oauth_proxy.rb +257 -0
- data/lib/better_auth/plugins/oidc_provider.rb +597 -0
- data/lib/better_auth/plugins/one_tap.rb +154 -0
- data/lib/better_auth/plugins/one_time_token.rb +106 -0
- data/lib/better_auth/plugins/open_api.rb +489 -0
- data/lib/better_auth/plugins/organization/schema.rb +106 -0
- data/lib/better_auth/plugins/organization.rb +990 -0
- data/lib/better_auth/plugins/passkey.rb +16 -0
- data/lib/better_auth/plugins/phone_number.rb +321 -0
- data/lib/better_auth/plugins/scim.rb +16 -0
- data/lib/better_auth/plugins/siwe.rb +242 -0
- data/lib/better_auth/plugins/sso.rb +16 -0
- data/lib/better_auth/plugins/stripe.rb +16 -0
- data/lib/better_auth/plugins/two_factor.rb +514 -0
- data/lib/better_auth/plugins/username.rb +278 -0
- data/lib/better_auth/plugins.rb +46 -0
- data/lib/better_auth/rate_limiter.rb +215 -0
- data/lib/better_auth/request_ip.rb +70 -0
- data/lib/better_auth/router.rb +365 -0
- data/lib/better_auth/routes/account.rb +211 -0
- data/lib/better_auth/routes/email_verification.rb +108 -0
- data/lib/better_auth/routes/error.rb +102 -0
- data/lib/better_auth/routes/ok.rb +15 -0
- data/lib/better_auth/routes/password.rb +164 -0
- data/lib/better_auth/routes/session.rb +137 -0
- data/lib/better_auth/routes/sign_in.rb +90 -0
- data/lib/better_auth/routes/sign_out.rb +15 -0
- data/lib/better_auth/routes/sign_up.rb +145 -0
- data/lib/better_auth/routes/social.rb +188 -0
- data/lib/better_auth/routes/user.rb +193 -0
- data/lib/better_auth/schema/sql.rb +191 -0
- data/lib/better_auth/schema.rb +275 -0
- data/lib/better_auth/session.rb +122 -0
- data/lib/better_auth/session_store.rb +91 -0
- data/lib/better_auth/social_providers/apple.rb +55 -0
- data/lib/better_auth/social_providers/base.rb +67 -0
- data/lib/better_auth/social_providers/discord.rb +59 -0
- data/lib/better_auth/social_providers/github.rb +59 -0
- data/lib/better_auth/social_providers/gitlab.rb +54 -0
- data/lib/better_auth/social_providers/google.rb +65 -0
- data/lib/better_auth/social_providers/microsoft_entra_id.rb +65 -0
- data/lib/better_auth/social_providers.rb +9 -0
- data/lib/better_auth/version.rb +1 -1
- data/lib/better_auth.rb +87 -2
- metadata +218 -21
- data/.ruby-version +0 -1
- data/.standard.yml +0 -12
- data/.vscode/settings.json +0 -22
- data/AGENTS.md +0 -50
- data/CLAUDE.md +0 -1
- data/CODE_OF_CONDUCT.md +0 -173
- data/CONTRIBUTING.md +0 -187
- data/Gemfile +0 -12
- data/Makefile +0 -207
- data/Rakefile +0 -25
- data/SECURITY.md +0 -28
- data/docker-compose.yml +0 -63
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "securerandom"
|
|
4
|
+
require "time"
|
|
5
|
+
|
|
6
|
+
module BetterAuth
  module Adapters
    # Dialect-aware SQL implementation of the adapter interface.
    #
    # Statements are built with bound parameters for the tables described by
    # Schema.auth_tables(options) and run through a duck-typed driver connection
    # (see #execute for the methods it probes for). Dialect branches cover
    # :postgres (RETURNING, $n placeholders), :sqlite (boolean/date storage
    # formats) and :mssql (TOP / OFFSET ... FETCH, BEGIN TRANSACTION).
    class SQL < Base
      attr_reader :connection, :dialect

      # @param options [Object] adapter options; passed to Base and used to resolve the schema
      # @param connection [Object] driver connection used by #execute
      # @param dialect [#to_sym] dialect name, stored as a Symbol
      def initialize(options, connection:, dialect:)
        super(options)
        @connection = connection
        @dialect = dialect.to_sym
      end

      # Inserts a row for +model+ and returns it as a normalized record.
      #
      # Postgres hands the row back through RETURNING *; other dialects re-read it by id.
      #
      # @param model [#to_s] logical model name (schema key)
      # @param data [Hash] field values; keys may be camelCase or snake_case
      # @param force_allow_id [Boolean] keep a caller-supplied id and allow setting non-input fields
      # @return [Hash, nil] the stored record
      # @raise [APIError] BAD_REQUEST when a required field is missing or a non-input field is set
      def create(model:, data:, force_allow_id: false)
        model = model.to_s
        input = transform_input(model, data, "create", force_allow_id)
        table = table_for(model)
        columns = input.keys.map { |field| storage_field(model, field) }
        params = input.keys.map { |field| input[field] }
        placeholders = params.each_index.map { |index| placeholder(index + 1) }
        returning = (dialect == :postgres) ? " RETURNING *" : ""
        sql = "INSERT INTO #{quote(table)} (#{columns.map { |column| quote(column) }.join(", ")}) VALUES (#{placeholders.join(", ")})#{returning}"
        rows = execute(sql, params)
        row = rows.first
        return normalize_record(model, row) if row

        find_one(model: model, where: [{field: "id", value: input.fetch("id")}])
      end

      # Returns the first record matching +where+, or nil.
      #
      # For the user -> account collection join no LIMIT is applied, so every joined
      # account row can be folded into the single user record.
      def find_one(model:, where: [], select: nil, join: nil)
        if collection_join?(model.to_s, join)
          find_many(model: model, where: where, select: select, join: join).first
        else
          find_many(model: model, where: where, select: select, join: join, limit: 1).first
        end
      end

      # Returns every record matching +where+.
      #
      # @param where [Array<Hash>, Hash] clauses with :field, :value and optional :operator / :connector
      # @param sort_by [Hash, nil] {field:, direction: "asc" | "desc"}
      # @param limit [Integer, nil]
      # @param offset [Integer, nil]
      # @param select [Array, nil] fields to read; defaults to every schema field
      # @param join [Hash, nil] joined models keyed by name
      # @return [Array<Hash>]
      def find_many(model:, where: [], sort_by: nil, limit: nil, offset: nil, select: nil, join: nil)
        model = model.to_s
        params = []
        sql = +"SELECT "
        sql << "TOP (#{Integer(limit)}) " if dialect == :mssql && limit && !offset
        sql << select_sql(model, select, join)
        sql << " FROM "
        sql << quote(table_for(model))
        sql << join_sql(model, join)
        where_sql = build_where(model, where || [], params)
        sql << " WHERE #{where_sql}" unless where_sql.empty?
        sql << order_sql(model, sort_by) if sort_by
        append_pagination_sql(sql, model, sort_by, limit, offset)

        records = execute(sql, params).map { |row| normalize_record(model, row, join: join) }
        collection_join?(model, join) ? aggregate_collection_joins(model, records, join) : records
      end

      # Updates the rows matching +where+ and returns one updated record (nil when none matched).
      #
      # Outside Postgres the id of a matching row is read first, so the record can be
      # re-read even if the update changes the columns used in +where+.
      def update(model:, where:, update:)
        model = model.to_s
        if dialect == :postgres
          records = update_many(model: model, where: where, update: update, returning: true)
          return records.is_a?(Array) ? records.first : records
        end

        existing = find_one(model: model, where: where, select: ["id"])
        return nil unless existing

        update_many(model: model, where: where, update: update)
        find_one(model: model, where: [{field: "id", value: existing.fetch("id")}])
      end

      # Applies +update+ to every row matching +where+.
      #
      # @return [Array<Hash>, nil] the updated records on Postgres (RETURNING *), or the
      #   driver's rows when +returning+ is set; otherwise nil
      def update_many(model:, where:, update:, returning: false)
        model = model.to_s
        data = transform_input(model, update, "update", true)
        params = []
        assignments = data.each_key.map do |field|
          params << data[field]
          "#{quote(storage_field(model, field))} = #{placeholder(params.length)}"
        end
        where_sql = build_where(model, where || [], params)
        sql = +"UPDATE "
        sql << quote(table_for(model))
        sql << " SET "
        sql << assignments.join(", ")
        sql << " WHERE #{where_sql}" unless where_sql.empty?
        sql << " RETURNING *" if dialect == :postgres
        rows = execute(sql, params).map { |row| normalize_record(model, row) }
        return rows if returning || dialect == :postgres

        nil
      end

      # Deletes the rows matching +where+; always returns nil.
      def delete(model:, where:)
        delete_many(model: model, where: where)
        nil
      end

      # Deletes the rows matching +where+.
      #
      # @return [Integer] rows removed as reported by the driver (0 when it cannot tell)
      def delete_many(model:, where:)
        model = model.to_s
        params = []
        where_sql = build_where(model, where || [], params)
        sql = +"DELETE FROM "
        sql << quote(table_for(model))
        sql << " WHERE #{where_sql}" unless where_sql.empty?
        # #execute turns a PG::Result into an Array, which drops cmd_tuples; count the
        # deleted rows through RETURNING on Postgres instead.
        return execute("#{sql} RETURNING 1", params).length if dialect == :postgres

        affected_rows(execute(sql, params))
      end

      # @return [Integer] number of rows matching +where+
      def count(model:, where: nil)
        model = model.to_s
        params = []
        where_sql = build_where(model, where || [], params)
        sql = +"SELECT COUNT(*) AS count FROM "
        sql << quote(table_for(model))
        sql << " WHERE #{where_sql}" unless where_sql.empty?
        row = execute(sql, params).first || {}
        (row["count"] || row[:count] || 0).to_i
      end

      # Runs the block between BEGIN and COMMIT on this connection and returns its value.
      #
      # A StandardError raised by the block or by COMMIT triggers ROLLBACK and is re-raised.
      # BEGIN is issued outside the rescue, so a failed BEGIN surfaces as-is instead of being
      # masked by a ROLLBACK error. A nested call issues another BEGIN on the same connection.
      def transaction
        # T-SQL treats a bare BEGIN as the start of a BEGIN...END block.
        execute((dialect == :mssql) ? "BEGIN TRANSACTION" : "BEGIN", [])
        begin
          result = yield self
          execute("COMMIT", [])
          result
        rescue
          execute("ROLLBACK", [])
          raise
        end
      end

      private

      # Maps caller data onto schema fields for +action+ "create" or "update".
      #
      # Applies default_value (create) and on_update (update) for absent fields, enforces
      # required and input: false, coerces values for the dialect, and generates an id on create.
      def transform_input(model, data, action, force_allow_id)
        fields = Schema.auth_tables(options).fetch(model).fetch(:fields)
        input = stringify_keys(data)
        output = {}

        fields.each do |field, attributes|
          next if field == "id" && input.key?(field) && !force_allow_id

          value_provided = input.key?(field)
          value = input[field]
          if value_provided && attributes[:input] == false && value && !force_allow_id
            raise APIError.new("BAD_REQUEST", message: "#{field} is not allowed to be set")
          end

          if !value_provided && action == "create" && attributes.key?(:default_value)
            value = resolve_default(attributes[:default_value])
            value_provided = true
          elsif !value_provided && action == "update" && attributes[:on_update]
            value = resolve_default(attributes[:on_update])
            value_provided = true
          end
          if !value_provided && action == "create" && attributes[:required]
            raise APIError.new("BAD_REQUEST", message: "#{field} is required") unless field == "id"
          end
          output[field] = coerce_value(value, attributes) if value_provided
        end

        output["id"] = generated_id if action == "create" && !output.key?("id")
        output
      end

      # Column list for SELECT: requested (or all) fields of +model+, plus aliased join columns.
      def select_sql(model, select, join)
        fields = Array(select).empty? ? schema_for(model).fetch(:fields).keys : Array(select).map { |field| storage_key(field) }
        columns = fields.map do |field|
          column = storage_field(model, field)
          "#{quote(table_for(model))}.#{quote(column)} AS #{quote(column)}"
        end
        columns.concat(join_select_sql(model, join)) if join
        columns.join(", ")
      end

      # Every column of each joined model, aliased as "<model>__<column>".
      def join_select_sql(model, join)
        join.flat_map do |join_model, _enabled|
          join_model = join_model.to_s
          schema_for(join_model).fetch(:fields).map do |field, attributes|
            column = attributes[:field_name] || physical_name(field)
            "#{quote(join_model)}.#{quote(column)} AS #{quote("#{join_model}__#{column}")}"
          end
        end
      end

      # LEFT JOIN clauses for the supported pairs (session/account -> user, user -> account).
      def join_sql(model, join)
        return "" unless join

        join.map do |join_model, _enabled|
          join_model = join_model.to_s
          case [model, join_model]
          when ["session", "user"], ["account", "user"]
            " LEFT JOIN #{quote(table_for("user"))} AS #{quote("user")} ON #{quote("user")}.#{quote("id")} = #{quote(table_for(model))}.#{quote("user_id")}"
          when ["user", "account"]
            " LEFT JOIN #{quote(table_for("account"))} AS #{quote("account")} ON #{quote("account")}.#{quote("user_id")} = #{quote(table_for(model))}.#{quote("id")}"
          else
            ""
          end
        end.join
      end

      # Builds the WHERE body for +where+, appending bound values to +params+.
      #
      # Clauses are joined left to right with their :connector ("OR", otherwise AND).
      # A single clause Hash is accepted as well as an Array of clauses.
      def build_where(model, where, params)
        clauses = where.is_a?(Hash) ? [where] : Array(where)
        fields = schema_for(model).fetch(:fields)

        clauses.each_with_index.map do |clause, index|
          field = storage_key(fetch_key(clause, :field))
          attributes = fields.fetch(field.to_s)
          column = "#{quote(table_for(model))}.#{quote(storage_field(model, field))}"
          operator = (fetch_key(clause, :operator) || "eq").to_s
          value = fetch_key(clause, :value)

          expression = where_expression(column, operator, value, attributes, params)
          next expression if index.zero?

          connector = (fetch_key(clause, :connector).to_s.upcase == "OR") ? "OR" : "AND"
          "#{connector} #{expression}"
        end.join(" ")
      end

      # SQL predicate for one where clause; appends any bound values to +params+.
      def where_expression(column, operator, value, attributes, params)
        case operator
        when "in", "not_in"
          entries = Array(value)
          # `IN ()` is a syntax error on most engines: an empty list matches nothing (IN)
          # or everything (NOT IN).
          return ((operator == "not_in") ? "1 = 1" : "1 = 0") if entries.empty?

          placeholders = entries.map { |entry| push_param(params, where_value(entry, attributes)) }.join(", ")
          keyword = (operator == "not_in") ? "NOT IN" : "IN"
          "#{column} #{keyword} (#{placeholders})"
        when "contains", "starts_with", "ends_with"
          pattern = case operator
          when "starts_with" then "#{value}%"
          when "ends_with" then "%#{value}"
          else "%#{value}%"
          end
          "#{column} LIKE #{push_param(params, pattern)}"
        else
          comparator = sql_operator(operator)
          # `= NULL` / `!= NULL` never match in SQL; NULL needs IS [NOT] NULL.
          if value.nil? && comparator == "="
            "#{column} IS NULL"
          elsif value.nil? && comparator == "!="
            "#{column} IS NOT NULL"
          else
            "#{column} #{comparator} #{push_param(params, where_value(value, attributes))}"
          end
        end
      end

      # Filter value in the stored representation. SQLite columns hold the coerce_value form
      # (0/1 booleans, ISO-8601 dates), so filters are converted the same way; other dialects
      # receive the value unchanged.
      def where_value(value, attributes)
        (dialect == :sqlite) ? coerce_value(value, attributes) : value
      end

      # Appends +value+ to +params+ and returns the placeholder that refers to it.
      def push_param(params, value)
        params << value
        placeholder(params.length)
      end

      # " ORDER BY <table>.<column> ASC|DESC" for a {field:, direction:} hash.
      def order_sql(model, sort_by)
        field = storage_key(fetch_key(sort_by, :field))
        direction = (fetch_key(sort_by, :direction).to_s.downcase == "desc") ? "DESC" : "ASC"
        " ORDER BY #{quote(table_for(model))}.#{quote(storage_field(model, field))} #{direction}"
      end

      # Appends LIMIT/OFFSET, or for MSSQL OFFSET ... FETCH NEXT (which needs an ORDER BY,
      # so id ascending is used when no sort was given). MSSQL limit-only queries use TOP instead.
      def append_pagination_sql(sql, model, sort_by, limit, offset)
        if dialect == :mssql
          return if limit && !offset
          return unless offset

          sql << order_sql(model, {field: "id", direction: "asc"}) unless sort_by
          sql << " OFFSET #{Integer(offset)} ROWS"
          sql << " FETCH NEXT #{Integer(limit)} ROWS ONLY" if limit
          return
        end

        sql << " LIMIT #{Integer(limit)}" if limit
        sql << " OFFSET #{Integer(offset)}" if offset
      end

      # Comparison operator for a where clause; unknown operators (including "eq") map to "=".
      def sql_operator(operator)
        {
          "ne" => "!=",
          "gt" => ">",
          "gte" => ">=",
          "lt" => "<",
          "lte" => "<="
        }.fetch(operator, "=")
      end

      # Runs +sql+ with +params+ and returns the rows as an Array (or the raw result when it
      # has no #to_a).
      #
      # Probes the connection for exec_params (pg), query (only without params), prepare or
      # execute. Prepared statements and result sets that own one are closed once the rows
      # have been read.
      def execute(sql, params)
        if connection.respond_to?(:exec_params)
          materialize(connection.exec_params(sql, params))
        elsif connection.respond_to?(:query) && params.empty?
          result = connection.query(sql)
          begin
            materialize(result)
          ensure
            result.close if result.respond_to?(:close)
          end
        elsif connection.respond_to?(:prepare)
          statement = connection.prepare(sql)
          begin
            materialize(statement.execute(*params))
          ensure
            statement.close if statement.respond_to?(:close)
          end
        elsif connection.respond_to?(:execute)
          materialize(connection.execute(sql, params))
        else
          raise Error, "SQL connection must respond to exec_params or prepare"
        end
      end

      # Reads every row out of a driver result.
      def materialize(result)
        result.respond_to?(:to_a) ? result.to_a : result
      end

      # Best-effort count of rows changed by the last statement; 0 when unknown.
      def affected_rows(result)
        return result.cmd_tuples if result.respond_to?(:cmd_tuples)
        return result.affected_rows if result.respond_to?(:affected_rows)
        return connection.changes if connection.respond_to?(:changes)
        return result.to_i if result.respond_to?(:to_i)

        0
      end

      # Converts a driver row into a record keyed by schema field name, with joined models
      # nested under their model name.
      def normalize_record(model, row, join: nil)
        return nil unless row

        fields = schema_for(model).fetch(:fields)
        record = fields.each_with_object({}) do |(field, attributes), output|
          column = attributes[:field_name] || physical_name(field)
          output[field] = coerce_output_value(fetch_row(row, column), attributes) if row_key?(row, column)
        end

        join&.each_key do |join_model|
          join_model = join_model.to_s
          record[join_model] = normalize_joined_record(join_model, row)
        end

        record
      end

      # Reads the "<model>__<column>" aliased columns of a joined model out of +row+.
      def normalize_joined_record(model, row)
        schema_for(model).fetch(:fields).each_with_object({}) do |(field, attributes), output|
          column = attributes[:field_name] || physical_name(field)
          key = "#{model}__#{column}"
          output[field] = coerce_output_value(fetch_row(row, key), attributes) if row_key?(row, key)
        end
      end

      # True for the one-to-many user -> account join.
      def collection_join?(model, join)
        model == "user" && join&.keys&.any? { |join_model| join_model.to_s == "account" }
      end

      # Folds one-row-per-account results into one record per user with an "account" array.
      def aggregate_collection_joins(_model, records, _join)
        grouped = {}
        records.each do |record|
          key = record.fetch("id")
          grouped[key] ||= record.merge("account" => [])
          account = record["account"]
          grouped[key]["account"] << account if account&.values&.any?
        end
        grouped.values
      end

      def row_key?(row, key)
        row.key?(key) || row.key?(key.to_sym)
      end

      def fetch_row(row, key)
        return row[key] if row.key?(key)

        row[key.to_sym]
      end

      def table_for(model)
        schema_for(model).fetch(:model_name)
      end

      def schema_for(model)
        Schema.auth_tables(options).fetch(model.to_s)
      end

      # Physical column name for +field+ (raises KeyError for fields unknown to the schema).
      def storage_field(model, field)
        schema_for(model).fetch(:fields).fetch(field.to_s).fetch(:field_name, physical_name(field))
      end

      def quote(identifier)
        Schema::SQL.quote(identifier, dialect)
      end

      # Positional placeholder: "$n" on Postgres, "?" elsewhere.
      def placeholder(index)
        (dialect == :postgres) ? "$#{index}" : "?"
      end

      # New record id from advanced.database.generate_id (callable or "uuid"), else 32 hex chars.
      def generated_id
        generator = options.advanced.dig(:database, :generate_id)
        return generator.call.to_s if generator.respond_to?(:call)
        return SecureRandom.uuid if generator == "uuid"

        SecureRandom.hex(16)
      end

      def resolve_default(default)
        default.respond_to?(:call) ? default.call : default
      end

      # Converts a value to its stored form: date Strings are parsed to Time; on SQLite
      # booleans become 1/0 and dates become ISO-8601 strings with microseconds.
      def coerce_value(value, attributes)
        return value if value.nil?
        return (value ? 1 : 0) if dialect == :sqlite && attributes[:type] == "boolean"

        if attributes[:type] == "date"
          # Parse first so that a String date also reaches the SQLite ISO-8601 branch.
          value = Time.parse(value) if value.is_a?(String)
          return value.iso8601(6) if dialect == :sqlite && value.respond_to?(:iso8601)
        end

        value
      end

      # Converts a stored value back: booleans via coerce_boolean, date Strings via Time.parse.
      def coerce_output_value(value, attributes)
        return value if value.nil?
        return coerce_boolean(value) if attributes[:type] == "boolean"
        return Time.parse(value) if attributes[:type] == "date" && value.is_a?(String)

        value
      end

      # Maps 0/1, "0"/"1", "f"/"t" and "false"/"true" to booleans; other values pass through.
      def coerce_boolean(value)
        return value if value == true || value == false
        return false if value == 0 || value.to_s == "0" || value.to_s.downcase == "f" || value.to_s.downcase == "false"
        return true if value == 1 || value.to_s == "1" || value.to_s.downcase == "t" || value.to_s.downcase == "true"

        value
      end

      def stringify_keys(data)
        data.each_with_object({}) do |(key, value), result|
          result[storage_key(key)] = value
        end
      end

      # Looks +key+ up as given, as a String and in camelCase (String or Symbol).
      #
      # Presence is tested with key? rather than truthiness, so a stored `false`
      # (e.g. a boolean filter value) is returned instead of being skipped.
      def fetch_key(hash, key)
        [key, key.to_s, storage_key(key), storage_key(key).to_sym].each do |candidate|
          return hash[candidate] if hash.key?(candidate)
        end
        nil
      end

      # camelCase schema key for +value+ ("user_id" / :userId -> "userId").
      def storage_key(value)
        parts = physical_name(value).split("_")
        ([parts.first] + parts.drop(1).map(&:capitalize)).join
      end

      # snake_case column name for +value+ ("userId" -> "user_id").
      def physical_name(value)
        value.to_s
          .gsub(/([a-z\d])([A-Z])/, "\\1_\\2")
          .tr("-", "_")
          .downcase
      end
    end
  end
end
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module BetterAuth
  module Adapters
    # SQL adapter specialised for SQLite through the sqlite3 gem.
    class SQLite < SQL
      attr_reader :path

      # @param options [Configuration, nil] adapter configuration; a default one is built when nil
      # @param path [String, nil] database location handed to SQLite3::Database.new (":memory:" when not given)
      # @param connection [Object, nil] pre-opened connection; when absent one is opened at +path+
      def initialize(options = nil, path: nil, connection: nil)
        # The driver is only loaded when this adapter has to open the database itself.
        require "sqlite3" unless connection

        settings = options || Configuration.new(secret: Configuration::DEFAULT_SECRET, database: :memory)
        @path = path || ":memory:"
        db = connection || SQLite3::Database.new(@path)

        # Rows are read by column name, so ask the driver for Hash rows when it supports that.
        db.results_as_hash = true if db.respond_to?(:results_as_hash=)
        # Turn on foreign-key enforcement for this connection (SQLite leaves it off by default).
        db.execute("PRAGMA foreign_keys = ON") if db.respond_to?(:execute)

        super(settings, connection: db, dialect: :sqlite)
      end
    end
  end
end
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module BetterAuth
  # Server-side entry point for invoking endpoints directly from Ruby.
  #
  # Every endpoint key gets a snake_case singleton method (e.g. :signInEmail ->
  # #sign_in_email). The method builds an Endpoint::Context from the given input,
  # runs the user and plugin before/after hooks around the endpoint, and formats
  # the result.
  class API
    attr_reader :context, :endpoints

    # @param context [Object] auth context; its options supply hooks and plugins
    # @param endpoints [Hash] endpoint objects keyed by name
    def initialize(context, endpoints)
      @context = context
      @endpoints = endpoints
      define_endpoint_methods
    end

    # Invokes the endpoint registered under +key+.
    #
    # @param key [#to_sym] endpoint name
    # @param input [Hash] optional :query, :body, :params and :headers, plus the flags
    #   :as_response, :return_headers and :return_status. Keys, including nested Hash keys,
    #   are normalized to snake_case symbols.
    # @return [Object] see #format_result
    # @raise [APIError] when the endpoint produced an error and :as_response is not set
    # @raise [KeyError] when no endpoint is registered under +key+
    def call_endpoint(key, input = {})
      context.reset_runtime! if context.respond_to?(:reset_runtime!)
      endpoint = endpoints.fetch(key.to_sym)
      input = symbolize_keys(input || {})
      endpoint_context = Endpoint::Context.new(
        path: endpoint.path,
        method: Array(endpoint.methods).first,
        query: input[:query] || {},
        body: input[:body] || {},
        params: input[:params] || {},
        headers: input[:headers] || {},
        context: context
      )

      result = run_endpoint_with_hooks(endpoint, endpoint_context)
      format_result(result, input)
    end

    # Runs +endpoint+ with hooks for an already-built context and returns the raw Endpoint::Result.
    def execute(endpoint, endpoint_context)
      run_endpoint_with_hooks(endpoint, endpoint_context)
    end

    private

    # Defines one singleton method per endpoint key that delegates to #call_endpoint.
    def define_endpoint_methods
      endpoints.each_key do |key|
        method_name = normalize_method_name(key)
        define_singleton_method(method_name) do |input = {}|
          call_endpoint(key, input || {})
        end
      end
    end

    # camelCase / kebab-case key -> snake_case Symbol.
    def normalize_method_name(key)
      key.to_s
        .gsub(/([a-z\d])([A-Z])/, "\\1_\\2")
        .tr("-", "_")
        .downcase
        .to_sym
    end

    # Runs the before hooks, the endpoint and the after hooks.
    #
    # A truthy before-hook result short-circuits the endpoint. An APIError raised by the
    # endpoint becomes the response. Raw responses skip the after hooks.
    def run_endpoint_with_hooks(endpoint, endpoint_context)
      before = run_before_hooks(endpoint_context)
      return normalize_short_circuit(before, endpoint_context) if before

      result = begin
        endpoint.call(endpoint_context)
      rescue APIError => error
        Endpoint::Result.new(
          response: error,
          status: error.status_code,
          headers: Endpoint::Result.merge_headers(endpoint_context.response_headers, error.headers)
        )
      end

      return result if result.raw_response?

      endpoint_context.returned = result.response
      endpoint_context.response_headers = result.headers.dup

      after_result = run_after_hooks(endpoint_context)
      result.response = after_result.response
      result.headers = after_result.headers
      result.status = after_result.status if after_result.status
      result
    rescue APIError => error
      # Reached for APIErrors raised outside the endpoint call itself, e.g. by a before hook.
      Endpoint::Result.new(response: error, status: error.status_code, headers: error.headers)
    end

    # Calls each matching before hook.
    #
    # A Hash result whose :context is a Hash is merged into the endpoint context and the
    # loop continues. Any other truthy result is returned as the short-circuit value.
    def run_before_hooks(endpoint_context)
      before_hooks.each do |hook|
        next unless hook_matches?(hook, endpoint_context)

        result = hook[:handler].call(endpoint_context)
        next unless result

        context_data = fetch_key(result, :context)
        if result.is_a?(Hash) && context_data.is_a?(Hash)
          endpoint_context.merge_context!(context_data)
          next
        end

        return result
      end

      nil
    end

    # Calls each matching after hook. A truthy return value, or an APIError raised by the
    # hook, replaces the response, status and headers seen by later hooks.
    def run_after_hooks(endpoint_context)
      result = Endpoint::Result.new(
        response: endpoint_context.returned,
        status: endpoint_context.status,
        headers: endpoint_context.response_headers
      )

      after_hooks.each do |hook|
        next unless hook_matches?(hook, endpoint_context)

        hook_result = begin
          hook[:handler].call(endpoint_context)
        rescue APIError => error
          error
        end

        result.headers = endpoint_context.response_headers.dup

        next unless hook_result

        normalized = Endpoint::Result.from_value(hook_result, endpoint_context)
        result.response = normalized.response
        result.status = normalized.status
        result.headers = normalized.headers
        endpoint_context.returned = result.response
        endpoint_context.response_headers = result.headers
      end

      result
    end

    # Wraps a before-hook short-circuit value in an Endpoint::Result.
    def normalize_short_circuit(value, endpoint_context)
      Endpoint::Result.from_value(value, endpoint_context)
    rescue APIError => error
      Endpoint::Result.new(response: error, status: error.status_code, headers: error.headers)
    end

    # Shapes the caller's return value.
    #
    # - Raw results and :as_response return a Rack response; an error is rendered as one.
    # - Otherwise an error is raised.
    # - :return_headers gives {headers:, response:} (plus :status with :return_status).
    # - :return_status alone gives {response:, status:}.
    # - With no flags, the plain response is returned.
    def format_result(result, input)
      return result.to_rack_response if result.raw_response?

      if result.response.is_a?(APIError)
        return error_response(result.response, headers: result.headers) if input[:as_response]

        raise result.response
      end

      return result.to_rack_response if input[:as_response]

      if input[:return_headers]
        output = {
          headers: result.headers,
          response: result.response
        }
        output[:status] = result.status if input[:return_status]
        return output
      end

      return {response: result.response, status: result.status} if input[:return_status]

      result.response
    end

    # Rack response for an APIError, merging +headers+ with the error's own headers.
    def error_response(error, headers: {})
      Endpoint::Result.new(
        response: error.to_h,
        status: error.status_code,
        headers: Endpoint::Result.merge_headers(headers, error.headers)
      ).to_rack_response
    end

    # User-level before hook (matches every request) followed by plugin before hooks.
    def before_hooks
      hooks = []
      user_before = context.options.hooks&.fetch(:before, nil)
      hooks << {matcher: ->(_ctx) { true }, handler: user_before} if user_before
      hooks.concat(plugin_hooks(:before))
      hooks
    end

    # User-level after hook (matches every request) followed by plugin after hooks.
    def after_hooks
      hooks = []
      user_after = context.options.hooks&.fetch(:after, nil)
      hooks << {matcher: ->(_ctx) { true }, handler: user_after} if user_after
      hooks.concat(plugin_hooks(:after))
      hooks
    end

    # Collects plugin hooks of +type+ (:before / :after) as {matcher:, handler:} hashes.
    #
    # A plugin may declare a list of hooks or a single hook Hash; nil entries are skipped.
    # A hook without a :matcher matches every request.
    def plugin_hooks(type)
      context.options.plugins.flat_map do |plugin|
        hooks = plugin.dig(:hooks, type)
        # Array(hash) would split a single hook Hash into [key, value] pairs.
        hook_list = hooks.is_a?(Hash) ? [hooks] : Array(hooks)
        hook_list.compact.map do |hook|
          {
            matcher: hook[:matcher] || ->(_ctx) { true },
            handler: hook[:handler]
          }
        end
      end
    end

    def hook_matches?(hook, endpoint_context)
      matcher = hook[:matcher] || ->(_ctx) { true }
      matcher.call(endpoint_context)
    end

    # Symbol-or-String lookup on a Hash; nil for non-Hash values.
    def fetch_key(hash, key)
      return unless hash.is_a?(Hash)

      hash[key] || hash[key.to_s]
    end

    # Recursively converts Hash keys (not Hashes inside Arrays) to snake_case symbols.
    def symbolize_keys(value)
      return value unless value.is_a?(Hash)

      value.each_with_object({}) do |(key, object_value), result|
        result[normalize_key(key)] = object_value.is_a?(Hash) ? symbolize_keys(object_value) : object_value
      end
    end

    def normalize_key(key)
      key.to_s
        .gsub(/([a-z\d])([A-Z])/, "\\1_\\2")
        .tr("-", "_")
        .downcase
        .to_sym
    end
  end
end
|