orez 0.0.38 → 0.0.40

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/pg-proxy.ts CHANGED
@@ -1,8 +1,9 @@
1
1
  /**
2
2
  * tcp proxy that makes pglite speak postgresql wire protocol.
3
3
  *
4
- * uses pg-gateway to handle protocol lifecycle for regular connections,
5
- * and directly handles the raw socket for replication connections.
4
+ * handles the postgresql wire protocol directly using raw tcp sockets,
5
+ * avoiding pg-gateway's Duplex.toWeb() which deadlocks under concurrent
6
+ * connections with large responses.
6
7
  *
7
8
  * regular connections: forwarded to pglite via execProtocolRaw()
8
9
  * replication connections: intercepted, replication protocol faked
@@ -14,8 +15,6 @@
14
15
 
15
16
  import { createServer, type Server, type Socket } from 'node:net'
16
17
 
17
- import { fromNodeSocket } from 'pg-gateway/node'
18
-
19
18
  import { log } from './log.js'
20
19
  import { Mutex } from './mutex.js'
21
20
  import { handleReplicationQuery, handleStartReplication } from './replication/handler.js'
@@ -67,6 +66,7 @@ const QUERY_REWRITES: Array<{ match: RegExp; replace: string }> = [
67
66
  // parameter status messages sent during connection handshake
68
67
  // pg_restore and other tools read these to determine server capabilities
69
68
  const SERVER_PARAMS: [string, string][] = [
69
+ ['server_version', '16.4'],
70
70
  ['server_encoding', 'UTF8'],
71
71
  ['client_encoding', 'UTF8'],
72
72
  ['DateStyle', 'ISO, MDY'],
@@ -76,7 +76,12 @@ const SERVER_PARAMS: [string, string][] = [
76
76
  ['IntervalStyle', 'postgres'],
77
77
  ]
78
78
 
79
- // build a ParameterStatus wire protocol message (type 'S', 0x53)
79
+ // queries to intercept and return no-op success (synthetic SET response)
80
+ // pglite rejects SET TRANSACTION if any query (e.g. SET search_path) ran first
81
+ const NOOP_QUERY_PATTERNS: RegExp[] = [/^\s*SET\s+TRANSACTION\b/i, /^\s*SET\s+SESSION\b/i]
82
+
83
+ // ── wire protocol helpers ──
84
+
80
85
  function buildParameterStatus(name: string, value: string): Uint8Array {
81
86
  const encoder = new TextEncoder()
82
87
  const nameBytes = encoder.encode(name)
@@ -95,13 +100,64 @@ function buildParameterStatus(name: string, value: string): Uint8Array {
95
100
  return buf
96
101
  }
97
102
 
98
- // queries to intercept and return no-op success (synthetic SET response)
99
- // pglite rejects SET TRANSACTION if any query (e.g. SET search_path) ran first
100
- const NOOP_QUERY_PATTERNS: RegExp[] = [/^\s*SET\s+TRANSACTION\b/i, /^\s*SET\s+SESSION\b/i]
103
// AuthenticationOk: 'R' message whose 4-byte auth code is 0, telling the
// client that authentication succeeded.
function buildAuthOk(): Uint8Array {
  const msg = new Uint8Array(9)
  const view = new DataView(msg.buffer)
  msg[0] = 0x52 // 'R'
  view.setInt32(1, 8) // length counts itself (4) + auth code (4)
  view.setInt32(5, 0) // 0 = ok
  return msg
}

// AuthenticationCleartextPassword: 'R' message whose auth code is 3, asking
// the client to send its password in a PasswordMessage.
function buildAuthCleartextPassword(): Uint8Array {
  const msg = new Uint8Array(9)
  const view = new DataView(msg.buffer)
  msg[0] = 0x52 // 'R'
  view.setInt32(1, 8)
  view.setInt32(5, 3) // 3 = cleartext password requested
  return msg
}

// BackendKeyData: 'K' message carrying a process id (this proxy's own pid)
// and a secret key, which is always 0 here.
function buildBackendKeyData(): Uint8Array {
  const msg = new Uint8Array(13)
  const view = new DataView(msg.buffer)
  msg[0] = 0x4b // 'K'
  view.setInt32(1, 12) // length: itself + pid + secret
  view.setInt32(5, process.pid)
  view.setInt32(9, 0)
  return msg
}

// ReadyForQuery: 'Z' message with a one-byte transaction status.
// `status` defaults to 0x49 ('I', idle).
function buildReadyForQuery(status = 0x49): Uint8Array {
  const msg = new Uint8Array(6)
  const view = new DataView(msg.buffer)
  msg[0] = 0x5a // 'Z'
  view.setInt32(1, 5) // length: itself + status byte
  msg[5] = status
  return msg
}
135
+
136
/**
 * build an ErrorResponse ('E') backend message.
 *
 * the body is a sequence of fields, each a one-byte field type followed by a
 * null-terminated string: S (severity), C (SQLSTATE code), M (message), then a
 * single 0 byte that ends the field list.
 *
 * @param message  human-readable error text
 * @param code     five-character SQLSTATE code. defaults to '08006'
 *                 (connection_failure) so existing callers are unchanged;
 *                 authentication failures should pass '28P01'
 *                 (invalid_password).
 * @param severity severity field text. defaults to 'ERROR'; postgres reports
 *                 startup/authentication failures as 'FATAL'.
 * @returns the complete wire message, including the type byte and length
 */
function buildErrorResponse(message: string, code = '08006', severity = 'ERROR'): Uint8Array {
  const encoder = new TextEncoder()

  // encode one field as: type byte, utf-8 text, 0 terminator. embedded NULs are
  // stripped: one would end the field early and desync the client's parser.
  const field = (type: number, value: string): Uint8Array => {
    const text = encoder.encode(value.replace(/\0/g, ''))
    const out = new Uint8Array(text.length + 2)
    out[0] = type
    out.set(text, 1)
    // the last byte stays 0 and acts as the field terminator
    return out
  }

  const fields = [field(0x53, severity), field(0x43, code), field(0x4d, message)]

  // length covers itself (4), every field, and the final list terminator (1)
  const bodyLen = 4 + fields.reduce((sum, f) => sum + f.length, 0) + 1
  const buf = new Uint8Array(1 + bodyLen)
  buf[0] = 0x45 // 'E'
  new DataView(buf.buffer).setInt32(1, bodyLen)

  let pos = 5
  for (const f of fields) {
    buf.set(f, pos)
    pos += f.length
  }
  // buf[pos] is already 0: the field-list terminator
  return buf
}
158
+
159
+ // ── query helpers ──
101
160
 
102
- /**
103
- * extract query text from a Parse message (0x50).
104
- */
105
161
  function extractParseQuery(data: Uint8Array): string | null {
106
162
  if (data[0] !== 0x50) return null
107
163
  let offset = 5
@@ -112,9 +168,6 @@ function extractParseQuery(data: Uint8Array): string | null {
112
168
  return new TextDecoder().decode(data.subarray(queryStart, offset))
113
169
  }
114
170
 
115
- /**
116
- * rebuild a Parse message with a modified query string.
117
- */
118
171
  function rebuildParseMessage(data: Uint8Array, newQuery: string): Uint8Array {
119
172
  let offset = 5
120
173
  while (offset < data.length && data[offset] !== 0) offset++
@@ -144,9 +197,6 @@ function rebuildParseMessage(data: Uint8Array, newQuery: string): Uint8Array {
144
197
  return result
145
198
  }
146
199
 
147
- /**
148
- * rebuild a Simple Query message with a modified query string.
149
- */
150
200
  function rebuildSimpleQuery(newQuery: string): Uint8Array {
151
201
  const encoder = new TextEncoder()
152
202
  const queryBytes = encoder.encode(newQuery + '\0')
@@ -157,9 +207,6 @@ function rebuildSimpleQuery(newQuery: string): Uint8Array {
157
207
  return buf
158
208
  }
159
209
 
160
- /**
161
- * intercept and rewrite query messages to make pglite look like real postgres.
162
- */
163
210
  function interceptQuery(data: Uint8Array): Uint8Array {
164
211
  const msgType = data[0]
165
212
 
@@ -203,9 +250,6 @@ function interceptQuery(data: Uint8Array): Uint8Array {
203
250
  return data
204
251
  }
205
252
 
206
- /**
207
- * check if a query should be intercepted as a no-op.
208
- */
209
253
  function isNoopQuery(data: Uint8Array): boolean {
210
254
  let query: string | null = null
211
255
  if (data[0] === 0x51) {
@@ -219,9 +263,6 @@ function isNoopQuery(data: Uint8Array): boolean {
219
263
  return NOOP_QUERY_PATTERNS.some((p) => p.test(query!))
220
264
  }
221
265
 
222
- /**
223
- * build a synthetic "SET" command complete response.
224
- */
225
266
  function buildSetCompleteResponse(): Uint8Array {
226
267
  const encoder = new TextEncoder()
227
268
  const tag = encoder.encode('SET\0')
@@ -241,9 +282,6 @@ function buildSetCompleteResponse(): Uint8Array {
241
282
  return result
242
283
  }
243
284
 
244
- /**
245
- * build a synthetic ParseComplete response for extended protocol no-ops.
246
- */
247
285
  function buildParseCompleteResponse(): Uint8Array {
248
286
  const pc = new Uint8Array(5)
249
287
  pc[0] = 0x31 // ParseComplete
@@ -251,9 +289,6 @@ function buildParseCompleteResponse(): Uint8Array {
251
289
  return pc
252
290
  }
253
291
 
254
- /**
255
- * strip ReadyForQuery messages from a response buffer.
256
- */
257
292
  function stripReadyForQuery(data: Uint8Array): Uint8Array {
258
293
  if (data.length === 0) return data
259
294
 
@@ -285,184 +320,281 @@ function stripReadyForQuery(data: Uint8Array): Uint8Array {
285
320
  return result
286
321
  }
287
322
 
288
- export async function startPgProxy(
289
- dbInput: PGlite | PGliteInstances,
323
+ // ── socket write with backpressure ──
324
+
325
/**
 * write `data` to the socket and resolve once node has flushed that chunk.
 *
 * empty payloads and already-destroyed sockets resolve immediately without
 * writing. the promise rejects with the error node passes to the write
 * callback (for example when the socket is destroyed before the data is
 * flushed).
 */
function socketWrite(socket: Socket, data: Uint8Array): Promise<void> {
  if (data.length === 0 || socket.destroyed) return Promise.resolve()
  return new Promise<void>((resolve, reject) => {
    // write()'s boolean result (false once the internal buffer is full) is
    // intentionally ignored: the callback only fires after this chunk is
    // flushed, so awaiting the promise already applies backpressure.
    socket.write(data, (err) => (err ? reject(err) : resolve()))
  })
}
333
+
334
+ // ── startup handshake ──
335
+
336
// decode the first message a client sends on a new connection.
//
// an SSLRequest is exactly 8 bytes: length 8 followed by code 80877103.
// anything else is treated as a StartupMessage: a 4-byte length, a 4-byte
// protocol version, then null-terminated key/value strings ending with an
// empty key. scanning never reads past the end of `buf`, so a buffer holding
// only the 8-byte header yields empty params.
function parseStartupMessage(buf: Buffer): {
  isSSL: boolean
  params: Record<string, string>
} {
  const view = new DataView(buf.buffer, buf.byteOffset, buf.byteLength)
  const declaredLength = view.getInt32(0)
  const requestCode = view.getInt32(4)

  const SSL_REQUEST_CODE = 80877103
  if (declaredLength === 8 && requestCode === SSL_REQUEST_CODE) {
    return { isSSL: true, params: {} }
  }

  // read the string starting at `from` up to the next 0 byte (or end of
  // buffer); returns the text and the index just past the terminator
  const readCString = (from: number): [string, number] => {
    let end = from
    while (end < buf.length && buf[end] !== 0) end++
    return [buf.subarray(from, end).toString(), end + 1]
  }

  const params: Record<string, string> = {}
  let cursor = 8 // skip length + protocol version
  while (cursor < declaredLength) {
    const [key, afterKey] = readCString(cursor)
    cursor = afterKey
    if (key === '') break // empty key marks the end of the parameter list
    const [value, afterValue] = readCString(cursor)
    params[key] = value
    cursor = afterValue
  }

  return { isSSL: false, params }
}
368
+
369
// read exactly `n` bytes from the socket.
//
// resolves with a buffer of exactly `n` bytes. bytes that arrive in the same
// chunk(s) beyond the first `n` are handed back with socket.unshift(), so the
// next reader (another readBytes() call, or a 'data' listener attached later)
// receives them instead of having them silently dropped. on success the socket
// is left paused; the next reader is expected to resume it.
//
// rejects if the socket emits 'error' or 'close' before `n` bytes arrive, or if
// it is already destroyed when called (its 'close' may already have fired, so
// waiting for it would hang forever).
function readBytes(socket: Socket, n: number): Promise<Buffer> {
  return new Promise((resolve, reject) => {
    if (n <= 0) {
      resolve(Buffer.alloc(0))
      return
    }
    if (socket.destroyed) {
      reject(new Error('socket closed'))
      return
    }

    // collect chunks and concatenate once, instead of re-copying on every chunk
    const chunks: Buffer[] = []
    let received = 0

    // detach every listener this call installed, however it settles
    const cleanup = () => {
      socket.removeListener('data', onData)
      socket.removeListener('error', onError)
      socket.removeListener('close', onClose)
    }

    const onData = (chunk: Buffer) => {
      chunks.push(chunk)
      received += chunk.length
      if (received < n) return

      cleanup()
      socket.pause()
      const collected = Buffer.concat(chunks, received)
      // return surplus bytes (start of the client's next message) to the stream
      if (received > n) socket.unshift(collected.subarray(n))
      resolve(collected.subarray(0, n))
    }
    const onError = (err: Error) => {
      cleanup()
      reject(err)
    }
    const onClose = () => {
      cleanup()
      reject(new Error('socket closed'))
    }

    socket.on('data', onData)
    socket.on('error', onError)
    socket.on('close', onClose)
    socket.resume()
  })
}
401
+
402
+ // perform the startup handshake (SSL negotiation, auth, parameter status)
403
+ async function performHandshake(
404
+ socket: Socket,
290
405
  config: ZeroLiteConfig
291
- ): Promise<Server> {
292
- // normalize input: single PGlite instance = use it for all databases (backwards compat for tests)
293
- const instances: PGliteInstances =
294
- 'postgres' in dbInput
295
- ? (dbInput as PGliteInstances)
296
- : { postgres: dbInput as PGlite, cvr: dbInput as PGlite, cdb: dbInput as PGlite }
406
+ ): Promise<{ params: Record<string, string> }> {
407
+ // read initial message length (first 4 bytes)
408
+ let buf = await readBytes(socket, 8)
409
+
410
+ // check for SSL request
411
+ const startup = parseStartupMessage(buf)
412
+ if (startup.isSSL) {
413
+ // reject SSL, client will reconnect without it
414
+ socket.write(Buffer.from('N'))
415
+ buf = await readBytes(socket, 8)
416
+ }
297
417
 
298
- // per-instance mutexes for serializing pglite access
299
- const mutexes = {
300
- postgres: new Mutex(),
301
- cvr: new Mutex(),
302
- cdb: new Mutex(),
418
+ // now we have startup message header - read the rest if needed
419
+ const dv = new DataView(buf.buffer, buf.byteOffset, buf.byteLength)
420
+ const msgLen = dv.getInt32(0)
421
+ if (buf.length < msgLen) {
422
+ const rest = await readBytes(socket, msgLen - buf.length)
423
+ buf = Buffer.concat([buf, rest])
303
424
  }
304
425
 
305
- // helper to get instance + mutex for a database name
306
- function getDbContext(dbName: string): { db: PGlite; mutex: Mutex } {
307
- if (dbName === 'zero_cvr') return { db: instances.cvr, mutex: mutexes.cvr }
308
- if (dbName === 'zero_cdb') return { db: instances.cdb, mutex: mutexes.cdb }
309
- return { db: instances.postgres, mutex: mutexes.postgres }
426
+ const { params } = parseStartupMessage(buf)
427
+
428
+ // request cleartext password
429
+ socket.write(buildAuthCleartextPassword())
430
+
431
+ // read password message: type(1) + len(4) + password + null
432
+ const pwBuf = await readBytes(socket, 5)
433
+ const pwDv = new DataView(pwBuf.buffer, pwBuf.byteOffset, pwBuf.byteLength)
434
+ const pwLen = pwDv.getInt32(1)
435
+ let fullPwBuf = pwBuf
436
+ if (fullPwBuf.length < 1 + pwLen) {
437
+ const rest = await readBytes(socket, 1 + pwLen - fullPwBuf.length)
438
+ fullPwBuf = Buffer.concat([fullPwBuf, rest])
439
+ }
440
+ const password = fullPwBuf.subarray(5, 1 + pwLen - 1).toString()
441
+
442
+ // validate credentials
443
+ if (params.user !== config.pgUser || password !== config.pgPassword) {
444
+ socket.write(buildErrorResponse('authentication failed'))
445
+ socket.write(buildReadyForQuery())
446
+ socket.destroy()
447
+ throw new Error('auth failed')
310
448
  }
311
449
 
312
- const server = createServer(async (socket: Socket) => {
313
- // prevent idle timeouts from killing connections
314
- socket.setKeepAlive(true, 30000)
315
- socket.setTimeout(0)
450
+ // auth ok
451
+ socket.write(buildAuthOk())
316
452
 
317
- let dbName = 'postgres'
318
- let isReplicationConnection = false
453
+ // send parameter status messages
454
+ for (const [name, value] of SERVER_PARAMS) {
455
+ socket.write(buildParameterStatus(name, value))
456
+ }
319
457
 
320
- // clean up pglite transaction state when a client disconnects
321
- socket.on('close', async () => {
322
- const { db, mutex } = getDbContext(dbName)
323
- await mutex.acquire()
324
- try {
325
- await db.exec('ROLLBACK')
326
- } catch {
327
- // no transaction to rollback
328
- } finally {
329
- mutex.release()
330
- }
331
- })
458
+ // backend key data
459
+ socket.write(buildBackendKeyData())
332
460
 
333
- try {
334
- const connection = await fromNodeSocket(socket, {
335
- serverVersion: '16.4',
336
- auth: {
337
- method: 'password',
338
- getClearTextPassword() {
339
- return config.pgPassword
340
- },
341
- validateCredentials(credentials: {
342
- username: string
343
- password: string
344
- clearTextPassword: string
345
- }) {
346
- return (
347
- credentials.password === credentials.clearTextPassword &&
348
- credentials.username === config.pgUser
349
- )
350
- },
351
- },
352
-
353
- // send ParameterStatus messages that standard postgres tools expect
354
- // pg-gateway sends server_version via the serverVersion option above,
355
- // but tools like pg_restore also need encoding, datestyle, etc.
356
- onAuthenticated() {
357
- for (const [name, value] of SERVER_PARAMS) {
358
- socket.write(buildParameterStatus(name, value))
359
- }
360
- },
461
+ // ready for query
462
+ socket.write(buildReadyForQuery())
361
463
 
362
- async onStartup(state) {
363
- const params = state.clientParams
364
- if (params?.replication === 'database') {
365
- isReplicationConnection = true
366
- }
367
- dbName = params?.database || 'postgres'
368
- log.debug.proxy(
369
- `connection: db=${dbName} user=${params?.user} replication=${params?.replication || 'none'}`
370
- )
371
- const { db } = getDbContext(dbName)
372
- await db.waitReady
373
- },
464
+ return { params }
465
+ }
374
466
 
375
- async onMessage(data, state) {
376
- if (!state.isAuthenticated) return
467
+ // ── message loop ──
377
468
 
378
- // handle replication connections (always go to postgres instance)
379
- if (isReplicationConnection) {
380
- if (data[0] === 0x51) {
381
- const view = new DataView(data.buffer, data.byteOffset, data.byteLength)
382
- const len = view.getInt32(1)
383
- const query = new TextDecoder()
384
- .decode(data.subarray(5, 1 + len - 1))
385
- .replace(/\0$/, '')
386
- log.debug.proxy(`repl query: ${query.slice(0, 200)}`)
387
- }
388
- return handleReplicationMessage(
389
- data,
390
- socket,
391
- instances.postgres,
392
- mutexes.postgres,
393
- connection
394
- )
395
- }
469
+ // process messages from a connected, authenticated client.
470
+ // uses callback-based 'data' events instead of async iterators
471
+ // for reliable behavior across runtimes (node.js, bun).
472
+ function messageLoop(
473
+ socket: Socket,
474
+ db: PGlite,
475
+ mutex: Mutex,
476
+ isReplicationConnection: boolean,
477
+ replicationDb: PGlite,
478
+ replicationMutex: Mutex
479
+ ): Promise<void> {
480
+ return new Promise<void>((resolve, reject) => {
481
+ let buffer: Buffer = Buffer.alloc(0)
482
+ let processing = false
483
+
484
+ async function processBuffer() {
485
+ if (processing) return
486
+ processing = true
487
+ socket.pause()
396
488
 
397
- // check for no-op queries
398
- if (isNoopQuery(data)) {
399
- if (data[0] === 0x51) {
400
- return buildSetCompleteResponse()
401
- } else if (data[0] === 0x50) {
402
- return buildParseCompleteResponse()
403
- }
404
- }
489
+ try {
490
+ while (buffer.length >= 5) {
491
+ const msgType = buffer[0]
492
+ const dv = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength)
493
+ const msgLen = dv.getInt32(1)
494
+ const totalLen = 1 + msgLen
495
+
496
+ if (buffer.length < totalLen) break // need more data
405
497
 
406
- // intercept and rewrite queries
407
- data = interceptQuery(data)
408
-
409
- // message-level locking on the connection's pglite instance
410
- const { db, mutex } = getDbContext(dbName)
411
- await mutex.acquire()
412
-
413
- let result: Uint8Array
414
- try {
415
- result = await db.execProtocolRaw(data, {
416
- throwOnError: false,
417
- })
418
- } catch (err) {
419
- mutex.release()
420
- throw err
498
+ // copy message out before modifying buffer
499
+ const message = new Uint8Array(
500
+ buffer.buffer.slice(buffer.byteOffset, buffer.byteOffset + totalLen)
501
+ )
502
+ buffer = buffer.subarray(totalLen)
503
+
504
+ // handle Terminate message
505
+ if (msgType === 0x58) {
506
+ resolve()
507
+ return
421
508
  }
422
509
 
423
- // strip ReadyForQuery from non-Sync/non-SimpleQuery responses
424
- if (data[0] !== 0x53 && data[0] !== 0x51) {
425
- result = stripReadyForQuery(result)
510
+ // handle replication connections
511
+ if (isReplicationConnection) {
512
+ await handleReplicationMsg(message, socket, replicationDb, replicationMutex)
513
+ continue
426
514
  }
427
515
 
428
- mutex.release()
429
- return result
430
- },
431
- })
432
- } catch (err) {
433
- if (!socket.destroyed) {
434
- socket.destroy()
516
+ // handle regular messages
517
+ await handleRegularMessage(message, socket, db, mutex)
518
+ }
519
+ } catch (err) {
520
+ reject(err)
521
+ return
435
522
  }
523
+
524
+ processing = false
525
+ socket.resume()
436
526
  }
437
- })
438
527
 
439
- return new Promise((resolve, reject) => {
440
- server.listen(config.pgPort, '127.0.0.1', () => {
441
- log.debug.proxy(`listening on port ${config.pgPort}`)
442
- resolve(server)
528
+ socket.on('data', (chunk: Buffer) => {
529
+ buffer = buffer.length > 0 ? Buffer.concat([buffer, chunk]) : chunk
530
+ processBuffer()
443
531
  })
444
- server.on('error', reject)
532
+
533
+ socket.on('end', () => resolve())
534
+ socket.on('error', (err) => reject(err))
535
+ socket.on('close', () => resolve())
536
+
537
+ socket.resume()
445
538
  })
446
539
  }
447
540
 
448
- async function handleReplicationMessage(
541
+ async function handleRegularMessage(
449
542
  data: Uint8Array,
450
543
  socket: Socket,
451
544
  db: PGlite,
452
- mutex: Mutex,
453
- connection: Awaited<ReturnType<typeof fromNodeSocket>>
454
- ): Promise<Uint8Array | undefined> {
455
- if (data[0] !== 0x51) return undefined
545
+ mutex: Mutex
546
+ ): Promise<void> {
547
+ // check for no-op queries
548
+ if (isNoopQuery(data)) {
549
+ if (data[0] === 0x51) {
550
+ await socketWrite(socket, buildSetCompleteResponse())
551
+ return
552
+ } else if (data[0] === 0x50) {
553
+ await socketWrite(socket, buildParseCompleteResponse())
554
+ return
555
+ }
556
+ }
557
+
558
+ // intercept and rewrite queries
559
+ data = interceptQuery(data)
560
+
561
+ // serialize pglite access
562
+ await mutex.acquire()
563
+ let result: Uint8Array
564
+ try {
565
+ result = await db.execProtocolRaw(data, { throwOnError: false })
566
+ } catch (err) {
567
+ mutex.release()
568
+ throw err
569
+ }
570
+
571
+ // strip ReadyForQuery from non-Sync/non-SimpleQuery responses
572
+ if (data[0] !== 0x53 && data[0] !== 0x51) {
573
+ result = stripReadyForQuery(result)
574
+ }
575
+
576
+ mutex.release()
577
+
578
+ // write response directly to socket
579
+ await socketWrite(socket, result)
580
+ }
581
+
582
+ async function handleReplicationMsg(
583
+ data: Uint8Array,
584
+ socket: Socket,
585
+ db: PGlite,
586
+ mutex: Mutex
587
+ ): Promise<void> {
588
+ if (data[0] !== 0x51) return
456
589
 
457
590
  const view = new DataView(data.buffer, data.byteOffset, data.byteLength)
458
591
  const len = view.getInt32(1)
459
592
  const query = new TextDecoder().decode(data.subarray(5, 1 + len - 1)).replace(/\0$/, '')
460
593
  const upper = query.trim().toUpperCase()
461
594
 
462
- // check if this is a START_REPLICATION command
463
- if (upper.startsWith('START_REPLICATION')) {
464
- await connection.detach()
595
+ log.debug.proxy(`repl query: ${query.slice(0, 200)}`)
465
596
 
597
+ if (upper.startsWith('START_REPLICATION')) {
466
598
  const writer = {
467
599
  write(chunk: Uint8Array) {
468
600
  if (!socket.destroyed) {
@@ -473,31 +605,116 @@ async function handleReplicationMessage(
473
605
 
474
606
  // drain incoming standby status updates
475
607
  socket.on('data', (_chunk: Buffer) => {})
608
+ socket.on('close', () => socket.destroy())
476
609
 
477
- socket.on('close', () => {
478
- socket.destroy()
479
- })
480
-
481
- handleStartReplication(query, writer, db, mutex).catch((err) => {
610
+ // this runs indefinitely until the socket closes
611
+ await handleStartReplication(query, writer, db, mutex).catch((err) => {
482
612
  log.debug.proxy(`replication stream ended: ${err}`)
483
613
  })
484
- return undefined
614
+ return
485
615
  }
486
616
 
487
- // handle replication queries + fallthrough to pglite, all under mutex
617
+ // handle replication queries + fallthrough to pglite
488
618
  await mutex.acquire()
489
619
  try {
490
620
  const response = await handleReplicationQuery(query, db)
491
- if (response) return response
621
+ if (response) {
622
+ await socketWrite(socket, response)
623
+ return
624
+ }
492
625
 
493
626
  // apply query rewrites before forwarding
494
627
  data = interceptQuery(data)
495
628
 
496
- // fall through to pglite for unrecognized queries
497
- return await db.execProtocolRaw(data, {
498
- throwOnError: false,
499
- })
629
+ const result = await db.execProtocolRaw(data, { throwOnError: false })
630
+ await socketWrite(socket, result)
500
631
  } finally {
501
632
  mutex.release()
502
633
  }
503
634
  }
635
+
636
+ // ── main entry point ──
637
+
638
/**
 * start the postgresql wire-protocol proxy in front of pglite.
 *
 * listens on 127.0.0.1:`config.pgPort`. each accepted connection runs the
 * startup handshake, then the message loop, which forwards traffic to the
 * pglite instance selected by the client's `database` startup parameter
 * (zero_cvr -> cvr, zero_cdb -> cdb, anything else -> postgres).
 *
 * @param dbInput a single PGlite instance used for every database (kept for
 *   backwards compatibility with tests) or separate postgres/cvr/cdb instances
 * @param config proxy configuration; `pgPort` is read here and the whole
 *   config is passed to the handshake for credential checks
 * @returns the server once it is listening; rejects if listening fails
 */
export async function startPgProxy(
  dbInput: PGlite | PGliteInstances,
  config: ZeroLiteConfig
): Promise<Server> {
  // normalize input: single PGlite instance = use it for all databases (backwards compat for tests)
  const instances: PGliteInstances =
    'postgres' in dbInput
      ? (dbInput as PGliteInstances)
      : { postgres: dbInput as PGlite, cvr: dbInput as PGlite, cdb: dbInput as PGlite }

  // per-instance mutexes for serializing pglite access
  const mutexes = {
    postgres: new Mutex(),
    cvr: new Mutex(),
    cdb: new Mutex(),
  }

  // map a client-supplied database name to its pglite instance and mutex
  function getDbContext(dbName: string): { db: PGlite; mutex: Mutex } {
    if (dbName === 'zero_cvr') return { db: instances.cvr, mutex: mutexes.cvr }
    if (dbName === 'zero_cdb') return { db: instances.cdb, mutex: mutexes.cdb }
    return { db: instances.postgres, mutex: mutexes.postgres }
  }

  const server = createServer(async (socket: Socket) => {
    socket.setKeepAlive(true, 30000)
    socket.setTimeout(0)
    socket.setNoDelay(true)

    // an 'error' event with no listener is thrown and takes down the whole
    // process. readBytes() and messageLoop() only listen while they are
    // running, so e.g. a client reset during `await db.waitReady` would
    // otherwise be fatal. keep one listener for the socket's whole lifetime.
    socket.on('error', (err: Error) => {
      log.debug.proxy(`socket error: ${err.message}`)
    })

    try {
      // perform startup handshake
      const { params } = await performHandshake(socket, config)

      const dbName = params.database || 'postgres'
      const isReplicationConnection = params.replication === 'database'

      log.debug.proxy(
        `connection: db=${dbName} user=${params.user} replication=${params.replication || 'none'}`
      )

      const { db, mutex } = getDbContext(dbName)
      await db.waitReady

      // the client may have gone away while we waited; its 'close' may already
      // have fired, so the listeners below would never run and the message
      // loop would never settle. nothing has touched pglite yet, so just stop.
      if (socket.destroyed) return

      // clean up pglite transaction state when client disconnects.
      // NOTE(review): all connections to the same database share one pglite
      // instance, so this ROLLBACK may also abort a transaction opened by
      // another client — confirm this is intended.
      socket.on('close', async () => {
        await mutex.acquire()
        try {
          await db.exec('ROLLBACK')
        } catch {
          // no transaction to rollback
        } finally {
          mutex.release()
        }
      })

      // enter message processing loop; replication traffic always targets
      // the postgres instance
      await messageLoop(
        socket,
        db,
        mutex,
        isReplicationConnection,
        instances.postgres,
        mutexes.postgres
      )
    } catch (err) {
      // connection error during handshake or message loop
      log.debug.proxy(
        `connection ended with error: ${err instanceof Error ? err.message : String(err)}`
      )
      if (!socket.destroyed) {
        socket.destroy()
      }
    }
  })

  return new Promise((resolve, reject) => {
    server.listen(config.pgPort, '127.0.0.1', () => {
      log.debug.proxy(`listening on port ${config.pgPort}`)
      resolve(server)
    })
    server.on('error', reject)
  })
}