better_auth 0.8.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +4 -4
  4. data/lib/better_auth/adapters/memory.rb +131 -17
  5. data/lib/better_auth/adapters/sql.rb +139 -57
  6. data/lib/better_auth/configuration.rb +7 -1
  7. data/lib/better_auth/cookies.rb +11 -3
  8. data/lib/better_auth/doctor.rb +97 -0
  9. data/lib/better_auth/endpoint.rb +88 -5
  10. data/lib/better_auth/http_client.rb +46 -0
  11. data/lib/better_auth/migration_plan.rb +15 -0
  12. data/lib/better_auth/oauth2.rb +1 -1
  13. data/lib/better_auth/plugins/admin.rb +6 -1
  14. data/lib/better_auth/plugins/anonymous.rb +2 -0
  15. data/lib/better_auth/plugins/captcha.rb +1 -1
  16. data/lib/better_auth/plugins/device_authorization.rb +34 -0
  17. data/lib/better_auth/plugins/dub.rb +8 -0
  18. data/lib/better_auth/plugins/generic_oauth.rb +34 -7
  19. data/lib/better_auth/plugins/have_i_been_pwned.rb +1 -1
  20. data/lib/better_auth/plugins/jwt.rb +10 -3
  21. data/lib/better_auth/plugins/mcp/schema.rb +13 -13
  22. data/lib/better_auth/plugins/mcp.rb +41 -0
  23. data/lib/better_auth/plugins/oauth_protocol.rb +98 -21
  24. data/lib/better_auth/plugins/oidc_provider.rb +62 -3
  25. data/lib/better_auth/plugins/one_tap.rb +17 -5
  26. data/lib/better_auth/plugins/open_api.rb +42 -2
  27. data/lib/better_auth/plugins/organization.rb +122 -11
  28. data/lib/better_auth/plugins/phone_number.rb +1 -1
  29. data/lib/better_auth/plugins/two_factor.rb +21 -0
  30. data/lib/better_auth/rate_limiter.rb +7 -2
  31. data/lib/better_auth/routes/account.rb +4 -0
  32. data/lib/better_auth/routes/email_verification.rb +5 -1
  33. data/lib/better_auth/routes/password.rb +1 -0
  34. data/lib/better_auth/routes/social.rb +29 -1
  35. data/lib/better_auth/routes/user.rb +6 -2
  36. data/lib/better_auth/schema/sql.rb +104 -15
  37. data/lib/better_auth/schema.rb +35 -2
  38. data/lib/better_auth/session.rb +2 -1
  39. data/lib/better_auth/social_providers/base.rb +4 -9
  40. data/lib/better_auth/social_providers/facebook.rb +1 -1
  41. data/lib/better_auth/social_providers/github.rb +2 -0
  42. data/lib/better_auth/social_providers/line.rb +1 -1
  43. data/lib/better_auth/social_providers/paypal.rb +1 -1
  44. data/lib/better_auth/sql_migration.rb +566 -0
  45. data/lib/better_auth/version.rb +1 -1
  46. data/lib/better_auth.rb +3 -0
  47. metadata +10 -6
@@ -527,14 +527,14 @@ module BetterAuth
527
527
  end
528
528
  end
529
529
 
# Builds the GET /organization/list-members endpoint.
#
# The target organization is taken, in order of preference, from the
# `organization_id` query param, the `organization_slug` query param, or the
# session's active organization. The caller must be a member of it.
#
# @param config [Hash, nil] organization plugin options; forwarded to
#   list_members_for, which reads `:membership_limit` to cap page size.
# @return [Endpoint]
def organization_list_members_endpoint(config)
  metadata = organization_openapi(
    "listOrganizationMembers",
    "List organization members",
    response: organization_members_response_schema
  )

  Endpoint.new(path: "/organization/list-members", method: "GET", metadata: metadata) do |ctx|
    session = Routes.current_session(ctx)
    query = normalize_hash(ctx.query)

    organization_id = query[:organization_id]
    organization_id ||= organization_by_slug(ctx, query[:organization_slug])&.fetch("id")
    organization_id ||= session[:session]["activeOrganizationId"]
    unless organization_id
      raise APIError.new("BAD_REQUEST", message: ORGANIZATION_ERROR_CODES.fetch("NO_ACTIVE_ORGANIZATION"))
    end

    viewer = session[:user]
    require_member!(ctx, viewer["id"], organization_id)
    ctx.json(list_members_for(ctx, organization_id, query, config, viewer))
  end
end
540
540
 
@@ -768,6 +768,7 @@ module BetterAuth
768
768
  end
769
769
 
770
770
  def organization_openapi(operation_id, description, response:, response_description: "Success", request: nil, required: [], parameters: nil)
771
+ request ||= organization_request_schema(operation_id)
771
772
  openapi = {
772
773
  operationId: operation_id,
773
774
  description: description,
@@ -781,6 +782,68 @@ module BetterAuth
781
782
  {openapi: openapi}
782
783
  end
783
784
 
# Returns the OpenAPI request-body property map for an organization
# operation, or nil when no body schema is defined for +operation_id+.
#
# Identifier parameters are listed under both their camelCase and snake_case
# spellings. A fresh hash is built on every call.
#
# @param operation_id [String] OpenAPI operationId, e.g. "createOrganization"
# @return [Hash{Symbol=>Hash}, nil]
def organization_request_schema(operation_id)
  string = {type: "string"}
  boolean = {type: "boolean"}
  object = {type: "object", additionalProperties: true}
  array = ->(items = string) { {type: "array", items: items} }
  role = -> { {oneOf: [string, array.call]} }
  # Interchangeable camelCase / snake_case parameter pairs.
  organization_ref = -> { {organizationId: string, organization_id: string} }
  slug_ref = -> { {organizationSlug: string, organization_slug: string} }
  user_ref = -> { {userId: string, user_id: string} }
  member_ref = -> { {memberId: string, member_id: string} }
  team_ref = -> { {teamId: string, team_id: string} }
  invitation_ref = -> { {invitationId: string, invitation_id: string} }
  role_id_ref = -> { {roleId: string, role_id: string} }
  role_name_fields = -> { {role: string, roleName: string, role_name: string} }
  permission_fields = -> { {permission: object, permissions: object} }

  case operation_id
  when "createOrganization"
    {name: string, slug: string, logo: string, metadata: object, **user_ref.call,
     keepCurrentActiveOrganization: boolean, keep_current_active_organization: boolean}
  when "checkOrganizationSlug"
    {slug: string}
  when "updateOrganization"
    {**organization_ref.call, **slug_ref.call, data: object, name: string, slug: string, logo: string, metadata: object}
  when "deleteOrganization", "setActiveOrganization"
    {**organization_ref.call, **slug_ref.call}
  when "createOrganizationInvitation"
    {**organization_ref.call, **slug_ref.call, email: {type: "string", format: "email"}, role: role.call,
     **team_ref.call, teamIds: array.call, team_ids: array.call}
  when "acceptOrganizationInvitation"
    {**invitation_ref.call, id: string}
  when "rejectOrganizationInvitation", "cancelOrganizationInvitation"
    invitation_ref.call
  when "addOrganizationMember"
    {**organization_ref.call, **user_ref.call, role: role.call}
  when "removeOrganizationMember"
    {**member_ref.call, **user_ref.call, **organization_ref.call}
  when "updateOrganizationMemberRole"
    {**member_ref.call, **user_ref.call, **organization_ref.call, role: role.call}
  when "leaveOrganization"
    organization_ref.call
  when "hasOrganizationPermission"
    {**organization_ref.call, **permission_fields.call}
  when "createOrganizationTeam"
    {**organization_ref.call, name: string}
  when "updateOrganizationTeam"
    {**team_ref.call, name: string}
  when "removeOrganizationTeam", "setActiveOrganizationTeam"
    team_ref.call
  when "addTeamMember", "removeTeamMember"
    {**team_ref.call, **user_ref.call}
  when "createOrganizationRole"
    {**organization_ref.call, **role_name_fields.call, **permission_fields.call}
  when "updateOrganizationRole"
    {**organization_ref.call, **role_id_ref.call, **role_name_fields.call, **permission_fields.call, data: object}
  when "deleteOrganizationRole"
    {**organization_ref.call, **role_id_ref.call, **role_name_fields.call}
  end
end
846
+
784
847
  def organization_ref_schema(name)
785
848
  {
786
849
  type: "object",
@@ -969,7 +1032,7 @@ module BetterAuth
969
1032
  ctx.context.adapter.find_one(model: "organizationRole", where: [{field: "organizationId", value: organization_id}, {field: "role", value: role}])
970
1033
  end
971
1034
 
972
- def list_members_for(ctx, organization_id, query = {})
1035
+ def list_members_for(ctx, organization_id, query = {}, config = nil, user = nil)
973
1036
  where = [{field: "organizationId", value: organization_id}]
974
1037
  if query[:filter_field]
975
1038
  where << {field: query[:filter_field], value: query[:filter_value], operator: query[:filter_operator]}
@@ -977,19 +1040,48 @@ module BetterAuth
977
1040
  filter = normalize_hash(query[:filter])
978
1041
  where << {field: filter[:field], value: filter[:value], operator: filter[:operator]}
979
1042
  end
1043
+ limit = member_list_limit(ctx, organization_id, query, config, user)
980
1044
  members = ctx.context.adapter.find_many(
981
1045
  model: "member",
982
1046
  where: where,
983
- limit: query[:limit],
1047
+ limit: limit,
984
1048
  offset: query[:offset],
985
1049
  sort_by: query[:sort_by] ? {field: query[:sort_by], direction: query[:sort_direction] || query[:sort_order] || "asc"} : nil
986
1050
  )
1051
+ users_by_id = member_users_by_id(ctx, members)
987
1052
  {
988
- members: members.map { |entry| member_wire(ctx, entry) },
1053
+ members: members.map { |entry| member_wire(ctx, entry, users_by_id: users_by_id) },
989
1054
  total: ctx.context.adapter.count(model: "member", where: where)
990
1055
  }
991
1056
  end
992
1057
 
# Chooses how many members list-members returns in one page.
#
# The cap comes from `config[:membership_limit]` (default 100; non-positive or
# unusable values also become 100). A positive `query[:limit]` may lower the
# page size but never raise it above the cap.
# +ctx+, +organization_id+ and +user+ are accepted but not used here.
#
# @return [Integer]
def member_list_limit(ctx, organization_id, query, config, user)
  raw_cap = config ? config[:membership_limit] : nil
  cap = raw_cap.nil? ? 100 : numeric_member_limit(raw_cap)
  cap = 100 if cap <= 0

  raw_requested = query[:limit] if query.key?(:limit)
  requested = raw_requested.to_s.empty? ? nil : raw_requested.to_i
  return cap unless requested && requested.positive?

  requested < cap ? requested : cap
end
1068
+
# Coerces a configured membership limit into an Integer.
#
# Finite numerics are truncated with #to_i and digit-only strings are parsed.
# Everything else falls back to 100: nil, callables, arbitrary strings, and
# non-finite floats such as Float::INFINITY or NaN. Calling #to_i on a
# non-finite float raises FloatDomainError, so those values must be caught here.
#
# @param value [Object] raw configuration value
# @return [Integer]
def numeric_member_limit(value)
  return value.to_i if value.is_a?(Numeric) && value.finite?
  return value.to_i if value.to_s.match?(/\A\d+\z/)

  100
end
1075
+
# Loads the users behind +members+ with a single adapter query and indexes
# them by id, so callers avoid one lookup per member.
#
# The query passes an explicit `limit` equal to the number of distinct ids.
# Without it, an adapter that applies a default page size to find_many could
# return only some of the users.
#
# @param members [Array<Hash>] member records with a "userId" key
# @return [Hash{String=>Hash}] user id => user record (empty when no ids)
def member_users_by_id(ctx, members)
  user_ids = members.map { |member| member["userId"] }.compact.uniq
  return {} if user_ids.empty?

  users = ctx.context.adapter.find_many(
    model: "user",
    where: [{field: "id", operator: "in", value: user_ids}],
    limit: user_ids.length
  )
  users.each_with_object({}) { |user, by_id| by_id[user["id"]] = user }
end
1084
+
993
1085
  def ensure_team_member_capacity!(ctx, config, team_ids)
994
1086
  max_members = config.dig(:teams, :maximum_members_per_team)
995
1087
  return unless max_members && team_ids.any?
@@ -1002,9 +1094,9 @@ module BetterAuth
1002
1094
  end
1003
1095
  end
1004
1096
 
# Serializes a member record for API output and attaches the public fields
# (id, name, email, image) of the member's user under "user".
#
# +users_by_id+ is an optional preloaded map (see member_users_by_id). When it
# is missing, or does not contain this member's user, the user is fetched
# directly. The "user" key is then omitted only when the user truly cannot be
# found, rather than whenever a batch load came back short.
#
# @param member [Hash] raw member record
# @param users_by_id [Hash{String=>Hash}, nil] optional preloaded users
# @return [Hash]
def member_wire(ctx, member, users_by_id: nil)
  data = Schema.parse_output(ctx.context.options, "member", member)
  user_id = member["userId"]
  user = users_by_id && users_by_id[user_id]
  user ||= ctx.context.internal_adapter.find_user_by_id(user_id)
  data["user"] = user.slice("id", "name", "email", "image") if user
  data
end
@@ -1034,8 +1126,11 @@ module BetterAuth
# Raises BAD_REQUEST when +member+ is an owner and is the only owner of their
# organization. Owners are counted by streaming the organization's members
# page by page instead of loading them all at once.
# Returns nil when +member+ is not an owner or other owners remain.
def ensure_not_last_owner!(ctx, member)
  owner = ->(record) { record["role"].to_s.split(",").include?("owner") }
  return unless owner.call(member)

  owners_seen = 0
  scope = [{field: "organizationId", value: member["organizationId"]}]
  organization_each_adapter_record(ctx.context.adapter, "member", where: scope) do |record|
    owners_seen += 1 if owner.call(record)
  end
  return if owners_seen > 1

  raise APIError.new("BAD_REQUEST", message: ORGANIZATION_ERROR_CODES.fetch("YOU_CANNOT_LEAVE_THE_ORGANIZATION_AS_THE_ONLY_OWNER"))
end
1040
1135
 
1041
1136
  def create_default_team(ctx, config, organization, session)
@@ -1053,8 +1148,24 @@ module BetterAuth
1053
1148
  end
1054
1149
 
# Counts the memberships of +user_id+ whose comma-separated role list includes
# "owner", streaming member records page by page.
# @return [Integer]
def organization_created_count(ctx, user_id)
  owned = 0
  organization_each_adapter_record(ctx.context.adapter, "member", where: [{field: "userId", value: user_id}]) do |membership|
    roles = membership["role"].to_s.split(",")
    owned += 1 if roles.include?("owner")
  end
  owned
end
1157
+
# Yields every +model+ record matching +where+, fetching +page_size+ rows per
# adapter call so that large result sets are never loaded at once.
#
# Pages are requested in ascending id order. LIMIT/OFFSET without ORDER BY has
# no defined row order in SQL, so consecutive pages could otherwise repeat or
# skip rows, which would make the owner counts built on this wrong.
#
# @param adapter [#find_many] database adapter
# @param model [String] logical model name
# @param where [Array<Hash>] adapter where-clauses
# @param page_size [Integer] rows per query
# @yieldparam record [Hash]
# @return [nil]
def organization_each_adapter_record(adapter, model, where:, page_size: 100)
  offset = 0
  loop do
    records = adapter.find_many(
      model: model,
      where: where,
      limit: page_size,
      offset: offset,
      sort_by: {field: "id", direction: "asc"}
    )
    break if records.empty?

    records.each { |record| yield record }
    # A short page means there is nothing left to fetch.
    break if records.length < page_size

    offset += records.length
  end
end
1059
1170
 
1060
1171
  def run_org_hook(config, key, data, ctx)
@@ -55,7 +55,7 @@ module BetterAuth
55
55
  rate_limit: [
56
56
  {
57
57
  path_matcher: ->(path) { path.start_with?("/phone-number") },
58
- window: 60_000,
58
+ window: 60,
59
59
  max: 10
60
60
  }
61
61
  ],
@@ -300,6 +300,7 @@ module BetterAuth
300
300
  openapi: {
301
301
  operationId: operation_id,
302
302
  description: description,
303
+ requestBody: two_factor_request_body(operation_id),
303
304
  responses: {
304
305
  "200" => OpenAPI.json_response("Success", response_schema)
305
306
  }
@@ -307,6 +308,26 @@ module BetterAuth
307
308
  }
308
309
  end
309
310
 
# Builds the OpenAPI requestBody for a two-factor endpoint from its
# operationId. Unknown operations get an empty object schema, and
# "sendTwoFactorOTP" reuses the schema from OpenAPI.empty_request_body.
# Both camelCase and snake_case spellings of flag parameters are listed.
def two_factor_request_body(operation_id)
  text = -> { {type: "string"} }
  flag = -> { {type: "boolean"} }
  trust_device = -> { {trustDevice: flag.call, trust_device: flag.call} }

  schema =
    case operation_id
    when "enableTwoFactor"
      OpenAPI.object_schema({password: text.call, issuer: text.call})
    when "disableTwoFactor", "getTOTPURI", "generateBackupCodes"
      OpenAPI.object_schema({password: text.call})
    when "generateTOTP"
      OpenAPI.object_schema({secret: text.call}, required: ["secret"])
    when "verifyTOTP", "verifyTwoFactorOTP"
      OpenAPI.object_schema({code: text.call, **trust_device.call}, required: ["code"])
    when "verifyBackupCode"
      properties = {code: text.call, disableSession: flag.call, disable_session: flag.call, **trust_device.call}
      OpenAPI.object_schema(properties, required: ["code"])
    when "sendTwoFactorOTP"
      OpenAPI.empty_request_body.dig(:content, "application/json", :schema)
    else
      {type: "object", properties: {}}
    end

  OpenAPI.json_request_body(schema)
end
330
+
310
331
  def two_factor_enable_response_schema
311
332
  OpenAPI.object_schema(
312
333
  {
@@ -121,9 +121,14 @@ module BetterAuth
121
121
  end
122
122
 
# Built-in stricter rate-limit rule for sensitive paths, or nil when none
# applies.
#   * sign-in / sign-up / change-password / change-email prefixes: max 3 per window 10
#   * password-reset and verification-email senders: max 3 per window 60
# NOTE(review): the window unit is presumably seconds; confirm against the rate limiter.
def default_special_rule(path)
  credential_prefixes = ["/sign-in", "/sign-up", "/change-password", "/change-email"]
  email_sending_paths = [
    "/request-password-reset",
    "/send-verification-email",
    "/email-otp/send-verification-otp",
    "/email-otp/request-password-reset"
  ]

  if path.start_with?(*credential_prefixes)
    {window: 10, max: 3}
  elsif email_sending_paths.include?(path) || path.start_with?("/forget-password")
    {window: 60, max: 3}
  end
end
128
133
 
129
134
  def matching_custom_rule(config, path)
@@ -118,6 +118,8 @@ module BetterAuth
118
118
  ) do |ctx|
119
119
  session = current_session(ctx, allow_nil: true)
120
120
  body = normalize_hash(ctx.body)
121
+ raise APIError.new("UNAUTHORIZED") if ctx.request && !session
122
+
121
123
  user_id = session&.dig(:user, "id") || body["userId"] || body["user_id"]
122
124
  raise APIError.new("UNAUTHORIZED") if user_id.to_s.empty?
123
125
 
@@ -174,6 +176,8 @@ module BetterAuth
174
176
  ) do |ctx|
175
177
  session = current_session(ctx, allow_nil: true)
176
178
  body = normalize_hash(ctx.body)
179
+ raise APIError.new("UNAUTHORIZED") if ctx.request && !session
180
+
177
181
  user_id = session&.dig(:user, "id") || body["userId"] || body["user_id"]
178
182
  raise APIError.new("BAD_REQUEST", message: "Either userId or session is required") if user_id.to_s.empty?
179
183
 
@@ -184,7 +184,11 @@ module BetterAuth
184
184
 
# Writes a session cookie for +user+ after email verification.
# The current session is reused only when it already belongs to +user+.
# Otherwise (no session, or a session for someone else) a new session is
# created for +user+.
def self.set_verified_session_cookie(ctx, user)
  existing = current_session(ctx, allow_nil: true)
  reusable = existing && existing[:user]["id"] == user["id"]
  session_data = reusable ? existing[:session] : ctx.context.internal_adapter.create_session(user["id"])
  Cookies.set_session_cookie(ctx, {session: session_data, user: user})
end
190
194
 
@@ -179,6 +179,7 @@ module BetterAuth
179
179
  path: "/verify-password",
180
180
  method: "POST",
181
181
  metadata: {
182
+ scope: "server",
182
183
  openapi: {
183
184
  operationId: "verifyPassword",
184
185
  description: "Verify the current user's password",
@@ -2,6 +2,7 @@
2
2
 
3
3
  require "uri"
4
4
  require "json"
5
+ require "net/http"
5
6
  require "securerandom"
6
7
 
7
8
  module BetterAuth
@@ -92,6 +93,7 @@ module BetterAuth
92
93
  ctx.context.secret,
93
94
  expires_in: 600
94
95
  )
96
+ store_oauth_state_cookie(ctx, state)
95
97
  url = call_provider(provider, :create_authorization_url, {
96
98
  state: state,
97
99
  codeVerifier: code_verifier,
@@ -148,6 +150,7 @@ module BetterAuth
148
150
  raise ctx.redirect(oauth_error_url(error_url, data["error"], data["errorDescription"] || data["error_description"])) if data["error"]
149
151
  raise ctx.redirect(oauth_error_url(error_url, "oauth_provider_not_found")) unless provider
150
152
  raise ctx.redirect(oauth_error_url(error_url, "state_not_found")) unless state_data
153
+ raise ctx.redirect(oauth_error_url(error_url, "state_mismatch")) unless valid_oauth_state_cookie?(ctx, state)
151
154
  raise ctx.redirect(oauth_error_url(error_url, "no_code")) if data["code"].to_s.empty?
152
155
 
153
156
  tokens = call_provider(provider, :validate_authorization_code, {
@@ -161,7 +164,11 @@ module BetterAuth
161
164
 
162
165
  token_data = token_hash(tokens)
163
166
  token_data["user"] = parse_json_hash(data["user"]) if data["user"]
164
- user_info = call_provider(provider, :get_user_info, token_data)
167
+ user_info = begin
168
+ call_provider(provider, :get_user_info, token_data)
169
+ rescue Net::OpenTimeout, Net::ReadTimeout, SocketError, SystemCallError
170
+ nil
171
+ end
165
172
  user = user_info[:user] || user_info["user"] if user_info
166
173
  raise ctx.redirect(oauth_error_url(error_url, "unable_to_get_user_info")) unless user
167
174
  raise ctx.redirect(oauth_error_url(error_url, "email_not_found")) if fetch_value(user, "email").to_s.empty?
@@ -269,6 +276,7 @@ module BetterAuth
269
276
  }
270
277
  }.merge(safe_additional_state(body))
271
278
  state = Crypto.sign_jwt(state_data, ctx.context.secret, expires_in: 600)
279
+ store_oauth_state_cookie(ctx, state)
272
280
  url = call_provider(provider, :create_authorization_url, {
273
281
  state: state,
274
282
  codeVerifier: code_verifier,
@@ -285,6 +293,10 @@ module BetterAuth
285
293
 
286
294
  def self.social_user_from_id_token!(ctx, provider, id_token)
287
295
  token = fetch_value(id_token, "token").to_s
296
+ unless provider_callable(provider, :verify_id_token)
297
+ raise APIError.new("NOT_FOUND", message: BASE_ERROR_CODES["ID_TOKEN_NOT_SUPPORTED"])
298
+ end
299
+
288
300
  valid = call_provider(provider, :verify_id_token, token, fetch_value(id_token, "nonce"))
289
301
  raise APIError.new("UNAUTHORIZED", message: BASE_ERROR_CODES["INVALID_TOKEN"]) unless valid
290
302
 
@@ -360,6 +372,22 @@ module BetterAuth
360
372
  {session: session, user: user, new_user: new_user}
361
373
  end
362
374
 
# Saves the signed OAuth +state+ in a signed "state" auth cookie (max_age 600)
# so the callback can check it with valid_oauth_state_cookie?.
# Does nothing when there is no HTTP request (e.g. direct server-side calls).
def self.store_oauth_state_cookie(ctx, state)
  if ctx.request
    state_cookie = ctx.context.create_auth_cookie("state", max_age: 600)
    ctx.set_signed_cookie(state_cookie.name, state, ctx.context.secret, state_cookie.attributes)
  end
end
381
+
# Compares the OAuth `state` returned to the callback with the signed state
# cookie set when the flow started, then expires that cookie so it is
# single-use.
#
# Requests without an HTTP request object (direct server-side calls) skip the
# check, matching store_oauth_state_cookie, which sets no cookie in that case.
#
# @return [Boolean] true only when a non-blank cookie is present and equals +state+
def self.valid_oauth_state_cookie?(ctx, state)
  return true unless ctx.request

  cookie = ctx.context.create_auth_cookie("state", max_age: 600)
  stored = ctx.get_signed_cookie(cookie.name, ctx.context.secret)
  Cookies.expire_cookie(ctx, cookie)
  # A missing or blank cookie must never match, even when +state+ is also nil or blank.
  return false if stored.to_s.empty?

  stored == state
end
390
+
363
391
  def self.oauth_error_url(base_url, error, description = nil)
364
392
  uri = URI.parse(base_url.to_s)
365
393
  query = URI.decode_www_form(uri.query.to_s)
@@ -80,7 +80,9 @@ module BetterAuth
80
80
  current_password = body["currentPassword"] || body["current_password"]
81
81
  validate_password_length!(new_password, ctx.context.options.email_and_password)
82
82
  account = credential_account(ctx, session[:user]["id"])
83
- unless account && account["password"] && verify_password_value(ctx, current_password.to_s, account["password"])
83
+ raise APIError.new("BAD_REQUEST", message: BASE_ERROR_CODES["CREDENTIAL_ACCOUNT_NOT_FOUND"]) unless account && account["password"]
84
+
85
+ unless verify_password_value(ctx, current_password.to_s, account["password"])
84
86
  raise APIError.new("BAD_REQUEST", message: BASE_ERROR_CODES["INVALID_PASSWORD"])
85
87
  end
86
88
 
@@ -176,7 +178,9 @@ module BetterAuth
176
178
  sender = ctx.context.options.user.dig(:delete_user, :send_delete_account_verification)
177
179
  if body["password"]
178
180
  account = credential_account(ctx, session[:user]["id"])
179
- unless account && account["password"] && verify_password_value(ctx, body["password"], account["password"])
181
+ raise APIError.new("BAD_REQUEST", message: BASE_ERROR_CODES["CREDENTIAL_ACCOUNT_NOT_FOUND"]) unless account && account["password"]
182
+
183
+ unless verify_password_value(ctx, body["password"], account["password"])
180
184
  raise APIError.new("BAD_REQUEST", message: BASE_ERROR_CODES["INVALID_PASSWORD"])
181
185
  end
182
186
  end
@@ -12,6 +12,31 @@ module BetterAuth
12
12
  statements.concat(tables.flat_map { |_logical_name, table| index_statements(table, dialect) })
13
13
  end
14
14
 
# Renders the SQL for a migration plan in three groups, in this order:
# CREATE TABLE for +plan.to_create+, ADD COLUMN for +plan.to_add+, then
# CREATE INDEX for +plan.to_index+. On PostgreSQL, a missing "id" column
# expands into the multi-statement backfill from
# add_postgres_id_column_statements.
#
# @return [Array<String>] flat list of SQL statements
def pending_statements(plan)
  dialect = plan.dialect

  table_creations = plan.to_create.map do |change|
    create_table_statement(change.logical_name, change.table, dialect, plan.tables)
  end

  column_additions = plan.to_add.flat_map do |change|
    change.fields.map do |logical_field, attributes|
      next add_postgres_id_column_statements(change.table_name) if dialect == :postgres && logical_field.to_s == "id"

      add_column_statement(change.table_name, logical_field, attributes, dialect)
    end
  end.flatten

  index_creations = plan.to_index.map do |change|
    index_statement(
      change.table_name,
      change.field_name,
      change.name,
      dialect,
      unique: change.unique,
      where_not_null: filtered_unique_index?(change.field, dialect)
    )
  end

  table_creations + column_additions + index_creations
end
39
+
15
40
  def create_table_statement(logical_name, table, dialect, tables = nil)
16
41
  table_name = table.fetch(:model_name)
17
42
  columns = table.fetch(:fields).map do |logical_field, attributes|
@@ -28,7 +53,7 @@ module BetterAuth
28
53
  when :mysql
29
54
  %(CREATE TABLE IF NOT EXISTS #{quote(table_name, dialect)} (\n #{body}\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;)
30
55
  when :mssql
31
- %(IF OBJECT_ID(N'#{quote(table_name, dialect)}', N'U') IS NULL\nCREATE TABLE #{quote(table_name, dialect)} (\n #{body}\n);)
56
+ %(#{mssql_required_set_options}\nIF OBJECT_ID(N'#{quote(table_name, dialect)}', N'U') IS NULL\nCREATE TABLE #{quote(table_name, dialect)} (\n #{body}\n);)
32
57
  else
33
58
  raise ArgumentError, "Unsupported SQL dialect: #{dialect}"
34
59
  end
@@ -52,7 +77,7 @@ module BetterAuth
52
77
  constraints = []
53
78
  column = attributes[:field_name] || physical_name(logical_field)
54
79
 
55
- if attributes[:unique] && logical_field != "id"
80
+ if attributes[:unique] && logical_field != "id" && !(dialect == :mssql && !attributes[:required])
56
81
  constraints << unique_constraint(table_name, column, dialect)
57
82
  end
58
83
 
@@ -67,21 +92,62 @@ module BetterAuth
# Index DDL for the fields of one table.
#
# A field gets an index when it is flagged `index:` and not `unique:`. Unique
# fields normally rely on their inline UNIQUE constraint. The exception is
# SQL Server: there an optional (not required) unique non-id column gets a
# filtered "uniq_" unique index instead, because SQL Server unique constraints
# allow only a single NULL.
def index_statements(table, dialect)
  table_name = table.fetch(:model_name)
  mssql = dialect == :mssql

  table.fetch(:fields).filter_map do |logical_field, attributes|
    unique_field = attributes[:unique]
    nullable_unique_mssql = mssql && unique_field && logical_field != "id" && !attributes[:required]
    wanted = unique_field ? nullable_unique_mssql : attributes[:index]
    next unless wanted

    column = attributes[:field_name] || Schema.physical_name(logical_field)
    unique = unique_field && mssql
    index_name = unique ? "uniq_#{table_name}_#{column}" : "index_#{table_name}_on_#{column}"
    index_statement(table_name, column, index_name, dialect, unique: unique, where_not_null: filtered_unique_index?(attributes, dialect))
  end
end
105
+
# ALTER TABLE statement that adds one column.
# T-SQL writes `ALTER TABLE ... ADD <column>`; the other dialects use `ADD COLUMN`.
def add_column_statement(table_name, logical_field, attributes, dialect)
  quoted_table = quote(table_name, dialect)
  clause = dialect == :mssql ? "ADD" : "ADD COLUMN"
  definition = column_definition(table_name, logical_field, attributes, dialect)
  "ALTER TABLE #{quoted_table} #{clause} #{definition};"
end
110
+
# PostgreSQL steps for adding an "id" primary key to a table that already has rows:
# add a nullable text column, backfill each row with an md5 of random/clock/ctid
# data, then enforce NOT NULL and declare the primary key.
# @return [Array<String>] four statements, to be run in order
def add_postgres_id_column_statements(table_name)
  table = quote(table_name, :postgres)
  id = quote("id", :postgres)

  add_nullable = "ALTER TABLE #{table} ADD COLUMN #{id} text;"
  backfill = "UPDATE #{table} SET #{id} = md5(random()::text || clock_timestamp()::text || ctid::text) WHERE #{id} IS NULL;"
  require_value = "ALTER TABLE #{table} ALTER COLUMN #{id} SET NOT NULL;"
  primary_key = "ALTER TABLE #{table} ADD PRIMARY KEY (#{id});"

  [add_nullable, backfill, require_value, primary_key]
end
121
+
# Renders one CREATE [UNIQUE] INDEX statement for +dialect+.
#
# PostgreSQL/SQLite use IF NOT EXISTS. MySQL gets a plain CREATE INDEX.
# SQL Server guards the statement with a sys.indexes lookup and is prefixed
# with the session SET options from mssql_required_set_options.
# +where_not_null+ adds a `WHERE <column> IS NOT NULL` filter and only affects
# SQL Server output.
#
# @raise [ArgumentError] for an unsupported dialect, matching
#   create_table_statement, so that a nil can never end up in a statement list
def index_statement(table_name, column, name, dialect, unique: false, where_not_null: false)
  unique_prefix = unique ? "UNIQUE " : ""
  case dialect
  when :postgres, :sqlite
    %(CREATE #{unique_prefix}INDEX IF NOT EXISTS #{quote(name, dialect)} ON #{quote(table_name, dialect)} (#{quote(column, dialect)});)
  when :mysql
    %(CREATE #{unique_prefix}INDEX #{quote(name, dialect)} ON #{quote(table_name, dialect)} (#{quote(column, dialect)});)
  when :mssql
    filter = where_not_null ? " WHERE #{quote(column, dialect)} IS NOT NULL" : ""
    # The name is embedded in a T-SQL string literal, so single quotes are doubled.
    escaped_name = name.gsub("'", "''")
    %(#{mssql_required_set_options}\nIF NOT EXISTS (SELECT name FROM sys.indexes WHERE name = '#{escaped_name}' AND object_id = OBJECT_ID(N'#{quote(table_name, dialect)}')) CREATE #{unique_prefix}INDEX #{quote(name, dialect)} ON #{quote(table_name, dialect)} (#{quote(column, dialect)})#{filter};)
  else
    raise ArgumentError, "Unsupported SQL dialect: #{dialect}"
  end
end
84
134
 
# Whether a unique index on this field needs a `WHERE ... IS NOT NULL` filter.
# This applies only on SQL Server, and only to unique fields that are not
# required. Returns false for other dialects; otherwise returns the truthiness
# of `unique && !required`, which can be nil.
def filtered_unique_index?(attributes, dialect)
  return false unless dialect == :mssql

  attributes[:unique] && !attributes[:required]
end
138
+
# SET options emitted before SQL Server DDL. They are joined with newlines and
# have no trailing newline. SQL Server requires these session settings when
# creating filtered indexes.
def mssql_required_set_options
  [
    "SET ANSI_NULLS ON;",
    "SET QUOTED_IDENTIFIER ON;",
    "SET ANSI_WARNINGS ON;",
    "SET ANSI_PADDING ON;",
    "SET CONCAT_NULL_YIELDS_NULL ON;",
    "SET ARITHABORT ON;",
    "SET NUMERIC_ROUNDABORT OFF;"
  ].join("\n")
end
150
+
85
151
  def sql_type(logical_field, attributes, dialect)
86
152
  case attributes[:type]
87
153
  when "boolean"
@@ -121,7 +187,7 @@ module BetterAuth
121
187
  end
122
188
  else
123
189
  if dialect == :mysql
124
- indexed = logical_field == "id" || attributes[:unique] || attributes[:index] || attributes[:references]
190
+ indexed = logical_field == "id" || attributes[:unique] || attributes[:index] || attributes[:references] || attributes[:sortable] || attributes.key?(:default_value)
125
191
  indexed ? "varchar(191)" : "text"
126
192
  elsif dialect == :mssql
127
193
  indexed = logical_field == "id" || attributes[:unique] || attributes[:index] || attributes[:references] || attributes[:sortable]
@@ -164,8 +230,9 @@ module BetterAuth
164
230
  end
165
231
 
166
232
  def foreign_key_constraint(table_name, column, reference, dialect, tables = nil)
167
- target_model = tables&.fetch(reference.fetch(:model).to_s, nil)&.fetch(:model_name) || reference.fetch(:model)
168
- target_field = reference.fetch(:field)
233
+ target_table = foreign_key_target_table(reference, tables)
234
+ target_model = target_table&.fetch(:model_name) || reference.fetch(:model)
235
+ target_field = foreign_key_target_field(reference, target_table)
169
236
  on_delete = reference[:on_delete] ? " ON DELETE #{reference[:on_delete].to_s.upcase}" : ""
170
237
 
171
238
  case dialect
@@ -178,6 +245,28 @@ module BetterAuth
178
245
  end
179
246
  end
180
247
 
# Finds the table a foreign key points at. The reference's model is looked up
# first as a logical table key, then as a physical :model_name.
# Returns nil when +tables+ is absent or nothing matches.
def foreign_key_target_table(reference, tables)
  return unless tables

  wanted = reference.fetch(:model).to_s
  by_key = tables.fetch(wanted, nil)
  return by_key if by_key

  tables.each_value.find { |candidate| candidate.fetch(:model_name).to_s == wanted }
end
254
+
# Resolves the physical column name a foreign key references.
# * No target table known: the referenced field is used as written.
# * Field is a logical key of the target: its :field_name, or the
#   physical_name of the key.
# * Field already equals some column's :field_name: used as written.
# * Otherwise: physical_name(field).
def foreign_key_target_field(reference, target_table)
  name = reference.fetch(:field).to_s
  return name unless target_table

  definitions = target_table.fetch(:fields)
  if (attributes = definitions.fetch(name, nil))
    attributes[:field_name] || physical_name(name)
  elsif definitions.each_value.any? { |data| data[:field_name].to_s == name }
    name
  else
    physical_name(name)
  end
end
269
+
181
270
  def quote(identifier, dialect)
182
271
  case dialect
183
272
  when :postgres, :sqlite
@@ -18,6 +18,7 @@ module BetterAuth
18
18
  tables.delete("verification") if secondary_storage?(options) && !verification_option(options, :store_in_database)
19
19
  tables.merge!(plugin_schema)
20
20
  tables["rateLimit"] = rate_limit_table(options) if rate_limit_option(options, :storage) == "database"
21
+ ensure_id_fields!(tables)
21
22
  tables.sort_by { |_name, table| table[:order] || Float::INFINITY }.to_h
22
23
  end
23
24
 
@@ -121,6 +122,15 @@ module BetterAuth
121
122
  }
122
123
  end
123
124
 
# Mutates +tables+ in place: any table without an "id" field gets the shared
# id field definition prepended, so it becomes the first field.
# Returns +tables+ (the return value of each_value).
private_class_method def self.ensure_id_fields!(tables)
  tables.each_value do |table|
    current = table.fetch(:fields)
    table[:fields] = id_field.merge(current) unless current.key?("id")
  end
end
133
+
# The id field definition followed by the timestamp field definitions.
private_class_method def self.base_fields
  identity = id_field
  identity.merge(timestamp_fields)
end
@@ -162,14 +172,19 @@ module BetterAuth
162
172
  schema.each do |raw_key, raw_table|
163
173
  key = storage_key(raw_key)
164
174
  table_data = symbolize_hash(raw_table || {})
165
- existing = tables[key] || {model_name: table_data[:model_name] || physical_name(key), fields: {}}
166
- existing[:model_name] = table_data[:model_name] || existing[:model_name] || physical_name(key)
175
+ existing = tables[key] || {model_name: table_data[:model_name] || physical_table_name(key), fields: {}}
176
+ existing[:model_name] = table_data[:model_name] || existing[:model_name] || physical_table_name(key)
167
177
  existing[:fields] = existing[:fields].merge(normalize_fields(table_data[:fields] || {}))
178
+ existing[:fields] = id_field.merge(existing[:fields]) unless core_table?(key) || existing[:fields].key?("id")
168
179
  tables[key] = existing
169
180
  end
170
181
  end
171
182
  end
172
183
 
# True for the built-in user/session/account/verification table keys.
# Plugin tables that lack an id get one added when their schema is merged.
private_class_method def self.core_table?(key)
  case key.to_s
  when "user", "session", "account", "verification" then true
  else false
  end
end
187
+
173
188
  private_class_method def self.normalize_fields(fields)
174
189
  fields.each_with_object({}) do |(raw_key, raw_value), result|
175
190
  key = storage_key(raw_key)
@@ -270,6 +285,24 @@ module BetterAuth
270
285
  underscore(value.to_s)
271
286
  end
272
287
 
# Default physical table name for a logical key: snake_case, then pluralized.
private_class_method def self.physical_table_name(value)
  snake = physical_name(value)
  pluralize_table_name(snake)
end
291
+
# Naive English pluralization for snake_case table names.
# Irregular names come from a lookup. Names already ending in "s" are treated
# as plural and returned unchanged. A consonant + "y" ending becomes "ies";
# "x"/"z"/"ch"/"sh" endings take "es"; everything else takes "s".
private_class_method def self.pluralize_table_name(value)
  irregular = {
    "apikey" => "api_keys",
    "api_key" => "api_keys",
    "wallet_address" => "wallet_addresses"
  }
  return irregular.fetch(value) if irregular.key?(value)

  if value.end_with?("s")
    value
  elsif value.match?(/[^aeiou]y\z/)
    value.sub(/y\z/, "ies")
  elsif value.match?(/(?:x|z|ch|sh)\z/)
    "#{value}es"
  else
    "#{value}s"
  end
end
305
+
273
306
  private_class_method def self.camelize_lower(value)
274
307
  parts = underscore(value).split("_")
275
308
  ([parts.first] + parts.drop(1).map(&:capitalize)).join
@@ -45,7 +45,8 @@ module BetterAuth
45
45
  strategy: config[:strategy] || "compact",
46
46
  version: config[:version],
47
47
  cookie_prefix: ctx.context.options.advanced[:cookie_prefix] || "better-auth",
48
- is_secure: ctx.context.auth_cookies[:session_data].name.start_with?(Cookies::SECURE_COOKIE_PREFIX)
48
+ is_secure: ctx.context.auth_cookies[:session_data].name.start_with?(Cookies::SECURE_COOKIE_PREFIX),
49
+ cookie_full_name: ctx.context.auth_cookies[:session_data].name
49
50
  )
50
51
  return nil unless payload
51
52
  return nil if payload["session"]["token"] && payload["session"]["token"] != token