better_auth 0.2.0 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +32 -0
- data/README.md +5 -3
- data/lib/better_auth/adapters/internal_adapter.rb +173 -20
- data/lib/better_auth/adapters/memory.rb +61 -12
- data/lib/better_auth/adapters/mongodb.rb +5 -365
- data/lib/better_auth/adapters/sql.rb +44 -3
- data/lib/better_auth/api.rb +7 -2
- data/lib/better_auth/async.rb +70 -0
- data/lib/better_auth/context.rb +2 -1
- data/lib/better_auth/database_hooks.rb +3 -3
- data/lib/better_auth/deprecate.rb +28 -0
- data/lib/better_auth/endpoint.rb +5 -2
- data/lib/better_auth/host.rb +166 -0
- data/lib/better_auth/instrumentation.rb +74 -0
- data/lib/better_auth/logger.rb +31 -0
- data/lib/better_auth/middleware/origin_check.rb +2 -2
- data/lib/better_auth/oauth2.rb +94 -0
- data/lib/better_auth/plugin.rb +14 -1
- data/lib/better_auth/plugins/email_otp.rb +16 -5
- data/lib/better_auth/plugins/generic_oauth.rb +14 -28
- data/lib/better_auth/plugins/oauth_protocol.rb +553 -64
- data/lib/better_auth/plugins/organization/schema.rb +6 -0
- data/lib/better_auth/plugins/organization.rb +56 -20
- data/lib/better_auth/plugins/two_factor.rb +53 -18
- data/lib/better_auth/rate_limiter.rb +37 -2
- data/lib/better_auth/request_state.rb +44 -0
- data/lib/better_auth/router.rb +14 -1
- data/lib/better_auth/routes/account.rb +16 -4
- data/lib/better_auth/routes/email_verification.rb +5 -2
- data/lib/better_auth/routes/password.rb +21 -1
- data/lib/better_auth/routes/session.rb +27 -4
- data/lib/better_auth/routes/sign_in.rb +3 -1
- data/lib/better_auth/routes/sign_up.rb +60 -1
- data/lib/better_auth/routes/social.rb +231 -22
- data/lib/better_auth/routes/user.rb +23 -5
- data/lib/better_auth/schema/sql.rb +11 -0
- data/lib/better_auth/schema.rb +16 -0
- data/lib/better_auth/session.rb +12 -1
- data/lib/better_auth/social_providers/apple.rb +44 -8
- data/lib/better_auth/social_providers/atlassian.rb +32 -0
- data/lib/better_auth/social_providers/base.rb +262 -4
- data/lib/better_auth/social_providers/cognito.rb +32 -0
- data/lib/better_auth/social_providers/discord.rb +27 -5
- data/lib/better_auth/social_providers/dropbox.rb +33 -0
- data/lib/better_auth/social_providers/facebook.rb +35 -0
- data/lib/better_auth/social_providers/figma.rb +31 -0
- data/lib/better_auth/social_providers/github.rb +21 -6
- data/lib/better_auth/social_providers/gitlab.rb +16 -3
- data/lib/better_auth/social_providers/google.rb +38 -13
- data/lib/better_auth/social_providers/huggingface.rb +31 -0
- data/lib/better_auth/social_providers/kakao.rb +32 -0
- data/lib/better_auth/social_providers/kick.rb +32 -0
- data/lib/better_auth/social_providers/line.rb +33 -0
- data/lib/better_auth/social_providers/linear.rb +44 -0
- data/lib/better_auth/social_providers/linkedin.rb +30 -0
- data/lib/better_auth/social_providers/microsoft_entra_id.rb +79 -7
- data/lib/better_auth/social_providers/naver.rb +31 -0
- data/lib/better_auth/social_providers/notion.rb +33 -0
- data/lib/better_auth/social_providers/paybin.rb +31 -0
- data/lib/better_auth/social_providers/paypal.rb +36 -0
- data/lib/better_auth/social_providers/polar.rb +31 -0
- data/lib/better_auth/social_providers/railway.rb +49 -0
- data/lib/better_auth/social_providers/reddit.rb +32 -0
- data/lib/better_auth/social_providers/roblox.rb +31 -0
- data/lib/better_auth/social_providers/salesforce.rb +38 -0
- data/lib/better_auth/social_providers/slack.rb +30 -0
- data/lib/better_auth/social_providers/spotify.rb +31 -0
- data/lib/better_auth/social_providers/tiktok.rb +35 -0
- data/lib/better_auth/social_providers/twitch.rb +39 -0
- data/lib/better_auth/social_providers/twitter.rb +32 -0
- data/lib/better_auth/social_providers/vercel.rb +47 -0
- data/lib/better_auth/social_providers/vk.rb +34 -0
- data/lib/better_auth/social_providers/wechat.rb +104 -0
- data/lib/better_auth/social_providers/zoom.rb +31 -0
- data/lib/better_auth/social_providers.rb +29 -0
- data/lib/better_auth/url_helpers.rb +195 -0
- data/lib/better_auth/version.rb +1 -1
- data/lib/better_auth.rb +8 -1
- metadata +38 -15
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "ipaddr"
|
|
4
|
+
|
|
5
|
+
module BetterAuth
  # Host classification used to tell loopback, private, link-local,
  # cloud-metadata and other non-public destinations apart from public ones.
  module Host
    # Hostnames commonly used for cloud instance-metadata services.
    CLOUD_METADATA_HOSTS = [
      "metadata.google.internal",
      "metadata.goog",
      "metadata",
      "instance-data",
      "instance-data.ec2.internal"
    ].freeze

    module_function

    # Classifies a host string.
    #
    # @param host [#to_s] a hostname or IP literal, optionally carrying a
    #   ":port", IPv6 brackets, an IPv6 zone id ("%eth0") or trailing dots.
    # @return [Hash] +{kind:, literal:, canonical:}+ where +literal+ is :fqdn,
    #   :ipv4 or :ipv6, +canonical+ is the lowercased name, dotted-quad IPv4
    #   or fully expanded IPv6 address, and +kind+ is :public or one of the
    #   non-public categories (:loopback, :localhost, :private, :link_local,
    #   :cloud_metadata, :reserved, ...).
    def classify_host(host)
      lowered = normalize_input(host).downcase
      return {kind: :reserved, literal: :fqdn, canonical: ""} if lowered.empty?
      # "/" is never valid in a host, and IPAddr.new would read it as a CIDR
      # mask that zeroes host bits ("192.168.1.1/1" -> public 128.0.0.0).
      return {kind: :reserved, literal: :fqdn, canonical: lowered} if lowered.include?("/")

      address = parse_ip(lowered)
      unless address
        return {kind: :localhost, literal: :fqdn, canonical: lowered} if lowered == "localhost" || lowered.end_with?(".localhost")
        return {kind: :cloud_metadata, literal: :fqdn, canonical: lowered} if CLOUD_METADATA_HOSTS.include?(lowered)

        return {kind: :public, literal: :fqdn, canonical: lowered}
      end

      # Unwrap IPv4-mapped/compatible IPv6 ("::ffff:10.0.0.1") to plain IPv4.
      native = address.respond_to?(:native) ? address.native : address
      if native.ipv4?
        canonical = native.to_s
        return {kind: classify_ipv4(canonical), literal: :ipv4, canonical: canonical}
      end

      canonical = expanded_ipv6(native)
      {kind: classify_ipv6(canonical), literal: :ipv6, canonical: canonical}
    end

    # True only for loopback IP literals (127.0.0.0/8, ::1).
    def loopback_ip?(host)
      classify_host(host)[:kind] == :loopback
    end

    # True for loopback IP literals and "localhost" / "*.localhost" names.
    def loopback_host?(host)
      [:loopback, :localhost].include?(classify_host(host)[:kind])
    end

    # True when +host+ is classified as :public.
    def public_routable_host?(host)
      classify_host(host)[:kind] == :public
    end

    # Strips surrounding whitespace, a ":port" suffix, IPv6 brackets, an IPv6
    # zone id and trailing dots. Case is left unchanged.
    def normalize_input(host)
      value = host.to_s.strip
      value = strip_port(value)
      value = value[1...-1] if value.start_with?("[") && value.end_with?("]")
      value = value.split("%", 2).first || ""
      value.gsub(/\.+\z/, "")
    end

    # Removes a ":port" suffix. Bracketed IPv6 keeps its brackets; an
    # unbracketed string with more than one colon is treated as a bare IPv6
    # literal and returned untouched.
    def strip_port(host)
      if host.start_with?("[")
        closing = host.index("]")
        return host unless closing

        return host[0..closing] if host[(closing + 1)..]&.match?(/\A:\d+\z/)
        return host
      end

      first_colon = host.index(":")
      return host unless first_colon
      return host if host.index(":", first_colon + 1)

      host[0...first_colon]
    end

    # Parses +host+ as an IP address; returns an IPAddr or nil.
    #
    # IPv4 literals are parsed with inet_aton(3) rules first, because that is
    # how the system resolver (and so Net::HTTP) reads them: "127.1",
    # "2130706433", "0x7f000001" and "0177.0.0.1" all mean 127.0.0.1. IPAddr
    # rejects or misreads those forms, which would otherwise let them pass as
    # public hostnames.
    def parse_ip(host)
      legacy = parse_legacy_ipv4(host)
      return IPAddr.new(legacy) if legacy

      IPAddr.new(host)
    rescue ArgumentError
      nil
    end

    # Parses an IPv4 literal in any inet_aton(3) form (1-4 dot-separated
    # parts, each decimal, 0-prefixed octal or 0x-prefixed hex) and returns
    # its canonical dotted-quad string, or nil if +host+ is not one.
    def parse_legacy_ipv4(host)
      parts = host.to_s.split(".", -1)
      return nil if parts.empty? || parts.length > 4

      numbers = parts.map { |part| parse_legacy_ipv4_part(part) }
      return nil if numbers.any?(&:nil?)

      *leading, last = numbers
      return nil if leading.any? { |number| number > 255 }
      # The final part fills every remaining byte, e.g. "10.1" => 10.0.0.1.
      return nil if last >= 256**(5 - numbers.length)

      value = leading.each_with_index.sum { |number, index| number << (8 * (3 - index)) } + last
      [24, 16, 8, 0].map { |shift| (value >> shift) & 0xff }.join(".")
    end

    # Parses one inet_aton(3) component; returns an Integer or nil.
    def parse_legacy_ipv4_part(part)
      case part
      when /\A0x\h+\z/i then part[2..].to_i(16)
      when /\A0[0-7]*\z/ then part.to_i(8)
      when /\A[1-9]\d*\z/ then part.to_i
      end
    end

    # Classifies a canonical dotted-quad IPv4 string by special-purpose range.
    def classify_ipv4(ip)
      return :unspecified if ip == "0.0.0.0"
      return :broadcast if ip == "255.255.255.255"

      value = ipv4_to_i(ip)
      return :loopback if ipv4_range?(value, "127.0.0.0", 8)
      return :private if ipv4_range?(value, "10.0.0.0", 8)
      return :private if ipv4_range?(value, "172.16.0.0", 12)
      return :private if ipv4_range?(value, "192.168.0.0", 16)
      return :link_local if ipv4_range?(value, "169.254.0.0", 16)
      return :shared_address_space if ipv4_range?(value, "100.64.0.0", 10)
      return :documentation if ipv4_range?(value, "192.0.2.0", 24)
      return :documentation if ipv4_range?(value, "198.51.100.0", 24)
      return :documentation if ipv4_range?(value, "203.0.113.0", 24)
      return :benchmarking if ipv4_range?(value, "198.18.0.0", 15)
      return :multicast if ipv4_range?(value, "224.0.0.0", 4)
      return :reserved if ipv4_range?(value, "0.0.0.0", 8)
      return :reserved if ipv4_range?(value, "192.0.0.0", 24)
      return :reserved if ipv4_range?(value, "240.0.0.0", 4)

      :public
    end

    # Converts a dotted-quad string to its 32-bit integer value.
    def ipv4_to_i(ip)
      ip.split(".").map(&:to_i).reduce(0) { |sum, part| (sum << 8) + part }
    end

    # True when +value+ lies in +prefix+/+length+.
    def ipv4_range?(value, prefix, length)
      mask = (length == 32) ? 0xffffffff : ((0xffffffff << (32 - length)) & 0xffffffff)
      (value & mask) == (ipv4_to_i(prefix) & mask)
    end

    # Classifies a fully expanded IPv6 string (8 groups of 4 hex digits, as
    # produced by expanded_ipv6).
    def classify_ipv6(expanded)
      return :unspecified if expanded == "0000:0000:0000:0000:0000:0000:0000:0000"
      return :loopback if expanded == "0000:0000:0000:0000:0000:0000:0000:0001"

      first_byte = expanded[0, 2].to_i(16)
      second_byte = expanded[2, 2].to_i(16)

      return :multicast if first_byte == 0xff
      return :link_local if first_byte == 0xfe && (second_byte & 0xc0) == 0x80
      # fec0::/10 site-local is deprecated (RFC 3879) but still not public.
      return :private if first_byte == 0xfe && (second_byte & 0xc0) == 0xc0
      return :private if (first_byte & 0xfe) == 0xfc
      return :documentation if expanded.start_with?("2001:0db8:")

      # 6to4: the public-ness follows the embedded IPv4 address.
      if expanded.start_with?("2002:")
        embedded = embedded_ipv4(expanded, 1)
        return (classify_ipv4(embedded) == :public) ? :public : :reserved if embedded
      end

      # NAT64 well-known prefix.
      if expanded.start_with?("0064:ff9b:0000:0000:0000:0000:")
        embedded = embedded_ipv4(expanded, 6)
        return :reserved if embedded
      end

      # Teredo: the client IPv4 is stored XOR-ed in the last two groups.
      if expanded.start_with?("2001:0000:")
        embedded = embedded_ipv4(expanded, 6, xor: true)
        return :reserved if embedded
      end

      # Discard-only prefix 100::/64.
      return :reserved if expanded.start_with?("0100:0000:0000:0000:")

      :public
    end

    # Extracts the IPv4 address held in groups +start_group+ and
    # +start_group+ + 1 of an expanded IPv6 string; nil if they are missing.
    def embedded_ipv4(expanded, start_group, xor: false)
      groups = expanded.split(":")
      combined = (groups.fetch(start_group).to_i(16) << 16) | groups.fetch(start_group + 1).to_i(16)
      combined ^= 0xffffffff if xor
      [
        (combined >> 24) & 0xff,
        (combined >> 16) & 0xff,
        (combined >> 8) & 0xff,
        combined & 0xff
      ].join(".")
    rescue IndexError
      nil
    end

    # Renders an IPv6 IPAddr as 8 colon-separated, zero-padded hex groups.
    def expanded_ipv6(address)
      address.hton.bytes.each_slice(2).map do |high, low|
        ((high << 8) + low).to_s(16).rjust(4, "0")
      end.join(":")
    end
  end
end
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module BetterAuth
  # No-op tracing facade shaped like the OpenTelemetry tracer/span API:
  # every span operation is accepted and ignored.
  module Instrumentation
    # Numeric span status codes.
    module SpanStatusCode
      UNSET = 0
      OK = 1
      ERROR = 2
    end

    # Span stand-in; every mutator ignores its arguments and returns the span
    # itself so calls can be chained.
    class NoopSpan
      def set_attribute(_key, _value)
        self
      end

      def set_attributes(_attributes)
        self
      end

      def record_exception(_error)
        self
      end

      def set_status(_status)
        self
      end

      def add_event(_name, _attributes = nil)
        self
      end

      def end
        self
      end
    end

    # Tracer stand-in that hands out NoopSpans.
    class NoopTracer
      # With a block, yields a fresh span and returns the block's value;
      # without one, returns the span. The span is ended either way.
      def start_active_span(_name, attributes: {}, &block)
        noop_span = NoopSpan.new
        begin
          block ? block.call(noop_span) : noop_span
        ensure
          noop_span.end
        end
      end
    end

    # Entry point mirroring the OpenTelemetry "trace" API object.
    class Trace
      def get_tracer(_name = "better-auth")
        NoopTracer.new
      end

      def get_active_span
        NoopSpan.new
      end
    end

    module_function

    # Memoized Trace instance.
    def trace
      return @trace if @trace

      @trace = Trace.new
    end

    # Runs the block inside a span named +name+. An exception raised by the
    # block is recorded on the span, the span is marked ERROR, and the
    # exception is re-raised.
    def with_span(name, attributes: {}, &block)
      tracer = trace.get_tracer("better-auth")
      tracer.start_active_span(name, attributes: attributes) do |active_span|
        begin
          block.call(active_span)
        rescue => error
          active_span.record_exception(error)
          active_span.set_status(SpanStatusCode::ERROR)
          raise
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module BetterAuth
  # Minimal leveled logger. Messages below the configured level, or all
  # messages when disabled, are dropped; the rest go to a custom handler or,
  # without one, to stderr via Kernel.warn.
  module Logger
    LEVELS = [:debug, :info, :success, :warn, :error].freeze
    # Level used when none (or nil) is configured.
    DEFAULT_LEVEL = :warn

    Internal = Struct.new(:level, :disabled, :handler, keyword_init: true) do
      LEVELS.each do |log_level|
        # Logs +message+ at +log_level+. Extra +args+ are forwarded to the
        # custom handler only; the stderr fallback prints +message+ alone.
        define_method(log_level) do |message, *args|
          return if disabled || !Logger.should_publish?(level, log_level)

          if handler
            # Handlers receive :success as :info.
            handler.call((log_level == :success) ? :info : log_level, message, *args)
          else
            Kernel.warn("#{log_level.upcase} [Better Auth]: #{message}")
          end
        end
      end
    end

    module_function

    # True when +log_level+ is at or above +current_log_level+ in LEVELS.
    # A nil +current_log_level+ falls back to DEFAULT_LEVEL instead of
    # raising; unknown level names rank as the lowest level.
    def should_publish?(current_log_level, log_level)
      current = (current_log_level || DEFAULT_LEVEL).to_sym
      LEVELS.index(log_level.to_sym).to_i >= LEVELS.index(current).to_i
    end

    # Builds a logger.
    #
    # @param level [Symbol, String, nil] minimum level; nil means DEFAULT_LEVEL.
    # @param disabled [Boolean] drop every message when truthy.
    # @param log [#call, nil] handler called as +(level, message, *args)+.
    # @return [Internal]
    def create(level: DEFAULT_LEVEL, disabled: false, log: nil, **)
      Internal.new(level: (level || DEFAULT_LEVEL).to_sym, disabled: disabled, handler: log)
    end
  end
end
|
|
@@ -14,7 +14,7 @@ module BetterAuth
|
|
|
14
14
|
|
|
15
15
|
validate_origin(endpoint_context)
|
|
16
16
|
validate_fetch_metadata(endpoint_context)
|
|
17
|
-
return if skip_origin_check?(endpoint_context)
|
|
17
|
+
return if skip_origin_check?(endpoint_context) || skip_origin_path?(endpoint_context)
|
|
18
18
|
|
|
19
19
|
validate_callback_urls(endpoint_context)
|
|
20
20
|
nil
|
|
@@ -87,7 +87,7 @@ module BetterAuth
|
|
|
87
87
|
end
|
|
88
88
|
|
|
89
89
|
def skip_origin_check?(endpoint_context)
  # Only an explicit boolean true disables the origin check; other truthy
  # values (e.g. the string "true") do not.
  advanced_options = endpoint_context.context.options.advanced
  advanced_options[:disable_origin_check] == true
end
|
|
92
92
|
|
|
93
93
|
def skip_csrf_for_backward_compat?(endpoint_context)
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "base64"
|
|
4
|
+
require "net/http"
|
|
5
|
+
require "uri"
|
|
6
|
+
require "jwt"
|
|
7
|
+
|
|
8
|
+
module BetterAuth
  # OAuth 2.0 / OIDC helpers: JWT validation against a JWK Set and
  # refresh-token grants against a token endpoint.
  module OAuth2
    # Algorithms accepted when verifying against a JWKS entry. A JWKS
    # publishes public keys, so only asymmetric signature algorithms make
    # sense; "none" and the HMAC family are deliberately absent.
    SUPPORTED_JWT_ALGORITHMS = %w[RS256 RS384 RS512 PS256 PS384 PS512 ES256 ES384 ES512 ES256K EdDSA].freeze

    module_function

    # Verifies +token+ with the JWKS key whose "kid" matches its header.
    #
    # @param token [String] compact JWS.
    # @param jwks [Hash] JWK Set with a "keys" (or :keys) array.
    # @param audience [Object, nil] expected "aud"; verified only when given.
    # @param issuer [String, nil] expected "iss"; verified only when given.
    # @return [Hash] the verified payload.
    # @raise [APIError] UNAUTHORIZED when the kid is missing or unknown, the
    #   algorithm is unsupported or disagrees with the key, or JWT
    #   verification fails.
    def validate_token(token, jwks:, audience: nil, issuer: nil)
      header = JWT.decode(token, nil, false).last
      kid = header["kid"]
      raise APIError.new("UNAUTHORIZED", message: "Missing jwt kid") if kid.to_s.empty?

      key_data = Array(jwks["keys"] || jwks[:keys]).find { |key| (key["kid"] || key[:kid]).to_s == kid.to_s }
      raise APIError.new("UNAUTHORIZED", message: "kid doesn't match any key") unless key_data

      algorithm = verification_algorithm(header, key_data)
      public_key = JWT::JWK.import(stringify_keys(key_data)).public_key
      options = {algorithm: algorithm}
      if audience
        options[:aud] = audience
        options[:verify_aud] = true
      end
      if issuer
        options[:iss] = issuer
        options[:verify_iss] = true
      end
      JWT.decode(token, public_key, true, **options).first
    rescue JWT::DecodeError => error
      raise APIError.new("UNAUTHORIZED", message: error.message)
    end

    # Picks the algorithm to verify with. The JWK's own "alg" wins when
    # present, and the (attacker-controlled) header "alg" must agree with it
    # and be an allowed asymmetric algorithm. Passing the header's alg
    # straight to JWT.decode would let a token choose its own algorithm,
    # including "none", which ruby-jwt 2.x verifies without any signature.
    #
    # @raise [APIError] UNAUTHORIZED for unsupported or mismatched algorithms.
    def verification_algorithm(header, key_data)
      header_alg = header["alg"].to_s
      key_alg = (key_data["alg"] || key_data[:alg]).to_s
      algorithm = key_alg.empty? ? header_alg : key_alg

      unless SUPPORTED_JWT_ALGORITHMS.include?(algorithm)
        raise APIError.new("UNAUTHORIZED", message: "Unsupported jwt algorithm")
      end
      unless header_alg == algorithm
        raise APIError.new("UNAUTHORIZED", message: "jwt algorithm doesn't match key")
      end

      algorithm
    end

    # Exchanges +refresh_token+ at +token_endpoint+.
    #
    # @param fetcher [#call, nil] optional +(url, request)+ callable returning
    #   the parsed response Hash; defaults to an HTTP form POST.
    # @return [Hash] token fields (nil values dropped) plus
    #   :access_token_expires_at / :refresh_token_expires_at Times when the
    #   response carries expires_in / refresh_token_expires_in.
    def refresh_access_token(refresh_token:, token_endpoint:, options:, authentication: nil, extra_params: nil, resource: nil, fetcher: nil)
      request = create_refresh_access_token_request(
        refresh_token: refresh_token,
        options: options,
        authentication: authentication,
        extra_params: extra_params,
        resource: resource
      )
      data = fetcher ? fetcher.call(token_endpoint, request) : post_form(token_endpoint, request)
      now = Time.now
      tokens = {
        access_token: data["access_token"] || data[:access_token],
        refresh_token: data["refresh_token"] || data[:refresh_token],
        token_type: data["token_type"] || data[:token_type],
        scopes: (data["scope"] || data[:scope])&.split(" "),
        id_token: data["id_token"] || data[:id_token]
      }.compact

      expires_in = data["expires_in"] || data[:expires_in]
      tokens[:access_token_expires_at] = now + expires_in.to_i if expires_in

      refresh_expires_in = data["refresh_token_expires_in"] || data[:refresh_token_expires_in]
      tokens[:refresh_token_expires_at] = now + refresh_expires_in.to_i if refresh_expires_in
      tokens
    end

    # Builds the form body and headers for a refresh_token grant. With
    # +authentication+ "basic" the client credentials go in an Authorization
    # header; otherwise they go in the body. Only the first client id is used.
    #
    # @return [Hash] +{body:, headers:}+ with string keys inside each.
    def create_refresh_access_token_request(refresh_token:, options:, authentication: nil, extra_params: nil, resource: nil)
      body = {
        "grant_type" => "refresh_token",
        "refresh_token" => refresh_token
      }
      headers = {
        "content-type" => "application/x-www-form-urlencoded",
        "accept" => "application/json"
      }
      client_id = Array(options[:client_id] || options["client_id"] || options[:clientId] || options["clientId"]).first
      client_secret = options[:client_secret] || options["client_secret"] || options[:clientSecret] || options["clientSecret"]

      if authentication.to_s == "basic"
        headers["authorization"] = "Basic #{Base64.strict_encode64("#{client_id}:#{client_secret}")}"
      else
        body["client_id"] = client_id if client_id
        body["client_secret"] = client_secret if client_secret
      end

      # Array values are emitted as repeated "resource" form fields.
      Array(resource).each { |entry| (body["resource"] ||= []) << entry } if resource
      extra_params&.each { |key, value| body[key.to_s] = value }
      {body: body, headers: headers}
    end

    # POSTs the form-encoded request and parses the JSON response body.
    # NOTE(review): the HTTP status is not checked and no timeouts are set.
    def post_form(token_endpoint, request)
      uri = URI.parse(token_endpoint)
      response = Net::HTTP.post(uri, URI.encode_www_form(request[:body]), request[:headers])
      JSON.parse(response.body)
    end

    # Recursively converts Hash keys to strings.
    def stringify_keys(hash)
      hash.each_with_object({}) do |(key, value), result|
        result[key.to_s] = value.is_a?(Hash) ? stringify_keys(value) : value
      end
    end
  end
end
|
data/lib/better_auth/plugin.rb
CHANGED
|
@@ -11,6 +11,8 @@ module BetterAuth
|
|
|
11
11
|
:schema,
|
|
12
12
|
:migrations,
|
|
13
13
|
:options,
|
|
14
|
+
:version,
|
|
15
|
+
:client,
|
|
14
16
|
:rate_limit,
|
|
15
17
|
:error_codes,
|
|
16
18
|
:on_request,
|
|
@@ -28,7 +30,8 @@ module BetterAuth
|
|
|
28
30
|
|
|
29
31
|
def initialize(data = {}, **keywords)
|
|
30
32
|
data = data.to_h if data.respond_to?(:to_h) && !data.is_a?(Hash)
|
|
31
|
-
|
|
33
|
+
input = (data || {}).merge(keywords)
|
|
34
|
+
raw = normalize_hash(input)
|
|
32
35
|
|
|
33
36
|
@id = raw[:id].to_s
|
|
34
37
|
@init = raw[:init]
|
|
@@ -38,6 +41,8 @@ module BetterAuth
|
|
|
38
41
|
@schema = raw[:schema] || {}
|
|
39
42
|
@migrations = raw[:migrations] || {}
|
|
40
43
|
@options = raw[:options] || {}
|
|
44
|
+
@version = raw[:version]
|
|
45
|
+
@client = stringify_hash(input[:client] || input["client"])
|
|
41
46
|
@rate_limit = Array(raw[:rate_limit])
|
|
42
47
|
@error_codes = normalize_error_codes(raw)
|
|
43
48
|
@on_request = raw[:on_request]
|
|
@@ -107,6 +112,14 @@ module BetterAuth
|
|
|
107
112
|
end
|
|
108
113
|
end
|
|
109
114
|
|
|
115
|
+
def stringify_hash(value)
  # Non-Hash input (including nil) yields nil; nested Hashes are converted
  # recursively, and all other values are kept as-is.
  return nil unless value.is_a?(Hash)

  value.to_h do |key, nested|
    [key.to_s, nested.is_a?(Hash) ? stringify_hash(nested) : nested]
  end
end
|
|
122
|
+
|
|
110
123
|
def normalize_key(key)
|
|
111
124
|
key.to_s
|
|
112
125
|
.delete_prefix("$")
|
|
@@ -193,9 +193,9 @@ module BetterAuth
|
|
|
193
193
|
user = if found
|
|
194
194
|
found[:user]
|
|
195
195
|
else
|
|
196
|
-
raise APIError.new("BAD_REQUEST", message:
|
|
196
|
+
raise APIError.new("BAD_REQUEST", message: EMAIL_OTP_ERROR_CODES["INVALID_OTP"]) if config[:disable_sign_up]
|
|
197
197
|
|
|
198
|
-
ctx.context.internal_adapter.create_user(email_otp_sign_up_user_data(body, email))
|
|
198
|
+
ctx.context.internal_adapter.create_user(email_otp_sign_up_user_data(ctx, body, email))
|
|
199
199
|
end
|
|
200
200
|
|
|
201
201
|
unless user["emailVerified"]
|
|
@@ -419,10 +419,21 @@ module BetterAuth
|
|
|
419
419
|
Array.new(config[:otp_length].to_i) { SecureRandom.random_number(10).to_s }.join
|
|
420
420
|
end
|
|
421
421
|
|
|
422
|
-
def email_otp_sign_up_user_data(body, email)
|
|
422
|
+
def email_otp_sign_up_user_data(ctx, body, email)
|
|
423
423
|
reserved = %i[email otp name image callback_url callbackURL callbackUrl]
|
|
424
|
-
|
|
425
|
-
|
|
424
|
+
user_fields = Schema.auth_tables(ctx.context.options).fetch("user").fetch(:fields)
|
|
425
|
+
core_fields = %w[id name email emailVerified image createdAt updatedAt]
|
|
426
|
+
additional = body.each_with_object({}) do |(key, value), result|
|
|
427
|
+
next if reserved.include?(key.to_sym)
|
|
428
|
+
|
|
429
|
+
field = Schema.storage_key(key)
|
|
430
|
+
attributes = user_fields[field]
|
|
431
|
+
next unless attributes
|
|
432
|
+
next if core_fields.include?(field)
|
|
433
|
+
next if attributes[:input] == false
|
|
434
|
+
|
|
435
|
+
result[field] = value
|
|
436
|
+
end
|
|
426
437
|
additional.merge(
|
|
427
438
|
"email" => email,
|
|
428
439
|
"emailVerified" => true,
|
|
@@ -52,19 +52,7 @@ module BetterAuth
|
|
|
52
52
|
data,
|
|
53
53
|
provider_id: "auth0",
|
|
54
54
|
discovery_url: "https://#{domain}/.well-known/openid-configuration",
|
|
55
|
-
scopes: ["openid", "profile", "email"]
|
|
56
|
-
get_user_info: ->(tokens) {
|
|
57
|
-
profile = generic_oauth_fetch_json("https://#{domain}/userinfo", authorization: "Bearer #{fetch_value(tokens, "accessToken")}")
|
|
58
|
-
return nil unless profile
|
|
59
|
-
|
|
60
|
-
{
|
|
61
|
-
id: fetch_value(profile, "sub"),
|
|
62
|
-
name: fetch_value(profile, "name") || fetch_value(profile, "nickname"),
|
|
63
|
-
email: fetch_value(profile, "email"),
|
|
64
|
-
image: fetch_value(profile, "picture"),
|
|
65
|
-
emailVerified: fetch_value(profile, "email_verified") || false
|
|
66
|
-
}
|
|
67
|
-
}
|
|
55
|
+
scopes: ["openid", "profile", "email"]
|
|
68
56
|
)
|
|
69
57
|
end
|
|
70
58
|
|
|
@@ -688,24 +676,12 @@ module BetterAuth
|
|
|
688
676
|
nil
|
|
689
677
|
end
|
|
690
678
|
|
|
691
|
-
def generic_oidc_helper_provider(options, provider_id, issuer, discovery_url, _user_info_url)
  # User info now comes from the shared generic OAuth flow, so only the
  # discovery document and the default OIDC scopes are configured here.
  default_scopes = %w[openid profile email]
  generic_oauth_provider_config(options, provider_id: provider_id, discovery_url: discovery_url, scopes: default_scopes)
end
|
|
711
687
|
|
|
@@ -730,12 +706,22 @@ module BetterAuth
|
|
|
730
706
|
result[provider_id.to_sym] = {
|
|
731
707
|
id: provider_id,
|
|
732
708
|
name: provider_id,
|
|
733
|
-
get_user_info: ->(tokens) {
|
|
709
|
+
get_user_info: ->(tokens) { generic_oauth_provider_user_info(provider, tokens) },
|
|
734
710
|
refresh_access_token: ->(refresh_token) { generic_oauth_refresh_access_token(context, provider, refresh_token) }
|
|
735
711
|
}
|
|
736
712
|
end
|
|
737
713
|
end
|
|
738
714
|
|
|
715
|
+
def generic_oauth_provider_user_info(provider, tokens)
  # Returns nil when the provider yields no profile; otherwise pairs the
  # mapped user with the raw profile data.
  profile = generic_oauth_user_info(provider, tokens)
  return nil unless profile

  mapped_user = generic_oauth_map_user(provider, profile)
  {user: mapped_user, data: profile}
end
|
|
724
|
+
|
|
739
725
|
def generic_oauth_refresh_access_token(ctx, provider, refresh_token)
|
|
740
726
|
token_url = provider[:token_url] || generic_oauth_discovery(provider)["token_endpoint"]
|
|
741
727
|
raise APIError.new("BAD_REQUEST", message: GENERIC_OAUTH_ERROR_CODES["TOKEN_URL_NOT_FOUND"]) if token_url.to_s.empty?
|