@juit/pgproxy-server 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/server.ts ADDED
@@ -0,0 +1,514 @@
1
+ import assert from 'node:assert'
2
+ import { randomUUID } from 'node:crypto'
3
+ import { createServer, STATUS_CODES } from 'node:http'
4
+ import { resolve } from 'node:path'
5
+ import { parse as parseQueryString } from 'node:querystring'
6
+
7
+ import { ConnectionPool } from '@juit/pgproxy-pool'
8
+ import { WebSocketServer } from 'ws'
9
+
10
+ import { verifyToken } from './token'
11
+
12
+ import type {
13
+ Connection,
14
+ ConnectionPoolOptions,
15
+ ConnectionPoolStats,
16
+ Logger,
17
+ } from '@juit/pgproxy-pool'
18
+ import type {
19
+ ServerOptions as HTTPOptions,
20
+ IncomingMessage as HTTPRequest,
21
+ ServerResponse as HTTPResponse,
22
+ Server as HTTPServer,
23
+ } from 'node:http'
24
+ import type { AddressInfo } from 'node:net'
25
+ import type { Duplex } from 'node:stream'
26
+ import type { Response } from './index'
27
+
28
+ /* ========================================================================== *
29
+ * EXPORTED TYPES *
30
+ * ========================================================================== */
31
+
32
/**
 * Options accepted when constructing a {@link Server}. Any key not listed
 * here is one of Node's own HTTP server options.
 */
export interface ServerOptions extends HTTPOptions {
  /** Shared secret that client authentication tokens are verified against */
  secret: string,
  /** Configuration of the connection pool backing the server */
  pool?: ConnectionPoolOptions,
  /** Address (host) the server binds to */
  address?: string,
  /** Port number the server binds to */
  port?: number,
  /** Maximum length of the queue of not-yet-accepted connections */
  backlog?: number,
  /** Path answering GET requests with pool stats (a health check) */
  healthCheck?: string,
}
46
+
47
/** A DB proxy server, as created by `new Server(logger, options)` */
export interface Server {
  /** Start the server; resolves to the server itself once it is listening */
  start(): Promise<Server>,
  /** Stop the server; resolves once it has been closed */
  stop(): Promise<void>,

  /** The address the server is bound to (only available once started) */
  readonly address: AddressInfo,
  /** An `http://` URL pointing at the bound address */
  readonly url: URL,
  /** Statistics of the connection pool backing this server */
  readonly stats: ConnectionPoolStats,
}
55
+
56
+ /* ========================================================================== *
57
+ * SERVER IMPLEMENTATION *
58
+ * ========================================================================== */
59
+
60
/** Internal type for a payload rejected by validation (`error` says why) */
interface PayloadError {
  valid: false,
  id: string,
  error: string,
  query?: never,
  params?: never,
}

/** Internal type for a payload accepted by validation */
interface PayloadQuery {
  valid: true,
  id: string,
  query: string,
  params?: (string | null)[],
  error?: never,
}

/** A validated payload, discriminated on its `valid` flag */
type Payload = Readonly<PayloadQuery> | Readonly<PayloadError>
80
+
81
/**
 * Implementation of the {@link Server} interface.
 *
 * Non-GET HTTP requests and WebSocket upgrades must carry a one-time `auth`
 * query parameter, validated with `verifyToken` against our secret. HTTP POST
 * bodies and WebSocket messages are JSON payloads (`{ id?, query, params? }`)
 * executed on connections acquired from the pool. GET requests are only
 * answered on the (optional, unauthenticated) health check path.
 */
class ServerImpl implements Server {
  /* Keep those as private class members, they contain auth data... */
  /** Tokens already seen, mapped to the time (epoch ms) we can forget them */
  readonly #tokens = new Map<string, number>()
  readonly #pool: ConnectionPool
  readonly #secret: string

  private readonly _server: HTTPServer
  private readonly _logger: Logger

  private readonly _healthCheck: string | null
  private readonly _backlog?: number
  private readonly _address?: string
  private readonly _port?: number

  private _started: boolean = false
  private _stopped: boolean = false
  /** Timer periodically wiping expired tokens (created by `start()`) */
  private _tokensTimer?: ReturnType<typeof setInterval>

  /**
   * Create a new (not yet started) server.
   *
   * @param logger - Logger used by the server and its connection pool.
   * @param options - Server options; keys not consumed here are forwarded to
   *                  Node's `http.createServer`.
   */
  constructor(logger: Logger, options: ServerOptions) {
    const { address, port, backlog, secret, healthCheck, pool, ...serverOptions } = options

    this.#pool = new ConnectionPool(logger, pool)
    this.#secret = secret

    /* Anchor the health check path at the root (e.g. "health" -> "/health") */
    this._healthCheck = healthCheck ? resolve('/', healthCheck) : null
    this._backlog = backlog
    this._address = address
    this._logger = logger
    this._port = port

    /* Create our HTTP and WebSocket servers */
    this._server = createServer(serverOptions)
    const wss = new WebSocketServer({ noServer: true })

    /* Setup handlers */
    this._server.on('request', (req, res) => this._requestHandler(req, res))
    this._server.on('upgrade', (req, sock, head) => this._upgradeHandler(req, sock, head, wss))
    this._server.on('close', () => {
      /* No more tokens will be validated, stop sweeping them */
      if (this._tokensTimer) clearInterval(this._tokensTimer)

      wss.close((wssError) => {
        /* coverage ignore if */
        if (wssError) logger.error('Error closing WebSocket server:', wssError)
        try {
          this._stopPool()
        } finally {
          this._logger.info('DB proxy server stopped')
        }
      })
    })
  }

  /** Create a rejection handler logging `message` alongside the error */
  private _catchError(message: string): (error: unknown) => void {
    /* coverage ignore next */
    return (error) => this._logger.error(message, error)
  }

  /** Stop the connection pool, logging (rather than throwing) any error */
  private _stopPool(): void {
    /* coverage ignore catch */
    try {
      this.#pool.stop()
    } catch (poolError) {
      this._logger.error('Error closing connection pool:', poolError)
    }
  }

  /** Extract a human-readable message from whatever value was thrown */
  private _errorMessage(error: unknown): string {
    return error instanceof Error ? error.message : String(error)
  }

  /* ======================================================================== *
   * PROPERTIES                                                               *
   * ======================================================================== */

  /** The address the server is bound to (asserts the server is listening) */
  get address(): AddressInfo {
    const address = this._server.address() as AddressInfo | null
    assert(address, 'Server not started')
    return address
  }

  /** An `http://` URL for the bound address (IPv4 or IPv6 only) */
  /* coverage ignore next */
  get url(): URL {
    const { address, family, port } = this.address
    if (family === 'IPv6') return new URL(`http://[${address}]:${port}/`)
    if (family === 'IPv4') return new URL(`http://${address}:${port}/`)
    throw new Error(`Unsupported address family "${family}"`)
  }

  /** Statistics of the backing connection pool */
  get stats(): ConnectionPoolStats {
    return this.#pool.stats
  }

  /* ======================================================================== *
   * LIFECYCLE METHODS                                                        *
   * ======================================================================== */

  /**
   * Start the connection pool, then listen for HTTP connections.
   *
   * If listening fails (e.g. the port is in use) the pool is stopped again
   * before the error is rethrown.
   */
  async start(): Promise<Server> {
    assert(! this._started, 'Server already started')
    this._started = true

    /* We're doing this! */
    this._logger.debug('Starting server')

    /* First of all, start the connection pool */
    await this.#pool.start()

    /* Start listening, and catch initial error */
    try {
      await new Promise<void>((done, fail) => {
        this._server.once('error', fail)
        this._server.listen(this._port, this._address, this._backlog, () => {
          this._server.off('error', fail)
          done()
        })
      })
    } catch (error) {
      /* Don't leave the pool (and its connections) running */
      this._stopPool()
      throw error
    }

    /* On normal errors try to stop the server and exit */
    this._server.on('error', /* coverage ignore next */ (error) => {
      this._logger.error('Server Error:', error)
      this.stop()
          .catch(this._catchError('Error stopping server'))
          .finally(() => process.exit(1)) // always exit!
    })

    /* Start a timer that will periodically wipe expired tokens. Tokens are
     * remembered for 60 seconds (see `_validateAuth`), so sweeping every 10
     * seconds is plenty (without a delay, this would run every millisecond) */
    this._tokensTimer = setInterval(() => {
      const now = Date.now()
      for (const [ token, expiry ] of this.#tokens) {
        /* coverage ignore if // all is hidden, hard to test */
        if (expiry < now) this.#tokens.delete(token)
      }
    }, 10_000)
    this._tokensTimer.unref() // let the process die...

    /* Log some important info... */
    this._logger.info(`DB proxy server started at ${this.url}`)
    if (this._healthCheck) {
      this._logger.info(`Unauthenticated health check available at "${this._healthCheck}"`)
    }
    this._logger.info('Connection pool options')
    for (const [ key, value ] of Object.entries(this.#pool.configuration)) {
      const name = key.replaceAll(/[A-Z]/g, (c) => ` ${c.toLowerCase()}`)
      this._logger.info(`- ${name}: ${value}`)
    }

    /* We're done! */
    return this
  }

  /**
   * Stop accepting connections and close the server. The pool is stopped by
   * the server's "close" handler (see the constructor).
   */
  async stop(): Promise<void> {
    assert(this._started, 'Server never started')
    assert(! this._stopped, 'Server already stopped')
    this._stopped = true

    this._logger.info(`Stopping DB proxy server at "${this.url}"`)
    /* NOTE(review): Node's `server.close()` presumably waits for every open
     * connection (possibly including upgraded WebSocket sockets) to end before
     * calling back. Confirm clients are disconnected first, or stop may hang */
    await new Promise<void>( /* coverage ignore next */ (done, fail) => {
      this._server.close((error) => error ? fail(error) : done())
    })
  }

  /* ======================================================================== *
   * REQUEST HANDLING                                                         *
   * ======================================================================== */

  /** Serialize `object` as JSON and send it (always ending the response) */
  private _sendResponse(
      object: object,
      statusCode: number,
      request: HTTPRequest,
      response: HTTPResponse,
  ): void {
    void new Promise<void>((done, fail) => {
      /* A synchronous throw here (e.g. from JSON.stringify) rejects the promise */
      const buffer = Buffer.from(JSON.stringify(object), 'utf-8')

      response.statusCode = statusCode
      response.setHeader('content-type', 'application/json')
      response.setHeader('content-length', buffer.length)
      response.write(buffer, (error) => {
        if (error) /* coverage ignore next */ fail(error)
        else done()
      })
    }).catch( /* coverage ignore next */ (error) => {
      this._logger.error(`Error handling request "${request.url}"`, error)
      if (! response.headersSent) response.statusCode = 500 // internal server error...
    }).finally(() => response.end())
  }

  /** Log a failed request and end its response (with a 500 status, if possible) */
  private _failRequest(error: unknown, request: HTTPRequest, response: HTTPResponse): void {
    this._logger.error(`Error handling request "${request.url}"`, error)
    if (! response.headersSent) response.statusCode = 500 // internal server error...
    response.end()
  }

  /**
   * Answer the (unauthenticated) health check: pool stats plus the measured
   * latency of a `SELECT now()` query. Any other GET path yields a 404, and
   * failures (e.g. the database is unreachable) yield a 500.
   */
  private _healthCheckHandler(request: HTTPRequest, response: HTTPResponse): void {
    /* Check that the URL is the one specified in the options */
    if (request.url !== this._healthCheck) {
      response.statusCode = 404
      return void response.end()
    }

    void Promise.resolve().then(async () => {
      /* Clone the stats before running the latency tests */
      const stats = { ...this.stats }

      /* Calculate our latency to the database */
      const start = process.hrtime.bigint()
      const connection = await this.#pool.acquire()
      const queryStart = process.hrtime.bigint()
      try {
        await connection.query('SELECT now()')
      } finally {
        this.#pool.release(connection)
      }
      const queryEnd = process.hrtime.bigint()

      /* Convert nanoseconds into milliseconds, truncated to two decimals */
      const latency = Math.floor(Number(queryEnd - start) / 10000) / 100
      const connTime = Math.floor(Number(queryStart - start) / 10000) / 100
      const queryTime = Math.floor(Number(queryEnd - queryStart) / 10000) / 100

      return { ...stats, latency, connTime, queryTime }
    }).then((data) => {
      const { connTime, queryTime, ...stats } = data
      this._sendResponse(stats, 200, request, response)
      this._logger.info(`Handled Health check with latency of ${data.latency} ms (connection ${connTime} ms, query ${queryTime} ms)`)
    }).catch((error) => this._failRequest(error, request, response))
  }

  /** Entry point for all plain (non-upgrade) HTTP requests */
  private _requestHandler(request: HTTPRequest, response: HTTPResponse): void {
    /* Health check on GET (if configured) */
    if (request.method === 'GET') return this._healthCheckHandler(request, response)

    /* Authorize requests to the pool */
    const statusCode = this._validateAuth(request)
    if (statusCode !== 200) {
      response.statusCode = statusCode
      return void response.end()
    }

    /* As a normal "request" we only accept POST */
    if (request.method !== 'POST') {
      response.statusCode = 405 // method not allowed
      return void response.end()
    }

    /* The only content type is JSON (parameters like "; charset=utf-8" are ignored) */
    const mediaType = request.headers['content-type']?.split(';', 1)[0]?.trim().toLowerCase()
    if (mediaType !== 'application/json') {
      response.statusCode = 415 // unsupported media type
      return void response.end()
    }

    /* Run asynchronously for the rest of the processing */
    const started = process.hrtime.bigint()
    void this._executeRequest(request).then((data) => {
      this._sendResponse(data, data.statusCode, request, response)
      const ms = Math.floor(Number(process.hrtime.bigint() - started) / 10000) / 100
      this._logger.info(`Handled "${data.command}" HTTP request in ${ms} ms`)
    }).catch((error) => this._failRequest(error, request, response))
  }

  /** Read, validate and run the query carried by an HTTP request body */
  private async _executeRequest(request: HTTPRequest): Promise<Response> {
    /* Extract the payload from the request and check for validation errors */
    const payload = this._validatePayload(await this._readRequest(request))
    if (! payload.valid) {
      return { id: payload.id, statusCode: 400, error: payload.error }
    }

    /* Acquire the connection */
    let connection: Connection
    /* coverage ignore catch */
    try {
      const acquireStart = process.hrtime.bigint()
      connection = await this.#pool.acquire()
      const ms = Number(process.hrtime.bigint() - acquireStart) / 1000000
      this._logger.debug(`Acquired connection ${connection.id} in ${ms} ms`)
    } catch (error) {
      this._logger.error('Error acquiring connection:', error)
      return { id: payload.id, statusCode: 500, error: 'Error acquiring connection' }
    }

    /* Run the query, always giving the connection back to the pool */
    try {
      const result = await connection.query(payload.query, payload.params)
      return { ...result, statusCode: 200, id: payload.id }
    } catch (error) {
      return { id: payload.id, statusCode: 400, error: this._errorMessage(error) }
    } finally {
      this.#pool.release(connection)
    }
  }

  /**
   * Authenticate and upgrade a WebSocket request. Each accepted socket holds
   * one pooled connection (acquired up front) for its whole lifetime, and
   * every message is run as a query on it.
   */
  private _upgradeHandler(request: HTTPRequest, socket: Duplex, head: Buffer, wss: WebSocketServer): void {
    /* Authenticate */
    const statusCode = this._validateAuth(request)
    if (statusCode !== 200) {
      /* Keep an error handler attached for the socket's remaining lifetime:
       * an "error" event with no listener would be thrown */
      socket.on('error', /* coverage ignore next */ (error: Error) => {
        this._logger.error('Socket error', error)
        socket.destroy()
      })

      /* Send the status line, destroying the socket only once it's flushed */
      socket.once('finish', () => socket.destroy())
      socket.end(`HTTP/1.1 ${statusCode} ${STATUS_CODES[statusCode]}\r\n\r\n`)
      return
    }

    /* Do the actual _upgrade_ of the socket */
    wss.handleUpgrade(request, socket, head, (ws) => {
      /* Eventually acquire a connection (resolves to undefined on failure) */
      const promise = this.#pool.acquire()
          .catch( /* coverage ignore next */ (error) => {
            this._logger.error('Error acquiring connection for WebSocket:', error)
            ws.close()
          })

      /* Eventually release: both the "error" and "close" handlers below call
       * this, so make sure the connection goes back to the pool only once */
      let released = false
      const release = (): void => {
        if (released) return
        released = true
        void promise
            .then((connection) => connection && this.#pool.release(connection))
            .catch(this._catchError('Error releasing connection for WebSocket:'))
      }

      /* Send data back over the websocket */
      const send = (data: Response): void => {
        const message = JSON.stringify(data)
        ws.send(message, (error) => {
          /* coverage ignore if */
          if (error) {
            this._logger.error('Error sending WebSocket response:', error)
            ws.close()
          }
        })
      }

      /* On websocket error, release the connection */
      ws.on('error', /* coverage ignore next */ (error) => {
        this._logger.error('WebSocket error', error)
        release()
      })

      /* On websocket close, release the connection */
      ws.on('close', (code, reason) => {
        const extra = reason.toString('utf-8')
        if (extra) this._logger.info(`WebSocket closed (${code}):`, extra)
        else this._logger.info(`WebSocket closed (${code}):`)
        release()
      })

      /* On message, run a query and send results back */
      ws.on('message', (data) => {
        const now = process.hrtime.bigint()
        const payload = this._validatePayload(data.toString('utf-8'))
        if (! payload.valid) {
          send({ id: payload.id, statusCode: 400, error: payload.error })
        } else {
          promise.then(async (connection) => {
            /* coverage ignore if // If we have no connection, the promise
             * catcher has also already closed the websocket, just ignore */
            if (! connection) return
            try {
              const result = await connection.query(payload.query, payload.params)

              const ms = Math.floor(Number(process.hrtime.bigint() - now) / 10000) / 100
              this._logger.info(`Handled "${result.command}" WebSocket request in ${ms} ms`)
              return send({ ...result, statusCode: 200, id: payload.id })
            } catch (error) {
              return send({ id: payload.id, statusCode: 400, error: this._errorMessage(error) })
            }
          }).catch(this._catchError('Error querying in websocket'))
        }
      })
    })
  }

  /* ======================================================================== *
   * INTERNALS                                                                *
   * ======================================================================== */

  /** Read the body of an HTTP request fully, decoding it as UTF-8 */
  private _readRequest(stream: HTTPRequest): Promise<string> {
    return new Promise<Buffer>((done, fail) => {
      const buffers: Buffer[] = []

      stream.on('error', /* coverage ignore next */ (error) => fail(error))
      stream.on('data', (buffer) => buffers.push(buffer))
      stream.on('end', () => done(Buffer.concat(buffers)))

      /* coverage ignore if */
      if (stream.isPaused()) stream.resume()
    }).then((buffer) => buffer.toString('utf-8'))
  }

  /**
   * Parse a payload string as JSON and validate it. An empty string is
   * treated as `{}`; a missing `id` is replaced by a random UUID.
   */
  private _validatePayload(string: string): Payload {
    try {
      const payload = JSON.parse(string || '{}')
      const id = payload?.id ? `${payload.id}` : randomUUID()

      if (! payload?.query) {
        return { id, valid: false, error: 'Invalid payload (or query missing)' }
      }
      if (typeof payload.query !== 'string') {
        return { id, valid: false, error: 'Query is not a string' }
      }
      if (payload.params && (! Array.isArray(payload.params))) {
        return { id, valid: false, error: 'Parameters are not an array' }
      }

      return { id, valid: true, query: payload.query, params: payload.params }
    } catch {
      return { id: randomUUID(), valid: false, error: 'Error parsing JSON' }
    }
  }

  /**
   * Validate a request: it must have a single "auth" query parameter holding
   * a valid token that was not seen before.
   *
   * @returns 200 when authorized, 401 when "auth" is missing (or repeated),
   *          403 when the token is invalid or being reused.
   */
  private _validateAuth(request: HTTPRequest): 200 | 401 | 404 | 403 {
    /* Parse the query string from the request path and extract "auth" */
    const auth = parseQueryString((request.url ?? '').split('?')[1] || '').auth

    /* Make sure that we have a proper authorization (defined, not an array) */
    if (typeof auth !== 'string') return 401 // No "auth", 401 (Unauthorized)

    try {
      /* Validate the auth against our stored secret */
      const token = verifyToken(auth, this.#secret)

      /* Token was already seen */
      if (this.#tokens.has(token)) {
        this._logger.error('Attempted to reuse an existing token')
        return 403
      }

      this.#tokens.set(token, Date.now() + 60_000) // expiry is 10 sec, but use 60
      return 200
    } catch (error) {
      this._logger.error(error)
      return 403
    }
  }
}
507
+
508
+ /* ========================================================================== *
509
+ * EXPORT SERVER IMPLEMENTATION *
510
+ * ========================================================================== */
511
+
512
/**
 * Public constructor for {@link Server} instances. The implementation class
 * itself stays private to this module.
 */
export const Server: new (logger: Logger, options: ServerOptions) => Server = ServerImpl
package/src/token.ts ADDED
@@ -0,0 +1,53 @@
1
import assert from 'node:assert'
import { createHmac, timingSafeEqual } from 'node:crypto'
3
+
4
+ /* ========================================================================== *
5
+ * AUTHENTICATION TOKEN *
6
+ * ========================================================================== *
7
+ * *
8
+ * Our authentication token is defined as follows: *
9
+ * *
10
+ * +-------------------+----------------+---------------------------+ *
11
+ * | bits | bytes | field | *
12
+ * +-------------------+----------------+---------------------------+ *
13
+ * | 0 ... 63 (64) | 0 ... 7 (8) | timestamp (little endian) | *
14
+ * | 64 ... 127 (64) | 8 ... 15 (8) | random bytes | *
15
+ * | 128 ... 383 (256) | 16 ... 47 (32) | HMAC-SHA-256 signature | *
16
+ * +-------------------+----------------+---------------------------+ *
17
+ * *
18
+ * The signature is calculated using the HMAC-SHA-256 algorithm, with the *
19
+ * UTF-8 encoding of our `secret` as the _key_ and the first 16 bytes of the *
20
+ * token itself as the message. *
21
+ * *
22
+ * The total length of 48 bytes has been chosen so that the BASE-64 encoding *
23
+ * of the authentication token is precisely 64 characters and doesn't require *
24
+ * any padding. *
25
+ * *
26
+ * Furthermore, authentication tokens must be validated against the current *
27
+ * timestamp, and this implementation requires any acceptable token to be *
28
+ * within +/- 10 seconds of _now_. *
29
+ * ========================================================================== */
30
+
31
/**
 * Verify an authentication token (format described above) against a secret.
 *
 * @param token - The base64url-encoded token (64 characters, 48 bytes).
 * @param secret - The shared secret, used (UTF-8 encoded) as HMAC key.
 * @returns The lowercase hex encoding of the token's first 16 bytes
 *          (timestamp + random bytes), usable to detect token reuse.
 * @throws AssertionError when the token is malformed, its timestamp is not
 *         within 10 seconds of now, or its signature does not match.
 */
export function verifyToken(
    token: string,
    secret: string,
): string {
  assert.strictEqual(token.length, 64, `Invalid encoded token length (${token.length} != 64)`)

  const buffer = Buffer.from(token, 'base64url')
  assert.strictEqual(buffer.length, 48, `Invalid decoded token length (${buffer.length} != 48)`)

  /* First of all check the time delta */
  const timeDelta = buffer.readBigInt64LE(0) - BigInt(Date.now())
  const absoluteDelta = timeDelta < 0n ? -timeDelta : timeDelta
  assert(absoluteDelta < 10_000n, `Timestamp delta out of range (${timeDelta} ms)`)

  /* Compute the HMAC-SHA-256 signature of the message using our secret */
  const signature = createHmac('sha256', Buffer.from(secret, 'utf8'))
      .update(buffer.subarray(0, 16))
      .digest()

  /* Compare the signatures (computed vs received) in constant time, so that
   * the response timing can't reveal how many leading bytes matched. Both
   * sides are 32 bytes long (the decoded length was asserted above) */
  assert(timingSafeEqual(signature, buffer.subarray(16)), 'Token signature mismatch')

  /* Buffer's hex encoding is already lowercase */
  return buffer.toString('hex', 0, 16)
}