omniauth_oidc 0.2.7 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,120 +3,147 @@
3
3
  module OmniAuth
4
4
  module Strategies
5
5
  class Oidc
6
- # Callback phase
7
- module Callback
8
- def callback_phase # rubocop:disable Metrics
9
- error = params["error_reason"] || params["error"]
10
- error_description = params["error_description"] || params["error_reason"]
11
- invalid_state = (options.require_state && params["state"].to_s.empty?) || params["state"] != stored_state
6
+ # Callback phase - handles OIDC provider response
7
+ module Callback # rubocop:disable Metrics/ModuleLength
8
+ def callback_phase
9
+ OmniauthOidc::Logging.instrument("callback_phase.start", provider: name) do
10
+ handle_callback_errors do
11
+ validate_callback_params!
12
12
 
13
- raise CallbackError, error: params["error"], reason: error_description, uri: params["error_uri"] if error
14
- raise CallbackError, error: :csrf_detected, reason: "Invalid 'state' parameter" if invalid_state
13
+ options.issuer = issuer if options.issuer.nil? || options.issuer.empty?
15
14
 
16
- return unless valid_response_type?
15
+ verify_id_token!(params["id_token"]) if configured_response_type == "id_token"
17
16
 
18
- options.issuer = issuer if options.issuer.nil? || options.issuer.empty?
17
+ client.redirect_uri = redirect_uri
19
18
 
20
- verify_id_token!(params["id_token"]) if configured_response_type == "id_token"
19
+ if configured_response_type == "id_token"
20
+ handle_id_token_response
21
+ else
22
+ handle_code_response
23
+ end
21
24
 
22
- client.redirect_uri = redirect_uri
23
-
24
- return id_token_callback_phase if configured_response_type == "id_token"
25
+ super
26
+ end
27
+ end
28
+ end
25
29
 
26
- client.authorization_code = authorization_code
30
+ private
27
31
 
28
- access_token
32
+ def handle_callback_errors
33
+ yield
29
34
  rescue CallbackError => e
35
+ OmniauthOidc::Logging.error("Callback error", error: e.error, reason: e.error_reason)
30
36
  fail!(e.error, e)
31
- rescue ::Rack::OAuth2::Client::Error => e
32
- fail!(e.response[:error], e)
37
+ rescue OmniauthOidc::TokenError => e
38
+ OmniauthOidc::Logging.error("Token error", error: e.class.name, message: e.message)
39
+ fail!(:token_error, e)
40
+ rescue OmniauthOidc::HttpClient::HttpError => e
41
+ OmniauthOidc::Logging.error("HTTP error", message: e.message)
42
+ fail!(:http_error, e)
33
43
  rescue ::Timeout::Error, ::Errno::ETIMEDOUT => e
44
+ OmniauthOidc::Logging.error("Timeout error", message: e.message)
34
45
  fail!(:timeout, e)
35
46
  rescue ::SocketError => e
47
+ OmniauthOidc::Logging.error("Connection error", message: e.message)
36
48
  fail!(:failed_to_connect, e)
37
49
  end
38
50
 
39
- private
40
-
41
- def access_token
42
- return @access_token if @access_token
43
-
44
- token_request_params = {
45
- scope: (scope if options.send_scope_to_token_endpoint),
46
- client_auth_method: options.client_auth_method
47
- }
51
+ def validate_callback_params! # rubocop:disable Naming/PredicateMethod
52
+ error = params["error_reason"] || params["error"]
53
+ error_description = params["error_description"] || params["error_reason"]
54
+ invalid_state = (options.require_state && params["state"].to_s.empty?) || params["state"] != stored_state
48
55
 
49
- if options.pkce
50
- token_request_params[:code_verifier] =
51
- params["code_verifier"] || session.delete("omniauth.pkce.verifier")
56
+ if error
57
+ raise CallbackError, error: params["error"], error_reason: error_description,
58
+ error_uri: params["error_uri"]
52
59
  end
60
+ raise CallbackError, error: :csrf_detected, error_reason: "Invalid 'state' parameter" if invalid_state
53
61
 
54
- set_client_options_for_callback_phase
62
+ valid_response_type?
63
+ end
64
+
65
+ def handle_code_response
66
+ # Get access token via token exchange
67
+ @access_token = fetch_access_token
55
68
 
56
- @access_token = client.access_token!(token_request_params)
69
+ # Verify the ID token from the token response
70
+ verify_id_token!(@access_token.id_token) if @access_token.id_token
57
71
 
58
- verify_id_token!(@access_token.id_token) if configured_response_type == "code"
72
+ # Fetch and set user info
73
+ @user_info = fetch_user_info
74
+ end
59
75
 
60
- options.fetch_user_info ? user_info_from_access_token : define_access_token
76
+ def handle_id_token_response
77
+ # For id_token response type, extract user data directly from the token
78
+ decoded_token = decode_id_token(params["id_token"])
79
+ @user_info = OmniauthOidc::ResponseObjects::UserInfo.new(decoded_token.raw_attributes)
80
+
81
+ # Create a minimal access token structure for credentials
82
+ @access_token = OmniauthOidc::ResponseObjects::AccessToken.new(
83
+ id_token: params["id_token"],
84
+ access_token: nil,
85
+ refresh_token: nil,
86
+ expires_in: nil,
87
+ scope: nil
88
+ )
61
89
  end
62
90
 
63
- def id_token_callback_phase
64
- user_data = decode_id_token(params["id_token"]).raw_attributes
91
+ def fetch_access_token
92
+ OmniauthOidc::Logging.instrument("token.exchange", provider: name) do
93
+ token_request_params = {
94
+ code: authorization_code,
95
+ redirect_uri: redirect_uri
96
+ }
97
+
98
+ if options.pkce
99
+ token_request_params[:code_verifier] =
100
+ params["code_verifier"] || session.delete(session_key("pkce.verifier"))
101
+ end
65
102
 
66
- define_user_info(user_data)
103
+ set_client_options_for_callback_phase
104
+
105
+ client.access_token!(token_request_params)
106
+ end
67
107
  end
68
108
 
69
- def valid_response_type?
70
- return true if params.key?(configured_response_type)
109
+ def fetch_user_info
110
+ return minimal_user_info_from_token unless options.fetch_user_info
71
111
 
72
- error_attrs = RESPONSE_TYPE_EXCEPTIONS[configured_response_type]
73
- fail!(error_attrs[:key], error_attrs[:exception_class].new(params["error"]))
112
+ OmniauthOidc::Logging.instrument("userinfo.fetch", provider: name) do
113
+ # Use our custom client to fetch userinfo
114
+ userinfo_data = client.userinfo!(@access_token.access_token).raw_attributes
74
115
 
75
- false
116
+ # Merge with ID token claims if available
117
+ if @access_token.id_token
118
+ id_token_claims = decode_id_token(@access_token.id_token).raw_attributes
119
+ userinfo_data = id_token_claims.merge(userinfo_data)
120
+ end
121
+
122
+ OmniauthOidc::ResponseObjects::UserInfo.new(userinfo_data)
123
+ end
124
+ rescue StandardError => e
125
+ OmniauthOidc::Logging.warn("Failed to fetch userinfo, falling back to ID token", error: e.message)
126
+ minimal_user_info_from_token
76
127
  end
77
128
 
78
- def user_info_from_access_token
79
- user_data = HTTParty.get(
80
- config.userinfo_endpoint, {
81
- headers: {
82
- "Authorization" => "Bearer #{@access_token}",
83
- "Content-Type" => "application/json"
84
- }
85
- }
86
- )
129
+ def minimal_user_info_from_token
130
+ return empty_user_info unless @access_token&.id_token
87
131
 
88
- define_user_info(user_data.parsed_response)
132
+ decoded = decode_id_token(@access_token.id_token)
133
+ OmniauthOidc::ResponseObjects::UserInfo.new(decoded.raw_attributes)
89
134
  end
90
135
 
91
- def define_user_info(user_data)
92
- env["omniauth.auth"] = AuthHash.new(
93
- provider: name,
94
- uid: user_data["sub"],
95
- info: { name: user_data["name"], email: user_data["email"] },
96
- extra: { raw_info: user_data },
97
- credentials: {
98
- id_token: @access_token.id_token,
99
- token: @access_token.access_token,
100
- refresh_token: @access_token.refresh_token,
101
- expires_in: @access_token.expires_in,
102
- scope: @access_token.scope
103
- }
104
- )
105
- call_app!
136
+ def empty_user_info
137
+ OmniauthOidc::ResponseObjects::UserInfo.new({})
106
138
  end
107
139
 
108
- def define_access_token
109
- env["omniauth.auth"] = AuthHash.new(
110
- provider: name,
111
- credentials: {
112
- id_token: @access_token.id_token,
113
- token: @access_token.access_token,
114
- refresh_token: @access_token.refresh_token,
115
- expires_in: @access_token.expires_in,
116
- scope: @access_token.scope
117
- }
118
- )
119
- call_app!
140
+ def valid_response_type?
141
+ return true if params.key?(configured_response_type)
142
+
143
+ error_attrs = RESPONSE_TYPE_EXCEPTIONS[configured_response_type]
144
+ fail!(error_attrs[:key], error_attrs[:exception_class].new(params["error"]))
145
+
146
+ false
120
147
  end
121
148
 
122
149
  def configured_response_type
@@ -127,9 +154,19 @@ module OmniAuth
127
154
  def set_client_options_for_callback_phase
128
155
  client.host = host
129
156
  client.redirect_uri = redirect_uri
130
- client.authorization_endpoint = resolve_endpoint_from_host(host, config.authorization_endpoint)
131
- client.token_endpoint = resolve_endpoint_from_host(host, config.token_endpoint)
132
- client.userinfo_endpoint = resolve_endpoint_from_host(host, config.userinfo_endpoint)
157
+ client.authorization_endpoint = config.authorization_endpoint
158
+ client.token_endpoint = config.token_endpoint
159
+ client.userinfo_endpoint = config.userinfo_endpoint
160
+ end
161
+
162
+ # Accessor for OmniAuth DSL blocks
163
+ def user_info
164
+ @user_info
165
+ end
166
+
167
+ # Accessor for OmniAuth DSL blocks
168
+ def access_token
169
+ @access_token
133
170
  end
134
171
  end
135
172
  end
@@ -6,15 +6,16 @@ module OmniAuth
6
6
  # Code request phase
7
7
  module Request
8
8
  def request_phase
9
- @identifier = client_options.identifier
10
- @secret = secret
9
+ OmniauthOidc::Logging.instrument("request_phase.start", provider: name) do
10
+ @identifier = client_options.identifier
11
+ @secret = secret
11
12
 
12
- set_client_options_for_request_phase
13
- redirect authorize_uri
13
+ set_client_options_for_request_phase
14
+ redirect authorize_uri
15
+ end
14
16
  end
15
17
 
16
18
  def authorize_uri # rubocop:disable Metrics/AbcSize
17
- client.redirect_uri = redirect_uri
18
19
  opts = request_options
19
20
 
20
21
  opts.merge!(options.extra_authorize_params) unless options.extra_authorize_params.empty?
@@ -27,10 +28,14 @@ module OmniAuth
27
28
  verifier = options.pkce_verifier ? options.pkce_verifier.call : SecureRandom.hex(64)
28
29
 
29
30
  opts.merge!(pkce_authorize_params(verifier))
30
- session["omniauth.pkce.verifier"] = verifier
31
+ session[session_key("pkce.verifier")] = verifier
31
32
  end
32
33
 
33
- client.authorization_uri(opts.reject { |_k, v| v.nil? })
34
+ # Add redirect_uri and extra_params to opts
35
+ opts[:redirect_uri] = redirect_uri
36
+ opts[:extra_params] = opts.compact
37
+
38
+ client.authorization_uri(opts)
34
39
  end
35
40
 
36
41
  private
@@ -59,21 +64,20 @@ module OmniAuth
59
64
  options.state.call
60
65
  end
61
66
  end
62
- session["omniauth.state"] = state || SecureRandom.hex(16)
67
+ session[session_key("state")] = state || SecureRandom.hex(16)
63
68
  end
64
69
 
65
70
  # Parse response from OIDC endpoint and set client options for request phase
66
- def set_client_options_for_request_phase # rubocop:disable Metrics/AbcSize
71
+ def set_client_options_for_request_phase
67
72
  client_options.host = host
68
- client_options.authorization_endpoint = resolve_endpoint_from_host(host, config.authorization_endpoint)
69
- client_options.token_endpoint = resolve_endpoint_from_host(host, config.token_endpoint)
70
- client_options.userinfo_endpoint = resolve_endpoint_from_host(host, config.userinfo_endpoint)
71
- client_options.jwks_uri = resolve_endpoint_from_host(host, config.jwks_uri)
73
+ client_options.authorization_endpoint = config.authorization_endpoint
74
+ client_options.token_endpoint = config.token_endpoint
75
+ client_options.userinfo_endpoint = config.userinfo_endpoint
76
+ client_options.jwks_uri = config.jwks_uri
72
77
 
73
- return unless config.respond_to?(:end_session_endpoint)
78
+ return unless config.respond_to?(:end_session_endpoint) && config.end_session_endpoint
74
79
 
75
- client_options.end_session_endpoint = resolve_endpoint_from_host(host,
76
- config.end_session_endpoint)
80
+ client_options.end_session_endpoint = config.end_session_endpoint
77
81
  end
78
82
  end
79
83
  end
@@ -27,10 +27,27 @@ module OmniAuth
27
27
  end
28
28
  end
29
29
 
30
+ # Force refresh JWKS cache and retry verification
31
+ def public_key_with_refresh
32
+ OmniauthOidc::Logging.info("Force refreshing JWKS cache")
33
+ OmniauthOidc::JwksCache.invalidate(config.jwks_uri)
34
+ @public_key = nil
35
+ @fetch_key = nil
36
+ public_key
37
+ end
38
+
30
39
  private
31
40
 
32
41
  def fetch_key
33
- @fetch_key ||= parse_jwk_key(::OpenIDConnect.http_client.get(config.jwks_uri).body)
42
+ @fetch_key ||= OmniauthOidc::JwksCache.instance.fetch(config.jwks_uri) do
43
+ OmniauthOidc::Logging.instrument("jwks.fetch", jwks_uri: config.jwks_uri) do
44
+ response = OmniauthOidc::HttpClient.get(config.jwks_uri)
45
+ OmniauthOidc::JwkHandler.parse_jwks(response)
46
+ end
47
+ end
48
+ rescue StandardError => e
49
+ OmniauthOidc::Logging.error("Failed to fetch JWKS", error: e.message, jwks_uri: config.jwks_uri)
50
+ raise OmniauthOidc::JwksFetchError, "Failed to fetch JWKS from #{config.jwks_uri}: #{e.message}"
34
51
  end
35
52
 
36
53
  def base64_decoded_jwt_secret
@@ -42,72 +59,121 @@ module OmniAuth
42
59
  def verify_id_token!(id_token)
43
60
  return unless id_token
44
61
 
45
- decode_id_token(id_token).verify!(issuer: config.issuer,
46
- client_id: client_options.identifier,
47
- nonce: params["nonce"].presence || stored_nonce)
62
+ OmniauthOidc::Logging.instrument("id_token.verify", provider: name) do
63
+ decoded = decode_id_token(id_token)
64
+ verify_claims!(decoded)
65
+ decoded
66
+ end
48
67
  end
49
68
 
50
- def decode_id_token(id_token)
51
- decoded = JSON::JWT.decode(id_token, :skip_verification)
52
- algorithm = decoded.algorithm.to_sym
69
+ def verify_claims!(decoded_token) # rubocop:disable Metrics/MethodLength
70
+ claims = decoded_token.raw_attributes
53
71
 
54
- validate_client_algorithm!(algorithm)
72
+ # Verify issuer
73
+ if config.issuer && claims["iss"] != config.issuer
74
+ raise OmniauthOidc::InvalidIssuerError,
75
+ "Issuer mismatch. Expected: #{config.issuer}, Got: #{claims["iss"]}"
76
+ end
55
77
 
56
- keyset =
57
- case algorithm
58
- when :HS256, :HS384, :HS512
59
- secret
60
- else
61
- public_key
62
- end
78
+ # Verify audience
79
+ audience = claims["aud"]
80
+ expected_aud = client_options.identifier
81
+ unless audience_matches?(audience, expected_aud)
82
+ raise OmniauthOidc::InvalidAudienceError,
83
+ "Audience mismatch. Expected: #{expected_aud}, Got: #{audience}"
84
+ end
85
+
86
+ # Verify nonce if present
87
+ expected_nonce = params["nonce"].presence || stored_nonce
88
+ if expected_nonce && claims["nonce"] != expected_nonce
89
+ raise OmniauthOidc::InvalidNonceError,
90
+ "Nonce mismatch. Expected: #{expected_nonce}, Got: #{claims["nonce"]}"
91
+ end
63
92
 
64
- decoded.verify!(keyset)
65
- ::OpenIDConnect::ResponseObject::IdToken.new(decoded)
66
- rescue JSON::JWK::Set::KidNotFound
67
- # Workaround for https://github.com/nov/json-jwt/pull/92#issuecomment-824654949
68
- raise if decoded&.header&.key?("kid")
93
+ # Verify expiration
94
+ if claims["exp"] && Time.at(claims["exp"].to_i) < Time.now
95
+ raise OmniauthOidc::TokenExpiredError,
96
+ "Token expired at #{Time.at(claims["exp"].to_i)}"
97
+ end
69
98
 
70
- decoded = decode_with_each_key!(id_token, keyset)
99
+ decoded_token
100
+ end
71
101
 
72
- raise unless decoded
102
+ def audience_matches?(audience, expected)
103
+ return audience == expected if audience.is_a?(String)
104
+ return audience.include?(expected) if audience.is_a?(Array)
73
105
 
74
- decoded
106
+ false
75
107
  end
76
108
 
77
- # Check for jwt to match defined client_signing_alg
78
- def validate_client_algorithm!(algorithm)
79
- client_signing_alg = options.client_signing_alg&.to_sym
109
+ def decode_id_token(id_token)
110
+ # First decode without verification to get the algorithm and kid
111
+ _unverified_payload, unverified_header = JWT.decode(id_token, nil, false)
112
+ algorithm = unverified_header["alg"]
113
+ kid = unverified_header["kid"]
80
114
 
81
- return unless client_signing_alg
82
- return if algorithm == client_signing_alg
115
+ validate_client_algorithm!(algorithm.to_sym)
83
116
 
84
- reason = "Received JWT is signed with #{algorithm}, but client_singing_alg is \
85
- configured for #{client_signing_alg}"
86
- raise CallbackError, error: :invalid_jwt_algorithm, reason: reason, uri: params["error_uri"]
87
- end
117
+ # Get the appropriate key/secret for verification
118
+ key = keyset_for_algorithm(algorithm.to_sym, kid)
88
119
 
89
- def decode!(id_token, key)
90
- ::OpenIDConnect::ResponseObject::IdToken.decode(id_token, key)
120
+ # Decode and verify
121
+ verify_signature!(id_token, key, algorithm)
122
+ rescue JWT::DecodeError => e
123
+ raise OmniauthOidc::TokenVerificationError, "Invalid JWT format: #{e.message}"
91
124
  end
92
125
 
93
- def decode_with_each_key!(id_token, keyset)
94
- return unless keyset.is_a?(JSON::JWK::Set)
95
-
96
- keyset.each do |key|
97
- begin
98
- decoded = decode!(id_token, key)
99
- rescue JSON::JWS::VerificationFailed, JSON::JWS::UnexpectedAlgorithm, JSON::JWK::UnknownAlgorithm
100
- next
126
+ def keyset_for_algorithm(algorithm, kid = nil)
127
+ case algorithm
128
+ when :HS256, :HS384, :HS512
129
+ secret
130
+ else
131
+ keys = public_key
132
+ if keys.is_a?(Array)
133
+ OmniauthOidc::JwkHandler.find_key(keys, kid)
134
+ else
135
+ keys
101
136
  end
102
-
103
- return decoded if decoded
104
137
  end
138
+ end
105
139
 
106
- nil
140
+ def verify_signature!(id_token, key, algorithm)
141
+ # Use jwt gem to decode and verify
142
+ payload, _header = JWT.decode(
143
+ id_token,
144
+ key,
145
+ true, # verify signature
146
+ {
147
+ algorithm: algorithm,
148
+ verify_expiration: false # We verify this manually in verify_claims!
149
+ }
150
+ )
151
+
152
+ # Create our custom IdToken object
153
+ OmniauthOidc::ResponseObjects::IdToken.new(payload.merge("algorithm" => algorithm))
154
+ rescue JWT::VerificationError => e
155
+ # Try refreshing JWKS cache and retry once
156
+ if key.is_a?(Array) && !@signature_retry_attempted
157
+ @signature_retry_attempted = true
158
+ OmniauthOidc::Logging.warn("Signature verification failed, refreshing JWKS and retrying")
159
+ refreshed_key = public_key_with_refresh
160
+ return verify_signature!(id_token, refreshed_key, algorithm)
161
+ end
162
+ raise OmniauthOidc::InvalidSignatureError, "JWT signature verification failed: #{e.message}"
163
+ rescue JWT::IncorrectAlgorithm => e
164
+ raise OmniauthOidc::InvalidAlgorithmError, "Unexpected JWT algorithm: #{e.message}"
107
165
  end
108
166
 
109
- def stored_nonce
110
- session.delete("omniauth.nonce")
167
+ # Check for jwt to match defined client_signing_alg
168
+ def validate_client_algorithm!(algorithm)
169
+ client_signing_alg = options.client_signing_alg&.to_sym
170
+
171
+ return unless client_signing_alg
172
+ return if algorithm == client_signing_alg
173
+
174
+ reason = "Received JWT is signed with #{algorithm}, but client_signing_alg is " \
175
+ "configured for #{client_signing_alg}"
176
+ raise OmniauthOidc::InvalidAlgorithmError, reason
111
177
  end
112
178
 
113
179
  def configured_public_key
@@ -120,31 +186,15 @@ module OmniAuth
120
186
 
121
187
  def parse_x509_key(key)
122
188
  OpenSSL::X509::Certificate.new(key).public_key
189
+ rescue OpenSSL::X509::CertificateError => e
190
+ raise OmniauthOidc::TokenVerificationError, "Invalid X.509 certificate: #{e.message}"
123
191
  end
124
192
 
125
193
  def parse_jwk_key(key)
126
194
  json = key.is_a?(String) ? JSON.parse(key) : key
127
- return JSON::JWK::Set.new(json["keys"]) if json.key?("keys")
128
-
129
- JSON::JWK.new(json)
130
- end
131
-
132
- def decode(str)
133
- UrlSafeBase64.decode64(str).unpack1("B*").to_i(2).to_s
134
- end
135
-
136
- def user_info
137
- return @user_info if @user_info
138
-
139
- if access_token.id_token
140
- decoded = decode_id_token(access_token.id_token).raw_attributes
141
-
142
- @user_info = ::OpenIDConnect::ResponseObject::UserInfo.new(
143
- access_token.userinfo!.raw_attributes.merge(decoded)
144
- )
145
- else
146
- @user_info = access_token.userinfo!
147
- end
195
+ OmniauthOidc::JwkHandler.parse_jwks(json)
196
+ rescue JSON::ParserError => e
197
+ raise OmniauthOidc::TokenVerificationError, "Invalid JWK format: #{e.message}"
148
198
  end
149
199
  end
150
200
  end