jwt 2.5.0 → 2.6.0

Sign up to get free protection for your applications and to get access to all the features.
data/lib/jwt/algos.rb CHANGED
@@ -1,5 +1,13 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ begin
4
+ require 'rbnacl'
5
+ rescue LoadError
6
+ raise if defined?(RbNaCl)
7
+ end
8
+ require 'openssl'
9
+
10
+ require 'jwt/security_utils'
3
11
  require 'jwt/algos/hmac'
4
12
  require 'jwt/algos/eddsa'
5
13
  require 'jwt/algos/ecdsa'
@@ -7,35 +15,50 @@ require 'jwt/algos/rsa'
7
15
  require 'jwt/algos/ps'
8
16
  require 'jwt/algos/none'
9
17
  require 'jwt/algos/unsupported'
18
+ require 'jwt/algos/algo_wrapper'
10
19
 
11
- # JWT::Signature module
12
20
  module JWT
13
- # Signature logic for JWT
14
21
  module Algos
15
22
  extend self
16
23
 
17
- ALGOS = [
18
- Algos::Hmac,
19
- Algos::Ecdsa,
20
- Algos::Rsa,
21
- Algos::Eddsa,
22
- Algos::Ps,
23
- Algos::None,
24
- Algos::Unsupported
25
- ].freeze
24
+ ALGOS = [Algos::Ecdsa,
25
+ Algos::Rsa,
26
+ Algos::Eddsa,
27
+ Algos::Ps,
28
+ Algos::None,
29
+ Algos::Unsupported].tap do |l|
30
+ if ::JWT.rbnacl_6_or_greater?
31
+ require_relative 'algos/hmac_rbnacl'
32
+ l.unshift(Algos::HmacRbNaCl)
33
+ elsif ::JWT.rbnacl?
34
+ require_relative 'algos/hmac_rbnacl_fixed'
35
+ l.unshift(Algos::HmacRbNaClFixed)
36
+ else
37
+ l.unshift(Algos::Hmac)
38
+ end
39
+ end.freeze
26
40
 
27
41
  def find(algorithm)
28
42
  indexed[algorithm && algorithm.downcase]
29
43
  end
30
44
 
45
+ def create(algorithm)
46
+ Algos::AlgoWrapper.new(*find(algorithm))
47
+ end
48
+
49
+ def implementation?(algorithm)
50
+ (algorithm.respond_to?(:valid_alg?) && algorithm.respond_to?(:verify)) ||
51
+ (algorithm.respond_to?(:alg) && algorithm.respond_to?(:sign))
52
+ end
53
+
31
54
  private
32
55
 
33
56
  def indexed
34
57
  @indexed ||= begin
35
- fallback = [Algos::Unsupported, nil]
36
- ALGOS.each_with_object(Hash.new(fallback)) do |alg, hash|
37
- alg.const_get(:SUPPORTED).each do |code|
38
- hash[code.downcase] = [alg, code]
58
+ fallback = [nil, Algos::Unsupported]
59
+ ALGOS.each_with_object(Hash.new(fallback)) do |cls, hash|
60
+ cls.const_get(:SUPPORTED).each do |alg|
61
+ hash[alg.downcase] = [alg, cls]
39
62
  end
40
63
  end
41
64
  end
data/lib/jwt/decode.rb CHANGED
@@ -2,9 +2,9 @@
2
2
 
3
3
  require 'json'
4
4
 
5
- require 'jwt/signature'
6
5
  require 'jwt/verify'
7
6
  require 'jwt/x5c_key_finder'
7
+
8
8
  # JWT::Decode module
9
9
  module JWT
10
10
  # Decoding logic for JWT
@@ -24,7 +24,7 @@ module JWT
24
24
  def decode_segments
25
25
  validate_segment_count!
26
26
  if @verify
27
- decode_crypto
27
+ decode_signature
28
28
  verify_algo
29
29
  set_key
30
30
  verify_signature
@@ -51,8 +51,8 @@ module JWT
51
51
 
52
52
  def verify_algo
53
53
  raise(JWT::IncorrectAlgorithm, 'An algorithm must be specified') if allowed_algorithms.empty?
54
- raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless algorithm
55
- raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless options_includes_algo_in_header?
54
+ raise(JWT::IncorrectAlgorithm, 'Token is missing alg header') unless alg_in_header
55
+ raise(JWT::IncorrectAlgorithm, 'Expected a different algorithm') unless valid_alg_in_header?
56
56
  end
57
57
 
58
58
  def set_key
@@ -64,27 +64,50 @@ module JWT
64
64
  end
65
65
 
66
66
  def verify_signature_for?(key)
67
- Signature.verify(algorithm, key, signing_input, @signature)
67
+ allowed_algorithms.any? do |alg|
68
+ alg.verify(data: signing_input, signature: @signature, verification_key: key)
69
+ end
70
+ end
71
+
72
+ def valid_alg_in_header?
73
+ allowed_algorithms.any? { |alg| alg.valid_alg?(alg_in_header) }
68
74
  end
69
75
 
70
- def options_includes_algo_in_header?
71
- allowed_algorithms.any? { |alg| alg.casecmp(algorithm).zero? }
76
+ # Order is very important - first check for string keys, next for symbols
77
+ ALGORITHM_KEYS = ['algorithm',
78
+ :algorithm,
79
+ 'algorithms',
80
+ :algorithms].freeze
81
+
82
+ def given_algorithms
83
+ ALGORITHM_KEYS.each do |alg_key|
84
+ alg = @options[alg_key]
85
+ return Array(alg) if alg
86
+ end
87
+ []
72
88
  end
73
89
 
74
90
  def allowed_algorithms
75
- # Order is very important - first check for string keys, next for symbols
76
- algos = if @options.key?('algorithm')
77
- @options['algorithm']
78
- elsif @options.key?(:algorithm)
79
- @options[:algorithm]
80
- elsif @options.key?('algorithms')
81
- @options['algorithms']
82
- elsif @options.key?(:algorithms)
83
- @options[:algorithms]
84
- else
85
- []
91
+ @allowed_algorithms ||= resolve_allowed_algorithms
92
+ end
93
+
94
+ def resolve_allowed_algorithms
95
+ algs = given_algorithms.map do |alg|
96
+ if Algos.implementation?(alg)
97
+ alg
98
+ else
99
+ Algos.create(alg)
100
+ end
86
101
  end
87
- Array(algos)
102
+
103
+ sort_by_alg_header(algs)
104
+ end
105
+
106
+ # Move algorithms matching the JWT alg header to the beginning of the list
107
+ def sort_by_alg_header(algs)
108
+ return algs if algs.size <= 1
109
+
110
+ algs.partition { |alg| alg.valid_alg?(alg_in_header) }.flatten
88
111
  end
89
112
 
90
113
  def find_key(&keyfinder)
@@ -113,14 +136,14 @@ module JWT
113
136
  end
114
137
 
115
138
  def none_algorithm?
116
- algorithm == 'none'
139
+ alg_in_header == 'none'
117
140
  end
118
141
 
119
- def decode_crypto
142
+ def decode_signature
120
143
  @signature = ::JWT::Base64.url_decode(@segments[2] || '')
121
144
  end
122
145
 
123
- def algorithm
146
+ def alg_in_header
124
147
  header['alg']
125
148
  end
126
149
 
data/lib/jwt/encode.rb CHANGED
@@ -1,28 +1,35 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require_relative './algos'
4
- require_relative './claims_validator'
3
+ require_relative 'algos'
4
+ require_relative 'claims_validator'
5
5
 
6
6
  # JWT::Encode module
7
7
  module JWT
8
8
  # Encoding logic for JWT
9
9
  class Encode
10
- ALG_NONE = 'none'
11
- ALG_KEY = 'alg'
10
+ ALG_KEY = 'alg'
12
11
 
13
12
  def initialize(options)
14
- @payload = options[:payload]
15
- @key = options[:key]
16
- _, @algorithm = Algos.find(options[:algorithm])
17
- @headers = options[:headers].transform_keys(&:to_s)
13
+ @payload = options[:payload]
14
+ @key = options[:key]
15
+ @algorithm = resolve_algorithm(options[:algorithm])
16
+ @headers = options[:headers].transform_keys(&:to_s)
17
+ @headers[ALG_KEY] = @algorithm.alg
18
18
  end
19
19
 
20
20
  def segments
21
- @segments ||= combine(encoded_header_and_payload, encoded_signature)
21
+ validate_claims!
22
+ combine(encoded_header_and_payload, encoded_signature)
22
23
  end
23
24
 
24
25
  private
25
26
 
27
+ def resolve_algorithm(algorithm)
28
+ return algorithm if Algos.implementation?(algorithm)
29
+
30
+ Algos.create(algorithm)
31
+ end
32
+
26
33
  def encoded_header
27
34
  @encoded_header ||= encode_header
28
35
  end
@@ -40,25 +47,28 @@ module JWT
40
47
  end
41
48
 
42
49
  def encode_header
43
- @headers[ALG_KEY] = @algorithm
44
- encode(@headers)
50
+ encode_data(@headers)
45
51
  end
46
52
 
47
53
  def encode_payload
48
- if @payload.is_a?(Hash)
49
- ClaimsValidator.new(@payload).validate!
50
- end
54
+ encode_data(@payload)
55
+ end
51
56
 
52
- encode(@payload)
57
+ def signature
58
+ @algorithm.sign(data: encoded_header_and_payload, signing_key: @key)
53
59
  end
54
60
 
55
- def encode_signature
56
- return '' if @algorithm == ALG_NONE
61
+ def validate_claims!
62
+ return unless @payload.is_a?(Hash)
63
+
64
+ ClaimsValidator.new(@payload).validate!
65
+ end
57
66
 
58
- ::JWT::Base64.url_encode(JWT::Signature.sign(@algorithm, encoded_header_and_payload, @key))
67
+ def encode_signature
68
+ ::JWT::Base64.url_encode(signature)
59
69
  end
60
70
 
61
- def encode(data)
71
+ def encode_data(data)
62
72
  ::JWT::Base64.url_encode(JWT::JSON.generate(data))
63
73
  end
64
74
 
data/lib/jwt/jwk/ec.rb CHANGED
@@ -9,39 +9,42 @@ module JWT
9
9
  def_delegators :keypair, :public_key
10
10
 
11
11
  KTY = 'EC'
12
- KTYS = [KTY, OpenSSL::PKey::EC].freeze
12
+ KTYS = [KTY, OpenSSL::PKey::EC, JWT::JWK::EC].freeze
13
13
  BINARY = 2
14
+ EC_PUBLIC_KEY_ELEMENTS = %i[kty crv x y].freeze
15
+ EC_PRIVATE_KEY_ELEMENTS = %i[d].freeze
16
+ EC_KEY_ELEMENTS = (EC_PRIVATE_KEY_ELEMENTS + EC_PUBLIC_KEY_ELEMENTS).freeze
14
17
 
15
- attr_reader :keypair
18
+ def initialize(key, params = nil, options = {})
19
+ params ||= {}
16
20
 
17
- def initialize(keypair, options = {})
18
- raise ArgumentError, 'keypair must be of type OpenSSL::PKey::EC' unless keypair.is_a?(OpenSSL::PKey::EC)
21
+ # For backwards compatibility when kid was a String
22
+ params = { kid: params } if params.is_a?(String)
19
23
 
20
- @keypair = keypair
24
+ key_params = extract_key_params(key)
21
25
 
22
- super(options)
26
+ params = params.transform_keys(&:to_sym)
27
+ check_jwk(key_params, params)
28
+
29
+ super(options, key_params.merge(params))
30
+ end
31
+
32
+ def keypair
33
+ @keypair ||= create_ec_key(self[:crv], self[:x], self[:y], self[:d])
23
34
  end
24
35
 
25
36
  def private?
26
- @keypair.private_key?
37
+ keypair.private_key?
27
38
  end
28
39
 
29
40
  def members
30
- crv, x_octets, y_octets = keypair_components(keypair)
31
- {
32
- kty: KTY,
33
- crv: crv,
34
- x: encode_octets(x_octets),
35
- y: encode_octets(y_octets)
36
- }
41
+ EC_PUBLIC_KEY_ELEMENTS.each_with_object({}) { |i, h| h[i] = self[i] }
37
42
  end
38
43
 
39
44
  def export(options = {})
40
- exported_hash = members.merge(kid: kid)
41
-
42
- return exported_hash unless private? && options[:include_private] == true
43
-
44
- append_private_parts(exported_hash)
45
+ exported = parameters.clone
46
+ exported.reject! { |k, _| EC_PRIVATE_KEY_ELEMENTS.include? k } unless private? && options[:include_private] == true
47
+ exported
45
48
  end
46
49
 
47
50
  def key_digest
@@ -51,13 +54,34 @@ module JWT
51
54
  OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)
52
55
  end
53
56
 
57
+ def []=(key, value)
58
+ if EC_KEY_ELEMENTS.include?(key.to_sym)
59
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes'
60
+ end
61
+
62
+ super(key, value)
63
+ end
64
+
54
65
  private
55
66
 
56
- def append_private_parts(the_hash)
57
- octets = keypair.private_key.to_bn.to_s(BINARY)
58
- the_hash.merge(
59
- d: encode_octets(octets)
60
- )
67
+ def extract_key_params(key)
68
+ case key
69
+ when JWT::JWK::EC
70
+ key.export(include_private: true)
71
+ when OpenSSL::PKey::EC # Accept OpenSSL key as input
72
+ @keypair = key # Preserve the object to avoid recreation
73
+ parse_ec_key(key)
74
+ when Hash
75
+ key.transform_keys(&:to_sym)
76
+ else
77
+ raise ArgumentError, 'key must be of type OpenSSL::PKey::EC or Hash with key parameters'
78
+ end
79
+ end
80
+
81
+ def check_jwk(keypair, params)
82
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes' unless (EC_KEY_ELEMENTS & params.keys).empty?
83
+ raise JWT::JWKError, "Incorrect 'kty' value: #{keypair[:kty]}, expected #{KTY}" unless keypair[:kty] == KTY
84
+ raise JWT::JWKError, 'Key format is invalid for EC' unless keypair[:crv] && keypair[:x] && keypair[:y]
61
85
  end
62
86
 
63
87
  def keypair_components(ec_keypair)
@@ -82,6 +106,8 @@ module JWT
82
106
  end
83
107
 
84
108
  def encode_octets(octets)
109
+ return unless octets
110
+
85
111
  ::JWT::Base64.url_encode(octets)
86
112
  end
87
113
 
@@ -89,15 +115,94 @@ module JWT
89
115
  ::JWT::Base64.url_encode(key_part.to_s(BINARY))
90
116
  end
91
117
 
92
- class << self
93
- def import(jwk_data)
94
- # See https://tools.ietf.org/html/rfc7518#section-6.2.1 for an
95
- # explanation of the relevant parameters.
118
+ def parse_ec_key(key)
119
+ crv, x_octets, y_octets = keypair_components(key)
120
+ octets = key.private_key&.to_bn&.to_s(BINARY)
121
+ {
122
+ kty: KTY,
123
+ crv: crv,
124
+ x: encode_octets(x_octets),
125
+ y: encode_octets(y_octets),
126
+ d: encode_octets(octets)
127
+ }.compact
128
+ end
96
129
 
97
- jwk_crv, jwk_x, jwk_y, jwk_d, jwk_kid = jwk_attrs(jwk_data, %i[crv x y d kid])
98
- raise JWT::JWKError, 'Key format is invalid for EC' unless jwk_crv && jwk_x && jwk_y
130
+ if ::JWT.openssl_3?
131
+ def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d) # rubocop:disable Metrics/MethodLength
132
+ curve = EC.to_openssl_curve(jwk_crv)
133
+
134
+ x_octets = decode_octets(jwk_x)
135
+ y_octets = decode_octets(jwk_y)
136
+
137
+ point = OpenSSL::PKey::EC::Point.new(
138
+ OpenSSL::PKey::EC::Group.new(curve),
139
+ OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
140
+ )
141
+
142
+ sequence = if jwk_d
143
+ # https://datatracker.ietf.org/doc/html/rfc5915.html
144
+ # ECPrivateKey ::= SEQUENCE {
145
+ # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
146
+ # privateKey OCTET STRING,
147
+ # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
148
+ # publicKey [1] BIT STRING OPTIONAL
149
+ # }
150
+
151
+ OpenSSL::ASN1::Sequence([
152
+ OpenSSL::ASN1::Integer(1),
153
+ OpenSSL::ASN1::OctetString(OpenSSL::BN.new(decode_octets(jwk_d), 2).to_s(2)),
154
+ OpenSSL::ASN1::ObjectId(curve, 0, :EXPLICIT),
155
+ OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed), 1, :EXPLICIT)
156
+ ])
157
+ else
158
+ OpenSSL::ASN1::Sequence([
159
+ OpenSSL::ASN1::Sequence([OpenSSL::ASN1::ObjectId('id-ecPublicKey'), OpenSSL::ASN1::ObjectId(curve)]),
160
+ OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed))
161
+ ])
162
+ end
99
163
 
100
- new(ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d), kid: jwk_kid)
164
+ OpenSSL::PKey::EC.new(sequence.to_der)
165
+ end
166
+ else
167
+ def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d)
168
+ curve = EC.to_openssl_curve(jwk_crv)
169
+
170
+ x_octets = decode_octets(jwk_x)
171
+ y_octets = decode_octets(jwk_y)
172
+
173
+ key = OpenSSL::PKey::EC.new(curve)
174
+
175
+ # The details of the `Point` instantiation are covered in:
176
+ # - https://docs.ruby-lang.org/en/2.4.0/OpenSSL/PKey/EC.html
177
+ # - https://www.openssl.org/docs/manmaster/man3/EC_POINT_new.html
178
+ # - https://tools.ietf.org/html/rfc5480#section-2.2
179
+ # - https://www.secg.org/SEC1-Ver-1.0.pdf
180
+ # Section 2.3.3 of the last of these references specifies that the
181
+ # encoding of an uncompressed point consists of the byte `0x04` followed
182
+ # by the x value then the y value.
183
+ point = OpenSSL::PKey::EC::Point.new(
184
+ OpenSSL::PKey::EC::Group.new(curve),
185
+ OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
186
+ )
187
+
188
+ key.public_key = point
189
+ key.private_key = OpenSSL::BN.new(decode_octets(jwk_d), 2) if jwk_d
190
+
191
+ key
192
+ end
193
+ end
194
+
195
+ def decode_octets(jwk_data)
196
+ ::JWT::Base64.url_decode(jwk_data)
197
+ end
198
+
199
+ def decode_open_ssl_bn(jwk_data)
200
+ OpenSSL::BN.new(::JWT::Base64.url_decode(jwk_data), BINARY)
201
+ end
202
+
203
+ class << self
204
+ def import(jwk_data)
205
+ new(jwk_data)
101
206
  end
102
207
 
103
208
  def to_openssl_curve(crv)
@@ -112,87 +217,6 @@ module JWT
112
217
  else raise JWT::JWKError, 'Invalid curve provided'
113
218
  end
114
219
  end
115
-
116
- private
117
-
118
- def jwk_attrs(jwk_data, attrs)
119
- attrs.map do |attr|
120
- jwk_data[attr] || jwk_data[attr.to_s]
121
- end
122
- end
123
-
124
- if ::JWT.openssl_3?
125
- def ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d) # rubocop:disable Metrics/MethodLength
126
- curve = to_openssl_curve(jwk_crv)
127
-
128
- x_octets = decode_octets(jwk_x)
129
- y_octets = decode_octets(jwk_y)
130
-
131
- point = OpenSSL::PKey::EC::Point.new(
132
- OpenSSL::PKey::EC::Group.new(curve),
133
- OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
134
- )
135
-
136
- sequence = if jwk_d
137
- # https://datatracker.ietf.org/doc/html/rfc5915.html
138
- # ECPrivateKey ::= SEQUENCE {
139
- # version INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
140
- # privateKey OCTET STRING,
141
- # parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
142
- # publicKey [1] BIT STRING OPTIONAL
143
- # }
144
-
145
- OpenSSL::ASN1::Sequence([
146
- OpenSSL::ASN1::Integer(1),
147
- OpenSSL::ASN1::OctetString(OpenSSL::BN.new(decode_octets(jwk_d), 2).to_s(2)),
148
- OpenSSL::ASN1::ObjectId(curve, 0, :EXPLICIT),
149
- OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed), 1, :EXPLICIT)
150
- ])
151
- else
152
- OpenSSL::ASN1::Sequence([
153
- OpenSSL::ASN1::Sequence([OpenSSL::ASN1::ObjectId('id-ecPublicKey'), OpenSSL::ASN1::ObjectId(curve)]),
154
- OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed))
155
- ])
156
- end
157
-
158
- OpenSSL::PKey::EC.new(sequence.to_der)
159
- end
160
- else
161
- def ec_pkey(jwk_crv, jwk_x, jwk_y, jwk_d)
162
- curve = to_openssl_curve(jwk_crv)
163
-
164
- x_octets = decode_octets(jwk_x)
165
- y_octets = decode_octets(jwk_y)
166
-
167
- key = OpenSSL::PKey::EC.new(curve)
168
-
169
- # The details of the `Point` instantiation are covered in:
170
- # - https://docs.ruby-lang.org/en/2.4.0/OpenSSL/PKey/EC.html
171
- # - https://www.openssl.org/docs/manmaster/man3/EC_POINT_new.html
172
- # - https://tools.ietf.org/html/rfc5480#section-2.2
173
- # - https://www.secg.org/SEC1-Ver-1.0.pdf
174
- # Section 2.3.3 of the last of these references specifies that the
175
- # encoding of an uncompressed point consists of the byte `0x04` followed
176
- # by the x value then the y value.
177
- point = OpenSSL::PKey::EC::Point.new(
178
- OpenSSL::PKey::EC::Group.new(curve),
179
- OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), 2)
180
- )
181
-
182
- key.public_key = point
183
- key.private_key = OpenSSL::BN.new(decode_octets(jwk_d), 2) if jwk_d
184
-
185
- key
186
- end
187
- end
188
-
189
- def decode_octets(jwk_data)
190
- ::JWT::Base64.url_decode(jwk_data)
191
- end
192
-
193
- def decode_open_ssl_bn(jwk_data)
194
- OpenSSL::BN.new(::JWT::Base64.url_decode(jwk_data), BINARY)
195
- end
196
220
  end
197
221
  end
198
222
  end
data/lib/jwt/jwk/hmac.rb CHANGED
@@ -4,15 +4,27 @@ module JWT
4
4
  module JWK
5
5
  class HMAC < KeyBase
6
6
  KTY = 'oct'
7
- KTYS = [KTY, String].freeze
7
+ KTYS = [KTY, String, JWT::JWK::HMAC].freeze
8
+ HMAC_PUBLIC_KEY_ELEMENTS = %i[kty].freeze
9
+ HMAC_PRIVATE_KEY_ELEMENTS = %i[k].freeze
10
+ HMAC_KEY_ELEMENTS = (HMAC_PRIVATE_KEY_ELEMENTS + HMAC_PUBLIC_KEY_ELEMENTS).freeze
8
11
 
9
- attr_reader :signing_key
12
+ def initialize(key, params = nil, options = {})
13
+ params ||= {}
10
14
 
11
- def initialize(signing_key, options = {})
12
- raise ArgumentError, 'signing_key must be of type String' unless signing_key.is_a?(String)
15
+ # For backwards compatibility when kid was a String
16
+ params = { kid: params } if params.is_a?(String)
13
17
 
14
- @signing_key = signing_key
15
- super(options)
18
+ key_params = extract_key_params(key)
19
+
20
+ params = params.transform_keys(&:to_sym)
21
+ check_jwk(key_params, params)
22
+
23
+ super(options, key_params.merge(params))
24
+ end
25
+
26
+ def keypair
27
+ self[:k]
16
28
  end
17
29
 
18
30
  def private?
@@ -25,26 +37,16 @@ module JWT
25
37
 
26
38
  # See https://tools.ietf.org/html/rfc7517#appendix-A.3
27
39
  def export(options = {})
28
- exported_hash = {
29
- kty: KTY,
30
- kid: kid
31
- }
32
-
33
- return exported_hash unless private? && options[:include_private] == true
34
-
35
- exported_hash.merge(
36
- k: signing_key
37
- )
40
+ exported = parameters.clone
41
+ exported.reject! { |k, _| HMAC_PRIVATE_KEY_ELEMENTS.include? k } unless private? && options[:include_private] == true
42
+ exported
38
43
  end
39
44
 
40
45
  def members
41
- {
42
- kty: KTY,
43
- k: signing_key
44
- }
46
+ HMAC_KEY_ELEMENTS.each_with_object({}) { |i, h| h[i] = self[i] }
45
47
  end
46
48
 
47
- alias keypair signing_key # for backwards compatibility
49
+ alias signing_key keypair # for backwards compatibility
48
50
 
49
51
  def key_digest
50
52
  sequence = OpenSSL::ASN1::Sequence([OpenSSL::ASN1::UTF8String.new(signing_key),
@@ -52,14 +54,38 @@ module JWT
52
54
  OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)
53
55
  end
54
56
 
55
- class << self
56
- def import(jwk_data)
57
- jwk_k = jwk_data[:k] || jwk_data['k']
58
- jwk_kid = jwk_data[:kid] || jwk_data['kid']
57
+ def []=(key, value)
58
+ if HMAC_KEY_ELEMENTS.include?(key.to_sym)
59
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes'
60
+ end
59
61
 
60
- raise JWT::JWKError, 'Key format is invalid for HMAC' unless jwk_k
62
+ super(key, value)
63
+ end
64
+
65
+ private
66
+
67
+ def extract_key_params(key)
68
+ case key
69
+ when JWT::JWK::HMAC
70
+ key.export(include_private: true)
71
+ when String # Accept String key as input
72
+ { kty: KTY, k: key }
73
+ when Hash
74
+ key.transform_keys(&:to_sym)
75
+ else
76
+ raise ArgumentError, 'key must be of type String or Hash with key parameters'
77
+ end
78
+ end
61
79
 
62
- new(jwk_k, kid: jwk_kid)
80
+ def check_jwk(keypair, params)
81
+ raise ArgumentError, 'cannot overwrite cryptographic key attributes' unless (HMAC_KEY_ELEMENTS & params.keys).empty?
82
+ raise JWT::JWKError, "Incorrect 'kty' value: #{keypair[:kty]}, expected #{KTY}" unless keypair[:kty] == KTY
83
+ raise JWT::JWKError, 'Key format is invalid for HMAC' unless keypair[:k]
84
+ end
85
+
86
+ class << self
87
+ def import(jwk_data)
88
+ new(jwk_data)
63
89
  end
64
90
  end
65
91
  end