@atproto/xrpc-server 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
1
1
  import { subsystemLogger } from '@atproto/common'
2
2
 
3
- export const logger = subsystemLogger('xrpc-stream')
3
+ export const logger: ReturnType<typeof subsystemLogger> =
4
+ subsystemLogger('xrpc-stream')
4
5
 
5
6
  export default logger
@@ -1,26 +1,39 @@
1
1
  import { XRPCError, ResponseType } from '@atproto/xrpc'
2
2
  import { DuplexOptions } from 'stream'
3
3
  import { createWebSocketStream, WebSocket } from 'ws'
4
- import { Frame } from './frames'
4
+ import { Frame, MessageFrame } from './frames'
5
5
 
6
- export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
7
- const wsStream = createWebSocketStream(ws, {
6
+ export function streamByteChunks(ws: WebSocket, options?: DuplexOptions) {
7
+ return createWebSocketStream(ws, {
8
8
  ...options,
9
9
  readableObjectMode: true, // Ensures frame bytes don't get buffered/combined together
10
10
  })
11
+ }
12
+
13
+ export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
14
+ const wsStream = streamByteChunks(ws, options)
11
15
  for await (const chunk of wsStream) {
12
16
  yield Frame.fromBytes(chunk)
13
17
  }
14
18
  }
15
19
 
16
20
  export async function* byMessage(ws: WebSocket, options?: DuplexOptions) {
17
- for await (const frame of byFrame(ws, options)) {
18
- if (frame.isMessage()) {
19
- yield frame
20
- } else if (frame.isError()) {
21
- throw new XRPCError(-1, frame.code, frame.message)
22
- } else {
23
- throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
24
- }
21
+ const wsStream = streamByteChunks(ws, options)
22
+ for await (const chunk of wsStream) {
23
+ const msg = ensureChunkIsMessage(chunk)
24
+ yield msg
25
+ }
26
+ }
27
+
28
+ export function ensureChunkIsMessage(chunk: Uint8Array): MessageFrame<unknown> {
29
+ const frame = Frame.fromBytes(chunk)
30
+ if (frame.isMessage()) {
31
+ return frame
32
+ } else if (frame.isError()) {
33
+ // @TODO work -1 error code into XRPCError
34
+ // @ts-ignore
35
+ throw new XRPCError(-1, frame.code, frame.message)
36
+ } else {
37
+ throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
25
38
  }
26
39
  }
@@ -1,7 +1,6 @@
1
- import { wait } from '@atproto/common'
2
- import { WebSocket, ClientOptions } from 'ws'
3
- import { byMessage } from './stream'
4
- import { CloseCode, DisconnectError } from './types'
1
+ import { ClientOptions } from 'ws'
2
+ import { WebSocketKeepAlive } from './websocket-keepalive'
3
+ import { ensureChunkIsMessage } from './stream'
5
4
 
6
5
  export class Subscription<T = unknown> {
7
6
  constructor(
@@ -9,6 +8,7 @@ export class Subscription<T = unknown> {
9
8
  service: string
10
9
  method: string
11
10
  maxReconnectSeconds?: number
11
+ heartbeatIntervalMs?: number
12
12
  signal?: AbortSignal
13
13
  validate: (obj: unknown) => T | undefined
14
14
  onReconnectError?: (
@@ -24,107 +24,31 @@ export class Subscription<T = unknown> {
24
24
  ) {}
25
25
 
26
26
  async *[Symbol.asyncIterator](): AsyncGenerator<T> {
27
- let initialSetup = true
28
- let reconnects: number | null = null
29
- const maxReconnectMs = 1000 * (this.opts.maxReconnectSeconds ?? 64)
30
- while (true) {
31
- if (reconnects !== null) {
32
- const duration = initialSetup
33
- ? Math.min(1000, maxReconnectMs)
34
- : backoffMs(reconnects++, maxReconnectMs)
35
- await wait(duration)
36
- }
37
- const ws = await this.getSocket()
38
- const ac = new AbortController()
39
- if (this.opts.signal) {
40
- forwardSignal(this.opts.signal, ac)
27
+ const ws = new WebSocketKeepAlive({
28
+ ...this.opts,
29
+ getUrl: async () => {
30
+ const params = (await this.opts.getParams?.()) ?? {}
31
+ const query = encodeQueryParams(params)
32
+ return `${this.opts.service}/xrpc/${this.opts.method}?${query}`
33
+ },
34
+ })
35
+ for await (const chunk of ws) {
36
+ const message = await ensureChunkIsMessage(chunk)
37
+ const t = message.header.t
38
+ const clone = message.body !== undefined ? { ...message.body } : undefined
39
+ if (clone !== undefined && t !== undefined) {
40
+ clone['$type'] = t.startsWith('#') ? this.opts.method + t : t
41
41
  }
42
- ws.once('open', () => {
43
- initialSetup = false
44
- reconnects = 0
45
- })
46
- ws.once('close', (code, reason) => {
47
- if (code === CloseCode.Abnormal) {
48
- // Forward into an error to distinguish from a clean close
49
- ac.abort(
50
- new AbnormalCloseError(`Abnormal ws close: ${reason.toString()}`),
51
- )
52
- }
53
- })
54
- try {
55
- const cancelable = { signal: ac.signal }
56
- for await (const message of byMessage(ws, cancelable)) {
57
- const t = message.header.t
58
- const clone =
59
- message.body !== undefined ? { ...message.body } : undefined
60
- if (clone !== undefined && t !== undefined) {
61
- clone['$type'] = t.startsWith('#') ? this.opts.method + t : t
62
- }
63
- const result = this.opts.validate(clone)
64
- if (result !== undefined) {
65
- yield result
66
- }
67
- }
68
- } catch (_err) {
69
- const err = _err?.['code'] === 'ABORT_ERR' ? _err['cause'] : _err
70
- if (err instanceof DisconnectError) {
71
- // We cleanly end the connection
72
- ws.close(err.wsCode)
73
- break
74
- }
75
- ws.close() // No-ops if already closed or closing
76
- if (isReconnectable(err)) {
77
- reconnects ??= 0 // Never reconnect with a null
78
- this.opts.onReconnectError?.(err, reconnects, initialSetup)
79
- continue
80
- } else {
81
- throw err
82
- }
42
+ const result = this.opts.validate(clone)
43
+ if (result !== undefined) {
44
+ yield result
83
45
  }
84
- break // Other side cleanly ended stream and disconnected
85
46
  }
86
47
  }
87
-
88
- private async getSocket() {
89
- const params = (await this.opts.getParams?.()) ?? {}
90
- const query = encodeQueryParams(params)
91
- const url = `${this.opts.service}/xrpc/${this.opts.method}?${query}`
92
- return new WebSocket(url, this.opts)
93
- }
94
48
  }
95
49
 
96
50
  export default Subscription
97
51
 
98
- class AbnormalCloseError extends Error {
99
- code = 'EWSABNORMALCLOSE'
100
- }
101
-
102
- function isReconnectable(err: unknown): boolean {
103
- // Network errors are reconnectable.
104
- // AuthenticationRequired and InvalidRequest XRPCErrors are not reconnectable.
105
- // @TODO method-specific XRPCErrors may be reconnectable, need to consider. Receiving
106
- // an invalid message is not current reconnectable, but the user can decide to skip them.
107
- if (!err || typeof err['code'] !== 'string') return false
108
- return networkErrorCodes.includes(err['code'])
109
- }
110
-
111
- const networkErrorCodes = [
112
- 'EWSABNORMALCLOSE',
113
- 'ECONNRESET',
114
- 'ECONNREFUSED',
115
- 'ECONNABORTED',
116
- 'EPIPE',
117
- 'ETIMEDOUT',
118
- 'ECANCELED',
119
- ]
120
-
121
- function backoffMs(n: number, maxMs: number) {
122
- const baseSec = Math.pow(2, n) // 1, 2, 4, ...
123
- const randSec = Math.random() - 0.5 // Random jitter between -.5 and .5 seconds
124
- const ms = 1000 * (baseSec + randSec)
125
- return Math.min(ms, maxMs)
126
- }
127
-
128
52
  function encodeQueryParams(obj: Record<string, unknown>): string {
129
53
  const params = new URLSearchParams()
130
54
  Object.entries(obj).forEach(([key, value]) => {
@@ -163,13 +87,3 @@ function encodeQueryParam(value: unknown): string | string[] {
163
87
  }
164
88
  throw new Error(`Cannot encode ${typeof value}s into query params`)
165
89
  }
166
-
167
- function forwardSignal(signal: AbortSignal, ac: AbortController) {
168
- if (signal.aborted) {
169
- return ac.abort(signal.reason)
170
- } else {
171
- signal.addEventListener('abort', () => ac.abort(signal.reason), {
172
- signal: ac.signal,
173
- })
174
- }
175
- }
@@ -0,0 +1,151 @@
1
+ import { SECOND, wait } from '@atproto/common'
2
+ import { WebSocket, ClientOptions } from 'ws'
3
+ import { streamByteChunks } from './stream'
4
+ import { CloseCode, DisconnectError } from './types'
5
+
6
+ export class WebSocketKeepAlive {
7
+ public ws: WebSocket | null = null
8
+ public initialSetup = true
9
+ public reconnects: number | null = null
10
+
11
+ constructor(
12
+ public opts: ClientOptions & {
13
+ getUrl: () => Promise<string>
14
+ maxReconnectSeconds?: number
15
+ signal?: AbortSignal
16
+ heartbeatIntervalMs?: number
17
+ onReconnectError?: (
18
+ error: unknown,
19
+ n: number,
20
+ initialSetup: boolean,
21
+ ) => void
22
+ },
23
+ ) {}
24
+
25
+ async *[Symbol.asyncIterator](): AsyncGenerator<Uint8Array> {
26
+ const maxReconnectMs = 1000 * (this.opts.maxReconnectSeconds ?? 64)
27
+ while (true) {
28
+ if (this.reconnects !== null) {
29
+ const duration = this.initialSetup
30
+ ? Math.min(1000, maxReconnectMs)
31
+ : backoffMs(this.reconnects++, maxReconnectMs)
32
+ await wait(duration)
33
+ }
34
+ const url = await this.opts.getUrl()
35
+ this.ws = new WebSocket(url, this.opts)
36
+ const ac = new AbortController()
37
+ if (this.opts.signal) {
38
+ forwardSignal(this.opts.signal, ac)
39
+ }
40
+ this.ws.once('open', () => {
41
+ this.initialSetup = false
42
+ this.reconnects = 0
43
+ if (this.ws) {
44
+ this.startHeartbeat(this.ws)
45
+ }
46
+ })
47
+ this.ws.once('close', (code, reason) => {
48
+ if (code === CloseCode.Abnormal) {
49
+ // Forward into an error to distinguish from a clean close
50
+ ac.abort(
51
+ new AbnormalCloseError(`Abnormal ws close: ${reason.toString()}`),
52
+ )
53
+ }
54
+ })
55
+
56
+ try {
57
+ const wsStream = streamByteChunks(this.ws, { signal: ac.signal })
58
+ for await (const chunk of wsStream) {
59
+ yield chunk
60
+ }
61
+ } catch (_err) {
62
+ const err = _err?.['code'] === 'ABORT_ERR' ? _err['cause'] : _err
63
+ if (err instanceof DisconnectError) {
64
+ // We cleanly end the connection
65
+ this.ws?.close(err.wsCode)
66
+ break
67
+ }
68
+ this.ws?.close() // No-ops if already closed or closing
69
+ if (isReconnectable(err)) {
70
+ this.reconnects ??= 0 // Never reconnect with a null
71
+ this.opts.onReconnectError?.(err, this.reconnects, this.initialSetup)
72
+ continue
73
+ } else {
74
+ throw err
75
+ }
76
+ }
77
+ break // Other side cleanly ended stream and disconnected
78
+ }
79
+ }
80
+
81
+ startHeartbeat(ws: WebSocket) {
82
+ let isAlive = true
83
+ let heartbeatInterval: NodeJS.Timer | null = null
84
+
85
+ const checkAlive = () => {
86
+ if (!isAlive) {
87
+ return ws.terminate()
88
+ }
89
+ isAlive = false // expect websocket to no longer be alive unless we receive a "pong" within the interval
90
+ ws.ping()
91
+ }
92
+
93
+ checkAlive()
94
+ heartbeatInterval = setInterval(
95
+ checkAlive,
96
+ this.opts.heartbeatIntervalMs ?? 10 * SECOND,
97
+ )
98
+
99
+ ws.on('pong', () => {
100
+ isAlive = true
101
+ })
102
+ ws.once('close', () => {
103
+ if (heartbeatInterval) {
104
+ clearInterval(heartbeatInterval)
105
+ heartbeatInterval = null
106
+ }
107
+ })
108
+ }
109
+ }
110
+
111
+ export default WebSocketKeepAlive
112
+
113
+ class AbnormalCloseError extends Error {
114
+ code = 'EWSABNORMALCLOSE'
115
+ }
116
+
117
+ function isReconnectable(err: unknown): boolean {
118
+ // Network errors are reconnectable.
119
+ // AuthenticationRequired and InvalidRequest XRPCErrors are not reconnectable.
120
+ // @TODO method-specific XRPCErrors may be reconnectable, need to consider. Receiving
121
+ // an invalid message is not current reconnectable, but the user can decide to skip them.
122
+ if (!err || typeof err['code'] !== 'string') return false
123
+ return networkErrorCodes.includes(err['code'])
124
+ }
125
+
126
+ const networkErrorCodes = [
127
+ 'EWSABNORMALCLOSE',
128
+ 'ECONNRESET',
129
+ 'ECONNREFUSED',
130
+ 'ECONNABORTED',
131
+ 'EPIPE',
132
+ 'ETIMEDOUT',
133
+ 'ECANCELED',
134
+ ]
135
+
136
+ function backoffMs(n: number, maxMs: number) {
137
+ const baseSec = Math.pow(2, n) // 1, 2, 4, ...
138
+ const randSec = Math.random() - 0.5 // Random jitter between -.5 and .5 seconds
139
+ const ms = 1000 * (baseSec + randSec)
140
+ return Math.min(ms, maxMs)
141
+ }
142
+
143
+ function forwardSignal(signal: AbortSignal, ac: AbortController) {
144
+ if (signal.aborted) {
145
+ return ac.abort(signal.reason)
146
+ } else {
147
+ signal.addEventListener('abort', () => ac.abort(signal.reason), {
148
+ signal: ac.signal,
149
+ })
150
+ }
151
+ }
package/src/types.ts CHANGED
@@ -15,6 +15,11 @@ export type Options = {
15
15
  blobLimit?: number
16
16
  textLimit?: number
17
17
  }
18
+ rateLimits?: {
19
+ creator: RateLimiterCreator
20
+ global?: ServerRateLimitDescription[]
21
+ shared?: ServerRateLimitDescription[]
22
+ }
18
23
  }
19
24
 
20
25
  export type UndecodedParams = typeof express.request['query']
@@ -50,13 +55,17 @@ export type HandlerError = zod.infer<typeof handlerError>
50
55
 
51
56
  export type HandlerOutput = HandlerSuccess | HandlerError
52
57
 
53
- export type XRPCHandler = (ctx: {
58
+ export type XRPCReqContext = {
54
59
  auth: HandlerAuth | undefined
55
60
  params: Params
56
61
  input: HandlerInput | undefined
57
62
  req: express.Request
58
63
  res: express.Response
59
- }) => Promise<HandlerOutput> | HandlerOutput | undefined
64
+ }
65
+
66
+ export type XRPCHandler = (
67
+ ctx: XRPCReqContext,
68
+ ) => Promise<HandlerOutput> | HandlerOutput | undefined
60
69
 
61
70
  export type XRPCStreamHandler = (ctx: {
62
71
  auth: HandlerAuth | undefined
@@ -76,7 +85,66 @@ export type StreamAuthVerifier = (ctx: {
76
85
  req: IncomingMessage
77
86
  }) => Promise<AuthOutput> | AuthOutput
78
87
 
88
+ export type CalcKeyFn = (ctx: XRPCReqContext) => string
89
+ export type CalcPointsFn = (ctx: XRPCReqContext) => number
90
+
91
+ export interface RateLimiterI {
92
+ consume: RateLimiterConsume
93
+ }
94
+
95
+ export type RateLimiterConsume = (
96
+ ctx: XRPCReqContext,
97
+ opts?: { calcKey?: CalcKeyFn; calcPoints?: CalcPointsFn },
98
+ ) => Promise<RateLimiterStatus | RateLimitExceededError | null>
99
+
100
+ export type RateLimiterCreator = (opts: {
101
+ keyPrefix: string
102
+ durationMs: number
103
+ points: number
104
+ calcKey?: (ctx: XRPCReqContext) => string
105
+ calcPoints?: (ctx: XRPCReqContext) => number
106
+ }) => RateLimiterI
107
+
108
+ export type ServerRateLimitDescription = {
109
+ name: string
110
+ durationMs: number
111
+ points: number
112
+ calcKey?: (ctx: XRPCReqContext) => string
113
+ calcPoints?: (ctx: XRPCReqContext) => number
114
+ }
115
+
116
+ export type SharedRateLimitOpts = {
117
+ name: string
118
+ calcKey?: (ctx: XRPCReqContext) => string
119
+ calcPoints?: (ctx: XRPCReqContext) => number
120
+ }
121
+
122
+ export type RouteRateLimitOpts = {
123
+ durationMs: number
124
+ points: number
125
+ calcKey?: (ctx: XRPCReqContext) => string
126
+ calcPoints?: (ctx: XRPCReqContext) => number
127
+ }
128
+
129
+ export type HandlerRateLimitOpts = SharedRateLimitOpts | RouteRateLimitOpts
130
+
131
+ export const isShared = (
132
+ opts: HandlerRateLimitOpts,
133
+ ): opts is SharedRateLimitOpts => {
134
+ return typeof opts['name'] === 'string'
135
+ }
136
+
137
+ export type RateLimiterStatus = {
138
+ limit: number
139
+ duration: number
140
+ remainingPoints: number
141
+ msBeforeNext: number
142
+ consumedPoints: number
143
+ isFirstInDuration: boolean
144
+ }
145
+
79
146
  export type XRPCHandlerConfig = {
147
+ rateLimit?: HandlerRateLimitOpts | HandlerRateLimitOpts[]
80
148
  auth?: AuthVerifier
81
149
  handler: XRPCHandler
82
150
  }
@@ -133,7 +201,13 @@ export class XRPCError extends Error {
133
201
  }
134
202
 
135
203
  export function isHandlerError(v: unknown): v is HandlerError {
136
- return handlerError.safeParse(v).success
204
+ return (
205
+ !!v &&
206
+ typeof v === 'object' &&
207
+ typeof v['status'] === 'number' &&
208
+ (v['error'] === undefined || typeof v['error'] === 'string') &&
209
+ (v['message'] === undefined || typeof v['message'] === 'string')
210
+ )
137
211
  }
138
212
 
139
213
  export class InvalidRequestError extends XRPCError {
@@ -154,6 +228,16 @@ export class ForbiddenError extends XRPCError {
154
228
  }
155
229
  }
156
230
 
231
+ export class RateLimitExceededError extends XRPCError {
232
+ constructor(
233
+ public status: RateLimiterStatus,
234
+ errorMessage?: string,
235
+ customErrorName?: string,
236
+ ) {
237
+ super(ResponseType.RateLimitExceeded, errorMessage, customErrorName)
238
+ }
239
+ }
240
+
157
241
  export class InternalServerError extends XRPCError {
158
242
  constructor(errorMessage?: string, customErrorName?: string) {
159
243
  super(ResponseType.InternalServerError, errorMessage, customErrorName)
@@ -166,9 +250,9 @@ export class UpstreamFailureError extends XRPCError {
166
250
  }
167
251
  }
168
252
 
169
- export class NotEnoughResoucesError extends XRPCError {
253
+ export class NotEnoughResourcesError extends XRPCError {
170
254
  constructor(errorMessage?: string, customErrorName?: string) {
171
- super(ResponseType.NotEnoughResouces, errorMessage, customErrorName)
255
+ super(ResponseType.NotEnoughResources, errorMessage, customErrorName)
172
256
  }
173
257
  }
174
258