better_auth 0.1.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +106 -16
  4. data/lib/better_auth/adapters/base.rb +49 -0
  5. data/lib/better_auth/adapters/internal_adapter.rb +439 -0
  6. data/lib/better_auth/adapters/memory.rb +232 -0
  7. data/lib/better_auth/adapters/mongodb.rb +369 -0
  8. data/lib/better_auth/adapters/mssql.rb +42 -0
  9. data/lib/better_auth/adapters/mysql.rb +33 -0
  10. data/lib/better_auth/adapters/postgres.rb +17 -0
  11. data/lib/better_auth/adapters/sql.rb +425 -0
  12. data/lib/better_auth/adapters/sqlite.rb +20 -0
  13. data/lib/better_auth/api.rb +226 -0
  14. data/lib/better_auth/api_error.rb +53 -0
  15. data/lib/better_auth/auth.rb +42 -0
  16. data/lib/better_auth/configuration.rb +399 -0
  17. data/lib/better_auth/context.rb +210 -0
  18. data/lib/better_auth/cookies.rb +278 -0
  19. data/lib/better_auth/core.rb +37 -1
  20. data/lib/better_auth/crypto/jwe.rb +76 -0
  21. data/lib/better_auth/crypto.rb +191 -0
  22. data/lib/better_auth/database_hooks.rb +114 -0
  23. data/lib/better_auth/endpoint.rb +326 -0
  24. data/lib/better_auth/error.rb +52 -0
  25. data/lib/better_auth/middleware/origin_check.rb +128 -0
  26. data/lib/better_auth/password.rb +120 -0
  27. data/lib/better_auth/plugin.rb +129 -0
  28. data/lib/better_auth/plugin_context.rb +16 -0
  29. data/lib/better_auth/plugin_registry.rb +67 -0
  30. data/lib/better_auth/plugins/access.rb +87 -0
  31. data/lib/better_auth/plugins/additional_fields.rb +29 -0
  32. data/lib/better_auth/plugins/admin/schema.rb +28 -0
  33. data/lib/better_auth/plugins/admin.rb +518 -0
  34. data/lib/better_auth/plugins/anonymous.rb +198 -0
  35. data/lib/better_auth/plugins/api_key.rb +16 -0
  36. data/lib/better_auth/plugins/bearer.rb +128 -0
  37. data/lib/better_auth/plugins/captcha.rb +159 -0
  38. data/lib/better_auth/plugins/custom_session.rb +84 -0
  39. data/lib/better_auth/plugins/device_authorization.rb +302 -0
  40. data/lib/better_auth/plugins/email_otp.rb +536 -0
  41. data/lib/better_auth/plugins/expo.rb +88 -0
  42. data/lib/better_auth/plugins/generic_oauth.rb +780 -0
  43. data/lib/better_auth/plugins/have_i_been_pwned.rb +94 -0
  44. data/lib/better_auth/plugins/jwt.rb +482 -0
  45. data/lib/better_auth/plugins/last_login_method.rb +92 -0
  46. data/lib/better_auth/plugins/magic_link.rb +181 -0
  47. data/lib/better_auth/plugins/mcp.rb +342 -0
  48. data/lib/better_auth/plugins/multi_session.rb +173 -0
  49. data/lib/better_auth/plugins/oauth_protocol.rb +348 -0
  50. data/lib/better_auth/plugins/oauth_provider.rb +16 -0
  51. data/lib/better_auth/plugins/oauth_proxy.rb +257 -0
  52. data/lib/better_auth/plugins/oidc_provider.rb +597 -0
  53. data/lib/better_auth/plugins/one_tap.rb +154 -0
  54. data/lib/better_auth/plugins/one_time_token.rb +106 -0
  55. data/lib/better_auth/plugins/open_api.rb +489 -0
  56. data/lib/better_auth/plugins/organization/schema.rb +106 -0
  57. data/lib/better_auth/plugins/organization.rb +990 -0
  58. data/lib/better_auth/plugins/passkey.rb +16 -0
  59. data/lib/better_auth/plugins/phone_number.rb +321 -0
  60. data/lib/better_auth/plugins/scim.rb +16 -0
  61. data/lib/better_auth/plugins/siwe.rb +242 -0
  62. data/lib/better_auth/plugins/sso.rb +16 -0
  63. data/lib/better_auth/plugins/stripe.rb +16 -0
  64. data/lib/better_auth/plugins/two_factor.rb +514 -0
  65. data/lib/better_auth/plugins/username.rb +278 -0
  66. data/lib/better_auth/plugins.rb +46 -0
  67. data/lib/better_auth/rate_limiter.rb +215 -0
  68. data/lib/better_auth/request_ip.rb +70 -0
  69. data/lib/better_auth/router.rb +365 -0
  70. data/lib/better_auth/routes/account.rb +211 -0
  71. data/lib/better_auth/routes/email_verification.rb +108 -0
  72. data/lib/better_auth/routes/error.rb +102 -0
  73. data/lib/better_auth/routes/ok.rb +15 -0
  74. data/lib/better_auth/routes/password.rb +164 -0
  75. data/lib/better_auth/routes/session.rb +137 -0
  76. data/lib/better_auth/routes/sign_in.rb +90 -0
  77. data/lib/better_auth/routes/sign_out.rb +15 -0
  78. data/lib/better_auth/routes/sign_up.rb +145 -0
  79. data/lib/better_auth/routes/social.rb +188 -0
  80. data/lib/better_auth/routes/user.rb +193 -0
  81. data/lib/better_auth/schema/sql.rb +191 -0
  82. data/lib/better_auth/schema.rb +275 -0
  83. data/lib/better_auth/session.rb +122 -0
  84. data/lib/better_auth/session_store.rb +91 -0
  85. data/lib/better_auth/social_providers/apple.rb +55 -0
  86. data/lib/better_auth/social_providers/base.rb +67 -0
  87. data/lib/better_auth/social_providers/discord.rb +59 -0
  88. data/lib/better_auth/social_providers/github.rb +59 -0
  89. data/lib/better_auth/social_providers/gitlab.rb +54 -0
  90. data/lib/better_auth/social_providers/google.rb +65 -0
  91. data/lib/better_auth/social_providers/microsoft_entra_id.rb +65 -0
  92. data/lib/better_auth/social_providers.rb +9 -0
  93. data/lib/better_auth/version.rb +1 -1
  94. data/lib/better_auth.rb +87 -2
  95. metadata +218 -21
  96. data/.ruby-version +0 -1
  97. data/.standard.yml +0 -12
  98. data/.vscode/settings.json +0 -22
  99. data/AGENTS.md +0 -50
  100. data/CLAUDE.md +0 -1
  101. data/CODE_OF_CONDUCT.md +0 -173
  102. data/CONTRIBUTING.md +0 -187
  103. data/Gemfile +0 -12
  104. data/Makefile +0 -207
  105. data/Rakefile +0 -25
  106. data/SECURITY.md +0 -28
  107. data/docker-compose.yml +0 -63
@@ -0,0 +1,425 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "securerandom"
4
+ require "time"
5
+
6
+ module BetterAuth
7
+ module Adapters
8
+ class SQL < Base
9
+ attr_reader :connection, :dialect
10
+
11
# Sets up the adapter around an already-open database connection.
#
# @param options the configuration object forwarded to the parent adapter
# @param connection the database handle used by #execute
# @param dialect [#to_sym] SQL dialect name; stored as a Symbol
def initialize(options, connection:, dialect:)
  super(options)
  @dialect = dialect.to_sym
  @connection = connection
end
16
+
17
# Inserts one record and returns it in normalized form.
#
# On Postgres the row comes back through RETURNING *; on other dialects the
# record is re-read by its id after the insert.
#
# @param model [#to_s] logical model name
# @param data [Hash] attributes to persist
# @param force_allow_id [Boolean] forwarded to transform_input
# @return [Hash, nil] the stored record
def create(model:, data:, force_allow_id: false)
  model = model.to_s
  input = transform_input(model, data, "create", force_allow_id)
  table = table_for(model)
  quoted_columns = input.keys.map { |field| quote(storage_field(model, field)) }
  params = input.values
  placeholders = (1..params.length).map { |position| placeholder(position) }

  sql = "INSERT INTO #{quote(table)} (#{quoted_columns.join(", ")}) VALUES (#{placeholders.join(", ")})"
  sql += " RETURNING *" if dialect == :postgres

  inserted = execute(sql, params).first
  return normalize_record(model, inserted) if inserted

  find_one(model: model, where: [{field: "id", value: input.fetch("id")}])
end
32
+
33
# Returns the first record matching +where+, or nil.
#
# A LIMIT 1 is applied unless the query is a collection join, where several
# joined rows are folded into a single record by find_many.
def find_one(model:, where: [], select: nil, join: nil)
  query = {model: model, where: where, select: select, join: join}
  query[:limit] = 1 unless collection_join?(model.to_s, join)
  find_many(**query).first
end
40
+
41
# Runs a SELECT for +model+ and returns the normalized records.
#
# @param where [Array<Hash>, nil] filter clauses handed to build_where
# @param sort_by [Hash, nil] ordering, handed to order_sql
# @param limit [Integer, nil] maximum rows (rendered as TOP on MSSQL when no offset is given)
# @param offset [Integer, nil] rows to skip
# @param select [Array, nil] fields to return; all schema fields when empty
# @param join [Hash, nil] joined models
# @return [Array<Hash>]
def find_many(model:, where: [], sort_by: nil, limit: nil, offset: nil, select: nil, join: nil)
  model = model.to_s
  params = []

  top = (dialect == :mssql && limit && !offset) ? "TOP (#{Integer(limit)}) " : ""
  sql = +"SELECT #{top}#{select_sql(model, select, join)} FROM #{quote(table_for(model))}#{join_sql(model, join)}"

  conditions = build_where(model, where || [], params)
  sql << " WHERE #{conditions}" unless conditions.empty?
  sql << order_sql(model, sort_by) if sort_by
  append_pagination_sql(sql, model, sort_by, limit, offset)

  records = execute(sql, params).map { |row| normalize_record(model, row, join: join) }
  return records unless collection_join?(model, join)

  aggregate_collection_joins(model, records, join)
end
58
+
59
# Applies +update+ to the rows matching +where+ and returns one updated record.
#
# Postgres returns the first row reported by update_many. Other dialects look
# up the id of the first match beforehand, run the update, then re-read that
# record by id.
#
# @return [Hash, nil] the updated record, or nil when nothing matched
def update(model:, where:, update:)
  model = model.to_s

  unless dialect == :postgres
    current = find_one(model: model, where: where, select: ["id"])
    return nil unless current

    update_many(model: model, where: where, update: update)
    return find_one(model: model, where: [{field: "id", value: current.fetch("id")}])
  end

  updated = update_many(model: model, where: where, update: update, returning: true)
  updated.is_a?(Array) ? updated.first : updated
end
72
+
73
# Issues an UPDATE for every row matching +where+.
#
# @param update [Hash] attributes to change (passed through transform_input)
# @param returning [Boolean] when true the normalized result rows are returned
# @return [Array<Hash>, nil] result rows when +returning+ is set or the dialect
#   is Postgres (which appends RETURNING *); nil otherwise
def update_many(model:, where:, update:, returning: false)
  model = model.to_s
  changes = transform_input(model, update, "update", true)
  params = []

  # SET parameters are pushed first so WHERE placeholders continue the numbering.
  set_clause = changes.map do |field, value|
    params << value
    "#{quote(storage_field(model, field))} = #{placeholder(params.length)}"
  end.join(", ")
  conditions = build_where(model, where || [], params)
  postgres = dialect == :postgres

  sql = +"UPDATE #{quote(table_for(model))} SET #{set_clause}"
  sql << " WHERE #{conditions}" unless conditions.empty?
  sql << " RETURNING *" if postgres

  rows = execute(sql, params).map { |row| normalize_record(model, row) }
  (returning || postgres) ? rows : nil
end
93
+
94
# Deletes the rows matching +where+ via delete_many, discarding its result.
#
# @return [nil] always
def delete(model:, where:)
  delete_many(model: model, where: where)
  nil
end
98
+
99
# Deletes every row matching +where+ (all rows when no clause is given).
#
# @return whatever affected_rows derives from the driver result
def delete_many(model:, where:)
  model = model.to_s
  params = []
  conditions = build_where(model, where || [], params)

  statement = "DELETE FROM #{quote(table_for(model))}"
  statement += " WHERE #{conditions}" unless conditions.empty?

  affected_rows(execute(statement, params))
end
109
+
110
# Counts the rows of +model+ matching +where+.
#
# The count column is read under either a String or Symbol key, since the
# row shape depends on the driver.
#
# @return [Integer] 0 when the driver returns no row
def count(model:, where: nil)
  model = model.to_s
  params = []
  conditions = build_where(model, where || [], params)

  parts = ["SELECT COUNT(*) AS count FROM #{quote(table_for(model))}"]
  parts << "WHERE #{conditions}" unless conditions.empty?

  row = execute(parts.join(" "), params).first || {}
  total = row["count"] || row[:count] || 0
  total.to_i
end
120
+
121
# Runs the block inside a database transaction and returns the block's value.
#
# BEGIN is issued outside the protected region: if it fails, no transaction
# is open, so no ROLLBACK is attempted and the BEGIN error propagates as-is.
# Any StandardError raised by the block or by COMMIT triggers a ROLLBACK and
# is then re-raised.
#
# @yieldparam adapter [self]
# @return the value returned by the block
def transaction
  execute("BEGIN", [])

  begin
    result = yield self
    execute("COMMIT", [])
    result
  rescue => error
    begin
      execute("ROLLBACK", [])
    rescue
      # Deliberately ignored: a failing ROLLBACK must not hide the error that
      # caused it; the original exception is re-raised below.
    end
    raise error
  end
end
130
+
131
+ private
132
+
133
# Builds the hash of field values to persist for +model+ from caller-supplied +data+.
#
# Keys of +data+ are normalized with stringify_keys, then every field declared
# in the model's schema is visited. Keys of +data+ that are not schema fields
# are never copied to the output.
#
# @param model [String] key into Schema.auth_tables(options)
# @param data [Hash] caller-supplied attributes
# @param action [String] "create" or "update"
# @param force_allow_id [Boolean] when true a supplied "id" is kept and the
#   `input: false` guard is skipped
# @return [Hash] field name => coerced value; on create an "id" is always present
# @raise [APIError] BAD_REQUEST when a non-input field is set, or a required
#   field (other than "id") is missing on create
def transform_input(model, data, action, force_allow_id)
  fields = Schema.auth_tables(options).fetch(model).fetch(:fields)
  input = stringify_keys(data)
  output = {}

  fields.each do |field, attributes|
    # A caller-supplied id is dropped unless explicitly allowed; one is generated below.
    next if field == "id" && input.key?(field) && !force_allow_id

    value_provided = input.key?(field)
    value = input[field]
    # Only truthy values trip the guard, so nil/false on an `input: false` field pass through.
    if value_provided && attributes[:input] == false && value && !force_allow_id
      raise APIError.new("BAD_REQUEST", message: "#{field} is not allowed to be set")
    end

    # Fill in schema defaults on create, and :on_update values on update, for omitted fields.
    if !value_provided && action == "create" && attributes.key?(:default_value)
      value = resolve_default(attributes[:default_value])
      value_provided = true
    elsif !value_provided && action == "update" && attributes[:on_update]
      value = resolve_default(attributes[:on_update])
      value_provided = true
    end
    if !value_provided && action == "create" && attributes[:required]
      # "id" is exempt because it is generated after the loop.
      raise APIError.new("BAD_REQUEST", message: "#{field} is required") unless field == "id"
    end
    output[field] = coerce_value(value, attributes) if value_provided
  end

  output["id"] = generated_id if action == "create" && !output.key?("id")
  output
end
163
+
164
# Renders the column list of a SELECT.
#
# Every column is qualified with the model's table and aliased to its bare
# column name; joined-model columns are appended when +join+ is given.
#
# @param select [Array, Object, nil] requested fields; all schema fields when empty
# @return [String] comma-separated column expressions
def select_sql(model, select, join)
  requested = Array(select)
  fields =
    if requested.empty?
      schema_for(model).fetch(:fields).keys
    else
      requested.map { |name| storage_key(name) }
    end

  columns = fields.map do |field|
    column = storage_field(model, field)
    "#{quote(table_for(model))}.#{quote(column)} AS #{quote(column)}"
  end
  columns += join_select_sql(model, join) if join
  columns.join(", ")
end
173
+
174
# Renders the aliased column expressions for every joined model.
#
# Each column is selected as "<join_model>__<column>" so normalize_joined_record
# can pick the joined values back out of the flat row.
#
# @return [Array<String>]
def join_select_sql(model, join)
  join.each_with_object([]) do |(join_model, _enabled), expressions|
    alias_name = join_model.to_s
    schema_for(alias_name).fetch(:fields).each do |field, attributes|
      column = attributes[:field_name] || physical_name(field)
      expressions << "#{quote(alias_name)}.#{quote(column)} AS #{quote("#{alias_name}__#{column}")}"
    end
  end
end
183
+
184
# Renders the LEFT JOIN clauses for the supported model pairs.
#
# Supported pairs: session -> user, account -> user, user -> account. Any
# other combination contributes nothing.
# NOTE(review): the join columns are the literal names "user_id" and "id",
# not resolved through storage_field — confirm this is intended when a schema
# overrides :field_name.
#
# @return [String] concatenated JOIN clauses, "" when +join+ is nil
def join_sql(model, join)
  return "" unless join

  join.each_with_object(+"") do |(join_model, _enabled), sql|
    target = join_model.to_s
    if target == "user" && (model == "session" || model == "account")
      sql << " LEFT JOIN #{quote(table_for("user"))} AS #{quote("user")} ON #{quote("user")}.#{quote("id")} = #{quote(table_for(model))}.#{quote("user_id")}"
    elsif model == "user" && target == "account"
      sql << " LEFT JOIN #{quote(table_for("account"))} AS #{quote("account")} ON #{quote("account")}.#{quote("user_id")} = #{quote(table_for(model))}.#{quote("id")}"
    end
  end
end
199
+
200
# Renders the body of a WHERE clause and appends its bind values to +params+.
#
# Each clause is a Hash with :field, optional :operator (default "eq"),
# :value and optional :connector ("OR"; anything else means AND). Clauses are
# joined left to right without parentheses.
#
# Special cases:
# * an empty "in" list renders as "1 = 0" and an empty "not_in" list as
#   "1 = 1", because "IN ()" is not valid SQL on every dialect;
# * "eq"/"ne" against nil render as IS NULL / IS NOT NULL, because a
#   comparison with NULL never matches;
# * the value is read with key? so that an explicit false is preserved.
#
# @param model [String]
# @param where [Array<Hash>, nil]
# @param params [Array] mutated: bind values are appended in placeholder order
# @return [String] "" when there are no clauses
def build_where(model, where, params)
  Array(where).each_with_index.map do |clause, index|
    field = storage_key(fetch_key(clause, :field))
    column = "#{quote(table_for(model))}.#{quote(storage_field(model, field))}"
    operator = (fetch_key(clause, :operator) || "eq").to_s
    value =
      if clause.key?(:value)
        clause[:value]
      elsif clause.key?("value")
        clause["value"]
      end

    expression = case operator
    when "in", "not_in"
      values = Array(value)
      if values.empty?
        (operator == "not_in") ? "1 = 1" : "1 = 0"
      else
        placeholders = values.map do |entry|
          params << entry
          placeholder(params.length)
        end.join(", ")
        sql_operator = (operator == "not_in") ? "NOT IN" : "IN"
        "#{column} #{sql_operator} (#{placeholders})"
      end
    when "contains", "starts_with", "ends_with"
      pattern = case operator
      when "starts_with" then "#{value}%"
      when "ends_with" then "%#{value}"
      else "%#{value}%"
      end
      params << pattern
      "#{column} LIKE #{placeholder(params.length)}"
    else
      if value.nil? && operator == "eq"
        "#{column} IS NULL"
      elsif value.nil? && operator == "ne"
        "#{column} IS NOT NULL"
      else
        params << value
        "#{column} #{sql_operator(operator)} #{placeholder(params.length)}"
      end
    end

    connector = (index.positive? && fetch_key(clause, :connector).to_s.upcase == "OR") ? "OR" : "AND"
    index.zero? ? expression : "#{connector} #{expression}"
  end.join(" ")
end
233
+
234
# Renders an ORDER BY clause (with a leading space) for a single sort field.
#
# @param sort_by [Hash] with :field and optional :direction; only "desc"
#   (case-insensitive) sorts descending, everything else ascending
def order_sql(model, sort_by)
  field = Schema.storage_key(fetch_key(sort_by, :field))
  direction = "ASC"
  direction = "DESC" if fetch_key(sort_by, :direction).to_s.downcase == "desc"
  column = "#{quote(table_for(model))}.#{quote(storage_field(model, field))}"
  " ORDER BY #{column} #{direction}"
end
239
+
240
# Appends the dialect-specific pagination clause to +sql+ in place.
#
# MSSQL: a limit without offset is already rendered as TOP by the caller, so
# nothing is appended; with an offset, OFFSET ... ROWS [FETCH NEXT ... ROWS ONLY]
# is used, adding an ORDER BY on id when the query has no ordering of its own.
#
# Other dialects: LIMIT / OFFSET. When only an offset is given, SQLite and
# MySQL still need a LIMIT, so an "unbounded" limit is emitted for them
# (-1 on SQLite, the maximum unsigned 64-bit value on MySQL).
# NOTE(review): assumes the MySQL adapter passes dialect :mysql — confirm.
#
# @param sql [String] mutable SQL buffer
# @param limit [Integer, nil]
# @param offset [Integer, nil]
def append_pagination_sql(sql, model, sort_by, limit, offset)
  if dialect == :mssql
    return unless offset

    sql << order_sql(model, {field: "id", direction: "asc"}) unless sort_by
    sql << " OFFSET #{Integer(offset)} ROWS"
    sql << " FETCH NEXT #{Integer(limit)} ROWS ONLY" if limit
    return
  end

  if limit
    sql << " LIMIT #{Integer(limit)}"
  elsif offset
    case dialect
    when :sqlite then sql << " LIMIT -1"
    when :mysql then sql << " LIMIT 18446744073709551615"
    end
  end
  sql << " OFFSET #{Integer(offset)}" if offset
end
254
+
255
# Maps a comparison operator name to its SQL symbol; unknown names
# (including "eq") map to "=".
def sql_operator(operator)
  case operator
  when "ne" then "!="
  when "gt" then ">"
  when "gte" then ">="
  when "lt" then "<"
  when "lte" then "<="
  else "="
  end
end
264
+
265
# Runs +sql+ with +params+ on the connection, picking the first driver entry
# point the connection responds to, in this order:
# exec_params, query (parameterless statements only), prepare, execute.
#
# Prepared statements are closed once their result has been materialized, so
# handles do not accumulate on the connection.
#
# @param sql [String]
# @param params [Array] bind values in placeholder order
# @return [Array, Object] result.to_a when the driver result supports it,
#   otherwise the raw driver result
# @raise [Error] when the connection exposes none of the supported methods
def execute(sql, params)
  if connection.respond_to?(:exec_params)
    result = connection.exec_params(sql, params)
    return result.to_a if result.respond_to?(:to_a)

    result
  elsif connection.respond_to?(:query) && params.empty?
    result = connection.query(sql)
    result.respond_to?(:to_a) ? result.to_a : result
  elsif connection.respond_to?(:prepare)
    statement = connection.prepare(sql)
    begin
      result = statement.execute(*params)
      # Rows are read before the statement is closed in the ensure below.
      result.respond_to?(:to_a) ? result.to_a : result
    ensure
      statement.close if statement.respond_to?(:close)
    end
  elsif connection.respond_to?(:execute)
    result = connection.execute(sql, params)
    result.respond_to?(:to_a) ? result.to_a : result
  else
    raise Error, "SQL connection must respond to exec_params or prepare"
  end
end
285
+
286
# Derives the number of affected rows from a driver result.
#
# Probes, in order: result.cmd_tuples, result.affected_rows,
# connection.changes, result.to_i; falls back to 0.
def affected_rows(result)
  if result.respond_to?(:cmd_tuples)
    result.cmd_tuples
  elsif result.respond_to?(:affected_rows)
    result.affected_rows
  elsif connection.respond_to?(:changes)
    connection.changes
  elsif result.respond_to?(:to_i)
    result.to_i
  else
    0
  end
end
294
+
295
# Converts a raw driver row into a record keyed by schema field name.
#
# Only columns present in the row are copied; each value goes through
# coerce_output_value. For every joined model a nested record is attached
# under the join model's name.
#
# @return [Hash, nil] nil when +row+ is nil/false
def normalize_record(model, row, join: nil)
  return nil unless row

  record = {}
  schema_for(model).fetch(:fields).each do |field, attributes|
    column = attributes[:field_name] || physical_name(field)
    next unless row_key?(row, column)

    record[field] = coerce_output_value(fetch_row(row, column), attributes)
  end

  join&.each_key do |join_model|
    name = join_model.to_s
    record[name] = normalize_joined_record(name, row)
  end

  record
end
311
+
312
# Extracts the joined model's fields from a flat row, reading the
# "<model>__<column>" aliases produced by join_select_sql.
#
# @return [Hash] field name => coerced value, for aliases present in the row
def normalize_joined_record(model, row)
  record = {}
  schema_for(model).fetch(:fields).each do |field, attributes|
    column = attributes[:field_name] || physical_name(field)
    aliased = "#{model}__#{column}"
    next unless row_key?(row, aliased)

    record[field] = coerce_output_value(fetch_row(row, aliased), attributes)
  end
  record
end
319
+
320
# Tells whether the query joins a collection onto its parent: currently only
# "user" joined with "account".
#
# @return [Boolean, nil] nil when +join+ is nil and the model is "user"
def collection_join?(model, join)
  return false unless model == "user"

  join&.keys&.any? { |name| name.to_s == "account" }
end
323
+
324
# Folds the flat rows of a user/account join into one record per user id,
# collecting the joined accounts into an Array under "account".
#
# Joined records whose values are all nil/false (no matching account row) are
# not added.
#
# @return [Array<Hash>] one record per distinct "id", in first-seen order
def aggregate_collection_joins(_model, records, _join)
  by_id = records.each_with_object({}) do |record, index|
    entry = (index[record.fetch("id")] ||= record.merge("account" => []))
    account = record["account"]
    entry["account"] << account if account&.values&.any?
  end
  by_id.values
end
334
+
335
# True when +row+ has +key+ as given or in Symbol form.
def row_key?(row, key)
  return true if row.key?(key)

  row.key?(key.to_sym)
end
338
+
339
# Reads +key+ from +row+, preferring the key as given and falling back to
# its Symbol form.
def fetch_row(row, key)
  row.key?(key) ? row[key] : row[key.to_sym]
end
344
+
345
# Returns the :model_name entry of the model's schema, used as the table name.
def table_for(model)
  schema_for(model).fetch(:model_name)
end
348
+
349
# Looks up the schema entry for +model+ in Schema.auth_tables(options).
# Raises KeyError (from fetch) when the model is not declared.
def schema_for(model)
  Schema.auth_tables(options).fetch(model.to_s)
end
352
+
353
# Resolves the column name for +field+: the schema's :field_name when set,
# otherwise the snake_case form from physical_name.
# Raises KeyError (from fetch) when the field is not declared on the model.
def storage_field(model, field)
  schema_for(model).fetch(:fields).fetch(field.to_s).fetch(:field_name, physical_name(field))
end
356
+
357
# Quotes an identifier for the current dialect by delegating to Schema::SQL.quote.
def quote(identifier)
  Schema::SQL.quote(identifier, dialect)
end
360
+
361
# Returns the bind placeholder for the 1-based parameter +index+:
# "$n" on Postgres, "?" on every other dialect.
def placeholder(index)
  (dialect == :postgres) ? "$#{index}" : "?"
end
364
+
365
# Produces a new record id according to options.advanced[:database][:generate_id]:
# a callable is invoked and its result stringified, "uuid" yields a UUID,
# anything else yields 32 random hex characters.
#
# @return [String]
def generated_id
  generator = options.advanced.dig(:database, :generate_id)

  if generator.respond_to?(:call)
    generator.call.to_s
  elsif generator == "uuid"
    SecureRandom.uuid
  else
    SecureRandom.hex(16)
  end
end
372
+
373
# Returns +default+ itself, or the result of calling it when it is callable.
def resolve_default(default)
  default.respond_to?(:call) ? default.call : default
end
376
+
377
# Converts an input value into the form handed to the driver.
#
# * nil passes through unchanged.
# * On SQLite, booleans become 1/0.
# * "date" fields given as a String are parsed with Time.parse.
# * On SQLite, "date" values that respond to iso8601 are rendered as ISO-8601
#   text with microseconds — this now includes String inputs, which are
#   parsed first, so dates are stored in one format regardless of input type.
#
# @param value [Object]
# @param attributes [Hash] schema attributes of the field (:type is read)
# @return [Object]
# @raise [ArgumentError] from Time.parse when a date String cannot be parsed
def coerce_value(value, attributes)
  return value if value.nil?

  type = attributes[:type]
  return value ? 1 : 0 if dialect == :sqlite && type == "boolean"

  if type == "date"
    value = Time.parse(value) if value.is_a?(String)
    return value.iso8601(6) if dialect == :sqlite && value.respond_to?(:iso8601)
  end

  value
end
385
+
386
# Converts a value read from the database into its Ruby form: booleans go
# through coerce_boolean, String dates are parsed into Time, everything else
# (including nil) is returned unchanged.
def coerce_output_value(value, attributes)
  return value if value.nil?

  case attributes[:type]
  when "boolean"
    coerce_boolean(value)
  when "date"
    value.is_a?(String) ? Time.parse(value) : value
  else
    value
  end
end
393
+
394
# Maps the boolean spellings drivers return (0/1, "0"/"1", "t"/"f",
# "true"/"false", case-insensitive) to true/false. Real booleans and
# unrecognized values are returned unchanged.
def coerce_boolean(value)
  return value if value == true || value == false

  text = value.to_s
  lowered = text.downcase
  return false if value == 0 || text == "0" || lowered == "f" || lowered == "false"
  return true if value == 1 || text == "1" || lowered == "t" || lowered == "true"

  value
end
401
+
402
# Returns a copy of +data+ whose keys are normalized through storage_key.
# When two keys normalize to the same name, the later value wins.
def stringify_keys(data)
  data.to_h { |key, value| [storage_key(key), value] }
end
407
+
408
# Reads +key+ from +hash+, tolerating Symbol/String keys in either the given
# or the storage_key spelling.
#
# Candidates are tried in order (key, key.to_s, storage_key(key) as String,
# then as Symbol) and the first non-nil value wins. An explicit false is a
# real value and is returned as-is instead of being skipped.
#
# @param hash [Hash]
# @param key [Symbol, String]
# @return [Object, nil] nil when no candidate holds a non-nil value
def fetch_key(hash, key)
  normalized = storage_key(key)
  [key, key.to_s, normalized, normalized.to_sym].each do |candidate|
    value = hash[candidate]
    return value unless value.nil?
  end
  nil
end
411
+
412
# Converts a name to lowerCamelCase by snake-casing it with physical_name and
# capitalizing every segment after the first.
#
# @return [String]
def storage_key(value)
  head, *tail = physical_name(value).split("_")
  [head, *tail.map(&:capitalize)].join
end
416
+
417
# Converts a name to snake_case: an underscore is inserted between a
# lowercase letter or digit and a following uppercase letter, dashes become
# underscores, and the result is downcased.
#
# @return [String]
def physical_name(value)
  separated = value.to_s.gsub(/([a-z\d])([A-Z])/) do
    "#{Regexp.last_match(1)}_#{Regexp.last_match(2)}"
  end
  separated.tr("-", "_").downcase
end
423
+ end
424
+ end
425
+ end
@@ -0,0 +1,20 @@
1
+ # frozen_string_literal: true
2
+
3
+ module BetterAuth
4
+ module Adapters
5
+ class SQLite < SQL
6
+ attr_reader :path
7
+
8
# Builds a SQLite-backed adapter.
#
# @param options [Object, nil] configuration; when nil a Configuration is
#   built with Configuration::DEFAULT_SECRET and `database: :memory`
# @param path [String, nil] database file; ":memory:" when omitted
# @param connection [Object, nil] an existing connection to reuse; when
#   omitted the sqlite3 gem is loaded and a database is opened at +path+
def initialize(options = nil, path: nil, connection: nil)
  # Loaded lazily so the gem is only needed when this adapter opens its own database.
  require "sqlite3" unless connection

  config = options || Configuration.new(secret: Configuration::DEFAULT_SECRET, database: :memory)
  @path = path || ":memory:"
  database = connection || SQLite3::Database.new(@path)
  database.results_as_hash = true if database.respond_to?(:results_as_hash=)
  database.execute("PRAGMA foreign_keys = ON") if database.respond_to?(:execute)

  super(config, connection: database, dialect: :sqlite)
end
18
+ end
19
+ end
20
+ end
@@ -0,0 +1,226 @@
1
+ # frozen_string_literal: true
2
+
3
+ module BetterAuth
4
+ class API
5
+ attr_reader :context, :endpoints
6
+
7
# Stores the auth context and endpoint table, then defines one singleton
# method per endpoint key.
#
# @param context the auth context (exposed via #context)
# @param endpoints [Hash] endpoint key => endpoint (exposed via #endpoints)
def initialize(context, endpoints)
  @context = context
  @endpoints = endpoints
  # Must run after @endpoints is set: it iterates the endpoint keys.
  define_endpoint_methods
end
12
+
13
# Invokes the endpoint registered under +key+ directly (without going through
# a router) and formats its result.
#
# @param key [#to_sym] endpoint key
# @param input [Hash, nil] may carry :query, :body, :params, :headers plus
#   the formatting flags read by format_result; keys are normalized by
#   symbolize_keys
# @return the value produced by format_result
# @raise [KeyError] when no endpoint is registered under +key+
def call_endpoint(key, input = {})
  context.reset_runtime! if context.respond_to?(:reset_runtime!)
  endpoint = endpoints.fetch(key.to_sym)
  input = symbolize_keys(input || {})

  request_parts = %i[query body params headers].to_h { |part| [part, input[part] || {}] }
  endpoint_context = Endpoint::Context.new(
    path: endpoint.path,
    method: Array(endpoint.methods).first,
    **request_parts,
    context: context
  )

  format_result(run_endpoint_with_hooks(endpoint, endpoint_context), input)
end
30
+
31
# Runs +endpoint+ with an already-built endpoint context through the
# before/after hook pipeline and returns the result without formatting it.
def execute(endpoint, endpoint_context)
  run_endpoint_with_hooks(endpoint, endpoint_context)
end
34
+
35
+ private
36
+
37
# Defines, on this instance, one method per endpoint key (named with
# normalize_method_name) that forwards its optional input hash to call_endpoint.
def define_endpoint_methods
  endpoints.each_key do |key|
    define_singleton_method(normalize_method_name(key)) { |input = {}| call_endpoint(key, input || {}) }
  end
end
45
+
46
# Turns an endpoint key into a snake_case Symbol usable as a method name
# (camelCase boundaries and dashes become underscores).
def normalize_method_name(key)
  separated = key.to_s.gsub(/([a-z\d])([A-Z])/) do
    "#{Regexp.last_match(1)}_#{Regexp.last_match(2)}"
  end
  separated.tr("-", "_").downcase.to_sym
end
53
+
54
# Runs the full pipeline for one endpoint call: before hooks, the endpoint
# itself, then after hooks.
#
# @param endpoint the endpoint to call (must respond to #call)
# @param endpoint_context the per-request context handed to hooks and endpoint
# @return [Endpoint::Result] the endpoint result, possibly rewritten by after
#   hooks; an APIError raised along the way is captured as the result's response
def run_endpoint_with_hooks(endpoint, endpoint_context)
  # A truthy value from a before hook short-circuits the endpoint entirely.
  before = run_before_hooks(endpoint_context)
  return normalize_short_circuit(before, endpoint_context) if before

  # An APIError raised by the endpoint becomes a result (keeping headers already
  # set on the context) so that after hooks still run for failed calls.
  result = begin
    endpoint.call(endpoint_context)
  rescue APIError => error
    Endpoint::Result.new(
      response: error,
      status: error.status_code,
      headers: Endpoint::Result.merge_headers(endpoint_context.response_headers, error.headers)
    )
  end

  # Raw responses bypass after hooks.
  return result if result.raw_response?

  # Expose the endpoint outcome to after hooks through the context.
  endpoint_context.returned = result.response
  endpoint_context.response_headers = result.headers.dup

  after_result = run_after_hooks(endpoint_context)
  result.response = after_result.response
  result.headers = after_result.headers
  result.status = after_result.status if after_result.status
  result
rescue APIError => error
  # Covers APIErrors raised outside the inner rescue (e.g. by a before hook).
  Endpoint::Result.new(response: error, status: error.status_code, headers: error.headers)
end
81
+
82
# Runs the matching before hooks in order.
#
# A hook returning a Hash with a Hash under :context / "context" has that
# data merged into the endpoint context and processing continues. Any other
# truthy return value stops processing and is returned to the caller.
#
# @return [Object, nil] the short-circuit value, or nil when every hook passed
def run_before_hooks(endpoint_context)
  before_hooks.each do |hook|
    next unless hook_matches?(hook, endpoint_context)

    outcome = hook[:handler].call(endpoint_context)
    next unless outcome

    extra_context = fetch_key(outcome, :context)
    return outcome unless outcome.is_a?(Hash) && extra_context.is_a?(Hash)

    endpoint_context.merge_context!(extra_context)
  end

  nil
end
100
+
101
# Runs the matching after hooks in order and returns the (possibly replaced)
# outcome.
#
# The starting point is the endpoint outcome stored on the context. Each hook
# may change the context's response headers; a hook that returns a truthy
# value (or raises an APIError, which is treated as its return value)
# replaces response, status and headers for the hooks that follow.
#
# @return [Endpoint::Result]
def run_after_hooks(endpoint_context)
  result = Endpoint::Result.new(
    response: endpoint_context.returned,
    status: endpoint_context.status,
    headers: endpoint_context.response_headers
  )

  after_hooks.each do |hook|
    next unless hook_matches?(hook, endpoint_context)

    hook_result = begin
      hook[:handler].call(endpoint_context)
    rescue APIError => error
      error
    end

    # Pick up header changes the hook made on the context, even if it returned nothing.
    result.headers = endpoint_context.response_headers.dup

    next unless hook_result

    normalized = Endpoint::Result.from_value(hook_result, endpoint_context)
    result.response = normalized.response
    result.status = normalized.status
    result.headers = normalized.headers
    # Keep the context in sync so later hooks see the replaced outcome.
    endpoint_context.returned = result.response
    endpoint_context.response_headers = result.headers
  end

  result
end
131
+
132
# Wraps a before-hook short-circuit value in an Endpoint::Result via
# Endpoint::Result.from_value; an APIError raised during that conversion is
# returned as an error result instead of propagating.
def normalize_short_circuit(value, endpoint_context)
  Endpoint::Result.from_value(value, endpoint_context)
rescue APIError => error
  Endpoint::Result.new(response: error, status: error.status_code, headers: error.headers)
end
137
+
138
# Shapes an endpoint result according to the caller's flags.
#
# * raw responses are always returned as a Rack response;
# * an APIError response is raised, unless :as_response is set, in which case
#   it is rendered as a Rack error response;
# * :as_response returns the Rack response;
# * :return_headers / :return_status return a Hash with :response plus the
#   requested :headers and/or :status;
# * otherwise the bare response is returned.
#
# @param result [Endpoint::Result]
# @param input [Hash] normalized call input carrying the flags
# @raise [APIError] when the result holds an error and :as_response is not set
def format_result(result, input)
  return result.to_rack_response if result.raw_response?

  payload = result.response
  if payload.is_a?(APIError)
    raise payload unless input[:as_response]

    return error_response(payload, headers: result.headers)
  end

  return result.to_rack_response if input[:as_response]

  wants_headers = input[:return_headers]
  wants_status = input[:return_status]
  return payload unless wants_headers || wants_status

  output = {}
  output[:headers] = result.headers if wants_headers
  output[:response] = payload
  output[:status] = result.status if wants_status
  output
end
162
+
163
# Renders an APIError as a Rack response: the body is error.to_h, the status
# is error.status_code, and the given headers are merged with the error's
# own via Endpoint::Result.merge_headers.
def error_response(error, headers: {})
  Endpoint::Result.new(
    response: error.to_h,
    status: error.status_code,
    headers: Endpoint::Result.merge_headers(headers, error.headers)
  ).to_rack_response
end
170
+
171
# Lists the before hooks to run: the user-configured :before hook (matching
# every request) first, followed by the plugins' before hooks.
#
# @return [Array<Hash>] entries with :matcher and :handler
def before_hooks
  user_before = context.options.hooks&.fetch(:before, nil)
  user_hooks = user_before ? [{matcher: ->(_ctx) { true }, handler: user_before}] : []
  user_hooks + plugin_hooks(:before)
end
178
+
179
# Lists the after hooks to run: the user-configured :after hook (matching
# every request) first, followed by the plugins' after hooks.
#
# @return [Array<Hash>] entries with :matcher and :handler
def after_hooks
  user_after = context.options.hooks&.fetch(:after, nil)
  user_hooks = user_after ? [{matcher: ->(_ctx) { true }, handler: user_after}] : []
  user_hooks + plugin_hooks(:after)
end
186
+
187
# Collects the hooks of the given +type+ (:before or :after) from every
# configured plugin, in plugin order. Hooks without a :matcher get one that
# matches every request.
#
# @return [Array<Hash>] entries with :matcher and :handler
def plugin_hooks(type)
  context.options.plugins.each_with_object([]) do |plugin, collected|
    Array(plugin.dig(:hooks, type)).each do |hook|
      collected << {
        matcher: hook[:matcher] || ->(_ctx) { true },
        handler: hook[:handler]
      }
    end
  end
end
198
+
199
# Calls the hook's :matcher with the endpoint context and returns its result;
# a hook without a matcher matches everything.
def hook_matches?(hook, endpoint_context)
  matcher = hook[:matcher] || ->(_ctx) { true }
  matcher.call(endpoint_context)
end
203
+
204
# Reads +key+ from +hash+ under its given form or its String form.
# Returns nil when +hash+ is not a Hash. Because the lookup uses ||, a
# false value stored under +key+ is treated like a missing one.
def fetch_key(hash, key)
  return unless hash.is_a?(Hash)

  hash[key] || hash[key.to_s]
end
209
+
210
# Recursively normalizes Hash keys with normalize_key. Nested Hashes are
# converted too; Arrays and other values are left untouched. Non-Hash input
# is returned as-is.
def symbolize_keys(value)
  return value unless value.is_a?(Hash)

  value.to_h { |key, nested| [normalize_key(key), symbolize_keys(nested)] }
end
217
+
218
# Converts a key to a snake_case Symbol (camelCase boundaries and dashes
# become underscores, everything is downcased).
def normalize_key(key)
  separated = key.to_s.gsub(/([a-z\d])([A-Z])/) do
    "#{Regexp.last_match(1)}_#{Regexp.last_match(2)}"
  end
  separated.tr("-", "_").downcase.to_sym
end
225
+ end
226
+ end