@atproto/xrpc-server 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/logger.d.ts CHANGED
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -0,0 +1,31 @@
1
+ import { RateLimiterAbstract, RateLimiterRes } from 'rate-limiter-flexible';
2
+ import { CalcKeyFn, CalcPointsFn, RateLimitExceededError, RateLimiterConsume, RateLimiterI, RateLimiterStatus, XRPCReqContext } from './types';
3
+ export declare type RateLimiterOpts = {
4
+ keyPrefix: string;
5
+ durationMs: number;
6
+ points: number;
7
+ bypassSecret?: string;
8
+ bypassIps?: string[];
9
+ calcKey?: CalcKeyFn;
10
+ calcPoints?: CalcPointsFn;
11
+ failClosed?: boolean;
12
+ };
13
+ export declare class RateLimiter implements RateLimiterI {
14
+ limiter: RateLimiterAbstract;
15
+ private bypassSecret?;
16
+ private bypassIps?;
17
+ private failClosed?;
18
+ calcKey: CalcKeyFn;
19
+ calcPoints: CalcPointsFn;
20
+ constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts);
21
+ static memory(opts: RateLimiterOpts): RateLimiter;
22
+ static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter;
23
+ consume(ctx: XRPCReqContext, opts?: {
24
+ calcKey?: CalcKeyFn;
25
+ calcPoints?: CalcPointsFn;
26
+ }): Promise<RateLimiterStatus | RateLimitExceededError | null>;
27
+ }
28
+ export declare const formatLimiterStatus: (limiter: RateLimiterAbstract, res: RateLimiterRes) => RateLimiterStatus;
29
+ export declare const consumeMany: (ctx: XRPCReqContext, fns: RateLimiterConsume[]) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
30
+ export declare const setResHeaders: (ctx: XRPCReqContext, status: RateLimiterStatus) => void;
31
+ export declare const getTightestLimit: (resps: (RateLimiterStatus | RateLimitExceededError | null)[]) => RateLimiterStatus | RateLimitExceededError | null;
package/dist/server.d.ts CHANGED
@@ -1,7 +1,7 @@
1
1
  import express, { NextFunction, RequestHandler } from 'express';
2
2
  import { Lexicons, LexXrpcProcedure, LexXrpcQuery, LexXrpcSubscription } from '@atproto/lexicon';
3
3
  import { XrpcStreamServer } from './stream';
4
- import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler } from './types';
4
+ import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler, RateLimiterI, RateLimiterConsume } from './types';
5
5
  export declare function createServer(lexicons?: unknown[], options?: Options): Server;
6
6
  export declare class Server {
7
7
  router: import("express-serve-static-core").Express;
@@ -10,6 +10,9 @@ export declare class Server {
10
10
  lex: Lexicons;
11
11
  options: Options;
12
12
  middleware: Record<'json' | 'text', RequestHandler>;
13
+ globalRateLimiters: RateLimiterI[];
14
+ sharedRateLimiters: Record<string, RateLimiterI>;
15
+ routeRateLimiterFns: Record<string, RateLimiterConsume[]>;
13
16
  constructor(lexicons?: unknown[], opts?: Options);
14
17
  method(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
15
18
  addMethod(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
@@ -22,4 +25,5 @@ export declare class Server {
22
25
  createHandler(nsid: string, def: LexXrpcQuery | LexXrpcProcedure, handler: XRPCHandler): RequestHandler;
23
26
  protected addSubscription(nsid: string, def: LexXrpcSubscription, config: XRPCStreamHandlerConfig): Promise<void>;
24
27
  private enableStreamingOnListen;
28
+ private setupRouteRateLimits;
25
29
  }
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -1,5 +1,8 @@
1
1
  /// <reference types="node" />
2
2
  import { DuplexOptions } from 'stream';
3
3
  import { WebSocket } from 'ws';
4
- export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
5
- export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown>, void, unknown>;
4
+ import { MessageFrame } from './frames';
5
+ export declare function streamByteChunks(ws: WebSocket, options?: DuplexOptions): import("stream").Duplex;
6
+ export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
7
+ export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown>, void, unknown>;
8
+ export declare function ensureChunkIsMessage(chunk: Uint8Array): MessageFrame<unknown>;
@@ -4,6 +4,7 @@ export declare class Subscription<T = unknown> {
4
4
  service: string;
5
5
  method: string;
6
6
  maxReconnectSeconds?: number;
7
+ heartbeatIntervalMs?: number;
7
8
  signal?: AbortSignal;
8
9
  validate: (obj: unknown) => T | undefined;
9
10
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
@@ -13,12 +14,12 @@ export declare class Subscription<T = unknown> {
13
14
  service: string;
14
15
  method: string;
15
16
  maxReconnectSeconds?: number;
17
+ heartbeatIntervalMs?: number;
16
18
  signal?: AbortSignal;
17
19
  validate: (obj: unknown) => T | undefined;
18
20
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
21
  getParams?: () => Record<string, unknown> | Promise<Record<string, unknown> | undefined> | undefined;
20
22
  });
21
23
  [Symbol.asyncIterator](): AsyncGenerator<T>;
22
- private getSocket;
23
24
  }
24
25
  export default Subscription;
@@ -0,0 +1,23 @@
1
+ import { WebSocket, ClientOptions } from 'ws';
2
+ export declare class WebSocketKeepAlive {
3
+ opts: ClientOptions & {
4
+ getUrl: () => Promise<string>;
5
+ maxReconnectSeconds?: number;
6
+ signal?: AbortSignal;
7
+ heartbeatIntervalMs?: number;
8
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
9
+ };
10
+ ws: WebSocket | null;
11
+ initialSetup: boolean;
12
+ reconnects: number | null;
13
+ constructor(opts: ClientOptions & {
14
+ getUrl: () => Promise<string>;
15
+ maxReconnectSeconds?: number;
16
+ signal?: AbortSignal;
17
+ heartbeatIntervalMs?: number;
18
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
+ });
20
+ [Symbol.asyncIterator](): AsyncGenerator<Uint8Array>;
21
+ startHeartbeat(ws: WebSocket): void;
22
+ }
23
+ export default WebSocketKeepAlive;
package/dist/types.d.ts CHANGED
@@ -10,6 +10,11 @@ export declare type Options = {
10
10
  blobLimit?: number;
11
11
  textLimit?: number;
12
12
  };
13
+ rateLimits?: {
14
+ creator: RateLimiterCreator;
15
+ global?: ServerRateLimitDescription[];
16
+ shared?: ServerRateLimitDescription[];
17
+ };
13
18
  };
14
19
  export declare type UndecodedParams = typeof express.request['query'];
15
20
  export declare type Primitive = string | number | boolean;
@@ -65,13 +70,14 @@ export declare const handlerError: zod.ZodObject<{
65
70
  }>;
66
71
  export declare type HandlerError = zod.infer<typeof handlerError>;
67
72
  export declare type HandlerOutput = HandlerSuccess | HandlerError;
68
- export declare type XRPCHandler = (ctx: {
73
+ export declare type XRPCReqContext = {
69
74
  auth: HandlerAuth | undefined;
70
75
  params: Params;
71
76
  input: HandlerInput | undefined;
72
77
  req: express.Request;
73
78
  res: express.Response;
74
- }) => Promise<HandlerOutput> | HandlerOutput | undefined;
79
+ };
80
+ export declare type XRPCHandler = (ctx: XRPCReqContext) => Promise<HandlerOutput> | HandlerOutput | undefined;
75
81
  export declare type XRPCStreamHandler = (ctx: {
76
82
  auth: HandlerAuth | undefined;
77
83
  params: Params;
@@ -86,7 +92,52 @@ export declare type AuthVerifier = (ctx: {
86
92
  export declare type StreamAuthVerifier = (ctx: {
87
93
  req: IncomingMessage;
88
94
  }) => Promise<AuthOutput> | AuthOutput;
95
+ export declare type CalcKeyFn = (ctx: XRPCReqContext) => string;
96
+ export declare type CalcPointsFn = (ctx: XRPCReqContext) => number;
97
+ export interface RateLimiterI {
98
+ consume: RateLimiterConsume;
99
+ }
100
+ export declare type RateLimiterConsume = (ctx: XRPCReqContext, opts?: {
101
+ calcKey?: CalcKeyFn;
102
+ calcPoints?: CalcPointsFn;
103
+ }) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
104
+ export declare type RateLimiterCreator = (opts: {
105
+ keyPrefix: string;
106
+ durationMs: number;
107
+ points: number;
108
+ calcKey?: (ctx: XRPCReqContext) => string;
109
+ calcPoints?: (ctx: XRPCReqContext) => number;
110
+ }) => RateLimiterI;
111
+ export declare type ServerRateLimitDescription = {
112
+ name: string;
113
+ durationMs: number;
114
+ points: number;
115
+ calcKey?: (ctx: XRPCReqContext) => string;
116
+ calcPoints?: (ctx: XRPCReqContext) => number;
117
+ };
118
+ export declare type SharedRateLimitOpts = {
119
+ name: string;
120
+ calcKey?: (ctx: XRPCReqContext) => string;
121
+ calcPoints?: (ctx: XRPCReqContext) => number;
122
+ };
123
+ export declare type RouteRateLimitOpts = {
124
+ durationMs: number;
125
+ points: number;
126
+ calcKey?: (ctx: XRPCReqContext) => string;
127
+ calcPoints?: (ctx: XRPCReqContext) => number;
128
+ };
129
+ export declare type HandlerRateLimitOpts = SharedRateLimitOpts | RouteRateLimitOpts;
130
+ export declare const isShared: (opts: HandlerRateLimitOpts) => opts is SharedRateLimitOpts;
131
+ export declare type RateLimiterStatus = {
132
+ limit: number;
133
+ duration: number;
134
+ remainingPoints: number;
135
+ msBeforeNext: number;
136
+ consumedPoints: number;
137
+ isFirstInDuration: boolean;
138
+ };
89
139
  export declare type XRPCHandlerConfig = {
140
+ rateLimit?: HandlerRateLimitOpts | HandlerRateLimitOpts[];
90
141
  auth?: AuthVerifier;
91
142
  handler: XRPCHandler;
92
143
  };
@@ -117,13 +168,17 @@ export declare class AuthRequiredError extends XRPCError {
117
168
  export declare class ForbiddenError extends XRPCError {
118
169
  constructor(errorMessage?: string, customErrorName?: string);
119
170
  }
171
+ export declare class RateLimitExceededError extends XRPCError {
172
+ status: RateLimiterStatus;
173
+ constructor(status: RateLimiterStatus, errorMessage?: string, customErrorName?: string);
174
+ }
120
175
  export declare class InternalServerError extends XRPCError {
121
176
  constructor(errorMessage?: string, customErrorName?: string);
122
177
  }
123
178
  export declare class UpstreamFailureError extends XRPCError {
124
179
  constructor(errorMessage?: string, customErrorName?: string);
125
180
  }
126
- export declare class NotEnoughResoucesError extends XRPCError {
181
+ export declare class NotEnoughResourcesError extends XRPCError {
127
182
  constructor(errorMessage?: string, customErrorName?: string);
128
183
  }
129
184
  export declare class UpstreamTimeoutError extends XRPCError {
package/package.json CHANGED
@@ -1,47 +1,47 @@
1
1
  {
2
2
  "name": "@atproto/xrpc-server",
3
- "version": "0.3.0",
4
- "main": "dist/index.js",
5
- "scripts": {
6
- "test": "jest",
7
- "prettier": "prettier --check src/ tests/",
8
- "prettier:fix": "prettier --write src/ tests/",
9
- "lint": "eslint . --ext .ts,.tsx",
10
- "lint:fix": "yarn lint --fix",
11
- "verify": "run-p prettier lint",
12
- "verify:fix": "yarn prettier:fix && yarn lint:fix",
13
- "build": "node ./build.js",
14
- "postbuild": "tsc --build tsconfig.build.json",
15
- "update-main-to-dist": "node ./update-pkg.js --update-main-to-dist",
16
- "update-main-to-src": "node ./update-pkg.js --update-main-to-src",
17
- "prepublish": "npm run update-main-to-dist",
18
- "postpublish": "npm run update-main-to-src"
19
- },
3
+ "version": "0.3.2",
20
4
  "license": "MIT",
5
+ "description": "atproto HTTP API (XRPC) server library",
6
+ "keywords": [
7
+ "atproto",
8
+ "xrpc"
9
+ ],
10
+ "homepage": "https://atproto.com",
21
11
  "repository": {
22
12
  "type": "git",
23
- "url": "https://github.com/bluesky-social/atproto.git",
13
+ "url": "https://github.com/bluesky-social/atproto",
24
14
  "directory": "packages/xrpc-server"
25
15
  },
16
+ "main": "dist/index.js",
26
17
  "dependencies": {
27
- "@atproto/common": "*",
28
- "@atproto/crypto": "*",
29
- "@atproto/lexicon": "*",
30
18
  "cbor-x": "^1.5.1",
31
19
  "express": "^4.17.2",
32
20
  "http-errors": "^2.0.0",
33
21
  "mime-types": "^2.1.35",
22
+ "rate-limiter-flexible": "^2.4.1",
34
23
  "uint8arrays": "3.0.0",
35
24
  "ws": "^8.12.0",
36
- "zod": "^3.21.4"
25
+ "zod": "^3.21.4",
26
+ "@atproto/common": "^0.3.1",
27
+ "@atproto/crypto": "^0.2.2",
28
+ "@atproto/lexicon": "^0.2.2"
37
29
  },
38
30
  "devDependencies": {
39
- "@atproto/crypto": "*",
40
- "@atproto/xrpc": "*",
41
31
  "@types/express": "^4.17.13",
32
+ "@types/express-serve-static-core": "^4.17.36",
42
33
  "@types/http-errors": "^2.0.1",
43
34
  "@types/ws": "^8.5.4",
44
35
  "get-port": "^6.1.2",
45
- "multiformats": "^9.6.4"
46
- }
47
- }
36
+ "multiformats": "^9.9.0",
37
+ "@atproto/crypto": "^0.2.2",
38
+ "@atproto/xrpc": "^0.3.2"
39
+ },
40
+ "scripts": {
41
+ "test": "jest",
42
+ "build": "node ./build.js",
43
+ "postbuild": "tsc --build tsconfig.build.json",
44
+ "update-main-to-dist": "node ../../update-main-to-dist.js packages/xrpc-server"
45
+ },
46
+ "types": "dist/index.d.ts"
47
+ }
package/src/index.ts CHANGED
@@ -2,6 +2,7 @@ export * from './types'
2
2
  export * from './auth'
3
3
  export * from './server'
4
4
  export * from './stream'
5
+ export * from './rate-limiter'
5
6
 
6
7
  export type { ServerTiming } from './util'
7
8
  export { serverTimingHeader, ServerTimer } from './util'
package/src/logger.ts CHANGED
@@ -1,5 +1,6 @@
1
1
  import { subsystemLogger } from '@atproto/common'
2
2
 
3
- export const logger = subsystemLogger('xrpc-server')
3
+ export const logger: ReturnType<typeof subsystemLogger> =
4
+ subsystemLogger('xrpc-server')
4
5
 
5
6
  export default logger
@@ -0,0 +1,173 @@
1
+ import {
2
+ RateLimiterAbstract,
3
+ RateLimiterMemory,
4
+ RateLimiterRedis,
5
+ RateLimiterRes,
6
+ } from 'rate-limiter-flexible'
7
+ import { logger } from './logger'
8
+ import {
9
+ CalcKeyFn,
10
+ CalcPointsFn,
11
+ RateLimitExceededError,
12
+ RateLimiterConsume,
13
+ RateLimiterI,
14
+ RateLimiterStatus,
15
+ XRPCReqContext,
16
+ } from './types'
17
+
18
+ export type RateLimiterOpts = {
19
+ keyPrefix: string
20
+ durationMs: number
21
+ points: number
22
+ bypassSecret?: string
23
+ bypassIps?: string[]
24
+ calcKey?: CalcKeyFn
25
+ calcPoints?: CalcPointsFn
26
+ failClosed?: boolean
27
+ }
28
+
29
+ export class RateLimiter implements RateLimiterI {
30
+ public limiter: RateLimiterAbstract
31
+ private bypassSecret?: string
32
+ private bypassIps?: string[]
33
+ private failClosed?: boolean
34
+ public calcKey: CalcKeyFn
35
+ public calcPoints: CalcPointsFn
36
+
37
+ constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts) {
38
+ this.limiter = limiter
39
+ this.bypassSecret = opts.bypassSecret
40
+ this.bypassIps = opts.bypassIps
41
+ this.calcKey = opts.calcKey ?? defaultKey
42
+ this.calcPoints = opts.calcPoints ?? defaultPoints
43
+ }
44
+
45
+ static memory(opts: RateLimiterOpts): RateLimiter {
46
+ const limiter = new RateLimiterMemory({
47
+ keyPrefix: opts.keyPrefix,
48
+ duration: Math.floor(opts.durationMs / 1000),
49
+ points: opts.points,
50
+ })
51
+ return new RateLimiter(limiter, opts)
52
+ }
53
+
54
+ static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter {
55
+ const limiter = new RateLimiterRedis({
56
+ storeClient,
57
+ keyPrefix: opts.keyPrefix,
58
+ duration: Math.floor(opts.durationMs / 1000),
59
+ points: opts.points,
60
+ })
61
+ return new RateLimiter(limiter, opts)
62
+ }
63
+
64
+ async consume(
65
+ ctx: XRPCReqContext,
66
+ opts?: { calcKey?: CalcKeyFn; calcPoints?: CalcPointsFn },
67
+ ): Promise<RateLimiterStatus | RateLimitExceededError | null> {
68
+ if (
69
+ this.bypassSecret &&
70
+ ctx.req.header('x-ratelimit-bypass') === this.bypassSecret
71
+ ) {
72
+ return null
73
+ }
74
+ if (this.bypassIps && this.bypassIps.includes(ctx.req.ip)) {
75
+ return null
76
+ }
77
+ const key = opts?.calcKey ? opts.calcKey(ctx) : this.calcKey(ctx)
78
+ const points = opts?.calcPoints
79
+ ? opts.calcPoints(ctx)
80
+ : this.calcPoints(ctx)
81
+ if (points < 1) {
82
+ return null
83
+ }
84
+ try {
85
+ const res = await this.limiter.consume(key, points)
86
+ return formatLimiterStatus(this.limiter, res)
87
+ } catch (err) {
88
+ // yes this library rejects with a res not an error
89
+ if (err instanceof RateLimiterRes) {
90
+ const status = formatLimiterStatus(this.limiter, err)
91
+ return new RateLimitExceededError(status)
92
+ } else {
93
+ if (this.failClosed) {
94
+ throw err
95
+ }
96
+ logger.error(
97
+ {
98
+ err,
99
+ keyPrefix: this.limiter.keyPrefix,
100
+ points: this.limiter.points,
101
+ duration: this.limiter.duration,
102
+ },
103
+ 'rate limiter failed to consume points',
104
+ )
105
+ return null
106
+ }
107
+ }
108
+ }
109
+ }
110
+
111
+ export const formatLimiterStatus = (
112
+ limiter: RateLimiterAbstract,
113
+ res: RateLimiterRes,
114
+ ): RateLimiterStatus => {
115
+ return {
116
+ limit: limiter.points,
117
+ duration: limiter.duration,
118
+ remainingPoints: res.remainingPoints,
119
+ msBeforeNext: res.msBeforeNext,
120
+ consumedPoints: res.consumedPoints,
121
+ isFirstInDuration: res.isFirstInDuration,
122
+ }
123
+ }
124
+
125
+ export const consumeMany = async (
126
+ ctx: XRPCReqContext,
127
+ fns: RateLimiterConsume[],
128
+ ): Promise<RateLimiterStatus | RateLimitExceededError | null> => {
129
+ if (fns.length === 0) return null
130
+ const results = await Promise.all(fns.map((fn) => fn(ctx)))
131
+ const tightestLimit = getTightestLimit(results)
132
+ if (tightestLimit === null) {
133
+ return null
134
+ } else if (tightestLimit instanceof RateLimitExceededError) {
135
+ setResHeaders(ctx, tightestLimit.status)
136
+ return tightestLimit
137
+ } else {
138
+ setResHeaders(ctx, tightestLimit)
139
+ return tightestLimit
140
+ }
141
+ }
142
+
143
+ export const setResHeaders = (
144
+ ctx: XRPCReqContext,
145
+ status: RateLimiterStatus,
146
+ ) => {
147
+ ctx.res.setHeader('RateLimit-Limit', status.limit)
148
+ ctx.res.setHeader('RateLimit-Remaining', status.remainingPoints)
149
+ ctx.res.setHeader(
150
+ 'RateLimit-Reset',
151
+ Math.floor((Date.now() + status.msBeforeNext) / 1000),
152
+ )
153
+ ctx.res.setHeader('RateLimit-Policy', `${status.limit};w=${status.duration}`)
154
+ }
155
+
156
+ export const getTightestLimit = (
157
+ resps: (RateLimiterStatus | RateLimitExceededError | null)[],
158
+ ): RateLimiterStatus | RateLimitExceededError | null => {
159
+ let lowest: RateLimiterStatus | null = null
160
+ for (const resp of resps) {
161
+ if (resp === null) continue
162
+ if (resp instanceof RateLimitExceededError) return resp
163
+ if (lowest === null || resp.remainingPoints < lowest.remainingPoints) {
164
+ lowest = resp
165
+ }
166
+ }
167
+ return lowest
168
+ }
169
+
170
+ // when using a proxy, ensure headers are getting forwarded correctly: `app.set('trust proxy', true)`
171
+ // https://expressjs.com/en/guide/behind-proxies.html
172
+ const defaultKey: CalcKeyFn = (ctx: XRPCReqContext) => ctx.req.ip
173
+ const defaultPoints: CalcPointsFn = () => 1
package/src/server.ts CHANGED
@@ -30,6 +30,11 @@ import {
30
30
  XRPCStreamHandler,
31
31
  Params,
32
32
  InternalServerError,
33
+ XRPCReqContext,
34
+ RateLimiterI,
35
+ RateLimiterConsume,
36
+ isShared,
37
+ RateLimitExceededError,
33
38
  } from './types'
34
39
  import {
35
40
  decodeQueryParams,
@@ -38,6 +43,7 @@ import {
38
43
  validateOutput,
39
44
  } from './util'
40
45
  import log from './logger'
46
+ import { consumeMany } from './rate-limiter'
41
47
 
42
48
  export function createServer(lexicons?: unknown[], options?: Options) {
43
49
  return new Server(lexicons, options)
@@ -50,6 +56,9 @@ export class Server {
50
56
  lex = new Lexicons()
51
57
  options: Options
52
58
  middleware: Record<'json' | 'text', RequestHandler>
59
+ globalRateLimiters: RateLimiterI[]
60
+ sharedRateLimiters: Record<string, RateLimiterI>
61
+ routeRateLimiterFns: Record<string, RateLimiterConsume[]>
53
62
 
54
63
  constructor(lexicons?: unknown[], opts?: Options) {
55
64
  if (lexicons) {
@@ -66,6 +75,27 @@ export class Server {
66
75
  json: express.json({ limit: opts?.payload?.jsonLimit }),
67
76
  text: express.text({ limit: opts?.payload?.textLimit }),
68
77
  }
78
+ this.globalRateLimiters = []
79
+ this.sharedRateLimiters = {}
80
+ this.routeRateLimiterFns = {}
81
+ if (opts?.rateLimits?.global) {
82
+ for (const limit of opts.rateLimits.global) {
83
+ const rateLimiter = opts.rateLimits.creator({
84
+ ...limit,
85
+ keyPrefix: `rl-${limit.name}`,
86
+ })
87
+ this.globalRateLimiters.push(rateLimiter)
88
+ }
89
+ }
90
+ if (opts?.rateLimits?.shared) {
91
+ for (const limit of opts.rateLimits.shared) {
92
+ const rateLimiter = opts.rateLimits.creator({
93
+ ...limit,
94
+ keyPrefix: `rl-${limit.name}`,
95
+ })
96
+ this.sharedRateLimiters[limit.name] = rateLimiter
97
+ }
98
+ }
69
99
  }
70
100
 
71
101
  // handlers
@@ -138,6 +168,7 @@ export class Server {
138
168
  middleware.push(this.middleware.json)
139
169
  middleware.push(this.middleware.text)
140
170
  }
171
+ this.setupRouteRateLimits(nsid, config)
141
172
  this.routes[verb](
142
173
  `/xrpc/${nsid}`,
143
174
  ...middleware,
@@ -185,6 +216,10 @@ export class Server {
185
216
  validateOutput(nsid, def, output, this.lex)
186
217
  const assertValidXrpcParams = (params: unknown) =>
187
218
  this.lex.assertValidXrpcParams(nsid, params)
219
+ const rlFns = this.routeRateLimiterFns[nsid] ?? []
220
+ const consumeRateLimit = (reqCtx: XRPCReqContext) =>
221
+ consumeMany(reqCtx, rlFns)
222
+
188
223
  return async function (req, res, next) {
189
224
  try {
190
225
  // validate request
@@ -203,14 +238,24 @@ export class Server {
203
238
 
204
239
  const locals: RequestLocals = req[kRequestLocals]
205
240
 
206
- // run the handler
207
- const outputUnvalidated = await handler({
241
+ const reqCtx: XRPCReqContext = {
208
242
  params,
209
243
  input,
210
244
  auth: locals.auth,
211
245
  req,
212
246
  res,
213
- })
247
+ }
248
+
249
+ // handle rate limits
250
+ if (consumeRateLimit) {
251
+ const result = await consumeRateLimit(reqCtx)
252
+ if (result instanceof RateLimitExceededError) {
253
+ return next(result)
254
+ }
255
+ }
256
+
257
+ // run the handler
258
+ const outputUnvalidated = await handler(reqCtx)
214
259
 
215
260
  if (isHandlerError(outputUnvalidated)) {
216
261
  throw XRPCError.fromError(outputUnvalidated)
@@ -345,6 +390,56 @@ export class Server {
345
390
  return httpServer
346
391
  }
347
392
  }
393
+
394
+ private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) {
395
+ this.routeRateLimiterFns[nsid] = []
396
+ for (const limit of this.globalRateLimiters) {
397
+ const consumeFn = async (ctx: XRPCReqContext) => {
398
+ return limit.consume(ctx)
399
+ }
400
+ this.routeRateLimiterFns[nsid].push(consumeFn)
401
+ }
402
+
403
+ if (config.rateLimit) {
404
+ const limits = Array.isArray(config.rateLimit)
405
+ ? config.rateLimit
406
+ : [config.rateLimit]
407
+ this.routeRateLimiterFns[nsid] = []
408
+ for (let i = 0; i < limits.length; i++) {
409
+ const limit = limits[i]
410
+ const { calcKey, calcPoints } = limit
411
+ if (isShared(limit)) {
412
+ const rateLimiter = this.sharedRateLimiters[limit.name]
413
+ if (rateLimiter) {
414
+ const consumeFn = (ctx: XRPCReqContext) =>
415
+ rateLimiter.consume(ctx, {
416
+ calcKey,
417
+ calcPoints,
418
+ })
419
+ this.routeRateLimiterFns[nsid].push(consumeFn)
420
+ }
421
+ } else {
422
+ const { durationMs, points } = limit
423
+ const rateLimiter = this.options.rateLimits?.creator({
424
+ keyPrefix: `nsid-${i}`,
425
+ durationMs,
426
+ points,
427
+ calcKey,
428
+ calcPoints,
429
+ })
430
+ if (rateLimiter) {
431
+ this.sharedRateLimiters[nsid] = rateLimiter
432
+ const consumeFn = (ctx: XRPCReqContext) =>
433
+ rateLimiter.consume(ctx, {
434
+ calcKey,
435
+ calcPoints,
436
+ })
437
+ this.routeRateLimiterFns[nsid].push(consumeFn)
438
+ }
439
+ }
440
+ }
441
+ }
442
+ }
348
443
  }
349
444
 
350
445
  function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
@@ -385,14 +480,23 @@ function createAuthMiddleware(verifier: AuthVerifier): RequestHandler {
385
480
  const errorMiddleware: ErrorRequestHandler = function (err, req, res, next) {
386
481
  const locals: RequestLocals | undefined = req[kRequestLocals]
387
482
  const methodSuffix = locals ? ` method ${locals.nsid}` : ''
388
- if (err instanceof XRPCError) {
389
- log.error(err, `error in xrpc${methodSuffix}`)
390
- } else {
483
+ const xrpcError = XRPCError.fromError(err)
484
+ if (xrpcError instanceof InternalServerError) {
485
+ // log trace for unhandled exceptions
391
486
  log.error(err, `unhandled exception in xrpc${methodSuffix}`)
487
+ } else {
488
+ // do not log trace for known xrpc errors
489
+ log.error(
490
+ {
491
+ status: xrpcError.type,
492
+ message: xrpcError.message,
493
+ name: xrpcError.customErrorName,
494
+ },
495
+ `error in xrpc${methodSuffix}`,
496
+ )
392
497
  }
393
498
  if (res.headersSent) {
394
499
  return next(err)
395
500
  }
396
- const xrpcError = XRPCError.fromError(err)
397
501
  return res.status(xrpcError.type).json(xrpcError.payload)
398
502
  }