@atproto/xrpc-server 0.2.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +8 -0
- package/LICENSE +21 -0
- package/README.md +11 -4
- package/build.js +0 -8
- package/dist/auth.d.ts +1 -1
- package/dist/index.d.ts +3 -0
- package/dist/index.js +17210 -8937
- package/dist/index.js.map +4 -4
- package/dist/logger.d.ts +2 -1
- package/dist/rate-limiter.d.ts +29 -0
- package/dist/server.d.ts +5 -1
- package/dist/stream/logger.d.ts +2 -1
- package/dist/stream/stream.d.ts +5 -2
- package/dist/stream/subscription.d.ts +2 -1
- package/dist/stream/types.d.ts +6 -6
- package/dist/stream/websocket-keepalive.d.ts +23 -0
- package/dist/types.d.ts +67 -9
- package/dist/util.d.ts +15 -0
- package/package.json +19 -25
- package/src/auth.ts +2 -2
- package/src/index.ts +4 -0
- package/src/logger.ts +2 -1
- package/src/rate-limiter.ts +167 -0
- package/src/server.ts +117 -7
- package/src/stream/logger.ts +2 -1
- package/src/stream/stream.ts +24 -11
- package/src/stream/subscription.ts +21 -107
- package/src/stream/websocket-keepalive.ts +151 -0
- package/src/types.ts +83 -4
- package/src/util.ts +33 -0
- package/tests/bodies.test.ts +3 -3
- package/tests/procedures.test.ts +12 -12
- package/tests/queries.test.ts +19 -14
- package/tests/rate-limiter.test.ts +249 -0
- package/tests/responses.test.ts +77 -0
- package/tests/subscriptions.test.ts +71 -15
- package/tsconfig.build.json +1 -1
- package/tsconfig.json +3 -3
- package/dist/src/index.d.ts +0 -2
- package/dist/src/logger.d.ts +0 -2
- package/dist/src/server.d.ts +0 -19
- package/dist/src/types.d.ts +0 -115
- package/dist/src/util.d.ts +0 -10
- package/dist/tsconfig.build.tsbuildinfo +0 -1
- package/tsconfig.build.tsbuildinfo +0 -1
- package/update-pkg.js +0 -14
package/src/server.ts
CHANGED
|
@@ -30,6 +30,11 @@ import {
|
|
|
30
30
|
XRPCStreamHandler,
|
|
31
31
|
Params,
|
|
32
32
|
InternalServerError,
|
|
33
|
+
XRPCReqContext,
|
|
34
|
+
RateLimiterI,
|
|
35
|
+
RateLimiterConsume,
|
|
36
|
+
isShared,
|
|
37
|
+
RateLimitExceededError,
|
|
33
38
|
} from './types'
|
|
34
39
|
import {
|
|
35
40
|
decodeQueryParams,
|
|
@@ -38,6 +43,7 @@ import {
|
|
|
38
43
|
validateOutput,
|
|
39
44
|
} from './util'
|
|
40
45
|
import log from './logger'
|
|
46
|
+
import { consumeMany } from './rate-limiter'
|
|
41
47
|
|
|
42
48
|
export function createServer(lexicons?: unknown[], options?: Options) {
|
|
43
49
|
return new Server(lexicons, options)
|
|
@@ -50,6 +56,9 @@ export class Server {
|
|
|
50
56
|
lex = new Lexicons()
|
|
51
57
|
options: Options
|
|
52
58
|
middleware: Record<'json' | 'text', RequestHandler>
|
|
59
|
+
globalRateLimiters: RateLimiterI[]
|
|
60
|
+
sharedRateLimiters: Record<string, RateLimiterI>
|
|
61
|
+
routeRateLimiterFns: Record<string, RateLimiterConsume[]>
|
|
53
62
|
|
|
54
63
|
constructor(lexicons?: unknown[], opts?: Options) {
|
|
55
64
|
if (lexicons) {
|
|
@@ -66,6 +75,27 @@ export class Server {
|
|
|
66
75
|
json: express.json({ limit: opts?.payload?.jsonLimit }),
|
|
67
76
|
text: express.text({ limit: opts?.payload?.textLimit }),
|
|
68
77
|
}
|
|
78
|
+
this.globalRateLimiters = []
|
|
79
|
+
this.sharedRateLimiters = {}
|
|
80
|
+
this.routeRateLimiterFns = {}
|
|
81
|
+
if (opts?.rateLimits?.global) {
|
|
82
|
+
for (const limit of opts.rateLimits.global) {
|
|
83
|
+
const rateLimiter = opts.rateLimits.creator({
|
|
84
|
+
...limit,
|
|
85
|
+
keyPrefix: `rl-${limit.name}`,
|
|
86
|
+
})
|
|
87
|
+
this.globalRateLimiters.push(rateLimiter)
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
if (opts?.rateLimits?.shared) {
|
|
91
|
+
for (const limit of opts.rateLimits.shared) {
|
|
92
|
+
const rateLimiter = opts.rateLimits.creator({
|
|
93
|
+
...limit,
|
|
94
|
+
keyPrefix: `rl-${limit.name}`,
|
|
95
|
+
})
|
|
96
|
+
this.sharedRateLimiters[limit.name] = rateLimiter
|
|
97
|
+
}
|
|
98
|
+
}
|
|
69
99
|
}
|
|
70
100
|
|
|
71
101
|
// handlers
|
|
@@ -138,6 +168,7 @@ export class Server {
|
|
|
138
168
|
middleware.push(this.middleware.json)
|
|
139
169
|
middleware.push(this.middleware.text)
|
|
140
170
|
}
|
|
171
|
+
this.setupRouteRateLimits(nsid, config)
|
|
141
172
|
this.routes[verb](
|
|
142
173
|
`/xrpc/${nsid}`,
|
|
143
174
|
...middleware,
|
|
@@ -185,6 +216,10 @@ export class Server {
|
|
|
185
216
|
validateOutput(nsid, def, output, this.lex)
|
|
186
217
|
const assertValidXrpcParams = (params: unknown) =>
|
|
187
218
|
this.lex.assertValidXrpcParams(nsid, params)
|
|
219
|
+
const rlFns = this.routeRateLimiterFns[nsid] ?? []
|
|
220
|
+
const consumeRateLimit = (reqCtx: XRPCReqContext) =>
|
|
221
|
+
consumeMany(reqCtx, rlFns)
|
|
222
|
+
|
|
188
223
|
return async function (req, res, next) {
|
|
189
224
|
try {
|
|
190
225
|
// validate request
|
|
@@ -203,14 +238,24 @@ export class Server {
|
|
|
203
238
|
|
|
204
239
|
const locals: RequestLocals = req[kRequestLocals]
|
|
205
240
|
|
|
206
|
-
|
|
207
|
-
const outputUnvalidated = await handler({
|
|
241
|
+
const reqCtx: XRPCReqContext = {
|
|
208
242
|
params,
|
|
209
243
|
input,
|
|
210
244
|
auth: locals.auth,
|
|
211
245
|
req,
|
|
212
246
|
res,
|
|
213
|
-
}
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
// handle rate limits
|
|
250
|
+
if (consumeRateLimit) {
|
|
251
|
+
const result = await consumeRateLimit(reqCtx)
|
|
252
|
+
if (result instanceof RateLimitExceededError) {
|
|
253
|
+
return next(result)
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
// run the handler
|
|
258
|
+
const outputUnvalidated = await handler(reqCtx)
|
|
214
259
|
|
|
215
260
|
if (isHandlerError(outputUnvalidated)) {
|
|
216
261
|
throw XRPCError.fromError(outputUnvalidated)
|
|
@@ -219,6 +264,12 @@ export class Server {
|
|
|
219
264
|
if (!outputUnvalidated || isHandlerSuccess(outputUnvalidated)) {
|
|
220
265
|
// validate response
|
|
221
266
|
const output = validateResOutput(outputUnvalidated)
|
|
267
|
+
// set headers
|
|
268
|
+
if (output?.headers) {
|
|
269
|
+
Object.entries(output.headers).forEach(([name, val]) => {
|
|
270
|
+
res.header(name, val)
|
|
271
|
+
})
|
|
272
|
+
}
|
|
222
273
|
// send response
|
|
223
274
|
if (
|
|
224
275
|
output?.encoding === 'application/json' ||
|
|
@@ -229,6 +280,7 @@ export class Server {
|
|
|
229
280
|
} else if (output?.body instanceof Readable) {
|
|
230
281
|
res.header('Content-Type', output.encoding)
|
|
231
282
|
res.status(200)
|
|
283
|
+
res.once('error', (err) => res.destroy(err))
|
|
232
284
|
forwardStreamErrors(output.body, res)
|
|
233
285
|
output.body.pipe(res)
|
|
234
286
|
} else if (output) {
|
|
@@ -338,6 +390,55 @@ export class Server {
|
|
|
338
390
|
return httpServer
|
|
339
391
|
}
|
|
340
392
|
}
|
|
393
|
+
|
|
394
|
+
private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) {
|
|
395
|
+
this.routeRateLimiterFns[nsid] = []
|
|
396
|
+
for (const limit of this.globalRateLimiters) {
|
|
397
|
+
const consumeFn = async (ctx: XRPCReqContext) => {
|
|
398
|
+
return limit.consume(ctx)
|
|
399
|
+
}
|
|
400
|
+
this.routeRateLimiterFns[nsid].push(consumeFn)
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
if (config.rateLimit) {
|
|
404
|
+
const limits = Array.isArray(config.rateLimit)
|
|
405
|
+
? config.rateLimit
|
|
406
|
+
: [config.rateLimit]
|
|
407
|
+
this.routeRateLimiterFns[nsid] = []
|
|
408
|
+
for (const limit of limits) {
|
|
409
|
+
const { calcKey, calcPoints } = limit
|
|
410
|
+
if (isShared(limit)) {
|
|
411
|
+
const rateLimiter = this.sharedRateLimiters[limit.name]
|
|
412
|
+
if (rateLimiter) {
|
|
413
|
+
const consumeFn = (ctx: XRPCReqContext) =>
|
|
414
|
+
rateLimiter.consume(ctx, {
|
|
415
|
+
calcKey,
|
|
416
|
+
calcPoints,
|
|
417
|
+
})
|
|
418
|
+
this.routeRateLimiterFns[nsid].push(consumeFn)
|
|
419
|
+
}
|
|
420
|
+
} else {
|
|
421
|
+
const { durationMs, points } = limit
|
|
422
|
+
const rateLimiter = this.options.rateLimits?.creator({
|
|
423
|
+
keyPrefix: nsid,
|
|
424
|
+
durationMs,
|
|
425
|
+
points,
|
|
426
|
+
calcKey,
|
|
427
|
+
calcPoints,
|
|
428
|
+
})
|
|
429
|
+
if (rateLimiter) {
|
|
430
|
+
this.sharedRateLimiters[nsid] = rateLimiter
|
|
431
|
+
const consumeFn = (ctx: XRPCReqContext) =>
|
|
432
|
+
rateLimiter.consume(ctx, {
|
|
433
|
+
calcKey,
|
|
434
|
+
calcPoints,
|
|
435
|
+
})
|
|
436
|
+
this.routeRateLimiterFns[nsid].push(consumeFn)
|
|
437
|
+
}
|
|
438
|
+
}
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
}
|
|
341
442
|
}
|
|
342
443
|
|
|
343
444
|
function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
|
|
@@ -378,14 +479,23 @@ function createAuthMiddleware(verifier: AuthVerifier): RequestHandler {
|
|
|
378
479
|
const errorMiddleware: ErrorRequestHandler = function (err, req, res, next) {
|
|
379
480
|
const locals: RequestLocals | undefined = req[kRequestLocals]
|
|
380
481
|
const methodSuffix = locals ? ` method ${locals.nsid}` : ''
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
482
|
+
const xrpcError = XRPCError.fromError(err)
|
|
483
|
+
if (xrpcError instanceof InternalServerError) {
|
|
484
|
+
// log trace for unhandled exceptions
|
|
384
485
|
log.error(err, `unhandled exception in xrpc${methodSuffix}`)
|
|
486
|
+
} else {
|
|
487
|
+
// do not log trace for known xrpc errors
|
|
488
|
+
log.error(
|
|
489
|
+
{
|
|
490
|
+
status: xrpcError.type,
|
|
491
|
+
message: xrpcError.message,
|
|
492
|
+
name: xrpcError.customErrorName,
|
|
493
|
+
},
|
|
494
|
+
`error in xrpc${methodSuffix}`,
|
|
495
|
+
)
|
|
385
496
|
}
|
|
386
497
|
if (res.headersSent) {
|
|
387
498
|
return next(err)
|
|
388
499
|
}
|
|
389
|
-
const xrpcError = XRPCError.fromError(err)
|
|
390
500
|
return res.status(xrpcError.type).json(xrpcError.payload)
|
|
391
501
|
}
|
package/src/stream/logger.ts
CHANGED
package/src/stream/stream.ts
CHANGED
|
@@ -1,26 +1,39 @@
|
|
|
1
1
|
import { XRPCError, ResponseType } from '@atproto/xrpc'
|
|
2
2
|
import { DuplexOptions } from 'stream'
|
|
3
3
|
import { createWebSocketStream, WebSocket } from 'ws'
|
|
4
|
-
import { Frame } from './frames'
|
|
4
|
+
import { Frame, MessageFrame } from './frames'
|
|
5
5
|
|
|
6
|
-
export
|
|
7
|
-
|
|
6
|
+
export function streamByteChunks(ws: WebSocket, options?: DuplexOptions) {
|
|
7
|
+
return createWebSocketStream(ws, {
|
|
8
8
|
...options,
|
|
9
9
|
readableObjectMode: true, // Ensures frame bytes don't get buffered/combined together
|
|
10
10
|
})
|
|
11
|
+
}
|
|
12
|
+
|
|
13
|
+
export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
|
|
14
|
+
const wsStream = streamByteChunks(ws, options)
|
|
11
15
|
for await (const chunk of wsStream) {
|
|
12
16
|
yield Frame.fromBytes(chunk)
|
|
13
17
|
}
|
|
14
18
|
}
|
|
15
19
|
|
|
16
20
|
export async function* byMessage(ws: WebSocket, options?: DuplexOptions) {
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
21
|
+
const wsStream = streamByteChunks(ws, options)
|
|
22
|
+
for await (const chunk of wsStream) {
|
|
23
|
+
const msg = ensureChunkIsMessage(chunk)
|
|
24
|
+
yield msg
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
export function ensureChunkIsMessage(chunk: Uint8Array): MessageFrame<unknown> {
|
|
29
|
+
const frame = Frame.fromBytes(chunk)
|
|
30
|
+
if (frame.isMessage()) {
|
|
31
|
+
return frame
|
|
32
|
+
} else if (frame.isError()) {
|
|
33
|
+
// @TODO work -1 error code into XRPCError
|
|
34
|
+
// @ts-ignore
|
|
35
|
+
throw new XRPCError(-1, frame.code, frame.message)
|
|
36
|
+
} else {
|
|
37
|
+
throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
|
|
25
38
|
}
|
|
26
39
|
}
|
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import { CloseCode, DisconnectError } from './types'
|
|
1
|
+
import { ClientOptions } from 'ws'
|
|
2
|
+
import { WebSocketKeepAlive } from './websocket-keepalive'
|
|
3
|
+
import { ensureChunkIsMessage } from './stream'
|
|
5
4
|
|
|
6
5
|
export class Subscription<T = unknown> {
|
|
7
6
|
constructor(
|
|
@@ -9,6 +8,7 @@ export class Subscription<T = unknown> {
|
|
|
9
8
|
service: string
|
|
10
9
|
method: string
|
|
11
10
|
maxReconnectSeconds?: number
|
|
11
|
+
heartbeatIntervalMs?: number
|
|
12
12
|
signal?: AbortSignal
|
|
13
13
|
validate: (obj: unknown) => T | undefined
|
|
14
14
|
onReconnectError?: (
|
|
@@ -24,107 +24,31 @@ export class Subscription<T = unknown> {
|
|
|
24
24
|
) {}
|
|
25
25
|
|
|
26
26
|
async *[Symbol.asyncIterator](): AsyncGenerator<T> {
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
const
|
|
38
|
-
const
|
|
39
|
-
if (
|
|
40
|
-
|
|
27
|
+
const ws = new WebSocketKeepAlive({
|
|
28
|
+
...this.opts,
|
|
29
|
+
getUrl: async () => {
|
|
30
|
+
const params = (await this.opts.getParams?.()) ?? {}
|
|
31
|
+
const query = encodeQueryParams(params)
|
|
32
|
+
return `${this.opts.service}/xrpc/${this.opts.method}?${query}`
|
|
33
|
+
},
|
|
34
|
+
})
|
|
35
|
+
for await (const chunk of ws) {
|
|
36
|
+
const message = await ensureChunkIsMessage(chunk)
|
|
37
|
+
const t = message.header.t
|
|
38
|
+
const clone = message.body !== undefined ? { ...message.body } : undefined
|
|
39
|
+
if (clone !== undefined && t !== undefined) {
|
|
40
|
+
clone['$type'] = t.startsWith('#') ? this.opts.method + t : t
|
|
41
41
|
}
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
})
|
|
46
|
-
ws.once('close', (code, reason) => {
|
|
47
|
-
if (code === CloseCode.Abnormal) {
|
|
48
|
-
// Forward into an error to distinguish from a clean close
|
|
49
|
-
ac.abort(
|
|
50
|
-
new AbnormalCloseError(`Abnormal ws close: ${reason.toString()}`),
|
|
51
|
-
)
|
|
52
|
-
}
|
|
53
|
-
})
|
|
54
|
-
try {
|
|
55
|
-
const cancelable = { signal: ac.signal }
|
|
56
|
-
for await (const message of byMessage(ws, cancelable)) {
|
|
57
|
-
const t = message.header.t
|
|
58
|
-
const clone =
|
|
59
|
-
message.body !== undefined ? { ...message.body } : undefined
|
|
60
|
-
if (clone !== undefined && t !== undefined) {
|
|
61
|
-
clone['$type'] = t.startsWith('#') ? this.opts.method + t : t
|
|
62
|
-
}
|
|
63
|
-
const result = this.opts.validate(clone)
|
|
64
|
-
if (result !== undefined) {
|
|
65
|
-
yield result
|
|
66
|
-
}
|
|
67
|
-
}
|
|
68
|
-
} catch (_err) {
|
|
69
|
-
const err = _err?.['code'] === 'ABORT_ERR' ? _err['cause'] : _err
|
|
70
|
-
if (err instanceof DisconnectError) {
|
|
71
|
-
// We cleanly end the connection
|
|
72
|
-
ws.close(err.wsCode)
|
|
73
|
-
break
|
|
74
|
-
}
|
|
75
|
-
ws.close() // No-ops if already closed or closing
|
|
76
|
-
if (isReconnectable(err)) {
|
|
77
|
-
reconnects ??= 0 // Never reconnect with a null
|
|
78
|
-
this.opts.onReconnectError?.(err, reconnects, initialSetup)
|
|
79
|
-
continue
|
|
80
|
-
} else {
|
|
81
|
-
throw err
|
|
82
|
-
}
|
|
42
|
+
const result = this.opts.validate(clone)
|
|
43
|
+
if (result !== undefined) {
|
|
44
|
+
yield result
|
|
83
45
|
}
|
|
84
|
-
break // Other side cleanly ended stream and disconnected
|
|
85
46
|
}
|
|
86
47
|
}
|
|
87
|
-
|
|
88
|
-
private async getSocket() {
|
|
89
|
-
const params = (await this.opts.getParams?.()) ?? {}
|
|
90
|
-
const query = encodeQueryParams(params)
|
|
91
|
-
const url = `${this.opts.service}/xrpc/${this.opts.method}?${query}`
|
|
92
|
-
return new WebSocket(url, this.opts)
|
|
93
|
-
}
|
|
94
48
|
}
|
|
95
49
|
|
|
96
50
|
export default Subscription
|
|
97
51
|
|
|
98
|
-
class AbnormalCloseError extends Error {
|
|
99
|
-
code = 'EWSABNORMALCLOSE'
|
|
100
|
-
}
|
|
101
|
-
|
|
102
|
-
function isReconnectable(err: unknown): boolean {
|
|
103
|
-
// Network errors are reconnectable.
|
|
104
|
-
// AuthenticationRequired and InvalidRequest XRPCErrors are not reconnectable.
|
|
105
|
-
// @TODO method-specific XRPCErrors may be reconnectable, need to consider. Receiving
|
|
106
|
-
// an invalid message is not current reconnectable, but the user can decide to skip them.
|
|
107
|
-
if (!err || typeof err['code'] !== 'string') return false
|
|
108
|
-
return networkErrorCodes.includes(err['code'])
|
|
109
|
-
}
|
|
110
|
-
|
|
111
|
-
const networkErrorCodes = [
|
|
112
|
-
'EWSABNORMALCLOSE',
|
|
113
|
-
'ECONNRESET',
|
|
114
|
-
'ECONNREFUSED',
|
|
115
|
-
'ECONNABORTED',
|
|
116
|
-
'EPIPE',
|
|
117
|
-
'ETIMEDOUT',
|
|
118
|
-
'ECANCELED',
|
|
119
|
-
]
|
|
120
|
-
|
|
121
|
-
function backoffMs(n: number, maxMs: number) {
|
|
122
|
-
const baseSec = Math.pow(2, n) // 1, 2, 4, ...
|
|
123
|
-
const randSec = Math.random() - 0.5 // Random jitter between -.5 and .5 seconds
|
|
124
|
-
const ms = 1000 * (baseSec + randSec)
|
|
125
|
-
return Math.min(ms, maxMs)
|
|
126
|
-
}
|
|
127
|
-
|
|
128
52
|
function encodeQueryParams(obj: Record<string, unknown>): string {
|
|
129
53
|
const params = new URLSearchParams()
|
|
130
54
|
Object.entries(obj).forEach(([key, value]) => {
|
|
@@ -163,13 +87,3 @@ function encodeQueryParam(value: unknown): string | string[] {
|
|
|
163
87
|
}
|
|
164
88
|
throw new Error(`Cannot encode ${typeof value}s into query params`)
|
|
165
89
|
}
|
|
166
|
-
|
|
167
|
-
function forwardSignal(signal: AbortSignal, ac: AbortController) {
|
|
168
|
-
if (signal.aborted) {
|
|
169
|
-
return ac.abort(signal.reason)
|
|
170
|
-
} else {
|
|
171
|
-
signal.addEventListener('abort', () => ac.abort(signal.reason), {
|
|
172
|
-
signal: ac.signal,
|
|
173
|
-
})
|
|
174
|
-
}
|
|
175
|
-
}
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import { SECOND, wait } from '@atproto/common'
|
|
2
|
+
import { WebSocket, ClientOptions } from 'ws'
|
|
3
|
+
import { streamByteChunks } from './stream'
|
|
4
|
+
import { CloseCode, DisconnectError } from './types'
|
|
5
|
+
|
|
6
|
+
export class WebSocketKeepAlive {
|
|
7
|
+
public ws: WebSocket | null = null
|
|
8
|
+
public initialSetup = true
|
|
9
|
+
public reconnects: number | null = null
|
|
10
|
+
|
|
11
|
+
constructor(
|
|
12
|
+
public opts: ClientOptions & {
|
|
13
|
+
getUrl: () => Promise<string>
|
|
14
|
+
maxReconnectSeconds?: number
|
|
15
|
+
signal?: AbortSignal
|
|
16
|
+
heartbeatIntervalMs?: number
|
|
17
|
+
onReconnectError?: (
|
|
18
|
+
error: unknown,
|
|
19
|
+
n: number,
|
|
20
|
+
initialSetup: boolean,
|
|
21
|
+
) => void
|
|
22
|
+
},
|
|
23
|
+
) {}
|
|
24
|
+
|
|
25
|
+
async *[Symbol.asyncIterator](): AsyncGenerator<Uint8Array> {
|
|
26
|
+
const maxReconnectMs = 1000 * (this.opts.maxReconnectSeconds ?? 64)
|
|
27
|
+
while (true) {
|
|
28
|
+
if (this.reconnects !== null) {
|
|
29
|
+
const duration = this.initialSetup
|
|
30
|
+
? Math.min(1000, maxReconnectMs)
|
|
31
|
+
: backoffMs(this.reconnects++, maxReconnectMs)
|
|
32
|
+
await wait(duration)
|
|
33
|
+
}
|
|
34
|
+
const url = await this.opts.getUrl()
|
|
35
|
+
this.ws = new WebSocket(url, this.opts)
|
|
36
|
+
const ac = new AbortController()
|
|
37
|
+
if (this.opts.signal) {
|
|
38
|
+
forwardSignal(this.opts.signal, ac)
|
|
39
|
+
}
|
|
40
|
+
this.ws.once('open', () => {
|
|
41
|
+
this.initialSetup = false
|
|
42
|
+
this.reconnects = 0
|
|
43
|
+
if (this.ws) {
|
|
44
|
+
this.startHeartbeat(this.ws)
|
|
45
|
+
}
|
|
46
|
+
})
|
|
47
|
+
this.ws.once('close', (code, reason) => {
|
|
48
|
+
if (code === CloseCode.Abnormal) {
|
|
49
|
+
// Forward into an error to distinguish from a clean close
|
|
50
|
+
ac.abort(
|
|
51
|
+
new AbnormalCloseError(`Abnormal ws close: ${reason.toString()}`),
|
|
52
|
+
)
|
|
53
|
+
}
|
|
54
|
+
})
|
|
55
|
+
|
|
56
|
+
try {
|
|
57
|
+
const wsStream = streamByteChunks(this.ws, { signal: ac.signal })
|
|
58
|
+
for await (const chunk of wsStream) {
|
|
59
|
+
yield chunk
|
|
60
|
+
}
|
|
61
|
+
} catch (_err) {
|
|
62
|
+
const err = _err?.['code'] === 'ABORT_ERR' ? _err['cause'] : _err
|
|
63
|
+
if (err instanceof DisconnectError) {
|
|
64
|
+
// We cleanly end the connection
|
|
65
|
+
this.ws?.close(err.wsCode)
|
|
66
|
+
break
|
|
67
|
+
}
|
|
68
|
+
this.ws?.close() // No-ops if already closed or closing
|
|
69
|
+
if (isReconnectable(err)) {
|
|
70
|
+
this.reconnects ??= 0 // Never reconnect with a null
|
|
71
|
+
this.opts.onReconnectError?.(err, this.reconnects, this.initialSetup)
|
|
72
|
+
continue
|
|
73
|
+
} else {
|
|
74
|
+
throw err
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
break // Other side cleanly ended stream and disconnected
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
startHeartbeat(ws: WebSocket) {
|
|
82
|
+
let isAlive = true
|
|
83
|
+
let heartbeatInterval: NodeJS.Timer | null = null
|
|
84
|
+
|
|
85
|
+
const checkAlive = () => {
|
|
86
|
+
if (!isAlive) {
|
|
87
|
+
return ws.terminate()
|
|
88
|
+
}
|
|
89
|
+
isAlive = false // expect websocket to no longer be alive unless we receive a "pong" within the interval
|
|
90
|
+
ws.ping()
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
checkAlive()
|
|
94
|
+
heartbeatInterval = setInterval(
|
|
95
|
+
checkAlive,
|
|
96
|
+
this.opts.heartbeatIntervalMs ?? 10 * SECOND,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
ws.on('pong', () => {
|
|
100
|
+
isAlive = true
|
|
101
|
+
})
|
|
102
|
+
ws.once('close', () => {
|
|
103
|
+
if (heartbeatInterval) {
|
|
104
|
+
clearInterval(heartbeatInterval)
|
|
105
|
+
heartbeatInterval = null
|
|
106
|
+
}
|
|
107
|
+
})
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
export default WebSocketKeepAlive
|
|
112
|
+
|
|
113
|
+
class AbnormalCloseError extends Error {
|
|
114
|
+
code = 'EWSABNORMALCLOSE'
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
function isReconnectable(err: unknown): boolean {
|
|
118
|
+
// Network errors are reconnectable.
|
|
119
|
+
// AuthenticationRequired and InvalidRequest XRPCErrors are not reconnectable.
|
|
120
|
+
// @TODO method-specific XRPCErrors may be reconnectable, need to consider. Receiving
|
|
121
|
+
// an invalid message is not current reconnectable, but the user can decide to skip them.
|
|
122
|
+
if (!err || typeof err['code'] !== 'string') return false
|
|
123
|
+
return networkErrorCodes.includes(err['code'])
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
const networkErrorCodes = [
|
|
127
|
+
'EWSABNORMALCLOSE',
|
|
128
|
+
'ECONNRESET',
|
|
129
|
+
'ECONNREFUSED',
|
|
130
|
+
'ECONNABORTED',
|
|
131
|
+
'EPIPE',
|
|
132
|
+
'ETIMEDOUT',
|
|
133
|
+
'ECANCELED',
|
|
134
|
+
]
|
|
135
|
+
|
|
136
|
+
function backoffMs(n: number, maxMs: number) {
|
|
137
|
+
const baseSec = Math.pow(2, n) // 1, 2, 4, ...
|
|
138
|
+
const randSec = Math.random() - 0.5 // Random jitter between -.5 and .5 seconds
|
|
139
|
+
const ms = 1000 * (baseSec + randSec)
|
|
140
|
+
return Math.min(ms, maxMs)
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
function forwardSignal(signal: AbortSignal, ac: AbortController) {
|
|
144
|
+
if (signal.aborted) {
|
|
145
|
+
return ac.abort(signal.reason)
|
|
146
|
+
} else {
|
|
147
|
+
signal.addEventListener('abort', () => ac.abort(signal.reason), {
|
|
148
|
+
signal: ac.signal,
|
|
149
|
+
})
|
|
150
|
+
}
|
|
151
|
+
}
|