@atproto/xrpc-server 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/logger.d.ts CHANGED
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -0,0 +1,29 @@
1
+ import { RateLimiterAbstract, RateLimiterRes } from 'rate-limiter-flexible';
2
+ import { CalcKeyFn, CalcPointsFn, RateLimitExceededError, RateLimiterConsume, RateLimiterI, RateLimiterStatus, XRPCReqContext } from './types';
3
/** Options accepted by the RateLimiter constructor and its memory/redis factories. */
export declare type RateLimiterOpts = {
    /** Prefix for this limiter's keys in the backing store. */
    keyPrefix: string;
    /** Window length in milliseconds (the underlying limiter is configured in seconds). */
    durationMs: number;
    /** Points a single key may consume per window. */
    points: number;
    /** Requests whose `x-ratelimit-bypass` header equals this value are not limited. */
    bypassSecret?: string;
    /** Derives the limiter key for a request (source defaults it to `req.ip`). */
    calcKey?: CalcKeyFn;
    /** Derives the cost of a request (source defaults it to 1; costs below 1 are not charged). */
    calcPoints?: CalcPointsFn;
    /**
     * Intended to rethrow unexpected store errors instead of logging them and
     * letting the request through. NOTE(review): the 0.3.1 constructor does not
     * copy this option onto the instance, so it has no effect in this release.
     */
    failClosed?: boolean;
};
/** Rate limiter wrapping a rate-limiter-flexible limiter (in-memory or Redis). */
export declare class RateLimiter implements RateLimiterI {
    limiter: RateLimiterAbstract;
    private byPassSecret?;
    private failClosed?;
    calcKey: CalcKeyFn;
    calcPoints: CalcPointsFn;
    constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts);
    /** Builds a limiter backed by process memory. */
    static memory(opts: RateLimiterOpts): RateLimiter;
    /** Builds a limiter backed by the given Redis client. */
    static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter;
    /**
     * Charges points for a request. Resolves to the resulting status, a
     * RateLimitExceededError when the key is out of points, or null when the
     * request was not charged (bypassed, zero cost, or a swallowed store error).
     */
    consume(ctx: XRPCReqContext, opts?: {
        calcKey?: CalcKeyFn;
        calcPoints?: CalcPointsFn;
    }): Promise<RateLimiterStatus | RateLimitExceededError | null>;
}
/** Maps a limiter's configuration plus a consume result onto RateLimiterStatus. */
export declare const formatLimiterStatus: (limiter: RateLimiterAbstract, res: RateLimiterRes) => RateLimiterStatus;
/** Runs several consume functions, sets RateLimit-* headers and returns the tightest result. */
export declare const consumeMany: (ctx: XRPCReqContext, fns: RateLimiterConsume[]) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
/** Writes RateLimit-Limit/Remaining/Reset/Policy headers for a status. */
export declare const setResHeaders: (ctx: XRPCReqContext, status: RateLimiterStatus) => void;
/** Returns the first exceeded error, else the status with the fewest remaining points, else null. */
export declare const getTightestLimit: (resps: (RateLimiterStatus | RateLimitExceededError | null)[]) => RateLimiterStatus | RateLimitExceededError | null;
package/dist/server.d.ts CHANGED
@@ -1,7 +1,7 @@
1
1
  import express, { NextFunction, RequestHandler } from 'express';
2
2
  import { Lexicons, LexXrpcProcedure, LexXrpcQuery, LexXrpcSubscription } from '@atproto/lexicon';
3
3
  import { XrpcStreamServer } from './stream';
4
- import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler } from './types';
4
+ import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler, RateLimiterI, RateLimiterConsume } from './types';
5
5
  export declare function createServer(lexicons?: unknown[], options?: Options): Server;
6
6
  export declare class Server {
7
7
  router: import("express-serve-static-core").Express;
@@ -10,6 +10,9 @@ export declare class Server {
10
10
  lex: Lexicons;
11
11
  options: Options;
12
12
  middleware: Record<'json' | 'text', RequestHandler>;
13
+ globalRateLimiters: RateLimiterI[];
14
+ sharedRateLimiters: Record<string, RateLimiterI>;
15
+ routeRateLimiterFns: Record<string, RateLimiterConsume[]>;
13
16
  constructor(lexicons?: unknown[], opts?: Options);
14
17
  method(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
15
18
  addMethod(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
@@ -22,4 +25,5 @@ export declare class Server {
22
25
  createHandler(nsid: string, def: LexXrpcQuery | LexXrpcProcedure, handler: XRPCHandler): RequestHandler;
23
26
  protected addSubscription(nsid: string, def: LexXrpcSubscription, config: XRPCStreamHandlerConfig): Promise<void>;
24
27
  private enableStreamingOnListen;
28
+ private setupRouteRateLimits;
25
29
  }
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -1,5 +1,8 @@
1
1
  /// <reference types="node" />
2
2
  import { DuplexOptions } from 'stream';
3
3
  import { WebSocket } from 'ws';
4
- export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
5
- export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown>, void, unknown>;
4
+ import { MessageFrame } from './frames';
5
+ export declare function streamByteChunks(ws: WebSocket, options?: DuplexOptions): import("stream").Duplex;
6
+ export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
7
+ export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown>, void, unknown>;
8
+ export declare function ensureChunkIsMessage(chunk: Uint8Array): MessageFrame<unknown>;
@@ -4,6 +4,7 @@ export declare class Subscription<T = unknown> {
4
4
  service: string;
5
5
  method: string;
6
6
  maxReconnectSeconds?: number;
7
+ heartbeatIntervalMs?: number;
7
8
  signal?: AbortSignal;
8
9
  validate: (obj: unknown) => T | undefined;
9
10
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
@@ -13,12 +14,12 @@ export declare class Subscription<T = unknown> {
13
14
  service: string;
14
15
  method: string;
15
16
  maxReconnectSeconds?: number;
17
+ heartbeatIntervalMs?: number;
16
18
  signal?: AbortSignal;
17
19
  validate: (obj: unknown) => T | undefined;
18
20
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
21
  getParams?: () => Record<string, unknown> | Promise<Record<string, unknown> | undefined> | undefined;
20
22
  });
21
23
  [Symbol.asyncIterator](): AsyncGenerator<T>;
22
- private getSocket;
23
24
  }
24
25
  export default Subscription;
@@ -0,0 +1,23 @@
1
+ import { WebSocket, ClientOptions } from 'ws';
2
/**
 * WebSocket client wrapper exposed as an async iterator of binary chunks.
 * NOTE(review): the implementation is not part of this diff; from its name and
 * options it presumably reconnects (with backoff capped by maxReconnectSeconds)
 * and sends heartbeats every heartbeatIntervalMs — confirm against the source.
 */
export declare class WebSocketKeepAlive {
    opts: ClientOptions & {
        /** Resolves the URL to connect to (presumably re-invoked per connection attempt — confirm). */
        getUrl: () => Promise<string>;
        maxReconnectSeconds?: number;
        /** Abort signal; presumably stops iteration/reconnection when aborted — confirm. */
        signal?: AbortSignal;
        heartbeatIntervalMs?: number;
        /** Called with the error, an attempt counter, and whether this was the initial setup. */
        onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
    };
    /** Current socket, or null when not connected. */
    ws: WebSocket | null;
    initialSetup: boolean;
    reconnects: number | null;
    constructor(opts: ClientOptions & {
        getUrl: () => Promise<string>;
        maxReconnectSeconds?: number;
        signal?: AbortSignal;
        heartbeatIntervalMs?: number;
        onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
    });
    /** Yields raw message payloads as Uint8Array chunks. */
    [Symbol.asyncIterator](): AsyncGenerator<Uint8Array>;
    startHeartbeat(ws: WebSocket): void;
}
export default WebSocketKeepAlive;
package/dist/types.d.ts CHANGED
@@ -10,6 +10,11 @@ export declare type Options = {
10
10
  blobLimit?: number;
11
11
  textLimit?: number;
12
12
  };
13
+ rateLimits?: {
14
+ creator: RateLimiterCreator;
15
+ global?: ServerRateLimitDescription[];
16
+ shared?: ServerRateLimitDescription[];
17
+ };
13
18
  };
14
19
  export declare type UndecodedParams = typeof express.request['query'];
15
20
  export declare type Primitive = string | number | boolean;
@@ -65,13 +70,14 @@ export declare const handlerError: zod.ZodObject<{
65
70
  }>;
66
71
  export declare type HandlerError = zod.infer<typeof handlerError>;
67
72
  export declare type HandlerOutput = HandlerSuccess | HandlerError;
68
- export declare type XRPCHandler = (ctx: {
73
+ export declare type XRPCReqContext = {
69
74
  auth: HandlerAuth | undefined;
70
75
  params: Params;
71
76
  input: HandlerInput | undefined;
72
77
  req: express.Request;
73
78
  res: express.Response;
74
- }) => Promise<HandlerOutput> | HandlerOutput | undefined;
79
+ };
80
+ export declare type XRPCHandler = (ctx: XRPCReqContext) => Promise<HandlerOutput> | HandlerOutput | undefined;
75
81
  export declare type XRPCStreamHandler = (ctx: {
76
82
  auth: HandlerAuth | undefined;
77
83
  params: Params;
@@ -86,7 +92,52 @@ export declare type AuthVerifier = (ctx: {
86
92
  export declare type StreamAuthVerifier = (ctx: {
87
93
  req: IncomingMessage;
88
94
  }) => Promise<AuthOutput> | AuthOutput;
95
/** Derives the rate-limit key (e.g. client IP or account) for a request. */
export declare type CalcKeyFn = (ctx: XRPCReqContext) => string;
/** Derives how many points a request costs. */
export declare type CalcPointsFn = (ctx: XRPCReqContext) => number;
/** Minimal contract the server needs from a rate limiter. */
export interface RateLimiterI {
    consume: RateLimiterConsume;
}
/**
 * Charges a request against a limiter. `opts` overrides the limiter's own
 * key/points functions for this call.
 */
export declare type RateLimiterConsume = (ctx: XRPCReqContext, opts?: {
    calcKey?: CalcKeyFn;
    calcPoints?: CalcPointsFn;
}) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
/**
 * Factory the server calls to build every limiter: global and shared limits
 * (keyPrefix `rl-<name>`) and per-route limits (keyPrefix derived from the nsid).
 */
export declare type RateLimiterCreator = (opts: {
    keyPrefix: string;
    durationMs: number;
    points: number;
    calcKey?: (ctx: XRPCReqContext) => string;
    calcPoints?: (ctx: XRPCReqContext) => number;
}) => RateLimiterI;
/** A named limit declared in server options under `rateLimits.global` / `rateLimits.shared`. */
export declare type ServerRateLimitDescription = {
    name: string;
    durationMs: number;
    points: number;
    calcKey?: (ctx: XRPCReqContext) => string;
    calcPoints?: (ctx: XRPCReqContext) => number;
};
/** Handler-level reference to a shared limit by name, with optional key/points overrides. */
export declare type SharedRateLimitOpts = {
    name: string;
    calcKey?: (ctx: XRPCReqContext) => string;
    calcPoints?: (ctx: XRPCReqContext) => number;
};
/** Handler-level limit that gets its own limiter for this route. */
export declare type RouteRateLimitOpts = {
    durationMs: number;
    points: number;
    calcKey?: (ctx: XRPCReqContext) => string;
    calcPoints?: (ctx: XRPCReqContext) => number;
};
export declare type HandlerRateLimitOpts = SharedRateLimitOpts | RouteRateLimitOpts;
/** Type guard for SharedRateLimitOpts (presumably keyed on the presence of `name` — confirm). */
export declare const isShared: (opts: HandlerRateLimitOpts) => opts is SharedRateLimitOpts;
/** Snapshot of a limiter after a consume call; used for RateLimit-* response headers. */
export declare type RateLimiterStatus = {
    /** Points allowed per window. */
    limit: number;
    /** Window duration as configured on the limiter. */
    duration: number;
    remainingPoints: number;
    /** Milliseconds until the window resets. */
    msBeforeNext: number;
    consumedPoints: number;
    isFirstInDuration: boolean;
};
89
139
  export declare type XRPCHandlerConfig = {
140
+ rateLimit?: HandlerRateLimitOpts | HandlerRateLimitOpts[];
90
141
  auth?: AuthVerifier;
91
142
  handler: XRPCHandler;
92
143
  };
@@ -117,13 +168,17 @@ export declare class AuthRequiredError extends XRPCError {
117
168
  export declare class ForbiddenError extends XRPCError {
118
169
  constructor(errorMessage?: string, customErrorName?: string);
119
170
  }
171
/**
 * Signals an exhausted rate limit. RateLimiter.consume returns (rather than
 * throws) an instance; the server forwards it to Express's `next()`. `status`
 * carries the limiter state used for the RateLimit-* headers.
 */
export declare class RateLimitExceededError extends XRPCError {
    status: RateLimiterStatus;
    constructor(status: RateLimiterStatus, errorMessage?: string, customErrorName?: string);
}
120
175
  export declare class InternalServerError extends XRPCError {
121
176
  constructor(errorMessage?: string, customErrorName?: string);
122
177
  }
123
178
  export declare class UpstreamFailureError extends XRPCError {
124
179
  constructor(errorMessage?: string, customErrorName?: string);
125
180
  }
126
- export declare class NotEnoughResoucesError extends XRPCError {
181
+ export declare class NotEnoughResourcesError extends XRPCError {
127
182
  constructor(errorMessage?: string, customErrorName?: string);
128
183
  }
129
184
  export declare class UpstreamTimeoutError extends XRPCError {
package/package.json CHANGED
@@ -1,22 +1,7 @@
1
1
  {
2
2
  "name": "@atproto/xrpc-server",
3
- "version": "0.3.0",
3
+ "version": "0.3.1",
4
4
  "main": "dist/index.js",
5
- "scripts": {
6
- "test": "jest",
7
- "prettier": "prettier --check src/ tests/",
8
- "prettier:fix": "prettier --write src/ tests/",
9
- "lint": "eslint . --ext .ts,.tsx",
10
- "lint:fix": "yarn lint --fix",
11
- "verify": "run-p prettier lint",
12
- "verify:fix": "yarn prettier:fix && yarn lint:fix",
13
- "build": "node ./build.js",
14
- "postbuild": "tsc --build tsconfig.build.json",
15
- "update-main-to-dist": "node ./update-pkg.js --update-main-to-dist",
16
- "update-main-to-src": "node ./update-pkg.js --update-main-to-src",
17
- "prepublish": "npm run update-main-to-dist",
18
- "postpublish": "npm run update-main-to-src"
19
- },
20
5
  "license": "MIT",
21
6
  "repository": {
22
7
  "type": "git",
@@ -24,24 +9,33 @@
24
9
  "directory": "packages/xrpc-server"
25
10
  },
26
11
  "dependencies": {
27
- "@atproto/common": "*",
28
- "@atproto/crypto": "*",
29
- "@atproto/lexicon": "*",
30
12
  "cbor-x": "^1.5.1",
31
13
  "express": "^4.17.2",
32
14
  "http-errors": "^2.0.0",
33
15
  "mime-types": "^2.1.35",
16
+ "rate-limiter-flexible": "^2.4.1",
34
17
  "uint8arrays": "3.0.0",
35
18
  "ws": "^8.12.0",
36
- "zod": "^3.21.4"
19
+ "zod": "^3.21.4",
20
+ "@atproto/common": "^0.3.0",
21
+ "@atproto/crypto": "^0.2.2",
22
+ "@atproto/lexicon": "^0.2.1"
37
23
  },
38
24
  "devDependencies": {
39
- "@atproto/crypto": "*",
40
- "@atproto/xrpc": "*",
41
25
  "@types/express": "^4.17.13",
26
+ "@types/express-serve-static-core": "^4.17.36",
42
27
  "@types/http-errors": "^2.0.1",
43
28
  "@types/ws": "^8.5.4",
44
29
  "get-port": "^6.1.2",
45
- "multiformats": "^9.6.4"
46
- }
47
- }
30
+ "multiformats": "^9.9.0",
31
+ "@atproto/crypto": "^0.2.2",
32
+ "@atproto/xrpc": "^0.3.1"
33
+ },
34
+ "scripts": {
35
+ "test": "jest",
36
+ "build": "node ./build.js",
37
+ "postbuild": "tsc --build tsconfig.build.json",
38
+ "update-main-to-dist": "node ../../update-main-to-dist.js packages/xrpc-server"
39
+ },
40
+ "types": "dist/index.d.ts"
41
+ }
package/src/index.ts CHANGED
@@ -2,6 +2,7 @@ export * from './types'
2
2
  export * from './auth'
3
3
  export * from './server'
4
4
  export * from './stream'
5
+ export * from './rate-limiter'
5
6
 
6
7
  export type { ServerTiming } from './util'
7
8
  export { serverTimingHeader, ServerTimer } from './util'
package/src/logger.ts CHANGED
@@ -1,5 +1,6 @@
1
1
  import { subsystemLogger } from '@atproto/common'
2
2
 
3
- export const logger = subsystemLogger('xrpc-server')
3
+ export const logger: ReturnType<typeof subsystemLogger> =
4
+ subsystemLogger('xrpc-server')
4
5
 
5
6
  export default logger
@@ -0,0 +1,167 @@
1
+ import {
2
+ RateLimiterAbstract,
3
+ RateLimiterMemory,
4
+ RateLimiterRedis,
5
+ RateLimiterRes,
6
+ } from 'rate-limiter-flexible'
7
+ import { logger } from './logger'
8
+ import {
9
+ CalcKeyFn,
10
+ CalcPointsFn,
11
+ RateLimitExceededError,
12
+ RateLimiterConsume,
13
+ RateLimiterI,
14
+ RateLimiterStatus,
15
+ XRPCReqContext,
16
+ } from './types'
17
+
18
/** Options for constructing a {@link RateLimiter}. */
export type RateLimiterOpts = {
  /** Prefix for this limiter's keys in the backing store. */
  keyPrefix: string
  /** Window length in milliseconds; converted to whole seconds (min 1s when positive). */
  durationMs: number
  /** Points a single key may consume per window. */
  points: number
  /** Requests whose `x-ratelimit-bypass` header equals this value are not limited. */
  bypassSecret?: string
  /** Derives the limiter key for a request; defaults to the client IP (`req.ip`). */
  calcKey?: CalcKeyFn
  /** Derives the cost of a request; defaults to 1. Costs below 1 are not charged. */
  calcPoints?: CalcPointsFn
  /**
   * When true, unexpected store errors (anything other than "out of points")
   * are rethrown. When false/unset they are logged and the request is allowed.
   */
  failClosed?: boolean
}

// rate-limiter-flexible takes its window in whole seconds and treats a
// duration of 0 as "never reset". Flooring a positive sub-second window to 0
// would turn it into a permanent limit, so positive windows are at least 1s.
function toDurationSec(durationMs: number): number {
  const seconds = Math.floor(durationMs / 1000)
  return durationMs > 0 ? Math.max(1, seconds) : seconds
}

/**
 * XRPC rate limiter backed by a rate-limiter-flexible limiter (in-memory or
 * Redis). `consume` charges points against a per-request key and reports the
 * resulting status, or a RateLimitExceededError when the key is out of points.
 */
export class RateLimiter implements RateLimiterI {
  public limiter: RateLimiterAbstract
  private byPassSecret?: string
  private failClosed?: boolean
  public calcKey: CalcKeyFn
  public calcPoints: CalcPointsFn

  constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts) {
    this.limiter = limiter
    this.byPassSecret = opts.bypassSecret
    // previously never assigned, which made `failClosed: true` a no-op
    this.failClosed = opts.failClosed
    this.calcKey = opts.calcKey ?? defaultKey
    this.calcPoints = opts.calcPoints ?? defaultPoints
  }

  /** Creates a limiter whose counters live in this process's memory. */
  static memory(opts: RateLimiterOpts): RateLimiter {
    const limiter = new RateLimiterMemory({
      keyPrefix: opts.keyPrefix,
      duration: toDurationSec(opts.durationMs),
      points: opts.points,
    })
    return new RateLimiter(limiter, opts)
  }

  /** Creates a limiter whose counters live in Redis via `storeClient`. */
  static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter {
    const limiter = new RateLimiterRedis({
      storeClient,
      keyPrefix: opts.keyPrefix,
      duration: toDurationSec(opts.durationMs),
      points: opts.points,
    })
    return new RateLimiter(limiter, opts)
  }

  /**
   * Charges a request against this limiter.
   *
   * @param ctx - the XRPC request context
   * @param opts - per-call overrides for the key and points functions
   * @returns the limiter status after charging; a RateLimitExceededError when
   *   the key is out of points; or null when nothing was charged (bypass
   *   secret matched, cost < 1, or a store error while not failing closed)
   * @throws the store error when `failClosed` is set
   */
  async consume(
    ctx: XRPCReqContext,
    opts?: { calcKey?: CalcKeyFn; calcPoints?: CalcPointsFn },
  ): Promise<RateLimiterStatus | RateLimitExceededError | null> {
    if (
      this.byPassSecret &&
      ctx.req.header('x-ratelimit-bypass') === this.byPassSecret
    ) {
      return null
    }
    const key = opts?.calcKey ? opts.calcKey(ctx) : this.calcKey(ctx)
    const points = opts?.calcPoints
      ? opts.calcPoints(ctx)
      : this.calcPoints(ctx)
    // nothing to charge, so no status is reported
    if (points < 1) {
      return null
    }
    try {
      const res = await this.limiter.consume(key, points)
      return formatLimiterStatus(this.limiter, res)
    } catch (err) {
      // yes this library rejects with a res not an error when out of points
      if (err instanceof RateLimiterRes) {
        const status = formatLimiterStatus(this.limiter, err)
        return new RateLimitExceededError(status)
      }
      // any other rejection is a store failure (e.g. Redis unavailable)
      if (this.failClosed) {
        throw err
      }
      logger.error(
        {
          err,
          keyPrefix: this.limiter.keyPrefix,
          points: this.limiter.points,
          duration: this.limiter.duration,
        },
        'rate limiter failed to consume points',
      )
      return null
    }
  }
}
104
+
105
/**
 * Converts a rate-limiter-flexible result into a RateLimiterStatus.
 * `limit` and `duration` come from the limiter's configuration; the remaining
 * fields are copied from the consume result.
 */
export const formatLimiterStatus = (
  limiter: RateLimiterAbstract,
  res: RateLimiterRes,
): RateLimiterStatus => {
  const { points: limit, duration } = limiter
  const { remainingPoints, msBeforeNext, consumedPoints, isFirstInDuration } =
    res
  return {
    limit,
    duration,
    remainingPoints,
    msBeforeNext,
    consumedPoints,
    isFirstInDuration,
  }
}
118
+
119
/**
 * Runs every consume function for a request in parallel and reduces the
 * results with getTightestLimit. If a result was produced, its status (or the
 * status carried by an exceeded error) is written to the RateLimit-* response
 * headers. Returns that tightest result, or null when there are no limiters
 * or none of them reported a status.
 */
export const consumeMany = async (
  ctx: XRPCReqContext,
  fns: RateLimiterConsume[],
): Promise<RateLimiterStatus | RateLimitExceededError | null> => {
  if (!fns.length) return null
  const outcomes = await Promise.all(fns.map((consume) => consume(ctx)))
  const tightest = getTightestLimit(outcomes)
  if (tightest !== null) {
    const headerStatus =
      tightest instanceof RateLimitExceededError ? tightest.status : tightest
    setResHeaders(ctx, headerStatus)
  }
  return tightest
}
136
+
137
/**
 * Writes the rate-limit response headers for a status:
 * - RateLimit-Limit: points allowed per window
 * - RateLimit-Remaining: points left in the current window
 * - RateLimit-Reset: Unix time (seconds) at which the window resets
 * - RateLimit-Policy: `<limit>;w=<status.duration>`
 */
export const setResHeaders = (
  ctx: XRPCReqContext,
  status: RateLimiterStatus,
) => {
  const { res } = ctx
  const { limit, remainingPoints, msBeforeNext, duration } = status
  const resetAtSec = Math.floor((Date.now() + msBeforeNext) / 1000)
  res.setHeader('RateLimit-Limit', limit)
  res.setHeader('RateLimit-Remaining', remainingPoints)
  res.setHeader('RateLimit-Reset', resetAtSec)
  res.setHeader('RateLimit-Policy', `${limit};w=${duration}`)
}
149
+
150
/**
 * Picks the most restrictive of several limiter results:
 * the first RateLimitExceededError in the list if any limiter was exceeded,
 * otherwise the status with the fewest remaining points (earliest wins ties).
 * `null` entries are ignored; returns null when every entry is null.
 */
export const getTightestLimit = (
  resps: (RateLimiterStatus | RateLimitExceededError | null)[],
): RateLimiterStatus | RateLimitExceededError | null => {
  const exceeded = resps.find(
    (r): r is RateLimitExceededError => r instanceof RateLimitExceededError,
  )
  if (exceeded) return exceeded

  let tightest: RateLimiterStatus | null = null
  for (const candidate of resps) {
    if (candidate === null || candidate instanceof RateLimitExceededError) {
      continue
    }
    const isTighter =
      tightest === null || candidate.remainingPoints < tightest.remainingPoints
    if (isTighter) tightest = candidate
  }
  return tightest
}
163
+
164
// Default key: the client IP as reported by Express (`req.ip`). Behind a
// proxy, configure `app.set('trust proxy', true)` so forwarded headers are
// honored and this is the real client rather than the proxy.
// https://expressjs.com/en/guide/behind-proxies.html
const defaultKey: CalcKeyFn = ({ req }: XRPCReqContext) => req.ip
// Default cost: every request consumes exactly one point.
const defaultPoints: CalcPointsFn = () => 1
package/src/server.ts CHANGED
@@ -30,6 +30,11 @@ import {
30
30
  XRPCStreamHandler,
31
31
  Params,
32
32
  InternalServerError,
33
+ XRPCReqContext,
34
+ RateLimiterI,
35
+ RateLimiterConsume,
36
+ isShared,
37
+ RateLimitExceededError,
33
38
  } from './types'
34
39
  import {
35
40
  decodeQueryParams,
@@ -38,6 +43,7 @@ import {
38
43
  validateOutput,
39
44
  } from './util'
40
45
  import log from './logger'
46
+ import { consumeMany } from './rate-limiter'
41
47
 
42
48
  export function createServer(lexicons?: unknown[], options?: Options) {
43
49
  return new Server(lexicons, options)
@@ -50,6 +56,9 @@ export class Server {
50
56
  lex = new Lexicons()
51
57
  options: Options
52
58
  middleware: Record<'json' | 'text', RequestHandler>
59
+ globalRateLimiters: RateLimiterI[]
60
+ sharedRateLimiters: Record<string, RateLimiterI>
61
+ routeRateLimiterFns: Record<string, RateLimiterConsume[]>
53
62
 
54
63
  constructor(lexicons?: unknown[], opts?: Options) {
55
64
  if (lexicons) {
@@ -66,6 +75,27 @@ export class Server {
66
75
  json: express.json({ limit: opts?.payload?.jsonLimit }),
67
76
  text: express.text({ limit: opts?.payload?.textLimit }),
68
77
  }
78
+ this.globalRateLimiters = []
79
+ this.sharedRateLimiters = {}
80
+ this.routeRateLimiterFns = {}
81
+ if (opts?.rateLimits?.global) {
82
+ for (const limit of opts.rateLimits.global) {
83
+ const rateLimiter = opts.rateLimits.creator({
84
+ ...limit,
85
+ keyPrefix: `rl-${limit.name}`,
86
+ })
87
+ this.globalRateLimiters.push(rateLimiter)
88
+ }
89
+ }
90
+ if (opts?.rateLimits?.shared) {
91
+ for (const limit of opts.rateLimits.shared) {
92
+ const rateLimiter = opts.rateLimits.creator({
93
+ ...limit,
94
+ keyPrefix: `rl-${limit.name}`,
95
+ })
96
+ this.sharedRateLimiters[limit.name] = rateLimiter
97
+ }
98
+ }
69
99
  }
70
100
 
71
101
  // handlers
@@ -138,6 +168,7 @@ export class Server {
138
168
  middleware.push(this.middleware.json)
139
169
  middleware.push(this.middleware.text)
140
170
  }
171
+ this.setupRouteRateLimits(nsid, config)
141
172
  this.routes[verb](
142
173
  `/xrpc/${nsid}`,
143
174
  ...middleware,
@@ -185,6 +216,10 @@ export class Server {
185
216
  validateOutput(nsid, def, output, this.lex)
186
217
  const assertValidXrpcParams = (params: unknown) =>
187
218
  this.lex.assertValidXrpcParams(nsid, params)
219
+ const rlFns = this.routeRateLimiterFns[nsid] ?? []
220
+ const consumeRateLimit = (reqCtx: XRPCReqContext) =>
221
+ consumeMany(reqCtx, rlFns)
222
+
188
223
  return async function (req, res, next) {
189
224
  try {
190
225
  // validate request
@@ -203,14 +238,24 @@ export class Server {
203
238
 
204
239
  const locals: RequestLocals = req[kRequestLocals]
205
240
 
206
- // run the handler
207
- const outputUnvalidated = await handler({
241
+ const reqCtx: XRPCReqContext = {
208
242
  params,
209
243
  input,
210
244
  auth: locals.auth,
211
245
  req,
212
246
  res,
213
- })
247
+ }
248
+
249
+ // handle rate limits
250
+ if (consumeRateLimit) {
251
+ const result = await consumeRateLimit(reqCtx)
252
+ if (result instanceof RateLimitExceededError) {
253
+ return next(result)
254
+ }
255
+ }
256
+
257
+ // run the handler
258
+ const outputUnvalidated = await handler(reqCtx)
214
259
 
215
260
  if (isHandlerError(outputUnvalidated)) {
216
261
  throw XRPCError.fromError(outputUnvalidated)
@@ -345,6 +390,55 @@ export class Server {
345
390
  return httpServer
346
391
  }
347
392
  }
393
+
394
/**
 * Builds the rate-limit consume functions for `nsid` into
 * `this.routeRateLimiterFns[nsid]`.
 *
 * Every route gets all global limiters. `config.rateLimit` then adds, per
 * entry, either a reference to a named shared limiter (with optional
 * key/points overrides) or a new per-route limiter built with
 * `options.rateLimits.creator`.
 */
private setupRouteRateLimits(nsid: string, config: XRPCHandlerConfig) {
  const fns: RateLimiterConsume[] = []
  this.routeRateLimiterFns[nsid] = fns

  for (const limit of this.globalRateLimiters) {
    fns.push(async (ctx: XRPCReqContext) => limit.consume(ctx))
  }

  if (!config.rateLimit) return
  // NOTE: the list is no longer reset here; the previous reset discarded the
  // global limiters for any route that declared its own limits.
  const limits = Array.isArray(config.rateLimit)
    ? config.rateLimit
    : [config.rateLimit]

  let routeLimitCount = 0
  for (const limit of limits) {
    const { calcKey, calcPoints } = limit
    if (isShared(limit)) {
      const rateLimiter = this.sharedRateLimiters[limit.name]
      if (!rateLimiter) {
        // a misspelled/unconfigured shared limit leaves the route unlimited
        if (this.options.rateLimits) {
          log.warn(
            { nsid, name: limit.name },
            'unknown shared rate limit, skipping',
          )
        }
        continue
      }
      fns.push((ctx: XRPCReqContext) =>
        rateLimiter.consume(ctx, { calcKey, calcPoints }),
      )
      continue
    }

    const { durationMs, points } = limit
    // Give each per-route limit its own key prefix so several limits on one
    // route don't share store keys; the first keeps the original `nsid`
    // prefix so existing counters stay valid.
    const keyPrefix =
      routeLimitCount === 0 ? nsid : `${nsid}-${routeLimitCount}`
    routeLimitCount++
    const rateLimiter = this.options.rateLimits?.creator({
      keyPrefix,
      durationMs,
      points,
      calcKey,
      calcPoints,
    })
    if (!rateLimiter) continue
    // NOTE(review): kept for compatibility; this registers the route limiter
    // in the shared map under the nsid (last per-route limit wins, and it
    // would replace a shared limit that happens to be named like the nsid).
    this.sharedRateLimiters[nsid] = rateLimiter
    fns.push((ctx: XRPCReqContext) =>
      rateLimiter.consume(ctx, { calcKey, calcPoints }),
    )
  }
}
348
442
  }
349
443
 
350
444
  function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
@@ -385,14 +479,23 @@ function createAuthMiddleware(verifier: AuthVerifier): RequestHandler {
385
479
  const errorMiddleware: ErrorRequestHandler = function (err, req, res, next) {
386
480
  const locals: RequestLocals | undefined = req[kRequestLocals]
387
481
  const methodSuffix = locals ? ` method ${locals.nsid}` : ''
388
- if (err instanceof XRPCError) {
389
- log.error(err, `error in xrpc${methodSuffix}`)
390
- } else {
482
+ const xrpcError = XRPCError.fromError(err)
483
+ if (xrpcError instanceof InternalServerError) {
484
+ // log trace for unhandled exceptions
391
485
  log.error(err, `unhandled exception in xrpc${methodSuffix}`)
486
+ } else {
487
+ // do not log trace for known xrpc errors
488
+ log.error(
489
+ {
490
+ status: xrpcError.type,
491
+ message: xrpcError.message,
492
+ name: xrpcError.customErrorName,
493
+ },
494
+ `error in xrpc${methodSuffix}`,
495
+ )
392
496
  }
393
497
  if (res.headersSent) {
394
498
  return next(err)
395
499
  }
396
- const xrpcError = XRPCError.fromError(err)
397
500
  return res.status(xrpcError.type).json(xrpcError.payload)
398
501
  }
@@ -1,5 +1,6 @@
1
1
  import { subsystemLogger } from '@atproto/common'
2
2
 
3
- export const logger = subsystemLogger('xrpc-stream')
3
+ export const logger: ReturnType<typeof subsystemLogger> =
4
+ subsystemLogger('xrpc-stream')
4
5
 
5
6
  export default logger