jwt 2.5.0 → 2.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/jwt/algos.rb CHANGED
@@ -1,5 +1,13 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ begin
4
+ require 'rbnacl'
5
+ rescue LoadError
6
+ raise if defined?(RbNaCl)
7
+ end
8
+ require 'openssl'
9
+
10
+ require 'jwt/security_utils'
3
11
  require 'jwt/algos/hmac'
4
12
  require 'jwt/algos/eddsa'
5
13
  require 'jwt/algos/ecdsa'
@@ -7,35 +15,50 @@ require 'jwt/algos/rsa'
7
15
  require 'jwt/algos/ps'
8
16
  require 'jwt/algos/none'
9
17
  require 'jwt/algos/unsupported'
18
+ require 'jwt/algos/algo_wrapper'
10
19
 
11
- # JWT::Signature module
12
20
  module JWT
13
- # Signature logic for JWT
14
21
  module Algos
15
22
  extend self
16
23
 
17
- ALGOS = [
18
- Algos::Hmac,
19
- Algos::Ecdsa,
20
- Algos::Rsa,
21
- Algos::Eddsa,
22
- Algos::Ps,
23
- Algos::None,
24
- Algos::Unsupported
25
- ].freeze
24
+ ALGOS = [Algos::Ecdsa,
25
+ Algos::Rsa,
26
+ Algos::Eddsa,
27
+ Algos::Ps,
28
+ Algos::None,
29
+ Algos::Unsupported].tap do |l|
30
+ if ::JWT.rbnacl_6_or_greater?
31
+ require_relative 'algos/hmac_rbnacl'
32
+ l.unshift(Algos::HmacRbNaCl)
33
+ elsif ::JWT.rbnacl?
34
+ require_relative 'algos/hmac_rbnacl_fixed'
35
+ l.unshift(Algos::HmacRbNaClFixed)
36
+ else
37
+ l.unshift(Algos::Hmac)
38
+ end
39
+ end.freeze
26
40
 
27
41
  def find(algorithm)
28
42
  indexed[algorithm && algorithm.downcase]
29
43
  end
30
44
 
45
+ def create(algorithm)
46
+ Algos::AlgoWrapper.new(*find(algorithm))
47
+ end
48
+
49
+ def implementation?(algorithm)
50
+ (algorithm.respond_to?(:valid_alg?) && algorithm.respond_to?(:verify)) ||
51
+ (algorithm.respond_to?(:alg) && algorithm.respond_to?(:sign))
52
+ end
53
+
31
54
  private
32
55
 
33
56
  def indexed
34
57
  @indexed ||= begin
35
- fallback = [Algos::Unsupported, nil]
36
- ALGOS.each_with_object(Hash.new(fallback)) do |alg, hash|
37
- alg.const_get(:SUPPORTED).each do |code|
38
- hash[code.downcase] = [alg, code]
58
+ fallback = [nil, Algos::Unsupported]
59
+ ALGOS.each_with_object(Hash.new(fallback)) do |cls, hash|
60
+ cls.const_get(:SUPPORTED).each do |alg|
61
+ hash[alg.downcase] = [alg, cls]
39
62
  end
40
63
  end
41
64
  end
data/lib/jwt/decode.rb CHANGED
@@ -2,9 +2,9 @@
2
2
 
3
3
  require 'json'
4
4
 
5
- require 'jwt/signature'
6
5
  require 'jwt/verify'
7
6
  require 'jwt/x5c_key_finder'
7
+
8
8
  # JWT::Decode module
9
9
  module JWT
10
10
  # Decoding logic for JWT
@@ -24,7 +24,7 @@ module JWT
24
24
  def decode_segments
25
25
  validate_segment_count!
26
26
  if @verify
27
- decode_crypto
27
+ decode_signature
28
28
  verify_algo
29
29
  set_key
30
30
  verify_signature
@@ -51,8 +51,8 @@ module JWT
51
51
 
52
52
  def verify_algo
53
53
  raise(JWT::IncorrectAlgorithm, 'An algorithm must be specified') if allowed_algorithms.empty?
54
- raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless algorithm
55
- raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless options_includes_algo_in_header?
54
+ raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless alg_in_header
55
+ raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless valid_alg_in_header?
56
56
  end
57
57
 
58
58
  def set_key
@@ -64,27 +64,50 @@ module JWT
64
64
  end
65
65
 
66
66
  def verify_signature_for?(key)
67
- Signature.verify(algorithm, key, signing_input, @signature)
67
+ allowed_algorithms.any? do |alg|
68
+ alg.verify(data: signing_input, signature: @signature, verification_key: key)
69
+ end
70
+ end
71
+
72
+ def valid_alg_in_header?
73
+ allowed_algorithms.any? { |alg| alg.valid_alg?(alg_in_header) }
68
74
  end
69
75
 
70
- def options_includes_algo_in_header?
71
- allowed_algorithms.any? { |alg| alg.casecmp(algorithm).zero? }
76
+ # Order is very important - first check for string keys, next for symbols
77
+ ALGORITHM_KEYS = ['algorithm',
78
+ :algorithm,
79
+ 'algorithms',
80
+ :algorithms].freeze
81
+
82
+ def given_algorithms
83
+ ALGORITHM_KEYS.each do |alg_key|
84
+ alg = @options[alg_key]
85
+ return Array(alg) if alg
86
+ end
87
+ []
72
88
  end
73
89
 
74
90
  def allowed_algorithms
75
- # Order is very important - first check for string keys, next for symbols
76
- algos = if @options.key?('algorithm')
77
- @options['algorithm']
78
- elsif @options.key?(:algorithm)
79
- @options[:algorithm]
80
- elsif @options.key?('algorithms')
81
- @options['algorithms']
82
- elsif @options.key?(:algorithms)
83
- @options[:algorithms]
84
- else
85
- []
91
+ @allowed_algorithms ||= resolve_allowed_algorithms
92
+ end
93
+
94
+ def resolve_allowed_algorithms
95
+ algs = given_algorithms.map do |alg|
96
+ if Algos.implementation?(alg)
97
+ alg
98
+ else
99
+ Algos.create(alg)
100
+ end
86
101
  end
87
- Array(algos)
102
+
103
+ sort_by_alg_header(algs)
104
+ end
105
+
106
+ # Move algorithms matching the JWT alg header to the beginning of the list
107
+ def sort_by_alg_header(algs)
108
+ return algs if algs.size <= 1
109
+
110
+ algs.partition { |alg| alg.valid_alg?(alg_in_header) }.flatten
88
111
  end
89
112
 
90
113
  def find_key(&keyfinder)
@@ -113,14 +136,14 @@ module JWT
113
136
  end
114
137
 
115
138
  def none_algorithm?
116
- algorithm == 'none'
139
+ alg_in_header == 'none'
117
140
  end
118
141
 
119
- def decode_crypto
142
+ def decode_signature
120
143
  @signature = ::JWT::Base64.url_decode(@segments[2] || '')
121
144
  end
122
145
 
123
- def algorithm
146
+ def alg_in_header
124
147
  header['alg']
125
148
  end
126
149
 
data/lib/jwt/encode.rb CHANGED
@@ -1,28 +1,35 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require_relative './algos'
4
- require_relative './claims_validator'
3
+ require_relative 'algos'
4
+ require_relative 'claims_validator'
5
5
 
6
6
  # JWT::Encode module
7
7
  module JWT
8
8
  # Encoding logic for JWT
9
9
  class Encode
10
- ALG_NONE = 'none'
11
- ALG_KEY = 'alg'
10
+ ALG_KEY = 'alg'
12
11
 
13
12
  def initialize(options)
14
- @payload = options[:payload]
15
- @key = options[:key]
16
- _, @algorithm = Algos.find(options[:algorithm])
17
- @headers = options[:headers].transform_keys(&:to_s)
13
+ @payload = options[:payload]
14
+ @key = options[:key]
15
+ @algorithm = resolve_algorithm(options[:algorithm])
16
+ @headers = options[:headers].transform_keys(&:to_s)
17
+ @headers[ALG_KEY] = @algorithm.alg
18
18
  end
19
19
 
20
20
  def segments
21
- @segments ||= combine(encoded_header_and_payload, encoded_signature)
21
+ validate_claims!
22
+ combine(encoded_header_and_payload, encoded_signature)
22
23
  end
23
24
 
24
25
  private
25
26
 
27
+ def resolve_algorithm(algorithm)
28
+ return algorithm if Algos.implementation?(algorithm)
29
+
30
+ Algos.create(algorithm)
31
+ end
32
+
26
33
  def encoded_header
27
34
  @encoded_header ||= encode_header
28
35
  end
@@ -40,25 +47,28 @@ module JWT
40
47
  end
41
48
 
42
49
  def encode_header
43
- @headers[ALG_KEY] = @algorithm
44
- encode(@headers)
50
+ encode_data(@headers)
45
51
  end
46
52
 
47
53
  def encode_payload
48
- if @payload.is_a?(Hash)
49
- ClaimsValidator.new(@payload).validate!
50
- end
54
+ encode_data(@payload)
55
+ end
51
56
 
52
- encode(@payload)
57
+ def signature
58
+ @algorithm.sign(data: encoded_header_and_payload, signing_key: @key)
53
59
  end
54
60
 
55
- def encode_signature
56
- return '' if @algorithm == ALG_NONE
61
+ def validate_claims!
62
+ return unless @payload.is_a?(Hash)
63
+
64
+ ClaimsValidator.new(@payload).validate!
65
+ end
57
66
 
58
- ::JWT::Base64.url_encode(JWT::Signature.sign(@algorithm, encoded_header_and_payload, @key))
67
+ def encode_signature
68
+ ::JWT::Base64.url_encode(signature)
59
69
  end
60
70
 
61
- def encode(data)
71
+ def encode_data(data)
62
72
  ::JWT::Base64.url_encode(JWT::JSON.generate(data))
63
73
  end
64
74
 
data/lib/jwt/jwk/ec.rb CHANGED
@@ -9,39 +9,42 @@ module JWT
9
9
  def_delegators :keypair, :public_key
10
10
 
11
11
  KTY = 'EC'
12
- KTYS = [KTY, OpenSSL::PKey::EC].freeze
12
+ KTYS = [KTY, OpenSSL::PKey::EC, JWT::JWK::EC].freeze
13
13
  BINARY = 2
14
+ EC_PUBLIC_KEY_ELEMENTS = %i[kty crv x y].freeze
15
+ EC_PRIVATE_KEY_ELEMENTS = %i[d].freeze
16
+ EC_KEY_ELEMENTS = (EC_PRIVATE_KEY_ELEMENTS + EC_PUBLIC_KEY_ELEMENTS).freeze
14
17
 
15
- attr_reader :keypair
18
+ def initialize(key, params = nil, options = {})
19
+ params ||= {}
16
20
 
17
- def initialize(keypair, options = {})
18
- raise ArgumentError, 'keypair must be of type OpenSSL::PKey::EC' unless keypair.is_a?(OpenSSL::PKey::EC)
21
+ # For backwards compatibility when kid was a String
22
+ params = { kid: params } if params.is_a?(String)
19
23
 
20
- @keypair = keypair
24
+ key_params = extract_key_params(key)
21
25
 
22
- super(options)
26
+ params = params.transform_keys(&:to_sym)
27
+ check_jwk(key_params, params)
28
+
29
+ super(options, key_params.merge(params))
30
+ end
31
+
32
+ def keypair
33
+ @keypair ||= create_ec_key(self[:crv], self[:x], self[:y], self[:d])
23
34
  end
24
35
 
25
36
  def private?
26
- @keypair.private_key?
37
+ keypair.private_key?
27
38
  end
28
39
 
29
40
  def members
30
- crv, x_octets, y_octets = keypair_components(keypair)
31
- {
32
- kty: KTY,
33
- crv: crv,
34
- x: encode_octets(x_octets),
35
- y: encode_octets(y_octets)
36
- }
41
+ EC_PUBLIC_KEY_ELEMENTS.each_with_object({}) { |i, h| h[i] = self[i] }
37
42
  end
38
43
 
39
44
  def export(options = {})
40
- exported_hash = members.merge(kid: kid)
41
-
42
- return exported_hash unless private? && options[:include_private] == true
43
-
44
- append_private_parts(exported_hash)
45
+ exported = parameters.clone
46
+ exported.reject! { |k, _| EC_PRIVATE_KEY_ELEMENTS.include? k } unless private? && options[:include_private] == true
47
+ exported
45
48
  end
46
49
 
47
50
  def key_digest
@@ -51,13 +54,34 @@ module JWT
51
54
  OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)
52
55
  end
53
56
 
57
+ def []=(key, value)
58
+ if EC_KEY_ELEMENTS.include?(key.to_sym)
59
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes'
60
+ end
61
+
62
+ super(key, value)
63
+ end
64
+
54
65
  private
55
66
 
56
- def append_private_parts(the_hash)
57
- octets = keypair.private_key.to_bn.to_s(BINARY)
58
- the_hash.merge(
59
- d: encode_octets(octets)
60
- )
67
+ def extract_key_params(key)
68
+ case key
69
+ when JWT::JWK::EC
70
+ key.export(include_private: true)
71
+ when OpenSSL::PKey::EC # Accept OpenSSL key as input
72
+ @keypair = key # Preserve the object to avoid recreation
73
+ parse_ec_key(key)
74
+ when Hash
75
+ key.transform_keys(&:to_sym)
76
+ else
77
+ raise ArgumentError, 'key must be of type OpenSSL::PKey::EC or Hash with key parameters'
78
+ end
79
+ end
80
+
81
+ def check_jwk(keypair, params)
82
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes' unless (EC_KEY_ELEMENTS & params.keys).empty?
83
+ raise JWT::JWKError, "Incorrect 'kty' value: #{keypair[:kty]}, expected #{KTY}" unless keypair[:kty] == KTY
84
+ raise JWT::JWKError, 'Key format is invalid for EC' unless keypair[:crv] && keypair[:x] && keypair[:y]
61
85
  end
62
86
 
63
87
  def keypair_components(ec_keypair)
@@ -82,6 +106,8 @@ module JWT
82
106
  end
83
107
 
84
108
  def encode_octets(octets)
109
+ return unless octets
110
+
85
111
  ::JWT::Base64.url_encode(octets)
86
112
  end
87
113
 
@@ -89,15 +115,94 @@ module JWT
89
115
  ::JWT::Base64.url_encode(key_part.to_s(BINARY))
90
116
  end
91
117
 
92
- class << self
93
- def import(jwk_data)
94
- # See https://tools.ietf.org/html/rfc7518#section-6.2.1 for an
95
- # explanation of the relevant parameters.
118
+ def parse_ec_key(key)
119
+ crv, x_octets, y_octets = keypair_components(key)
120
+ octets = key.private_key&.to_bn&.to_s(BINARY)
121
+ {
122
+ kty: KTY,
123
+ crv: crv,
124
+ x: encode_octets(x_octets),
125
+ y: encode_octets(y_octets),
126
+ d: encode_octets(octets)
127
+ }.compact
128
+ end
96
129
 
97
- jwk_crv, jwk_x, jwk_y, jwk_d, jwk_kid = jwk_attrs(jwk_data, %i[crv x y d kid])
98
- raise JWT::JWKError, 'Key format is invalid for EC' unless jwk_crv && jwk_x && jwk_y
130
+ if ::JWT.openssl_3?
131
+ def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d) # rubocop:disable Metrics/MethodLength
132
+ curve = EC.to_openssl_curve(jwk_crv)
133
+
134
+ x_octets = decode_octets(jwk_x)
135
+ y_octets = decode_octets(jwk_y)
136
+
137
+ point = OpenSSL::PKey::EC::Point.new(
138
+ OpenSSL::PKey::EC::Group.new(curve),
139
+ OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
140
+ )
141
+
142
+ sequence = if jwk_d
143
+ # https://datatracker.ietf.org/doc/html/rfc5915.html
144
+ # ECPrivateKey ::= SEQUENCE {
145
+ # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
146
+ # privateKey OCTET STRING,
147
+ # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
148
+ # publicKey [1] BIT STRING OPTIONAL
149
+ # }
150
+
151
+ OpenSSL::ASN1::Sequence([
152
+ OpenSSL::ASN1::Integer(1),
153
+ OpenSSL::ASN1::OctetString(OpenSSL::BN.new(decode_octets(jwk_d), 2).to_s(2)),
154
+ OpenSSL::ASN1::ObjectId(curve, 0, :EXPLICIT),
155
+ OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed), 1, :EXPLICIT)
156
+ ])
157
+ else
158
+ OpenSSL::ASN1::Sequence([
159
+ OpenSSL::ASN1::Sequence([OpenSSL::ASN1::ObjectId('id-ecPublicKey'), OpenSSL::ASN1::ObjectId(curve)]),
160
+ OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed))
161
+ ])
162
+ end
99
163
 
100
- new(ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d), kid: jwk_kid)
164
+ OpenSSL::PKey::EC.new(sequence.to_der)
165
+ end
166
+ else
167
+ def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d)
168
+ curve = EC.to_openssl_curve(jwk_crv)
169
+
170
+ x_octets = decode_octets(jwk_x)
171
+ y_octets = decode_octets(jwk_y)
172
+
173
+ key = OpenSSL::PKey::EC.new(curve)
174
+
175
+ # The details of the `Point` instantiation are covered in:
176
+ # - https://docs.ruby-lang.org/en/2.4.0/OpenSSL/PKey/EC.html
177
+ # - https://www.openssl.org/docs/manmaster/man3/EC_POINT_new.html
178
+ # - https://tools.ietf.org/html/rfc5480#section-2.2
179
+ # - https://www.secg.org/SEC1-Ver-1.0.pdf
180
+ # Section 2.3.3 of the last of these references specifies that the
181
+ # encoding of an uncompressed point consists of the byte `0x04` followed
182
+ # by the x value then the y value.
183
+ point = OpenSSL::PKey::EC::Point.new(
184
+ OpenSSL::PKey::EC::Group.new(curve),
185
+ OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
186
+ )
187
+
188
+ key.public_key = point
189
+ key.private_key = OpenSSL::BN.new(decode_octets(jwk_d), 2) if jwk_d
190
+
191
+ key
192
+ end
193
+ end
194
+
195
+ def decode_octets(jwk_data)
196
+ ::JWT::Base64.url_decode(jwk_data)
197
+ end
198
+
199
+ def decode_open_ssl_bn(jwk_data)
200
+ OpenSSL::BN.new(::JWT::Base64.url_decode(jwk_data), BINARY)
201
+ end
202
+
203
+ class << self
204
+ def import(jwk_data)
205
+ new(jwk_data)
101
206
  end
102
207
 
103
208
  def to_openssl_curve(crv)
@@ -112,87 +217,6 @@ module JWT
112
217
  else raise JWT::JWKError, 'Invalid curve provided'
113
218
  end
114
219
  end
115
-
116
- private
117
-
118
- def jwk_attrs(jwk_data, attrs)
119
- attrs.map do |attr|
120
- jwk_data[attr] || jwk_data[attr.to_s]
121
- end
122
- end
123
-
124
- if ::JWT.openssl_3?
125
- def ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d) # rubocop:disable Metrics/MethodLength
126
- curve = to_openssl_curve(jwk_crv)
127
-
128
- x_octets = decode_octets(jwk_x)
129
- y_octets = decode_octets(jwk_y)
130
-
131
- point = OpenSSL::PKey::EC::Point.new(
132
- OpenSSL::PKey::EC::Group.new(curve),
133
- OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
134
- )
135
-
136
- sequence = if jwk_d
137
- # https://datatracker.ietf.org/doc/html/rfc5915.html
138
- # ECPrivateKey ::= SEQUENCE {
139
- # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
140
- # privateKey OCTET STRING,
141
- # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
142
- # publicKey [1] BIT STRING OPTIONAL
143
- # }
144
-
145
- OpenSSL::ASN1::Sequence([
146
- OpenSSL::ASN1::Integer(1),
147
- OpenSSL::ASN1::OctetString(OpenSSL::BN.new(decode_octets(jwk_d), 2).to_s(2)),
148
- OpenSSL::ASN1::ObjectId(curve, 0, :EXPLICIT),
149
- OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed), 1, :EXPLICIT)
150
- ])
151
- else
152
- OpenSSL::ASN1::Sequence([
153
- OpenSSL::ASN1::Sequence([OpenSSL::ASN1::ObjectId('id-ecPublicKey'), OpenSSL::ASN1::ObjectId(curve)]),
154
- OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed))
155
- ])
156
- end
157
-
158
- OpenSSL::PKey::EC.new(sequence.to_der)
159
- end
160
- else
161
- def ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d)
162
- curve = to_openssl_curve(jwk_crv)
163
-
164
- x_octets = decode_octets(jwk_x)
165
- y_octets = decode_octets(jwk_y)
166
-
167
- key = OpenSSL::PKey::EC.new(curve)
168
-
169
- # The details of the `Point` instantiation are covered in:
170
- # - https://docs.ruby-lang.org/en/2.4.0/OpenSSL/PKey/EC.html
171
- # - https://www.openssl.org/docs/manmaster/man3/EC_POINT_new.html
172
- # - https://tools.ietf.org/html/rfc5480#section-2.2
173
- # - https://www.secg.org/SEC1-Ver-1.0.pdf
174
- # Section 2.3.3 of the last of these references specifies that the
175
- # encoding of an uncompressed point consists of the byte `0x04` followed
176
- # by the x value then the y value.
177
- point = OpenSSL::PKey::EC::Point.new(
178
- OpenSSL::PKey::EC::Group.new(curve),
179
- OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
180
- )
181
-
182
- key.public_key = point
183
- key.private_key = OpenSSL::BN.new(decode_octets(jwk_d), 2) if jwk_d
184
-
185
- key
186
- end
187
- end
188
-
189
- def decode_octets(jwk_data)
190
- ::JWT::Base64.url_decode(jwk_data)
191
- end
192
-
193
- def decode_open_ssl_bn(jwk_data)
194
- OpenSSL::BN.new(::JWT::Base64.url_decode(jwk_data), BINARY)
195
- end
196
220
  end
197
221
  end
198
222
  end
data/lib/jwt/jwk/hmac.rb CHANGED
@@ -4,15 +4,27 @@ module JWT
4
4
  module JWK
5
5
  class HMAC < KeyBase
6
6
  KTY = 'oct'
7
- KTYS = [KTY, String].freeze
7
+ KTYS = [KTY, String, JWT::JWK::HMAC].freeze
8
+ HMAC_PUBLIC_KEY_ELEMENTS = %i[kty].freeze
9
+ HMAC_PRIVATE_KEY_ELEMENTS = %i[k].freeze
10
+ HMAC_KEY_ELEMENTS = (HMAC_PRIVATE_KEY_ELEMENTS + HMAC_PUBLIC_KEY_ELEMENTS).freeze
8
11
 
9
- attr_reader :signing_key
12
+ def initialize(key, params = nil, options = {})
13
+ params ||= {}
10
14
 
11
- def initialize(signing_key, options = {})
12
- raise ArgumentError, 'signing_key must be of type String' unless signing_key.is_a?(String)
15
+ # For backwards compatibility when kid was a String
16
+ params = { kid: params } if params.is_a?(String)
13
17
 
14
- @signing_key = signing_key
15
- super(options)
18
+ key_params = extract_key_params(key)
19
+
20
+ params = params.transform_keys(&:to_sym)
21
+ check_jwk(key_params, params)
22
+
23
+ super(options, key_params.merge(params))
24
+ end
25
+
26
+ def keypair
27
+ self[:k]
16
28
  end
17
29
 
18
30
  def private?
@@ -25,26 +37,16 @@ module JWT
25
37
 
26
38
  # See https://tools.ietf.org/html/rfc7517#appendix-A.3
27
39
  def export(options = {})
28
- exported_hash = {
29
- kty: KTY,
30
- kid: kid
31
- }
32
-
33
- return exported_hash unless private? && options[:include_private] == true
34
-
35
- exported_hash.merge(
36
- k: signing_key
37
- )
40
+ exported = parameters.clone
41
+ exported.reject! { |k, _| HMAC_PRIVATE_KEY_ELEMENTS.include? k } unless private? && options[:include_private] == true
42
+ exported
38
43
  end
39
44
 
40
45
  def members
41
- {
42
- kty: KTY,
43
- k: signing_key
44
- }
46
+ HMAC_KEY_ELEMENTS.each_with_object({}) { |i, h| h[i] = self[i] }
45
47
  end
46
48
 
47
- alias keypair signing_key # for backwards compatibility
49
+ alias signing_key keypair # for backwards compatibility
48
50
 
49
51
  def key_digest
50
52
  sequence = OpenSSL::ASN1::Sequence([OpenSSL::ASN1::UTF8String.new(signing_key),
@@ -52,14 +54,38 @@ module JWT
52
54
  OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)
53
55
  end
54
56
 
55
- class << self
56
- def import(jwk_data)
57
- jwk_k = jwk_data[:k] || jwk_data['k']
58
- jwk_kid = jwk_data[:kid] || jwk_data['kid']
57
+ def []=(key, value)
58
+ if HMAC_KEY_ELEMENTS.include?(key.to_sym)
59
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes'
60
+ end
59
61
 
60
- raise JWT::JWKError, 'Key format is invalid for HMAC' unless jwk_k
62
+ super(key, value)
63
+ end
64
+
65
+ private
66
+
67
+ def extract_key_params(key)
68
+ case key
69
+ when JWT::JWK::HMAC
70
+ key.export(include_private: true)
71
+ when String # Accept String key as input
72
+ { kty: KTY, k: key }
73
+ when Hash
74
+ key.transform_keys(&:to_sym)
75
+ else
76
+ raise ArgumentError, 'key must be of type String or Hash with key parameters'
77
+ end
78
+ end
61
79
 
62
- new(jwk_k, kid: jwk_kid)
80
+ def check_jwk(keypair, params)
81
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes' unless (HMAC_KEY_ELEMENTS & params.keys).empty?
82
+ raise JWT::JWKError, "Incorrect 'kty' value: #{keypair[:kty]}, expected #{KTY}" unless keypair[:kty] == KTY
83
+ raise JWT::JWKError, 'Key format is invalid for HMAC' unless keypair[:k]
84
+ end
85
+
86
+ class << self
87
+ def import(jwk_data)
88
+ new(jwk_data)
63
89
  end
64
90
  end
65
91
  end