@alliander-opensource/aws-jwt-sts 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "name": "@alliander-opensource/aws-jwt-sts",
3
+ "version": "0.2.6",
4
+ "author": {
5
+ "name": "Alliander NV"
6
+ },
7
+ "main": "dist/index.js",
8
+ "files": [
9
+ "dist/*",
10
+ "src/*"
11
+ ],
12
+ "license": "MIT",
13
+ "repository": {
14
+ "type": "git",
15
+ "url": "https://github.com/alliander-opensource/aws-jwt-sts.git"
16
+ },
17
+ "scripts": {
18
+ "build": "rm -rf dist && tsc",
19
+ "watch": "tsc -w",
20
+ "test": "jest --coverage",
21
+ "cdk": "cdk",
22
+ "lint": "eslint . --ext .ts"
23
+ },
24
+ "devDependencies": {
25
+ "@types/jest": "^29.5.0",
26
+ "@types/node": "18.15.11",
27
+ "@types/prettier": "2.7.2",
28
+ "@typescript-eslint/eslint-plugin": "^5.58.0",
29
+ "@typescript-eslint/parser": "^5.58.0",
30
+ "aws-cdk": "2.74.0",
31
+ "aws-sdk-client-mock": "^2.1.1",
32
+ "esbuild": "^0.17.17",
33
+ "eslint-config-standard": "^17.0.0",
34
+ "eslint-plugin-import": "^2.27.5",
35
+ "eslint-plugin-jest": "^27.2.1",
36
+ "eslint-plugin-n": "^15.7.0",
37
+ "eslint-plugin-promise": "^6.1.1",
38
+ "jest": "^29.5.0",
39
+ "jwt-decode": "^3.1.2",
40
+ "ts-jest": "^29.1.0",
41
+ "ts-node": "^10.9.1",
42
+ "typescript": "~5.0.4"
43
+ },
44
+ "dependencies": {
45
+ "@aws-lambda-powertools/logger": "^1.8.0",
46
+ "@aws-sdk/client-kms": "^3.312.0",
47
+ "@aws-sdk/client-s3": "^3.312.0",
48
+ "@types/aws-lambda": "^8.10.114",
49
+ "@types/jsrsasign": "^10.5.8",
50
+ "aws-cdk-lib": "2.74.0",
51
+ "base64url": "^3.0.1",
52
+ "constructs": "^10.1.313",
53
+ "jsrsasign": "^10.8.2",
54
+ "source-map-support": "^0.5.21"
55
+ }
56
+ }
@@ -0,0 +1,228 @@
1
+ // SPDX-FileCopyrightText: 2023 Alliander NV
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ import {
6
+ KMSClient,
7
+ DescribeKeyCommand,
8
+ GetPublicKeyCommand,
9
+ CreateKeyCommand,
10
+ UpdateAliasCommand,
11
+ ScheduleKeyDeletionCommand,
12
+ TagResourceCommand,
13
+ Tag,
14
+ NotFoundException,
15
+ CreateAliasCommand
16
+ } from '@aws-sdk/client-kms'
17
+ import { S3Client, PutObjectCommand } from '@aws-sdk/client-s3'
18
+ import { KEYUTIL, KJUR } from 'jsrsasign'
19
+
20
// One KMS client, created at module load and shared by every rotation step below.
const client = new KMSClient({})

// KMS aliases marking the three key generations handled by the rotation steps.
const ALIAS_PREVIOUS = 'alias/sts/PREVIOUS'
const ALIAS_CURRENT = 'alias/sts/CURRENT'
const ALIAS_PENDING = 'alias/sts/PENDING'

// All aliases, in the order generateJWKS visits them when building the key set.
const ALIASES: string[] = [ALIAS_PREVIOUS, ALIAS_CURRENT, ALIAS_PENDING]
31
+
32
/**
 * Entry point of the key-rotation Lambda.
 *
 * Runs the single rotation step named by `event.step`:
 * `deletePrevious`, `movePrevious`, `moveCurrent`, `createPending` or
 * `generateArtifacts` (which publishes both the JWKS and the OpenID configuration).
 * An unrecognised step is only logged; the invocation still resolves normally.
 *
 * @param event - invocation payload; only its `step` property is read
 */
export const handler = async (event: any): Promise<any> => {
  const stepActions = new Map<string, () => Promise<void>>([
    ['deletePrevious', deletePrevious],
    ['movePrevious', movePrevious],
    ['moveCurrent', moveCurrent],
    ['createPending', createPending],
    ['generateArtifacts', async () => {
      await generateJWKS()
      await generateOpenIDConfiguration()
    }]
  ])

  const runStep = stepActions.get(event.step)
  if (runStep === undefined) {
    console.log('invalid step')
    return
  }

  await runStep()
}
59
+
60
/**
 * Schedules deletion of the key currently behind the PREVIOUS alias.
 *
 * Skips (with a log line) when no PREVIOUS key exists. It also skips when that key is
 * already in the `PendingDeletion` state. That happens when an earlier rotation run
 * scheduled the deletion but failed in a later step before PREVIOUS was moved. KMS
 * rejects a second ScheduleKeyDeletion for such a key, which would otherwise make
 * every subsequent rotation fail at this first step.
 */
async function deletePrevious () {
  console.log('Deleting PREVIOUS aliased key')

  const prevKeyId = await getKeyIdForAlias(ALIAS_PREVIOUS)
  if (!prevKeyId) {
    console.log('No PREVIOUS key at the moment, skip deletion')
    return
  }

  // Make the step idempotent across retries of a partially failed rotation
  const { KeyMetadata: prevKeyMetadata } = await client.send(new DescribeKeyCommand({ KeyId: prevKeyId }))
  if (prevKeyMetadata?.KeyState === 'PendingDeletion') {
    console.log(`PREVIOUS key ${prevKeyId} is already pending deletion, skip deletion`)
    return
  }

  const scheduleDeleteResponse = await client.send(
    new ScheduleKeyDeletionCommand({ KeyId: prevKeyId })
  )
  console.log(scheduleDeleteResponse)
}
73
+
74
/**
 * Re-points the PREVIOUS alias at the key that is currently aliased as CURRENT.
 * When no CURRENT key exists yet, nothing is changed and a message is logged.
 */
async function movePrevious () {
  console.log('moving PREVIOUS alias')

  const sourceKeyId = await getKeyIdForAlias(ALIAS_CURRENT)
  if (!sourceKeyId) {
    console.log('No CURRENT key at the moment, skip assigning the PREVIOUS alias to this key.')
    return
  }

  await updateOrCreateAlias(ALIAS_PREVIOUS, sourceKeyId)
}
83
+
84
/**
 * Re-points the CURRENT alias at the key that is currently aliased as PENDING.
 * When no PENDING key exists yet, nothing is changed and a message is logged.
 */
async function moveCurrent () {
  console.log('Moving CURRENT alias')

  const sourceKeyId = await getKeyIdForAlias(ALIAS_PENDING)
  if (!sourceKeyId) {
    console.log('No PENDING key at the moment, skip assigning the CURRENT alias to this key.')
    return
  }

  await updateOrCreateAlias(ALIAS_CURRENT, sourceKeyId)
}
94
+
95
/**
 * Creates a new RSA-2048 sign/verify key in KMS and points the PENDING alias at it.
 *
 * @throws Error when the CreateKey response does not contain the new key's id
 */
async function createPending () {
  console.log('Creating new key for PENDING')

  const createResponse = await client.send(new CreateKeyCommand({
    KeySpec: 'RSA_2048',
    KeyUsage: 'SIGN_VERIFY'
  }))
  console.log(createResponse)

  // Fail with a clear message instead of relying on non-null assertions
  const newKeyId = createResponse.KeyMetadata?.KeyId
  if (newKeyId === undefined) {
    throw new Error('CreateKey response did not contain a KeyId')
  }

  await updateOrCreateAlias(ALIAS_PENDING, newKeyId)
}
108
+
109
/**
 * Points `aliasName` at `keyId`, creating the alias when it does not exist yet.
 *
 * UpdateAlias is attempted first. Only a NotFoundException triggers the CreateAlias
 * fallback; every other error is rethrown unchanged.
 *
 * @param aliasName - alias to (re)point, e.g. `alias/sts/CURRENT`
 * @param keyId - id of the key the alias should target
 */
async function updateOrCreateAlias (aliasName: string, keyId: string) {
  try {
    console.log(await client.send(new UpdateAliasCommand({ AliasName: aliasName, TargetKeyId: keyId })))
  } catch (err) {
    if (!(err instanceof NotFoundException)) {
      throw err
    }
    console.log('ALIAS not found, creating it.')
    console.log(await client.send(new CreateAliasCommand({ AliasName: aliasName, TargetKeyId: keyId })))
  }
}
129
+
130
/**
 * Resolves an alias (any reference DescribeKey accepts) to the id of its target key.
 *
 * @param aliasName - alias name passed to DescribeKey as `KeyId`
 * @returns the target key id; `undefined` if the response has no key metadata;
 *          `null` when KMS answers with NotFoundException
 */
async function getKeyIdForAlias (aliasName: string) {
  try {
    const described = await client.send(new DescribeKeyCommand({ KeyId: aliasName }))
    console.log(described)
    return described.KeyMetadata?.KeyId
  } catch (err) {
    if (!(err instanceof NotFoundException)) {
      throw err
    }
    return null
  }
}
143
+
144
/**
 * Builds the JSON Web Key Set for the rotation keys and uploads it to S3 as `discovery/keys`.
 *
 * For every alias in ALIASES that resolves to a key, the key's public JWK is generated.
 * Its `kid` is then written to the KMS key as the `jwk_kid` tag, which the token issuer
 * reads back.
 *
 * A key reachable through more than one alias (e.g. CURRENT and PENDING after a
 * partially failed rotation) is published only once, so the set never contains
 * duplicate `kid` entries.
 */
async function generateJWKS () {
  const allKeys: object[] = []
  const publishedKeyIds = new Set<string>()

  for (const keyAlias of ALIASES) {
    const keyId = await getKeyIdForAlias(keyAlias)
    if (!keyId || publishedKeyIds.has(keyId)) {
      continue
    }
    publishedKeyIds.add(keyId)

    // Use the resolved key id (not the alias) so the exported public key and the tag
    // are guaranteed to belong to the same KMS key.
    const jwkContents = await generateJWK(keyId)
    await setKMSKeyTags(keyId, [{ TagKey: 'jwk_kid', TagValue: jwkContents.kid }])
    allKeys.push(jwkContents)
  }

  await uploadToS3('discovery/keys', { keys: allKeys })
}
160
+
161
/**
 * Uploads the OpenID Connect discovery document to S3 as `.well-known/openid-configuration`.
 *
 * The issuer comes from the ISSUER environment variable, and `jwks_uri` is derived from
 * it (pointing at `discovery/keys`, where generateJWKS uploads the key set).
 *
 * @throws Error when ISSUER is unset or empty. Publishing anyway would produce a document
 *         without `issuer` and with `jwks_uri` "undefined/discovery/keys", replacing any
 *         valid document already in the bucket.
 */
async function generateOpenIDConfiguration () {
  const issuer = process.env.ISSUER
  if (!issuer) {
    throw new Error('ISSUER environment variable is not set')
  }

  const openIdConfiguration = {
    issuer,
    jwks_uri: `${issuer}/discovery/keys`,
    response_types_supported: ['token'],
    id_token_signing_alg_values_supported: ['RS256'],
    scopes_supported: ['openid'],
    token_endpoint_auth_methods_supported: ['client_secret_basic'],
    claims_supported: ['aud', 'exp', 'iat', 'iss', 'sub']
  }

  await uploadToS3('.well-known/openid-configuration', openIdConfiguration)
}
190
+
191
/**
 * Exports the public half of a KMS key as a JWK marked for RS256 signature use.
 *
 * @param keyAlias - any key reference GetPublicKey accepts (alias, key id or ARN)
 * @returns the JWK produced by jsrsasign (including its `kid`), with `use` and `alg` added
 * @throws Error when KMS returns no public key material
 */
async function generateJWK (keyAlias: string): Promise<any> {
  // KMS returns the public key DER-encoded
  const { PublicKey: publicKeyDer } = await client.send(new GetPublicKeyCommand({ KeyId: keyAlias }))
  if (publicKeyDer === undefined) {
    throw new Error(`KMS returned no public key for ${keyAlias}`)
  }

  // DER -> hex -> PEM, the input format jsrsasign's getJWK accepts here
  const pubKeyHex = Buffer.from(publicKeyDer).toString('hex')
  const pubKeyPem = KJUR.asn1.ASN1Util.getPEMStringFromHex(pubKeyHex, 'PUBLIC KEY')

  const jwk: any = KEYUTIL.getJWK(pubKeyPem)
  jwk.use = 'sig'
  jwk.alg = 'RS256'

  return jwk
}
209
+
210
/**
 * Applies the given tags to a KMS key.
 *
 * @param keyId - key reference passed to TagResource as `KeyId`
 * @param tags - tags to set on the key
 * @returns the raw TagResource response
 */
async function setKMSKeyTags (keyId: string, tags: Tag[]) {
  const tagCommand = new TagResourceCommand({ KeyId: keyId, Tags: tags })
  return await client.send(tagCommand)
}
213
+
214
// Module-level S3 client, created once and reused across uploads and warm invocations
// (consistent with how this module handles its KMS client).
const s3client = new S3Client({})

/**
 * Serialises `contents` as JSON and writes it to the bucket named by S3_BUCKET.
 *
 * @param key - object key to write, e.g. `discovery/keys`
 * @param contents - value to serialise with JSON.stringify
 * @throws Error when the S3_BUCKET environment variable is unset or empty
 */
async function uploadToS3 (key: string, contents: object) {
  const s3Bucket = process.env.S3_BUCKET
  if (!s3Bucket) {
    throw new Error('S3_BUCKET environment variable is not set')
  }

  await s3client.send(new PutObjectCommand({
    Bucket: s3Bucket,
    Key: key,
    Body: Buffer.from(JSON.stringify(contents)),
    ContentType: 'application/json',
    ContentEncoding: ''
  }))
}
@@ -0,0 +1,145 @@
1
+ // SPDX-FileCopyrightText: 2023 Alliander NV
2
+ //
3
+ // SPDX-License-Identifier: Apache-2.0
4
+
5
+ import { Context, APIGatewayProxyResult, APIGatewayEvent } from 'aws-lambda'
6
+ import { KMSClient, SignCommand, DescribeKeyCommand, ListResourceTagsCommand, Tag } from '@aws-sdk/client-kms'
7
+ import base64url from 'base64url'
8
+
9
+ import { Logger } from '@aws-lambda-powertools/logger'
10
+
11
// Powertools logger shared by the token handler.
const logger = new Logger()

// Alias of the KMS key used to sign newly issued tokens.
const KEY_ALIAS_CURRENT = 'alias/sts/CURRENT'
13
+
14
// KMS client created once per execution environment and reused across invocations.
const kms = new KMSClient({})

/**
 * API Gateway handler that issues an RS256-signed JWT for the calling IAM identity.
 *
 * The caller's `userArn` from the request context is converted to a base role ARN,
 * which becomes the `sub` claim. The `aud` query parameter, when present and non-empty,
 * overrides DEFAULT_AUDIENCE; ISSUER sets `iss`. The token is valid for one hour, with
 * `nbf` five minutes before `iat`. It is signed by the key behind the CURRENT alias, and
 * the `kid` header is taken from that key's `jwk_kid` tag.
 *
 * @returns 200 with body `{"token": "<jwt>"}`.
 *          400 when the caller identity cannot be resolved.
 *          500 when the signing key is missing or untagged, or KMS returns no signature.
 */
export const handler = async (apiEvent: APIGatewayEvent, context: Context): Promise<APIGatewayProxyResult> => {
  const userArn = apiEvent.requestContext.identity?.userArn
  const identityArn = getARNFromIdentity(userArn)

  if (identityArn === undefined || identityArn === null) {
    logger.info(`Unable to resolve identityArn for userArn: ${userArn}`)
    return respond('Unable to resolve identity', 400)
  }
  // Logged only after the null check so the debug line always carries a real ARN
  logger.debug(identityArn)

  // A non-empty `aud` query parameter takes precedence over the configured default
  const aud = apiEvent.queryStringParameters?.aud || process.env.DEFAULT_AUDIENCE

  // Resolve the signing key; the JWT `kid` comes from its tag, not from the KMS key id
  const currentKeyId = await resolveCurrentKeyId()
  if (currentKeyId === undefined) {
    return respond('KMS key could not be retrieved', 500)
  }

  const listResourceTagsResponse = await kms.send(new ListResourceTagsCommand({ KeyId: currentKeyId }))
  const kid = getTagValueFromTags('jwk_kid', listResourceTagsResponse.Tags ?? [])
  if (kid == null) {
    return respond('KMS key is not correctly tagged', 500)
  }

  const headers = {
    alg: 'RS256',
    typ: 'JWT',
    kid: `${kid}`
  }

  const issuedAt = Math.floor(Date.now() / 1000)
  const payload = {
    sub: identityArn, // base role ARN of the caller
    aud,
    iss: process.env.ISSUER,
    iat: issuedAt,
    exp: issuedAt + 60 * 60, // valid for one hour
    nbf: issuedAt - 5 * 60 // 5 minutes before iat
  }

  const signingInput = `${base64url(JSON.stringify(headers))}.${base64url(JSON.stringify(payload))}`

  const signResponse = await kms.send(new SignCommand({
    KeyId: currentKeyId,
    Message: Buffer.from(signingInput),
    SigningAlgorithm: 'RSASSA_PKCS1_V1_5_SHA_256',
    MessageType: 'RAW'
  }))

  if (signResponse.Signature === undefined) {
    return respond('KMS did not return a signature', 500)
  }
  logger.debug(`Signed token with KMS key ${signResponse.KeyId ?? currentKeyId}`)

  // base64 -> base64url without padding, as required for JWS compact serialisation
  const signature = Buffer
    .from(signResponse.Signature)
    .toString('base64')
    .replace(/\+/g, '-')
    .replace(/\//g, '_')
    .replace(/=/g, '')

  // The token is a bearer credential, so it is deliberately never logged
  return respond(JSON.stringify({
    token: `${signingInput}.${signature}`
  }))
}

/**
 * Resolves the id of the key behind KEY_ALIAS_CURRENT.
 *
 * @returns the key id, or `undefined` when the alias does not exist or the response has
 *          no metadata. Until the rotation has promoted a key to CURRENT, the alias does
 *          not exist.
 */
async function resolveCurrentKeyId (): Promise<string | undefined> {
  try {
    const currentResponse = await kms.send(new DescribeKeyCommand({ KeyId: KEY_ALIAS_CURRENT }))
    return currentResponse.KeyMetadata?.KeyId
  } catch (err) {
    if (err instanceof Error && err.name === 'NotFoundException') {
      logger.error(`KMS alias ${KEY_ALIAS_CURRENT} not found`)
      return undefined
    }
    throw err
  }
}
100
+
101
/**
 * Wraps a body string in a minimal API Gateway proxy result.
 *
 * @param message - response body, returned verbatim
 * @param statusCode - HTTP status code; 200 unless given
 * @returns an object with `statusCode` and `body`
 */
function respond (message: string, statusCode: number = 200) {
  const result = { statusCode, body: message }
  return result
}
107
+
108
// Matches STS assumed-role ARNs, e.g. arn:aws:sts::123456789012:assumed-role/MyRole/my-session.
// The role-name class follows the IAM role-name character set [\w+=,.@-]; the original
// [A-z0-9\-] rejected names containing + = , . @ and accepted [ \ ] ^ and backtick.
// Any AWS partition (aws, aws-cn, aws-us-gov, ...) is accepted.
const ASSUMED_ROLE_ARN_PATTERN = new RegExp([
  '^arn:',
  '(?<partition>aws(?:-[a-z]+)*)',
  ':sts:',
  '(?<regionName>[^:]*)',
  ':',
  '(?<accountId>\\d{12})',
  ':assumed-role\\/',
  '(?<roleName>[\\w+=,.@-]+)',
  '\\/',
  '(?<user>[^:]*)', // role session name
  '$'
].join(''))

/**
 * Converts an STS assumed-role ARN into the ARN of the underlying IAM role.
 *
 * The role session name is dropped, and the partition, region segment (empty for STS)
 * and account id are carried over.
 *
 * @param identityArn - caller ARN from the API Gateway request context; may be absent
 * @returns the base role ARN, or `null` when the input is missing or not an assumed-role ARN
 */
function getARNFromIdentity (identityArn: string | null | undefined) {
  if (identityArn === undefined || identityArn === null) {
    return null
  }

  const { partition, regionName, accountId, roleName } = ASSUMED_ROLE_ARN_PATTERN.exec(identityArn)?.groups ?? {}

  if (partition === undefined || regionName === undefined || accountId === undefined || roleName === undefined) {
    return null
  }

  return `arn:${partition}:iam:${regionName}:${accountId}:role/${roleName}`
}
136
+
137
/**
 * Finds the value of the first tag whose key equals `tagKey`.
 *
 * @param tagKey - tag key to look for
 * @param tags - tags as returned by KMS ListResourceTags
 * @returns the matching tag's value (which may itself be `undefined`), or `null` if no tag matches
 */
function getTagValueFromTags (tagKey: string, tags: Tag[]) {
  const match = tags.find(({ TagKey }) => TagKey === tagKey)
  return match === undefined ? null : match.TagValue
}