standard_singpass 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,350 @@
1
+ # typed: strict
2
+
3
+ module StandardSingpass
4
+ module Myinfo
5
+ # Native ECDH-ES key-agreement JWE implementation (encryption and decryption).
6
+ #
7
+ # The `jwt` gem does not support ECDH-ES key agreement algorithms; this
8
+ # service implements the subset needed for MyInfo (FAPI 2.0):
9
+ # alg: ECDH-ES+A128KW, ECDH-ES+A256KW
10
+ # enc: A128CBC-HS256, A256CBC-HS512, A128GCM, A256GCM
11
+ #
12
+ # References:
13
+ # - RFC 7516 (JWE)
14
+ # - RFC 7518 Section 4.6 (ECDH-ES key agreement)
15
+ # - NIST SP 800-56A Concat KDF
16
+ class EcdhJwe
17
+ extend T::Sig
18
+
19
+ class DecryptionFailed < StandardError; end
20
+ class InvalidAlgorithm < StandardError; end
21
+
22
+ SUPPORTED_ALGS = T.let(%w[ECDH-ES+A128KW ECDH-ES+A256KW].freeze, T::Array[String])
23
+ SUPPORTED_ENCS = T.let(%w[A128CBC-HS256 A256CBC-HS512 A128GCM A256GCM].freeze, T::Array[String])
24
+
25
+ # Key wrap key sizes (in bytes) for each alg
26
+ KEK_SIZES = T.let({
27
+ "ECDH-ES+A128KW" => 16,
28
+ "ECDH-ES+A256KW" => 32
29
+ }.freeze, T::Hash[String, Integer])
30
+
31
+ # CEK sizes (in bytes) for each enc
32
+ CEK_SIZES = T.let({
33
+ "A128CBC-HS256" => 32,
34
+ "A256CBC-HS512" => 64,
35
+ "A128GCM" => 16,
36
+ "A256GCM" => 32
37
+ }.freeze, T::Hash[String, Integer])
38
+
39
+ # Encrypts a payload and returns a compact-serialized JWE string.
40
+ sig do
41
+ params(
42
+ payload: String,
43
+ public_key: OpenSSL::PKey::EC,
44
+ alg: String,
45
+ enc: String,
46
+ kid: T.nilable(String),
47
+ apu: T.nilable(String),
48
+ apv: T.nilable(String)
49
+ ).returns(String)
50
+ end
51
# Produces a five-segment compact JWE for +payload+.
#
# Steps: validate alg/enc, generate an ephemeral EC key on the recipient's
# curve, run ECDH + Concat KDF to obtain the key-wrapping key, wrap a fresh
# random content key, then encrypt the payload using the base64url-encoded
# protected header as additional authenticated data.
#
# Raises InvalidAlgorithm when alg/enc (or the recipient curve) is unsupported.
def self.encrypt(payload, public_key:, alg:, enc:, kid: nil, apu: nil, apv: nil)
  validate_algorithms!(alg, enc)

  # The ephemeral sender key must live on the same curve as the recipient key.
  curve = public_key.group.curve_name
  sender_key = OpenSSL::PKey::EC.generate(curve)

  # Key agreement followed by KEK derivation.
  agreed_secret = derive_shared_secret(sender_key, public_key)
  wrapping_key = concat_kdf(agreed_secret, alg, KEK_SIZES.fetch(alg), apu:, apv:)

  # Fresh content-encryption key, wrapped under the derived KEK.
  content_key = SecureRandom.random_bytes(CEK_SIZES.fetch(enc))
  wrapped_key = AESKeyWrap.wrap(content_key, wrapping_key)

  # Protected header; optional members are only emitted when supplied.
  protected_header = { "alg" => alg, "enc" => enc, "epk" => ec_public_key_to_jwk(sender_key) }
  protected_header["kid"] = kid if kid
  protected_header["apu"] = Base64.urlsafe_encode64(apu, padding: false) if apu
  protected_header["apv"] = Base64.urlsafe_encode64(apv, padding: false) if apv
  encoded_header = Base64.urlsafe_encode64(protected_header.to_json, padding: false)

  # The encoded header doubles as the AAD for content encryption.
  iv, ciphertext, auth_tag = encrypt_content(content_key, enc, payload, encoded_header)

  binary_segments = [wrapped_key, T.must(iv), T.must(ciphertext), T.must(auth_tag)]
  encoded_segments = binary_segments.map { |segment| Base64.urlsafe_encode64(segment, padding: false) }
  [encoded_header, *encoded_segments].join(".")
end
92
+
93
+ # Decrypts a compact-serialized JWE string.
94
+ sig { params(jwe_string: String, private_key: OpenSSL::PKey::EC).returns(String) }
95
# Decrypts a compact-serialized JWE and returns the plaintext.
#
# The JWE comes from an external party, so every header member is
# type-checked before use: a structurally wrong header must surface as
# DecryptionFailed rather than as a stray TypeError/NoMethodError.
#
# Raises:
#   DecryptionFailed - malformed serialization/header, key unwrap failure,
#                      or any OpenSSL-level failure.
#   InvalidAlgorithm - alg/enc (or the epk curve) is not supported.
def self.decrypt(jwe_string, private_key:)
  parts = jwe_string.split(".")
  raise DecryptionFailed, "Invalid JWE format" unless parts.length == 5

  header_b64, encrypted_key_b64, iv_b64, ciphertext_b64, tag_b64 = parts

  header = JSON.parse(Base64.urlsafe_decode64(T.must(header_b64)))
  # JSON.parse can legitimately return an Array/String/Integer; only an
  # object is a valid protected header.
  raise DecryptionFailed, "Invalid JWE header" unless header.is_a?(Hash)

  alg = header["alg"]
  enc = header["enc"]

  validate_algorithms!(alg, enc)

  epk = header["epk"]
  raise DecryptionFailed, "Missing ephemeral public key (epk)" unless epk

  epk_well_formed = epk.is_a?(Hash) && epk["x"].is_a?(String) && epk["y"].is_a?(String)
  raise DecryptionFailed, "Invalid ephemeral public key (epk)" unless epk_well_formed

  # Decode apu/apv from header if present; they feed the KDF as raw bytes.
  apu, apv = %w[apu apv].map do |member|
    encoded = header[member]
    next nil unless encoded
    raise DecryptionFailed, "Invalid #{member} header" unless encoded.is_a?(String)

    Base64.urlsafe_decode64(encoded)
  end

  # Reconstruct ephemeral public key
  ephemeral_public_key = jwk_to_ec_public_key(epk)

  # ECDH key agreement
  shared_secret = derive_shared_secret(private_key, ephemeral_public_key)

  # Derive KEK via Concat KDF
  kek_size = KEK_SIZES.fetch(alg)
  kek = concat_kdf(shared_secret, alg, kek_size, apu:, apv:)

  # Unwrap CEK
  # NOTE(review): this treats a nil result as failure; confirm whether
  # AESKeyWrap.unwrap signals an integrity failure by returning nil or by
  # raising its own error class.
  encrypted_key = Base64.urlsafe_decode64(T.must(encrypted_key_b64))
  cek = AESKeyWrap.unwrap(encrypted_key, kek)
  raise DecryptionFailed, "Key unwrap failed" unless cek

  # Decrypt content; the still-encoded header is the AAD.
  iv = Base64.urlsafe_decode64(T.must(iv_b64))
  ciphertext = Base64.urlsafe_decode64(T.must(ciphertext_b64))
  auth_tag = Base64.urlsafe_decode64(T.must(tag_b64))

  decrypt_content(cek, enc, ciphertext, iv, auth_tag, T.must(header_b64))
rescue JSON::ParserError, ArgumentError => e
  # ArgumentError covers invalid base64 in any segment.
  raise DecryptionFailed, "Malformed JWE: #{e.message}"
rescue OpenSSL::OpenSSLError => e
  raise DecryptionFailed, "Decryption failed: #{e.message}"
end
140
+
141
+ class << self
142
+ extend T::Sig
143
+
144
+ private
145
+
146
+ sig { params(alg: T.untyped, enc: T.untyped).void }
147
# Guards against header values outside the supported sets.
# Checks +alg+ first, then +enc+; raises InvalidAlgorithm naming the
# offending value, and returns nil when both are acceptable.
def validate_algorithms!(alg, enc)
  unless SUPPORTED_ALGS.include?(alg)
    raise InvalidAlgorithm, "Unsupported alg: #{alg}"
  end
  return if SUPPORTED_ENCS.include?(enc)

  raise InvalidAlgorithm, "Unsupported enc: #{enc}"
end
151
+
152
+ # Performs ECDH key agreement and returns the raw shared secret.
153
+ sig { params(local_key: OpenSSL::PKey::EC, remote_public_key: OpenSSL::PKey::EC).returns(String) }
154
# ECDH: combines our key with the peer's public point and returns the
# raw shared secret produced by OpenSSL.
def derive_shared_secret(local_key, remote_public_key)
  peer_point = remote_public_key.public_key
  local_key.dh_compute_key(peer_point)
end
157
+
158
+ # NIST Concat KDF (single-pass, SHA-256) per RFC 7518 Section 4.6.2
159
+ sig { params(shared_secret: String, algorithm: String, key_length: Integer, apu: T.nilable(String), apv: T.nilable(String)).returns(String) }
160
# Concat KDF with SHA-256.
#
# OtherInfo is AlgorithmID || PartyUInfo || PartyVInfo || SuppPubInfo, where
# the first three are 32-bit big-endian length-prefixed and SuppPubInfo is
# the requested key length in bits. The derived key is the leftmost
# +key_length+ bytes of H(1 || Z || OtherInfo) || H(2 || Z || OtherInfo) ...
#
# shared_secret - raw ECDH output (Z).
# algorithm     - identifier placed in AlgorithmID.
# key_length    - number of bytes to derive; may exceed one SHA-256 block.
# apu / apv     - optional party info (raw bytes, not base64).
#
# Returns a binary String of +key_length+ bytes.
def concat_kdf(shared_secret, algorithm, key_length, apu: nil, apv: nil)
  # Work on binary copies throughout: mixing a non-ASCII UTF-8 apu/apv with
  # binary data would otherwise raise Encoding::CompatibilityError.
  algorithm_id = [algorithm.bytesize].pack("N") + algorithm.b
  party_u_info = apu ? [apu.bytesize].pack("N") + apu.b : [0].pack("N")
  party_v_info = apv ? [apv.bytesize].pack("N") + apv.b : [0].pack("N")
  supp_pub_info = [key_length * 8].pack("N")

  other_info = algorithm_id + party_u_info + party_v_info + supp_pub_info
  secret = shared_secret.b

  # One SHA-256 round yields 32 bytes; run as many rounds as the requested
  # length needs instead of silently returning a short key.
  hash_length = 32
  rounds = (key_length + hash_length - 1) / hash_length

  derived = String.new(encoding: Encoding::BINARY)
  1.upto(rounds) do |counter|
    derived << OpenSSL::Digest::SHA256.digest([counter].pack("N") + secret + other_info)
  end
  derived[0, key_length]
end
173
+
174
+ # Converts an OpenSSL EC key to a JWK hash (public components only).
175
+ sig { params(ec_key: OpenSSL::PKey::EC).returns(T::Hash[String, String]) }
176
# Exports the public half of +ec_key+ as a JWK hash with kty/crv/x/y.
# Coordinates are base64url-encoded without padding.
# Raises InvalidAlgorithm for curves other than P-256, P-384 and P-521.
def ec_public_key_to_jwk(ec_key)
  public_point = ec_key.public_key
  openssl_curve = ec_key.group.curve_name

  # Map OpenSSL's curve name onto the JWK "crv" identifier.
  jwk_curve = {
    "prime256v1" => "P-256",
    "secp384r1" => "P-384",
    "secp521r1" => "P-521"
  }.fetch(openssl_curve) do
    raise InvalidAlgorithm, "Unsupported curve: #{openssl_curve}"
  end

  # Uncompressed encoding is 0x04 || x || y with equally sized coordinates.
  encoded_point = public_point.to_octet_string(:uncompressed)
  coordinate_size = (encoded_point.bytesize - 1) / 2
  x_bytes = encoded_point[1, coordinate_size]
  y_bytes = encoded_point[1 + coordinate_size, coordinate_size]

  {
    "kty" => "EC",
    "crv" => jwk_curve,
    "x" => Base64.urlsafe_encode64(x_bytes, padding: false),
    "y" => Base64.urlsafe_encode64(y_bytes, padding: false)
  }
end
205
+
206
+ # Reconstructs an EC public key from a JWK hash.
207
+ sig { params(jwk: T::Hash[String, T.untyped]).returns(OpenSSL::PKey::EC) }
208
# Rebuilds a public-only OpenSSL EC key from a JWK hash.
#
# jwk - string-keyed hash carrying "crv", "x" and "y" (base64url).
#
# Raises:
#   InvalidAlgorithm - "crv" is not P-256, P-384 or P-521.
#   DecryptionFailed - "x" or "y" is missing or not a string. The JWK comes
#                      from an untrusted JWE header, so a malformed one must
#                      not escape as NoMethodError from the Base64 decoder.
def jwk_to_ec_public_key(jwk)
  crv = jwk["crv"]
  curve_name = case crv
               when "P-256" then "prime256v1"
               when "P-384" then "secp384r1"
               when "P-521" then "secp521r1"
               else raise InvalidAlgorithm, "Unsupported curve: #{crv}"
               end

  x_b64 = jwk["x"]
  y_b64 = jwk["y"]
  unless x_b64.is_a?(String) && y_b64.is_a?(String)
    raise DecryptionFailed, "Invalid ephemeral public key (epk): missing x or y coordinate"
  end

  x = Base64.urlsafe_decode64(x_b64)
  y = Base64.urlsafe_decode64(y_b64)

  group = OpenSSL::PKey::EC::Group.new(curve_name)
  # Build uncompressed point: 0x04 || x || y
  point_hex = "04" + x.unpack1("H*") + y.unpack1("H*")
  point = OpenSSL::PKey::EC::Point.new(group, OpenSSL::BN.new(point_hex, 16))

  # Build a public-only EC key by DER-encoding a SubjectPublicKeyInfo
  # (algorithm = id-ecPublicKey + named curve, key = encoded point).
  inner = [OpenSSL::ASN1::ObjectId.new("id-ecPublicKey"),
           OpenSSL::ASN1::ObjectId.new(curve_name)]
  outer = [OpenSSL::ASN1::Sequence.new(inner),
           OpenSSL::ASN1::BitString.new(point.to_octet_string(:uncompressed))]
  asn1 = OpenSSL::ASN1::Sequence.new(outer)
  OpenSSL::PKey::EC.new(asn1.to_der)
end
233
+
234
+ sig { params(cek: String, enc: String, plaintext: String, aad: String).returns(T::Array[String]) }
235
# Routes content encryption to the GCM or CBC-HMAC implementation based on
# +enc+ and returns whatever that implementation returns.
# Raises InvalidAlgorithm for any other +enc+ value.
def encrypt_content(cek, enc, plaintext, aad)
  return encrypt_gcm(cek, plaintext, aad, enc) if %w[A128GCM A256GCM].include?(enc)
  return encrypt_cbc(cek, plaintext, aad, enc) if %w[A128CBC-HS256 A256CBC-HS512].include?(enc)

  raise InvalidAlgorithm, "Unsupported enc: #{enc}"
end
245
+
246
+ sig { params(cek: String, enc: String, ciphertext: String, iv: String, auth_tag: String, aad: String).returns(String) }
247
# Routes content decryption to the GCM or CBC-HMAC implementation based on
# +enc+ and returns that implementation's result.
# Raises InvalidAlgorithm for any other +enc+ value.
def decrypt_content(cek, enc, ciphertext, iv, auth_tag, aad)
  return decrypt_gcm(cek, ciphertext, iv, auth_tag, aad, enc) if %w[A128GCM A256GCM].include?(enc)
  return decrypt_cbc(cek, ciphertext, iv, auth_tag, aad, enc) if %w[A128CBC-HS256 A256CBC-HS512].include?(enc)

  raise InvalidAlgorithm, "Unsupported enc: #{enc}"
end
257
+
258
+ sig { params(enc: String).returns(String) }
259
# OpenSSL cipher name for a GCM +enc+: "A128GCM" selects AES-128-GCM,
# every other value falls through to AES-256-GCM.
def gcm_cipher_name(enc)
  case enc
  when "A128GCM" then "aes-128-gcm"
  else "aes-256-gcm"
  end
end
262
+
263
+ sig { params(enc: String).returns(String) }
264
# OpenSSL cipher name for a CBC-HMAC +enc+: "A128CBC-HS256" selects
# AES-128-CBC, every other value falls through to AES-256-CBC.
def cbc_cipher_name(enc)
  case enc
  when "A128CBC-HS256" then "aes-128-cbc"
  else "aes-256-cbc"
  end
end
267
+
268
+ sig { params(enc: String).returns(String) }
269
# HMAC digest name for a CBC-HMAC +enc+: "A128CBC-HS256" selects SHA256,
# every other value falls through to SHA512.
def cbc_hmac_digest(enc)
  case enc
  when "A128CBC-HS256" then "SHA256"
  else "SHA512"
  end
end
272
+
273
+ sig { params(cek: String, plaintext: String, aad: String, enc: String).returns(T::Array[String]) }
274
# AES-GCM content encryption.
#
# cek       - content-encryption key; its size must match the cipher for +enc+.
# plaintext - data to encrypt; may be empty.
# aad       - additional authenticated data.
#
# Returns [iv, ciphertext, auth_tag]; the IV is generated by the cipher.
def encrypt_gcm(cek, plaintext, aad, enc)
  cipher = OpenSSL::Cipher.new(gcm_cipher_name(enc))
  cipher.encrypt
  cipher.key = cek
  iv = cipher.random_iv
  cipher.auth_data = aad
  # Cipher#update rejects empty input on some Ruby OpenSSL versions, so an
  # empty payload goes straight to #final (which still produces the tag).
  ciphertext = plaintext.empty? ? cipher.final : cipher.update(plaintext) + cipher.final
  auth_tag = cipher.auth_tag
  [iv, ciphertext, auth_tag]
end
284
+
285
+ sig { params(cek: String, ciphertext: String, iv: String, auth_tag: String, aad: String, enc: String).returns(String) }
286
# AES-GCM content decryption with tag verification.
#
# Raises DecryptionFailed when the tag is not 16 bytes, the IV is not
# 12 bytes (the sizes JWE mandates for AES-GCM), or authentication fails.
def decrypt_gcm(cek, ciphertext, iv, auth_tag, aad, enc)
  # Validate lengths up front so malformed input is reported as
  # DecryptionFailed instead of an ArgumentError from the cipher setters.
  raise DecryptionFailed, "Invalid authentication tag" unless auth_tag.bytesize == 16
  raise DecryptionFailed, "Invalid initialization vector" unless iv.bytesize == 12

  cipher = OpenSSL::Cipher.new(gcm_cipher_name(enc))
  cipher.decrypt
  cipher.key = cek
  cipher.iv = iv
  cipher.auth_tag = auth_tag
  cipher.auth_data = aad
  # Cipher#update rejects empty input on some Ruby OpenSSL versions; #final
  # alone still verifies the tag for an empty ciphertext.
  ciphertext.empty? ? cipher.final : cipher.update(ciphertext) + cipher.final
rescue OpenSSL::Cipher::CipherError
  raise DecryptionFailed, "Content decryption failed"
end
299
+
300
+ sig { params(cek: String, plaintext: String, aad: String, enc: String).returns(T::Array[String]) }
301
# AES-CBC + HMAC content encryption.
#
# The CEK is split in half: the first half is the MAC key, the second half
# the AES key. The tag is the HMAC over AAD || IV || ciphertext || AL,
# truncated to the MAC key length, where AL is the AAD length in bits as a
# 64-bit big-endian integer.
#
# Returns [iv, ciphertext, auth_tag]; +plaintext+ may be empty.
def encrypt_cbc(cek, plaintext, aad, enc)
  mac_key_len = cek.bytesize / 2
  mac_key = cek[0, mac_key_len]
  enc_key = cek[mac_key_len, mac_key_len]

  cipher = OpenSSL::Cipher.new(cbc_cipher_name(enc))
  cipher.encrypt
  cipher.key = T.must(enc_key)
  iv = cipher.random_iv
  # Cipher#update rejects empty input on some Ruby OpenSSL versions; for an
  # empty payload #final alone emits the single padding block.
  ciphertext = plaintext.empty? ? cipher.final : cipher.update(plaintext) + cipher.final

  # Compute authentication tag (HMAC over AAD || IV || ciphertext || AL)
  al = [aad.bytesize * 8].pack("Q>")
  hmac_input = aad + iv + ciphertext + al
  hmac = OpenSSL::HMAC.digest(cbc_hmac_digest(enc), T.must(mac_key), hmac_input)
  tag_len = mac_key_len # half of HMAC output
  auth_tag = hmac[0, tag_len]

  [iv, ciphertext, auth_tag]
end
321
+
322
+ sig { params(cek: String, ciphertext: String, iv: String, auth_tag: String, aad: String, enc: String).returns(String) }
323
# AES-CBC + HMAC content decryption.
#
# Recomputes the HMAC over AAD || IV || ciphertext || AL with the first half
# of the CEK and compares it with +auth_tag+ before decrypting with the
# second half.
#
# Raises DecryptionFailed when the tag has the wrong length or value, when
# the IV/ciphertext is structurally invalid, or when OpenSSL rejects the data.
def decrypt_cbc(cek, ciphertext, iv, auth_tag, aad, enc)
  mac_key_len = cek.bytesize / 2
  mac_key = cek[0, mac_key_len]
  enc_key = cek[mac_key_len, mac_key_len]

  # Verify authentication tag
  al = [aad.bytesize * 8].pack("Q>")
  hmac_input = aad + iv + ciphertext + al
  hmac = OpenSSL::HMAC.digest(cbc_hmac_digest(enc), T.must(mac_key), hmac_input)
  expected_tag = T.must(hmac[0, mac_key_len])

  # fixed_length_secure_compare raises ArgumentError on unequal lengths, so
  # check the (non-secret) length first and report a domain error instead.
  tag_valid = auth_tag.bytesize == expected_tag.bytesize &&
              OpenSSL.fixed_length_secure_compare(auth_tag, expected_tag)
  raise DecryptionFailed, "Authentication tag verification failed" unless tag_valid

  cipher = OpenSSL::Cipher.new(cbc_cipher_name(enc))
  cipher.decrypt
  cipher.key = T.must(enc_key)
  # A padded CBC ciphertext is never empty and the IV is one block; reject
  # anything else here rather than letting the cipher raise ArgumentError.
  if ciphertext.empty? || iv.bytesize != cipher.iv_len
    raise DecryptionFailed, "Content decryption failed"
  end

  cipher.iv = iv
  cipher.update(ciphertext) + cipher.final
rescue OpenSSL::Cipher::CipherError
  raise DecryptionFailed, "Content decryption failed"
end
347
+ end
348
+ end
349
+ end
350
+ end
@@ -0,0 +1,14 @@
1
+ # typed: strict
2
+
3
module StandardSingpass
  module Myinfo
    # Base class of the Myinfo error hierarchy. Every error below inherits
    # from it, so `rescue StandardSingpass::Myinfo::Error` catches them all.
    class Error < StandardError
    end

    class AuthenticationError < Error
    end

    class ApiError < Error
    end

    class PARError < Error
    end

    class DecryptionError < Error
    end

    class SignatureError < Error
    end

    class RateLimitError < Error
    end

    class ConfigurationError < Error
    end
  end
end
@@ -0,0 +1,116 @@
1
+ # typed: strict
2
+
3
+ module StandardSingpass
4
+ module Myinfo
5
+ # Generates and validates the private JWKS document that gets pasted into
6
+ # the host application's `MYINFO_PRIVATE_JWKS` env var (or equivalent).
7
+ # Public-facing entrypoint is the `standard_singpass:myinfo:generate_jwks`
8
+ # rake task; this module holds the logic so it's testable without going
9
+ # through Rake.
10
+ #
11
+ # The validator mirrors what `Configuration#private_jwks_json=` requires —
12
+ # particularly the "must have private scalar `d`" check that traps the
13
+ # public/private key mix-up.
14
+ module JwksGenerator
15
+ extend T::Sig
16
+
17
+ SIG_ALG = T.let("ES256", String)
18
+ ENC_ALG = T.let("ECDH-ES+A256KW", String)
19
+ EC_CURVE = T.let("P-256", String)
20
+ EC_OPENSSL_NAME = T.let("prime256v1", String)
21
+
22
+ sig { params(sig_kid: String, enc_kid: String).returns(T::Hash[Symbol, T.untyped]) }
23
# Creates a fresh private JWKS holding one signing key and one encryption
# key, each on a newly generated EC key.
#
# The result is passed through .validate before being returned; if that
# reports anything, a RuntimeError listing the issues is raised instead.
def self.generate(sig_kid:, enc_kid:)
  signing_key = OpenSSL::PKey::EC.generate(EC_OPENSSL_NAME)
  signing_jwk = build_jwk(signing_key, kid: sig_kid, use: "sig", alg: SIG_ALG)

  encryption_key = OpenSSL::PKey::EC.generate(EC_OPENSSL_NAME)
  encryption_jwk = build_jwk(encryption_key, kid: enc_kid, use: "enc", alg: ENC_ALG)

  document = { keys: [signing_jwk, encryption_jwk] }

  # Self-check: never hand back a JWKS the validator would reject
  # (for example a public-only key).
  problems = validate(document)
  return document if problems.empty?

  raise "Internal: generated JWKS failed self-validation:\n#{problems.join("\n")}"
end
36
+
37
+ # Returns an array of issue strings; empty array means valid.
38
+ # Accepts either symbol- or string-keyed hashes (JSON.parse output is
39
+ # string-keyed; in-memory values from .generate are symbol-keyed).
40
+ sig { params(jwks: T.untyped).returns(T::Array[String]) }
41
# Checks a JWKS document and returns a list of human-readable problems
# (empty when the document is acceptable). Keys may be symbol- or
# string-keyed.
#
# Order of the returned issues: set-level checks (sig/enc counts), then
# duplicate kids, then per-key checks in document order.
def self.validate(jwks)
  return ["root is not a JSON object (got #{jwks.class})"] unless jwks.is_a?(Hash)

  entries = jwks["keys"] || jwks[:keys]
  return ["missing 'keys' array"] unless entries.is_a?(Array)

  # Only object-shaped entries take part in the set-level checks; anything
  # else is reported individually further down.
  objects = entries.grep(Hash)
  problems = []

  sig_count = objects.count { |jwk| key_field(jwk, :use) == "sig" }
  enc_count = objects.count { |jwk| key_field(jwk, :use) == "enc" }

  problems << "expected exactly one sig key (use=\"sig\"), got #{sig_count}" if sig_count != 1
  problems << "expected at least one enc key (use=\"enc\"), got 0" if enc_count.zero?

  # RFC 7517 §4.5: kids within a set should be distinct so consumers can
  # select keys unambiguously; a repeat is most likely a copy/paste slip.
  objects.map { |jwk| key_field(jwk, :kid) }.tally.each do |kid, count|
    next unless count > 1

    problems << "duplicate kid #{kid.inspect} appears #{count} times — kids must be unique within a JWKS (RFC 7517 §4.5)"
  end

  entries.each_with_index do |jwk, index|
    unless jwk.is_a?(Hash)
      problems << "keys[#{index}] is not an object (got #{jwk.class})"
      next
    end

    kid, use, d, kty, crv, alg = %i[kid use d kty crv alg].map { |name| key_field(jwk, name) }
    label = "keys[#{index}] kid=#{kid.inspect} use=#{use.inspect}"

    # A JWK without `d` is public-only and can neither sign nor decrypt.
    # `.blank?` is used so this matches the runtime config loader's check.
    problems << "#{label}: missing 'd' (public-only — re-export with include_private: true)" if d.blank?
    problems << "#{label}: kty=#{kty.inspect} expected \"EC\"" if kty != "EC"
    problems << "#{label}: crv=#{crv.inspect} expected \"P-256\" (FAPI 2.0 requires EC P-256)" if crv != EC_CURVE

    case use
    when "sig"
      problems << "#{label}: alg=#{alg.inspect} expected \"#{SIG_ALG}\"" if alg != SIG_ALG
    when "enc"
      problems << "#{label}: alg=#{alg.inspect} expected \"#{ENC_ALG}\"" if alg != ENC_ALG
    else
      problems << "#{label}: use=#{use.inspect} expected \"sig\" or \"enc\""
    end
  end

  problems
end
101
+
102
+ sig { params(key: OpenSSL::PKey::EC, kid: String, use: String, alg: String).returns(T::Hash[Symbol, T.untyped]) }
103
# Exports +key+ as a JWK hash including its private component, then
# stamps the given `use` and `alg` onto it.
private_class_method def self.build_jwk(key, kid:, use:, alg:)
  JWT::JWK.new(key, kid:).export(include_private: true).tap do |exported|
    exported[:use] = use
    exported[:alg] = alg
  end
end
109
+
110
+ sig { params(hash: T::Hash[T.untyped, T.untyped], field: Symbol).returns(T.untyped) }
111
# Reads +field+ from a hash that may be symbol- or string-keyed.
# The symbol key wins when it holds a truthy value; otherwise the
# string key's value (possibly nil) is returned.
private_class_method def self.key_field(hash, field)
  symbol_value = hash[field]
  return symbol_value if symbol_value

  hash[field.to_s]
end
114
+ end
115
+ end
116
+ end