jwt_auth_cognito 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +33 -0
- data/CLAUDE.md +135 -0
- data/Gemfile +10 -0
- data/LICENSE.txt +21 -0
- data/PUBLISH_GUIDE.md +187 -0
- data/README.md +384 -0
- data/Rakefile +11 -0
- data/jwt_auth_cognito.gemspec +52 -0
- data/lib/generators/jwt_auth_cognito/install_generator.rb +88 -0
- data/lib/generators/jwt_auth_cognito/templates/jwt_auth_cognito.rb.erb +129 -0
- data/lib/jwt_auth_cognito/configuration.rb +83 -0
- data/lib/jwt_auth_cognito/jwks_service.rb +141 -0
- data/lib/jwt_auth_cognito/jwt_validator.rb +225 -0
- data/lib/jwt_auth_cognito/railtie.rb +13 -0
- data/lib/jwt_auth_cognito/redis_service.rb +194 -0
- data/lib/jwt_auth_cognito/token_blacklist_service.rb +56 -0
- data/lib/jwt_auth_cognito/version.rb +5 -0
- data/lib/jwt_auth_cognito.rb +41 -0
- data/lib/tasks/jwt_auth_cognito.rake +290 -0
- metadata +207 -0
@@ -0,0 +1,83 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "openssl"
require "base64"

module JwtAuthCognito
  # Settings for Cognito JWT validation and the Redis-backed token blacklist.
  #
  # Every attribute is seeded from environment variables in #initialize and can
  # be overridden afterwards through its accessor. Environment variables that
  # are set but empty (or whitespace-only) are treated as unset, so the
  # defaults below still apply instead of e.g. a Redis port of 0.
  class Configuration
    attr_accessor :cognito_user_pool_id, :cognito_region, :cognito_client_id, :cognito_client_secret,
                  :redis_host, :redis_port, :redis_password, :redis_db,
                  :redis_ssl, :redis_timeout, :redis_connect_timeout, :redis_read_timeout,
                  :redis_ca_cert_path, :redis_ca_cert_name, :redis_verify_mode,
                  :redis_tls_min_version, :redis_tls_max_version,
                  :jwks_cache_ttl, :validation_mode, :environment

    def initialize
      @cognito_region = env_value("COGNITO_REGION", "AWS_REGION") || "us-east-1"
      @cognito_user_pool_id = env_value("COGNITO_USER_POOL_ID")
      @cognito_client_id = env_value("COGNITO_CLIENT_ID")
      @cognito_client_secret = env_value("COGNITO_CLIENT_SECRET")

      # Redis connection settings
      @redis_host = env_value("REDIS_HOST") || "localhost"
      @redis_port = (env_value("REDIS_PORT") || 6379).to_i
      @redis_password = env_value("REDIS_PASSWORD")
      @redis_db = (env_value("REDIS_DB") || 0).to_i
      @redis_ssl = ENV["REDIS_TLS"] == "true" || ENV["REDIS_SSL"] == "true"
      @redis_timeout = (env_value("REDIS_TIMEOUT") || 5).to_i
      @redis_connect_timeout = (env_value("REDIS_CONNECT_TIMEOUT") || 10).to_i
      @redis_read_timeout = (env_value("REDIS_READ_TIMEOUT") || 10).to_i

      # TLS specific configuration
      @redis_ca_cert_path = env_value("REDIS_CA_CERT_PATH")
      @redis_ca_cert_name = env_value("REDIS_CA_CERT_NAME")
      @redis_verify_mode = env_value("REDIS_VERIFY_MODE") || "peer"
      @redis_tls_min_version = env_value("REDIS_TLS_MIN_VERSION") || "TLSv1.2"
      @redis_tls_max_version = env_value("REDIS_TLS_MAX_VERSION") || "TLSv1.3"

      @jwks_cache_ttl = (env_value("JWKS_CACHE_TTL") || 3600).to_i # seconds (1 hour)
      @environment = env_value("RAILS_ENV", "RACK_ENV", "NODE_ENV") || "development"
      # NOTE: every environment other than "production" (including e.g.
      # "staging") defaults to :basic, in which tokens are decoded WITHOUT
      # signature verification. Set validation_mode = :secure explicitly for
      # any deployment that receives untrusted tokens.
      @validation_mode = production? ? :secure : :basic
    end

    def production?
      @environment == "production"
    end

    def development?
      @environment == "development"
    end

    # @return [String] the issuer URL Cognito places in the "iss" claim
    def cognito_issuer
      "https://cognito-idp.#{cognito_region}.amazonaws.com/#{cognito_user_pool_id}"
    end

    # @return [String] URL of the user pool's JSON Web Key Set
    def jwks_url
      "#{cognito_issuer}/.well-known/jwks.json"
    end

    # Ensures the settings required for validation are present.
    #
    # @raise [ConfigurationError] if the pool ID, region or Redis host is nil or blank
    def validate!
      raise ConfigurationError, "cognito_user_pool_id is required" if blank_value?(cognito_user_pool_id)
      raise ConfigurationError, "cognito_region is required" if blank_value?(cognito_region)
      raise ConfigurationError, "redis_host is required" if blank_value?(redis_host)
    end

    # @return [Boolean] true when a non-blank app client secret is configured
    def has_client_secret?
      !blank_value?(cognito_client_secret)
    end

    # Computes the Cognito SECRET_HASH: Base64(HMAC-SHA256(client_secret, identifier + client_id)).
    #
    # @param identifier [#to_s] usually the username
    # @return [String] the hash, or "" when no client secret or client ID is configured
    # @raise [ArgumentError] if +identifier+ is nil
    # @raise [ConfigurationError] if the HMAC computation fails
    def calculate_secret_hash(identifier)
      return "" unless has_client_secret?
      return "" if blank_value?(cognito_client_id)
      raise ArgumentError, "identifier is required to calculate the secret hash" if identifier.nil?

      message = "#{identifier}#{cognito_client_id}"
      begin
        hmac = OpenSSL::HMAC.digest("SHA256", cognito_client_secret, message)
        # strict_encode64 never inserts line breaks, so no trailing strip is needed.
        Base64.strict_encode64(hmac)
      rescue StandardError => e
        raise ConfigurationError, "Error calculating secret hash: #{e.message}"
      end
    end

    private

    # Returns the value of the first environment variable in +names+ that is
    # set to a non-blank string, or nil when none is.
    def env_value(*names)
      ENV.values_at(*names).find { |value| !blank_value?(value) }
    end

    # Named blank_value? (not blank?) to avoid clashing with ActiveSupport's
    # zero-argument Object#blank? when used inside Rails.
    def blank_value?(value)
      value.nil? || value.to_s.strip.empty?
    end
  end
end
|
@@ -0,0 +1,141 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "net/http"
|
4
|
+
require "json"
|
5
|
+
require "jwt"
|
6
|
+
require "openssl"
|
7
|
+
require "base64"
|
8
|
+
|
9
|
+
module JwtAuthCognito
  # Verifies Cognito-issued JWTs against the RSA public keys published at the
  # user pool's JWKS endpoint (Configuration#jwks_url).
  #
  # Converted public keys are cached per key ID (kid) for
  # +config.jwks_cache_ttl+ seconds; an unknown or stale kid triggers a fresh
  # JWKS download.
  class JwksService
    # @param config [Configuration] provides issuer, JWKS URL, app client ID and cache TTL
    def initialize(config = JwtAuthCognito.configuration)
      @config = config
      @cache = {}
      @cache_timestamps = {}
    end

    # Verifies +token+'s RS256 signature and issuer, its time-based claims and
    # token_use, and (when a client ID is configured) that it was issued to
    # that app client.
    #
    # @param token [String] encoded JWT
    # @return [Hash] +{ valid: true, payload:, sub:, username:, token_use: }+ on
    #   success, otherwise +{ valid: false, error: String }+. Errors raised while
    #   validating (including from Configuration#validate!) are reported in the
    #   hash rather than raised, provided they are StandardErrors.
    def validate_token_with_jwks(token)
      @config.validate!

      _unverified_payload, header = JWT.decode(token, nil, false)
      kid = header["kid"]
      raise ValidationError, "Token missing key ID (kid)" unless kid

      # The audience is deliberately not verified by JWT.decode: Cognito access
      # tokens carry the app client ID in "client_id" and have no "aud" claim,
      # so #validate_client_id checks the right claim per token_use instead.
      payload, = JWT.decode(
        token,
        get_public_key(kid),
        true,
        { algorithm: "RS256", iss: @config.cognito_issuer, verify_iss: true }
      )

      validate_token_claims(payload)
      validate_client_id(payload)

      {
        valid: true,
        payload: payload,
        sub: payload["sub"],
        username: payload["cognito:username"] || payload["username"],
        token_use: payload["token_use"]
      }
    rescue JWT::DecodeError => e
      { valid: false, error: "JWT decode error: #{e.message}" }
    rescue ValidationError => e
      { valid: false, error: e.message }
    rescue StandardError => e
      { valid: false, error: "Validation error: #{e.message}" }
    end

    private

    # Returns the public key for +kid+, from the cache while it is fresh,
    # otherwise by downloading the JWKS and converting the matching entry.
    #
    # @raise [ValidationError] if the JWKS cannot be fetched or lacks +kid+
    def get_public_key(kid)
      return @cache[kid] if @cache[kid] && cache_valid?(kid)

      jwks = fetch_jwks
      keys = jwks.is_a?(Hash) ? Array(jwks["keys"]) : []
      key_data = keys.find { |key| key.is_a?(Hash) && key["kid"] == kid }
      raise ValidationError, "Key ID not found in JWKS" unless key_data

      public_key = jwk_to_pem(key_data)
      @cache[kid] = public_key
      @cache_timestamps[kid] = Time.now
      public_key
    end

    # Downloads and parses the JWKS document (10 s open/read timeouts).
    #
    # @raise [ValidationError] on non-200 responses, invalid JSON or transport errors
    def fetch_jwks
      uri = URI(@config.jwks_url)

      response = Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https",
                                                     open_timeout: 10, read_timeout: 10) do |http|
        http.request(Net::HTTP::Get.new(uri))
      end
      raise ValidationError, "Failed to fetch JWKS: #{response.code}" unless response.code == "200"

      JSON.parse(response.body)
    rescue ValidationError
      # Already descriptive; don't let the StandardError clause wrap it a second time.
      raise
    rescue JSON::ParserError => e
      raise ValidationError, "Invalid JWKS JSON: #{e.message}"
    rescue StandardError => e
      raise ValidationError, "Failed to fetch JWKS: #{e.message}"
    end

    # Builds an OpenSSL RSA public key from a JWK's base64url "n"/"e" members.
    # Despite the name, an OpenSSL::PKey::RSA object (not a PEM string) is returned.
    #
    # The key is assembled as a DER SubjectPublicKeyInfo because OpenSSL::PKey
    # objects cannot be populated through setters on current openssl gem
    # versions (keys are immutable under OpenSSL 3.0).
    #
    # @raise [ValidationError] if the JWK lacks "n" or "e"
    def jwk_to_pem(key_data)
      raise ValidationError, "JWK is missing RSA modulus/exponent" unless key_data["n"] && key_data["e"]

      modulus = OpenSSL::BN.new(base64url_decode(key_data["n"]), 2)
      exponent = OpenSSL::BN.new(base64url_decode(key_data["e"]), 2)

      rsa_public_key = OpenSSL::ASN1::Sequence.new([
        OpenSSL::ASN1::Integer.new(modulus),
        OpenSSL::ASN1::Integer.new(exponent)
      ])
      algorithm = OpenSSL::ASN1::Sequence.new([
        OpenSSL::ASN1::ObjectId.new("rsaEncryption"),
        OpenSSL::ASN1::Null.new(nil)
      ])
      spki = OpenSSL::ASN1::Sequence.new([algorithm, OpenSSL::ASN1::BitString.new(rsa_public_key.to_der)])

      OpenSSL::PKey::RSA.new(spki.to_der)
    end

    # Decodes unpadded base64url (RFC 7515 §2), adding only the padding that is missing.
    def base64url_decode(str)
      padded = str + ("=" * (-str.length % 4))
      Base64.decode64(padded.tr("-_", "+/"))
    end

    # @return [Boolean] whether the cached key for +kid+ is younger than jwks_cache_ttl
    def cache_valid?(kid)
      return false unless @cache_timestamps[kid]

      Time.now - @cache_timestamps[kid] < @config.jwks_cache_ttl
    end

    # Checks exp, nbf, iat (5 minutes of clock skew allowed) and token_use.
    #
    # @raise [ValidationError] on the first failing claim
    def validate_token_claims(payload)
      now = Time.now.to_i

      raise ValidationError, "Token has expired" if payload["exp"] && payload["exp"] < now
      raise ValidationError, "Token not yet valid" if payload["nbf"] && payload["nbf"] > now
      raise ValidationError, "Token issued in the future" if payload["iat"] && payload["iat"] > now + 300

      raise ValidationError, "Invalid token_use claim" unless %w[access id].include?(payload["token_use"])
    end

    # Checks that the token was issued to the configured app client. ID tokens
    # name it in "aud", access tokens in "client_id". Skipped when no client ID
    # is configured.
    #
    # @raise [ValidationError] if the relevant claim does not match
    def validate_client_id(payload)
      expected = @config.cognito_client_id
      return if expected.nil? || expected.to_s.empty?

      claim = payload["token_use"] == "access" ? "client_id" : "aud"
      return if Array(payload[claim]).include?(expected)

      raise ValidationError, "Token #{claim} does not match the configured client ID"
    end
  end
end
|
@@ -0,0 +1,225 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "jwt"
|
4
|
+
|
5
|
+
module JwtAuthCognito
  # Public entry point for validating and revoking Cognito JWTs.
  #
  # Tokens are first checked against the revocation blacklist, then validated
  # according to +config.validation_mode+:
  # * :secure - signature, issuer and claims verified through JwksService;
  # * :basic  - decoded WITHOUT signature verification; only claims are checked.
  class JwtValidator
    # @param config [Configuration] shared with the JWKS and blacklist services
    def initialize(config = JwtAuthCognito.configuration)
      @config = config
      @jwks_service = JwksService.new(config)
      @blacklist_service = TokenBlacklistService.new(config)
    end

    # Validates +token+ after checking the revocation blacklist.
    #
    # @param token [String] encoded JWT
    # @param options [Hash] optional extra checks: :user_id (matched against
    #   "sub"), :client_id, :token_use and :required_scopes
    # @return [Hash] +{ valid: true, payload:, ... }+ or +{ valid: false, error: String }+
    # @raise [ConfigurationError] if the configuration is incomplete or
    #   validation_mode is neither :secure nor :basic
    def validate_token(token, options = {})
      @config.validate!

      # Revoked tokens are rejected before any (potentially networked) validation.
      return { valid: false, error: "Token has been revoked" } if @blacklist_service.is_blacklisted?(token)

      case @config.validation_mode
      when :secure then validate_token_secure(token, options)
      when :basic then validate_token_basic(token, options)
      else raise ConfigurationError, "Invalid validation_mode: #{@config.validation_mode}"
      end
    end

    # Like #validate_token, but additionally requires token_use == "access".
    def validate_access_token(token)
      result = validate_token(token)

      if result[:valid] && result[:payload]["token_use"] != "access"
        return { valid: false, error: "Token is not an access token" }
      end

      result
    end

    # Like #validate_token, but additionally requires token_use == "id".
    def validate_id_token(token)
      result = validate_token(token)

      if result[:valid] && result[:payload]["token_use"] != "id"
        return { valid: false, error: "Token is not an ID token" }
      end

      result
    end

    # @return [Array<Hash>] one #validate_token result per token, in order
    def validate_multiple_tokens(tokens)
      tokens.map { |token| validate_token(token) }
    end

    def revoke_token(token, user_id: nil)
      @blacklist_service.add_to_blacklist(token, user_id: user_id)
    end

    def revoke_user_tokens(user_id)
      @blacklist_service.invalidate_user_tokens(user_id)
    end

    # Extracts the token from an "Authorization: Bearer <token>" header value.
    # The scheme is matched case-insensitively, as auth schemes are per RFC 7235 §2.1.
    #
    # @return [String, nil] the token, or nil for a missing/non-Bearer header
    def extract_token_from_header(authorization_header)
      return nil unless authorization_header

      match = authorization_header.match(/\ABearer\s+(.+)\z/i)
      match && match[1]
    end

    # Decodes the payload WITHOUT verifying the signature.
    #
    # @return [Hash] the payload, or +{ error: String }+ when decoding fails
    def decode_token(token)
      JWT.decode(token, nil, false).first
    rescue JWT::DecodeError => e
      { error: "Failed to decode token: #{e.message}" }
    end

    # Summarizes the (unverified) claims of +token+.
    #
    # @return [Hash] claim summary, or +{ error: String }+ when decoding fails
    def get_token_info(token)
      payload = decode_token(token)
      return payload if payload.is_a?(Hash) && payload[:error]

      {
        sub: payload["sub"],
        username: payload["cognito:username"] || payload["username"],
        email: payload["email"],
        token_use: payload["token_use"],
        client_id: token_client_id(payload),
        issued_at: payload["iat"] ? Time.at(payload["iat"]) : nil,
        expires_at: payload["exp"] ? Time.at(payload["exp"]) : nil,
        not_before: payload["nbf"] ? Time.at(payload["nbf"]) : nil,
        jti: payload["jti"],
        has_client_secret: @config.has_client_secret?
      }
    end

    # Calculate secret hash for Cognito operations (when client secret is configured)
    def calculate_secret_hash(identifier)
      @config.calculate_secret_hash(identifier)
    end

    # Check if client secret is configured
    def has_client_secret?
      @config.has_client_secret?
    end

    # @return [Boolean] true when the token is undecodable or its "exp" has passed
    def is_token_expired?(token)
      payload = decode_token(token)
      return true if payload.is_a?(Hash) && payload[:error]

      exp = payload["exp"]
      return false unless exp

      Time.now.to_i >= exp
    end

    # @return [Integer, nil] seconds until "exp" (never negative), or nil when
    #   the token is undecodable or has no "exp"
    def get_time_to_expiry(token)
      payload = decode_token(token)
      return nil if payload.is_a?(Hash) && payload[:error]

      exp = payload["exp"]
      return nil unless exp

      seconds = exp - Time.now.to_i
      seconds > 0 ? seconds : 0
    end

    # Builds a validator bound to +config+, or to the global configuration when
    # +config+ is nil. The global configuration is never modified.
    def self.create_cognito_validator(config = nil)
      config ? new(config) : new
    end

    private

    # JWKS-backed validation followed by the caller's custom checks. Custom
    # check failures are reported in the result hash, as in basic mode.
    def validate_token_secure(token, options = {})
      result = @jwks_service.validate_token_with_jwks(token)
      return result unless result[:valid]

      validate_custom_claims(result[:payload], options)
      result
    rescue ValidationError => e
      { valid: false, error: e.message }
    end

    # Basic validation without signature verification (development only).
    def validate_token_basic(token, options = {})
      payload, _header = JWT.decode(token, nil, false)

      validate_basic_claims(payload)
      validate_custom_claims(payload, options)

      {
        valid: true,
        payload: payload,
        sub: payload["sub"],
        username: payload["cognito:username"] || payload["username"],
        token_use: payload["token_use"]
      }
    rescue JWT::DecodeError => e
      { valid: false, error: "JWT decode error: #{e.message}" }
    rescue ValidationError => e
      { valid: false, error: e.message }
    rescue StandardError => e
      { valid: false, error: "Validation error: #{e.message}" }
    end

    # Checks exp, iss and token_use of an unverified payload.
    def validate_basic_claims(payload)
      now = Time.now.to_i

      raise ValidationError, "Token has expired" if payload["exp"] && payload["exp"] < now

      expected_issuer = @config.cognito_issuer
      if payload["iss"] != expected_issuer
        raise ValidationError, "Invalid issuer. Expected: #{expected_issuer}, got: #{payload["iss"]}"
      end

      raise ValidationError, "Invalid token_use claim" unless %w[access id].include?(payload["token_use"])
    end

    # Applies the optional per-call checks passed to #validate_token.
    #
    # @raise [ValidationError] on the first failing check
    def validate_custom_claims(payload, options)
      if options[:user_id] && payload["sub"] != options[:user_id]
        raise ValidationError, "Token subject does not match expected user ID"
      end

      if options[:client_id] && token_client_id(payload) != options[:client_id]
        raise ValidationError, "Token audience does not match expected client ID"
      end

      if options[:token_use] && payload["token_use"] != options[:token_use]
        raise ValidationError, "Token use does not match expected type"
      end

      return unless options[:required_scopes]

      token_scopes = payload["scope"]&.split(" ") || []
      missing_scopes = Array(options[:required_scopes]) - token_scopes
      return if missing_scopes.empty?

      raise ValidationError, "Token missing required scopes: #{missing_scopes.join(", ")}"
    end

    # Cognito ID tokens name the app client in "aud"; access tokens have no
    # "aud" and use "client_id" instead.
    def token_client_id(payload)
      payload["aud"] || payload["client_id"]
    end
  end
end
|
@@ -0,0 +1,194 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require "redis"
|
4
|
+
require "digest"
|
5
|
+
require "openssl"
|
6
|
+
|
7
|
+
module JwtAuthCognito
  # Redis-backed storage for revoked token IDs and per-user token-ID sets.
  #
  # Blacklist entries live under BLACKLIST_PREFIX + token_id; the token IDs
  # tracked for a user live in a Redis set under USER_TOKENS_PREFIX + user_id.
  # The Redis client is created lazily on first use.
  class RedisService
    BLACKLIST_PREFIX = "jwt_blacklist:"
    USER_TOKENS_PREFIX = "user_tokens:"

    # Configuration spellings (compared upper-cased) mapped to the protocol
    # symbols accepted by OpenSSL::SSL::SSLContext#min_version= / #max_version=.
    TLS_VERSIONS = {
      "TLSV1" => :TLS1,
      "TLSV1.1" => :TLS1_1,
      "TLSV1.2" => :TLS1_2,
      "TLSV1.3" => :TLS1_3
    }.freeze
    private_constant :TLS_VERSIONS

    # Reconnect attempts made after the initial connection attempt fails.
    MAX_CONNECT_RETRIES = 3
    private_constant :MAX_CONNECT_RETRIES

    # Keys deleted per DEL command in #clear_revoked_tokens.
    DELETE_BATCH_SIZE = 1000
    private_constant :DELETE_BATCH_SIZE

    def initialize(config = JwtAuthCognito.configuration)
      @config = config
      @redis = nil
    end

    # Marks +token_id+ as revoked; the entry expires after +ttl+ seconds when given.
    #
    # @return [true]
    # @raise [BlacklistError] if Redis fails or no connection can be established
    def save_revoked_token(token_id, ttl = nil)
      connect_redis
      key = "#{BLACKLIST_PREFIX}#{token_id}"
      ttl ? @redis.setex(key, ttl, "revoked") : @redis.set(key, "revoked")
      true
    rescue Redis::BaseError => e
      raise BlacklistError, "Failed to save revoked token: #{e.message}"
    end

    # @return [Boolean] whether +token_id+ is blacklisted. Fails open: returns
    #   false when Redis errors or no connection can be established, so an
    #   outage does not block validation (revocations are not enforced then).
    def is_token_revoked?(token_id)
      connect_redis
      result = @redis.exists?("#{BLACKLIST_PREFIX}#{token_id}")
      # Handles both an Integer count and a Boolean return value from exists?.
      result.is_a?(Integer) ? result.positive? : result
    rescue Redis::BaseError, BlacklistError
      false
    end

    # Deletes every blacklist entry. SCAN is used instead of KEYS so a large
    # keyspace does not block the Redis server.
    #
    # @return [Integer] number of entries removed
    # @raise [BlacklistError] on Redis failures
    def clear_revoked_tokens
      connect_redis
      keys = @redis.scan_each(match: "#{BLACKLIST_PREFIX}*").to_a.uniq
      keys.each_slice(DELETE_BATCH_SIZE) { |batch| @redis.del(*batch) }
      keys.length
    rescue Redis::BaseError => e
      raise BlacklistError, "Failed to clear revoked tokens: #{e.message}"
    end

    # Blacklists every token ID recorded for +user_id+ via #track_user_token,
    # then deletes the user's set.
    #
    # Entries that already exist keep their TTL (SET NX). A plain SET would
    # strip that expiry and leave the entry in Redis forever. Missing entries
    # get the user set's remaining TTL when it has one.
    #
    # @return [Integer] number of token IDs processed
    # @raise [BlacklistError] on Redis failures
    def invalidate_user_tokens(user_id)
      connect_redis

      user_key = "#{USER_TOKENS_PREFIX}#{user_id}"
      token_ids = @redis.smembers(user_key)
      remaining_ttl = @redis.ttl(user_key)

      token_ids.each do |token_id|
        options = { nx: true }
        options[:ex] = remaining_ttl if remaining_ttl.positive?
        @redis.set("#{BLACKLIST_PREFIX}#{token_id}", "revoked", **options)
      end

      @redis.del(user_key)
      token_ids.length
    rescue Redis::BaseError => e
      raise BlacklistError, "Failed to invalidate user tokens: #{e.message}"
    end

    # Records +token_id+ in +user_id+'s set, resetting the set's expiry to +ttl+ when given.
    #
    # @return [Boolean] false if Redis is unavailable (best effort; never raises for Redis errors)
    def track_user_token(user_id, token_id, ttl = nil)
      connect_redis

      user_key = "#{USER_TOKENS_PREFIX}#{user_id}"
      @redis.sadd(user_key, token_id)
      @redis.expire(user_key, ttl) if ttl

      true
    rescue Redis::BaseError, BlacklistError
      false
    end

    # @return [String] the token's "jti" claim when decodable and present,
    #   otherwise the first 16 hex chars of the token's SHA-256 digest
    def generate_token_id(token)
      begin
        payload = JWT.decode(token, nil, false).first
        return payload["jti"] if payload["jti"]
      rescue JWT::DecodeError
        # Fall back to hash if token can't be decoded
      end

      Digest::SHA256.hexdigest(token)[0, 16]
    end

    private

    # Creates the Redis client and pings it, retrying with exponential backoff.
    #
    # @raise [BlacklistError] after MAX_CONNECT_RETRIES failed retries
    def connect_redis
      return @redis if @redis

      redis_options = build_redis_options
      attempts = 0

      begin
        # @redis is assigned before the ping, so once retries are exhausted the
        # last (unverified) client stays cached and later calls reuse it
        # instead of re-running this backoff loop.
        @redis = Redis.new(redis_options)
        @redis.ping
        @redis
      rescue Redis::BaseError => e
        attempts += 1
        if attempts > MAX_CONNECT_RETRIES
          raise BlacklistError, "Failed to connect to Redis after #{MAX_CONNECT_RETRIES} retries: #{e.message}"
        end

        sleep(0.1 * (2**attempts)) # exponential backoff: 0.2s, 0.4s, 0.8s
        retry
      end
    end

    # @return [Hash] options for Redis.new built from the configuration
    def build_redis_options
      options = {
        host: @config.redis_host,
        port: @config.redis_port,
        db: @config.redis_db,
        timeout: @config.redis_timeout,
        connect_timeout: @config.redis_connect_timeout,
        read_timeout: @config.redis_read_timeout
      }
      options[:password] = @config.redis_password if @config.redis_password

      if @config.redis_ssl
        options[:ssl] = true
        options[:ssl_params] = build_ssl_params
      end

      options
    end

    # @return [Hash] SSLContext parameters: TLS version bounds, optional CA
    #   file and verification mode
    def build_ssl_params
      ssl_params = {}
      ssl_params[:min_version] = parse_tls_version(@config.redis_tls_min_version) if @config.redis_tls_min_version
      ssl_params[:max_version] = parse_tls_version(@config.redis_tls_max_version) if @config.redis_tls_max_version

      # The CA file is only used when both path and name are set and the file exists.
      if @config.redis_ca_cert_path && @config.redis_ca_cert_name
        ca_cert_file = File.join(@config.redis_ca_cert_path, @config.redis_ca_cert_name)
        ssl_params[:ca_file] = ca_cert_file if File.exist?(ca_cert_file)
      end

      # Anything other than the explicit opt-out "none" verifies the server certificate.
      ssl_params[:verify_mode] =
        @config.redis_verify_mode == "none" ? OpenSSL::SSL::VERIFY_NONE : OpenSSL::SSL::VERIFY_PEER

      ssl_params
    end

    # Maps "TLSv1.2"-style names (case-insensitive) to SSLContext protocol
    # symbols such as :TLS1_2. Unknown names default to :TLS1_2.
    def parse_tls_version(version_string)
      TLS_VERSIONS.fetch(version_string.to_s.upcase, :TLS1_2)
    end
  end
end
|
@@ -0,0 +1,56 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
module JwtAuthCognito
  # Revocation facade: derives a stable ID for each token and stores it in the
  # Redis blacklist (through RedisService) for as long as the token's own
  # expiry allows, falling back to one day.
  class TokenBlacklistService
    # @param config [Configuration] passed through to the underlying RedisService
    def initialize(config = JwtAuthCognito.configuration)
      @config = config
      @redis_service = RedisService.new(config)
    end

    # Revokes +token+ until its "exp" claim (or for one day when that cannot be
    # read). When +user_id+ is given, the token ID is also recorded for that user.
    #
    # @return the value returned by RedisService#save_revoked_token
    def add_to_blacklist(token, user_id: nil)
      redis = @redis_service
      id = redis.generate_token_id(token)
      expiry_seconds = calculate_ttl(token)

      stored = redis.save_revoked_token(id, expiry_seconds)
      redis.track_user_token(user_id, id, expiry_seconds) if user_id
      stored
    end

    # @return whether the ID derived from +token+ is on the blacklist
    def is_blacklisted?(token)
      id = @redis_service.generate_token_id(token)
      @redis_service.is_token_revoked?(id)
    end

    # Delegates to RedisService#invalidate_user_tokens.
    def invalidate_user_tokens(user_id)
      @redis_service.invalidate_user_tokens(user_id)
    end

    # Delegates to RedisService#clear_revoked_tokens.
    def clear_all_blacklisted_tokens
      @redis_service.clear_revoked_tokens
    end

    private

    # Seconds until the token's "exp" claim (at least 1), read without
    # verifying the signature. Returns 86_400 (one day) when the token cannot
    # be decoded or carries no "exp".
    def calculate_ttl(token)
      claims = JWT.decode(token, nil, false).first
      expires_at = claims["exp"]
      return 86_400 unless expires_at

      remaining = expires_at - Time.now.to_i
      remaining > 0 ? remaining : 1
    rescue JWT::DecodeError
      86_400
    end
  end
end
|