@atproto/xrpc-server 0.2.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. package/CHANGELOG.md +8 -0
  2. package/LICENSE +21 -0
  3. package/README.md +11 -4
  4. package/build.js +0 -8
  5. package/dist/auth.d.ts +1 -1
  6. package/dist/index.d.ts +3 -0
  7. package/dist/index.js +17210 -8937
  8. package/dist/index.js.map +4 -4
  9. package/dist/logger.d.ts +2 -1
  10. package/dist/rate-limiter.d.ts +29 -0
  11. package/dist/server.d.ts +5 -1
  12. package/dist/stream/logger.d.ts +2 -1
  13. package/dist/stream/stream.d.ts +5 -2
  14. package/dist/stream/subscription.d.ts +2 -1
  15. package/dist/stream/types.d.ts +6 -6
  16. package/dist/stream/websocket-keepalive.d.ts +23 -0
  17. package/dist/types.d.ts +67 -9
  18. package/dist/util.d.ts +15 -0
  19. package/package.json +19 -25
  20. package/src/auth.ts +2 -2
  21. package/src/index.ts +4 -0
  22. package/src/logger.ts +2 -1
  23. package/src/rate-limiter.ts +167 -0
  24. package/src/server.ts +117 -7
  25. package/src/stream/logger.ts +2 -1
  26. package/src/stream/stream.ts +24 -11
  27. package/src/stream/subscription.ts +21 -107
  28. package/src/stream/websocket-keepalive.ts +151 -0
  29. package/src/types.ts +83 -4
  30. package/src/util.ts +33 -0
  31. package/tests/bodies.test.ts +3 -3
  32. package/tests/procedures.test.ts +12 -12
  33. package/tests/queries.test.ts +19 -14
  34. package/tests/rate-limiter.test.ts +249 -0
  35. package/tests/responses.test.ts +77 -0
  36. package/tests/subscriptions.test.ts +71 -15
  37. package/tsconfig.build.json +1 -1
  38. package/tsconfig.json +3 -3
  39. package/dist/src/index.d.ts +0 -2
  40. package/dist/src/logger.d.ts +0 -2
  41. package/dist/src/server.d.ts +0 -19
  42. package/dist/src/types.d.ts +0 -115
  43. package/dist/src/util.d.ts +0 -10
  44. package/dist/tsconfig.build.tsbuildinfo +0 -1
  45. package/tsconfig.build.tsbuildinfo +0 -1
  46. package/update-pkg.js +0 -14
package/dist/logger.d.ts CHANGED
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -0,0 +1,29 @@
1
+ import { RateLimiterAbstract, RateLimiterRes } from 'rate-limiter-flexible';
2
+ import { CalcKeyFn, CalcPointsFn, RateLimitExceededError, RateLimiterConsume, RateLimiterI, RateLimiterStatus, XRPCReqContext } from './types';
3
+ export declare type RateLimiterOpts = {
4
+ keyPrefix: string;
5
+ durationMs: number;
6
+ points: number;
7
+ bypassSecret?: string;
8
+ calcKey?: CalcKeyFn;
9
+ calcPoints?: CalcPointsFn;
10
+ failClosed?: boolean;
11
+ };
12
+ export declare class RateLimiter implements RateLimiterI {
13
+ limiter: RateLimiterAbstract;
14
+ private byPassSecret?;
15
+ private failClosed?;
16
+ calcKey: CalcKeyFn;
17
+ calcPoints: CalcPointsFn;
18
+ constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts);
19
+ static memory(opts: RateLimiterOpts): RateLimiter;
20
+ static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter;
21
+ consume(ctx: XRPCReqContext, opts?: {
22
+ calcKey?: CalcKeyFn;
23
+ calcPoints?: CalcPointsFn;
24
+ }): Promise<RateLimiterStatus | RateLimitExceededError | null>;
25
+ }
26
+ export declare const formatLimiterStatus: (limiter: RateLimiterAbstract, res: RateLimiterRes) => RateLimiterStatus;
27
+ export declare const consumeMany: (ctx: XRPCReqContext, fns: RateLimiterConsume[]) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
28
+ export declare const setResHeaders: (ctx: XRPCReqContext, status: RateLimiterStatus) => void;
29
+ export declare const getTightestLimit: (resps: (RateLimiterStatus | RateLimitExceededError | null)[]) => RateLimiterStatus | RateLimitExceededError | null;
package/dist/server.d.ts CHANGED
@@ -1,7 +1,7 @@
1
1
  import express, { NextFunction, RequestHandler } from 'express';
2
2
  import { Lexicons, LexXrpcProcedure, LexXrpcQuery, LexXrpcSubscription } from '@atproto/lexicon';
3
3
  import { XrpcStreamServer } from './stream';
4
- import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler } from './types';
4
+ import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler, RateLimiterI, RateLimiterConsume } from './types';
5
5
  export declare function createServer(lexicons?: unknown[], options?: Options): Server;
6
6
  export declare class Server {
7
7
  router: import("express-serve-static-core").Express;
@@ -10,6 +10,9 @@ export declare class Server {
10
10
  lex: Lexicons;
11
11
  options: Options;
12
12
  middleware: Record<'json' | 'text', RequestHandler>;
13
+ globalRateLimiters: RateLimiterI[];
14
+ sharedRateLimiters: Record<string, RateLimiterI>;
15
+ routeRateLimiterFns: Record<string, RateLimiterConsume[]>;
13
16
  constructor(lexicons?: unknown[], opts?: Options);
14
17
  method(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
15
18
  addMethod(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
@@ -22,4 +25,5 @@ export declare class Server {
22
25
  createHandler(nsid: string, def: LexXrpcQuery | LexXrpcProcedure, handler: XRPCHandler): RequestHandler;
23
26
  protected addSubscription(nsid: string, def: LexXrpcSubscription, config: XRPCStreamHandlerConfig): Promise<void>;
24
27
  private enableStreamingOnListen;
28
+ private setupRouteRateLimits;
25
29
  }
@@ -1,2 +1,3 @@
1
- export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
1
+ import { subsystemLogger } from '@atproto/common';
2
+ export declare const logger: ReturnType<typeof subsystemLogger>;
2
3
  export default logger;
@@ -1,5 +1,8 @@
1
1
  /// <reference types="node" />
2
2
  import { DuplexOptions } from 'stream';
3
3
  import { WebSocket } from 'ws';
4
- export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
5
- export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown>, void, unknown>;
4
+ import { MessageFrame } from './frames';
5
+ export declare function streamByteChunks(ws: WebSocket, options?: DuplexOptions): import("stream").Duplex;
6
+ export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
7
+ export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<MessageFrame<unknown>, void, unknown>;
8
+ export declare function ensureChunkIsMessage(chunk: Uint8Array): MessageFrame<unknown>;
@@ -4,6 +4,7 @@ export declare class Subscription<T = unknown> {
4
4
  service: string;
5
5
  method: string;
6
6
  maxReconnectSeconds?: number;
7
+ heartbeatIntervalMs?: number;
7
8
  signal?: AbortSignal;
8
9
  validate: (obj: unknown) => T | undefined;
9
10
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
@@ -13,12 +14,12 @@ export declare class Subscription<T = unknown> {
13
14
  service: string;
14
15
  method: string;
15
16
  maxReconnectSeconds?: number;
17
+ heartbeatIntervalMs?: number;
16
18
  signal?: AbortSignal;
17
19
  validate: (obj: unknown) => T | undefined;
18
20
  onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
21
  getParams?: () => Record<string, unknown> | Promise<Record<string, unknown> | undefined> | undefined;
20
22
  });
21
23
  [Symbol.asyncIterator](): AsyncGenerator<T>;
22
- private getSocket;
23
24
  }
24
25
  export default Subscription;
@@ -7,11 +7,11 @@ export declare const messageFrameHeader: z.ZodObject<{
7
7
  op: z.ZodLiteral<FrameType.Message>;
8
8
  t: z.ZodOptional<z.ZodString>;
9
9
  }, "strip", z.ZodTypeAny, {
10
- t?: string | undefined;
11
10
  op: FrameType.Message;
12
- }, {
13
11
  t?: string | undefined;
12
+ }, {
14
13
  op: FrameType.Message;
14
+ t?: string | undefined;
15
15
  }>;
16
16
  export declare type MessageFrameHeader = z.infer<typeof messageFrameHeader>;
17
17
  export declare const errorFrameHeader: z.ZodObject<{
@@ -25,11 +25,11 @@ export declare const errorFrameBody: z.ZodObject<{
25
25
  error: z.ZodString;
26
26
  message: z.ZodOptional<z.ZodString>;
27
27
  }, "strip", z.ZodTypeAny, {
28
- message?: string | undefined;
29
28
  error: string;
30
- }, {
31
29
  message?: string | undefined;
30
+ }, {
32
31
  error: string;
32
+ message?: string | undefined;
33
33
  }>;
34
34
  export declare type ErrorFrameHeader = z.infer<typeof errorFrameHeader>;
35
35
  export declare type ErrorFrameBody<T extends string = string> = {
@@ -39,11 +39,11 @@ export declare const frameHeader: z.ZodUnion<[z.ZodObject<{
39
39
  op: z.ZodLiteral<FrameType.Message>;
40
40
  t: z.ZodOptional<z.ZodString>;
41
41
  }, "strip", z.ZodTypeAny, {
42
- t?: string | undefined;
43
42
  op: FrameType.Message;
44
- }, {
45
43
  t?: string | undefined;
44
+ }, {
46
45
  op: FrameType.Message;
46
+ t?: string | undefined;
47
47
  }>, z.ZodObject<{
48
48
  op: z.ZodLiteral<FrameType.Error>;
49
49
  }, "strip", z.ZodTypeAny, {
@@ -0,0 +1,23 @@
1
+ import { WebSocket, ClientOptions } from 'ws';
2
+ export declare class WebSocketKeepAlive {
3
+ opts: ClientOptions & {
4
+ getUrl: () => Promise<string>;
5
+ maxReconnectSeconds?: number;
6
+ signal?: AbortSignal;
7
+ heartbeatIntervalMs?: number;
8
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
9
+ };
10
+ ws: WebSocket | null;
11
+ initialSetup: boolean;
12
+ reconnects: number | null;
13
+ constructor(opts: ClientOptions & {
14
+ getUrl: () => Promise<string>;
15
+ maxReconnectSeconds?: number;
16
+ signal?: AbortSignal;
17
+ heartbeatIntervalMs?: number;
18
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
+ });
20
+ [Symbol.asyncIterator](): AsyncGenerator<Uint8Array>;
21
+ startHeartbeat(ws: WebSocket): void;
22
+ }
23
+ export default WebSocketKeepAlive;
package/dist/types.d.ts CHANGED
@@ -10,6 +10,11 @@ export declare type Options = {
10
10
  blobLimit?: number;
11
11
  textLimit?: number;
12
12
  };
13
+ rateLimits?: {
14
+ creator: RateLimiterCreator;
15
+ global?: ServerRateLimitDescription[];
16
+ shared?: ServerRateLimitDescription[];
17
+ };
13
18
  };
14
19
  export declare type UndecodedParams = typeof express.request['query'];
15
20
  export declare type Primitive = string | number | boolean;
@@ -18,11 +23,11 @@ export declare const handlerInput: zod.ZodObject<{
18
23
  encoding: zod.ZodString;
19
24
  body: zod.ZodAny;
20
25
  }, "strip", zod.ZodTypeAny, {
21
- body?: any;
22
26
  encoding: string;
23
- }, {
24
27
  body?: any;
28
+ }, {
25
29
  encoding: string;
30
+ body?: any;
26
31
  }>;
27
32
  export declare type HandlerInput = zod.infer<typeof handlerInput>;
28
33
  export declare const handlerAuth: zod.ZodObject<{
@@ -39,12 +44,15 @@ export declare type HandlerAuth = zod.infer<typeof handlerAuth>;
39
44
  export declare const handlerSuccess: zod.ZodObject<{
40
45
  encoding: zod.ZodString;
41
46
  body: zod.ZodAny;
47
+ headers: zod.ZodOptional<zod.ZodRecord<zod.ZodString, zod.ZodString>>;
42
48
  }, "strip", zod.ZodTypeAny, {
43
- body?: any;
44
49
  encoding: string;
45
- }, {
46
50
  body?: any;
51
+ headers?: Record<string, string> | undefined;
52
+ }, {
47
53
  encoding: string;
54
+ body?: any;
55
+ headers?: Record<string, string> | undefined;
48
56
  }>;
49
57
  export declare type HandlerSuccess = zod.infer<typeof handlerSuccess>;
50
58
  export declare const handlerError: zod.ZodObject<{
@@ -52,23 +60,24 @@ export declare const handlerError: zod.ZodObject<{
52
60
  error: zod.ZodOptional<zod.ZodString>;
53
61
  message: zod.ZodOptional<zod.ZodString>;
54
62
  }, "strip", zod.ZodTypeAny, {
63
+ status: number;
55
64
  error?: string | undefined;
56
65
  message?: string | undefined;
57
- status: number;
58
66
  }, {
67
+ status: number;
59
68
  error?: string | undefined;
60
69
  message?: string | undefined;
61
- status: number;
62
70
  }>;
63
71
  export declare type HandlerError = zod.infer<typeof handlerError>;
64
72
  export declare type HandlerOutput = HandlerSuccess | HandlerError;
65
- export declare type XRPCHandler = (ctx: {
73
+ export declare type XRPCReqContext = {
66
74
  auth: HandlerAuth | undefined;
67
75
  params: Params;
68
76
  input: HandlerInput | undefined;
69
77
  req: express.Request;
70
78
  res: express.Response;
71
- }) => Promise<HandlerOutput> | HandlerOutput | undefined;
79
+ };
80
+ export declare type XRPCHandler = (ctx: XRPCReqContext) => Promise<HandlerOutput> | HandlerOutput | undefined;
72
81
  export declare type XRPCStreamHandler = (ctx: {
73
82
  auth: HandlerAuth | undefined;
74
83
  params: Params;
@@ -83,7 +92,52 @@ export declare type AuthVerifier = (ctx: {
83
92
  export declare type StreamAuthVerifier = (ctx: {
84
93
  req: IncomingMessage;
85
94
  }) => Promise<AuthOutput> | AuthOutput;
95
+ export declare type CalcKeyFn = (ctx: XRPCReqContext) => string;
96
+ export declare type CalcPointsFn = (ctx: XRPCReqContext) => number;
97
+ export interface RateLimiterI {
98
+ consume: RateLimiterConsume;
99
+ }
100
+ export declare type RateLimiterConsume = (ctx: XRPCReqContext, opts?: {
101
+ calcKey?: CalcKeyFn;
102
+ calcPoints?: CalcPointsFn;
103
+ }) => Promise<RateLimiterStatus | RateLimitExceededError | null>;
104
+ export declare type RateLimiterCreator = (opts: {
105
+ keyPrefix: string;
106
+ durationMs: number;
107
+ points: number;
108
+ calcKey?: (ctx: XRPCReqContext) => string;
109
+ calcPoints?: (ctx: XRPCReqContext) => number;
110
+ }) => RateLimiterI;
111
+ export declare type ServerRateLimitDescription = {
112
+ name: string;
113
+ durationMs: number;
114
+ points: number;
115
+ calcKey?: (ctx: XRPCReqContext) => string;
116
+ calcPoints?: (ctx: XRPCReqContext) => number;
117
+ };
118
+ export declare type SharedRateLimitOpts = {
119
+ name: string;
120
+ calcKey?: (ctx: XRPCReqContext) => string;
121
+ calcPoints?: (ctx: XRPCReqContext) => number;
122
+ };
123
+ export declare type RouteRateLimitOpts = {
124
+ durationMs: number;
125
+ points: number;
126
+ calcKey?: (ctx: XRPCReqContext) => string;
127
+ calcPoints?: (ctx: XRPCReqContext) => number;
128
+ };
129
+ export declare type HandlerRateLimitOpts = SharedRateLimitOpts | RouteRateLimitOpts;
130
+ export declare const isShared: (opts: HandlerRateLimitOpts) => opts is SharedRateLimitOpts;
131
+ export declare type RateLimiterStatus = {
132
+ limit: number;
133
+ duration: number;
134
+ remainingPoints: number;
135
+ msBeforeNext: number;
136
+ consumedPoints: number;
137
+ isFirstInDuration: boolean;
138
+ };
86
139
  export declare type XRPCHandlerConfig = {
140
+ rateLimit?: HandlerRateLimitOpts | HandlerRateLimitOpts[];
87
141
  auth?: AuthVerifier;
88
142
  handler: XRPCHandler;
89
143
  };
@@ -114,13 +168,17 @@ export declare class AuthRequiredError extends XRPCError {
114
168
  export declare class ForbiddenError extends XRPCError {
115
169
  constructor(errorMessage?: string, customErrorName?: string);
116
170
  }
171
+ export declare class RateLimitExceededError extends XRPCError {
172
+ status: RateLimiterStatus;
173
+ constructor(status: RateLimiterStatus, errorMessage?: string, customErrorName?: string);
174
+ }
117
175
  export declare class InternalServerError extends XRPCError {
118
176
  constructor(errorMessage?: string, customErrorName?: string);
119
177
  }
120
178
  export declare class UpstreamFailureError extends XRPCError {
121
179
  constructor(errorMessage?: string, customErrorName?: string);
122
180
  }
123
- export declare class NotEnoughResoucesError extends XRPCError {
181
+ export declare class NotEnoughResourcesError extends XRPCError {
124
182
  constructor(errorMessage?: string, customErrorName?: string);
125
183
  }
126
184
  export declare class UpstreamTimeoutError extends XRPCError {
package/dist/util.d.ts CHANGED
@@ -9,3 +9,18 @@ export declare function validateOutput(nsid: string, def: LexXrpcProcedure | Lex
9
9
  export declare function normalizeMime(v: string): any;
10
10
  export declare function hasBody(req: express.Request): string | true | undefined;
11
11
  export declare function processBodyAsBytes(req: express.Request): Promise<Uint8Array>;
12
+ export declare function serverTimingHeader(timings: ServerTiming[]): string;
13
+ export declare class ServerTimer implements ServerTiming {
14
+ name: string;
15
+ description?: string | undefined;
16
+ duration?: number;
17
+ private startMs?;
18
+ constructor(name: string, description?: string | undefined);
19
+ start(): this;
20
+ stop(): this;
21
+ }
22
+ export interface ServerTiming {
23
+ name: string;
24
+ duration?: number;
25
+ description?: string;
26
+ }
package/package.json CHANGED
@@ -1,22 +1,7 @@
1
1
  {
2
2
  "name": "@atproto/xrpc-server",
3
- "version": "0.2.0",
3
+ "version": "0.3.1",
4
4
  "main": "dist/index.js",
5
- "scripts": {
6
- "test": "jest",
7
- "prettier": "prettier --check src/",
8
- "prettier:fix": "prettier --write src/",
9
- "lint": "eslint . --ext .ts,.tsx",
10
- "lint:fix": "yarn lint --fix",
11
- "verify": "run-p prettier lint",
12
- "verify:fix": "yarn prettier:fix && yarn lint:fix",
13
- "build": "node ./build.js",
14
- "postbuild": "tsc --build tsconfig.build.json",
15
- "update-main-to-dist": "node ./update-pkg.js --update-main-to-dist",
16
- "update-main-to-src": "node ./update-pkg.js --update-main-to-src",
17
- "prepublish": "npm run update-main-to-dist",
18
- "postpublish": "npm run update-main-to-src"
19
- },
20
5
  "license": "MIT",
21
6
  "repository": {
22
7
  "type": "git",
@@ -24,24 +9,33 @@
24
9
  "directory": "packages/xrpc-server"
25
10
  },
26
11
  "dependencies": {
27
- "@atproto/common": "*",
28
- "@atproto/crypto": "*",
29
- "@atproto/lexicon": "*",
30
12
  "cbor-x": "^1.5.1",
31
13
  "express": "^4.17.2",
32
14
  "http-errors": "^2.0.0",
33
15
  "mime-types": "^2.1.35",
16
+ "rate-limiter-flexible": "^2.4.1",
34
17
  "uint8arrays": "3.0.0",
35
18
  "ws": "^8.12.0",
36
- "zod": "^3.14.2"
19
+ "zod": "^3.21.4",
20
+ "@atproto/common": "^0.3.0",
21
+ "@atproto/crypto": "^0.2.2",
22
+ "@atproto/lexicon": "^0.2.1"
37
23
  },
38
24
  "devDependencies": {
39
- "@atproto/crypto": "*",
40
- "@atproto/xrpc": "*",
41
25
  "@types/express": "^4.17.13",
26
+ "@types/express-serve-static-core": "^4.17.36",
42
27
  "@types/http-errors": "^2.0.1",
43
28
  "@types/ws": "^8.5.4",
44
29
  "get-port": "^6.1.2",
45
- "multiformats": "^9.6.4"
46
- }
47
- }
30
+ "multiformats": "^9.9.0",
31
+ "@atproto/crypto": "^0.2.2",
32
+ "@atproto/xrpc": "^0.3.1"
33
+ },
34
+ "scripts": {
35
+ "test": "jest",
36
+ "build": "node ./build.js",
37
+ "postbuild": "tsc --build tsconfig.build.json",
38
+ "update-main-to-dist": "node ../../update-main-to-dist.js packages/xrpc-server"
39
+ },
40
+ "types": "dist/index.d.ts"
41
+ }
package/src/auth.ts CHANGED
@@ -44,7 +44,7 @@ const jsonToB64Url = (json: Record<string, unknown>): string => {
44
44
 
45
45
  export const verifyJwt = async (
46
46
  jwtStr: string,
47
- ownDid: string,
47
+ ownDid: string | null, // null indicates to skip the audience check
48
48
  getSigningKey: (did: string) => Promise<string>,
49
49
  ): Promise<string> => {
50
50
  const parts = jwtStr.split('.')
@@ -57,7 +57,7 @@ export const verifyJwt = async (
57
57
  if (Date.now() / 1000 > payload.exp) {
58
58
  throw new AuthRequiredError('jwt expired', 'JwtExpired')
59
59
  }
60
- if (payload.aud !== ownDid) {
60
+ if (ownDid !== null && payload.aud !== ownDid) {
61
61
  throw new AuthRequiredError(
62
62
  'jwt audience does not match service did',
63
63
  'BadJwtAudience',
package/src/index.ts CHANGED
@@ -2,3 +2,7 @@ export * from './types'
2
2
  export * from './auth'
3
3
  export * from './server'
4
4
  export * from './stream'
5
+ export * from './rate-limiter'
6
+
7
+ export type { ServerTiming } from './util'
8
+ export { serverTimingHeader, ServerTimer } from './util'
package/src/logger.ts CHANGED
@@ -1,5 +1,6 @@
1
1
  import { subsystemLogger } from '@atproto/common'
2
2
 
3
- export const logger = subsystemLogger('xrpc-server')
3
+ export const logger: ReturnType<typeof subsystemLogger> =
4
+ subsystemLogger('xrpc-server')
4
5
 
5
6
  export default logger
@@ -0,0 +1,167 @@
1
+ import {
2
+ RateLimiterAbstract,
3
+ RateLimiterMemory,
4
+ RateLimiterRedis,
5
+ RateLimiterRes,
6
+ } from 'rate-limiter-flexible'
7
+ import { logger } from './logger'
8
+ import {
9
+ CalcKeyFn,
10
+ CalcPointsFn,
11
+ RateLimitExceededError,
12
+ RateLimiterConsume,
13
+ RateLimiterI,
14
+ RateLimiterStatus,
15
+ XRPCReqContext,
16
+ } from './types'
17
+
/**
 * Options for constructing a {@link RateLimiter}.
 */
export type RateLimiterOpts = {
  /** Prefix applied to every key stored by the underlying limiter. */
  keyPrefix: string
  /**
   * Length of the rate-limit window in milliseconds. The underlying
   * `rate-limiter-flexible` limiter works in whole seconds, so this is floored
   * to seconds; positive values below one second are rounded up to 1s.
   */
  durationMs: number
  /** Number of points available per key within one window. */
  points: number
  /** Requests whose `x-ratelimit-bypass` header equals this value are not limited. */
  bypassSecret?: string
  /** Derives the limiter key for a request; defaults to `ctx.req.ip`. */
  calcKey?: CalcKeyFn
  /** Derives the point cost of a request; defaults to 1. A cost below 1 skips limiting. */
  calcPoints?: CalcPointsFn
  /**
   * When true, unexpected limiter failures are rethrown to the caller.
   * When false or omitted, they are logged and the request is let through.
   */
  failClosed?: boolean
}

/**
 * A {@link RateLimiterI} backed by a `rate-limiter-flexible` limiter.
 *
 * `consume` resolves to:
 * - `null` when limiting is skipped (bypass header matched, computed cost
 *   below 1, or an unexpected limiter failure while failing open);
 * - a {@link RateLimiterStatus} when the points were consumed;
 * - a {@link RateLimitExceededError} when the key has run out of points.
 */
export class RateLimiter implements RateLimiterI {
  public limiter: RateLimiterAbstract
  private byPassSecret?: string
  private failClosed: boolean
  public calcKey: CalcKeyFn
  public calcPoints: CalcPointsFn

  constructor(limiter: RateLimiterAbstract, opts: RateLimiterOpts) {
    this.limiter = limiter
    this.byPassSecret = opts.bypassSecret
    // This option used to be accepted but never stored, so every limiter
    // failed open regardless of configuration.
    this.failClosed = opts.failClosed ?? false
    this.calcKey = opts.calcKey ?? defaultKey
    this.calcPoints = opts.calcPoints ?? defaultPoints
  }

  /** Builds a limiter that keeps its counters in process memory. */
  static memory(opts: RateLimiterOpts): RateLimiter {
    const limiter = new RateLimiterMemory({
      keyPrefix: opts.keyPrefix,
      duration: RateLimiter.durationSecs(opts.durationMs),
      points: opts.points,
    })
    return new RateLimiter(limiter, opts)
  }

  /** Builds a limiter that keeps its counters in the given Redis client. */
  static redis(storeClient: unknown, opts: RateLimiterOpts): RateLimiter {
    const limiter = new RateLimiterRedis({
      storeClient,
      keyPrefix: opts.keyPrefix,
      duration: RateLimiter.durationSecs(opts.durationMs),
      points: opts.points,
    })
    return new RateLimiter(limiter, opts)
  }

  /**
   * Consumes points for the request described by `ctx`.
   *
   * @param ctx - the XRPC request context used to derive key and cost.
   * @param opts - per-call overrides for the key/cost functions.
   * @returns the limiter status, a `RateLimitExceededError`, or `null` when
   *   limiting was skipped (see class docs).
   * @throws the underlying limiter error when it is not a rate-limit
   *   rejection and `failClosed` is enabled.
   */
  async consume(
    ctx: XRPCReqContext,
    opts?: { calcKey?: CalcKeyFn; calcPoints?: CalcPointsFn },
  ): Promise<RateLimiterStatus | RateLimitExceededError | null> {
    // Callers presenting the shared secret bypass limiting entirely.
    if (
      this.byPassSecret &&
      ctx.req.header('x-ratelimit-bypass') === this.byPassSecret
    ) {
      return null
    }
    const key = opts?.calcKey ? opts.calcKey(ctx) : this.calcKey(ctx)
    const points = opts?.calcPoints
      ? opts.calcPoints(ctx)
      : this.calcPoints(ctx)
    if (points < 1) {
      // Nothing to consume; don't touch the limiter at all.
      return null
    }
    try {
      const res = await this.limiter.consume(key, points)
      return formatLimiterStatus(this.limiter, res)
    } catch (err) {
      // rate-limiter-flexible rejects with a RateLimiterRes (not an Error)
      // when the key does not have enough points left.
      if (err instanceof RateLimiterRes) {
        const status = formatLimiterStatus(this.limiter, err)
        return new RateLimitExceededError(status)
      }
      // Any other rejection is an unexpected failure (presumably the backing
      // store); honour the configured fail-open / fail-closed policy.
      if (this.failClosed) {
        throw err
      }
      logger.error(
        {
          err,
          keyPrefix: this.limiter.keyPrefix,
          points: this.limiter.points,
          duration: this.limiter.duration,
        },
        'rate limiter failed to consume points',
      )
      return null
    }
  }

  /**
   * Converts a window length in milliseconds to the whole seconds expected by
   * `rate-limiter-flexible`. That library treats a duration of 0 as "never
   * reset", so a positive sub-second window is rounded up to 1s instead of
   * silently becoming a permanent limit. Non-positive input is floored as-is.
   */
  private static durationSecs(durationMs: number): number {
    const secs = Math.floor(durationMs / 1000)
    return durationMs > 0 ? Math.max(1, secs) : secs
  }
}

/**
 * Combines a limiter's configured capacity/window with the outcome of a single
 * consume call into a {@link RateLimiterStatus}.
 *
 * @param limiter - supplies `limit` (its `points`) and `duration`.
 * @param res - supplies the per-key counters from the consume call.
 */
export const formatLimiterStatus = (
  limiter: RateLimiterAbstract,
  res: RateLimiterRes,
): RateLimiterStatus => {
  const { points: limit, duration } = limiter
  const { remainingPoints, msBeforeNext, consumedPoints, isFirstInDuration } =
    res
  return {
    limit,
    duration,
    remainingPoints,
    msBeforeNext,
    consumedPoints,
    isFirstInDuration,
  }
}

/**
 * Runs every consume function against `ctx` in parallel and reports the most
 * restrictive outcome, also writing the matching `RateLimit-*` response headers.
 *
 * @returns `null` when there are no limiters or none applied; otherwise the
 *   tightest status, or the first `RateLimitExceededError` encountered.
 */
export const consumeMany = async (
  ctx: XRPCReqContext,
  fns: RateLimiterConsume[],
): Promise<RateLimiterStatus | RateLimitExceededError | null> => {
  if (fns.length === 0) return null

  const outcomes = await Promise.all(fns.map((consume) => consume(ctx)))
  const tightest = getTightestLimit(outcomes)

  if (tightest !== null) {
    // Exceeded errors carry their status; plain statuses are used directly.
    const headerStatus =
      tightest instanceof RateLimitExceededError ? tightest.status : tightest
    setResHeaders(ctx, headerStatus)
  }
  return tightest
}

/**
 * Writes the `RateLimit-*` headers describing `status` onto the response.
 *
 * `RateLimit-Reset` is an absolute epoch time in whole seconds, derived from the
 * current clock plus `msBeforeNext`. `RateLimit-Policy` is `<limit>;w=<duration>`.
 */
export const setResHeaders = (
  ctx: XRPCReqContext,
  status: RateLimiterStatus,
) => {
  const { limit, remainingPoints, msBeforeNext, duration } = status
  const resetAtSecs = Math.floor((Date.now() + msBeforeNext) / 1000)
  const headers: [string, number | string][] = [
    ['RateLimit-Limit', limit],
    ['RateLimit-Remaining', remainingPoints],
    ['RateLimit-Reset', resetAtSecs],
    ['RateLimit-Policy', `${limit};w=${duration}`],
  ]
  for (const [name, value] of headers) {
    ctx.res.setHeader(name, value)
  }
}

/**
 * Picks the most restrictive outcome from a set of limiter results.
 *
 * If any limiter was exceeded, the first such error (in array order) wins.
 * Otherwise the status with the fewest remaining points is returned, keeping
 * the earliest one on ties. `null` entries (skipped limiters) are ignored, and
 * the result is `null` when nothing applied.
 */
export const getTightestLimit = (
  resps: (RateLimiterStatus | RateLimitExceededError | null)[],
): RateLimiterStatus | RateLimitExceededError | null => {
  const firstExceeded = resps.find(
    (resp): resp is RateLimitExceededError =>
      resp instanceof RateLimitExceededError,
  )
  if (firstExceeded !== undefined) return firstExceeded

  const statuses = resps.filter(
    (resp): resp is RateLimiterStatus =>
      resp !== null && !(resp instanceof RateLimitExceededError),
  )
  return statuses.reduce<RateLimiterStatus | null>(
    (tightest, candidate) =>
      tightest === null || candidate.remainingPoints < tightest.remainingPoints
        ? candidate
        : tightest,
    null,
  )
}

/**
 * Default limiter key: the client address Express reports as `req.ip`.
 * Behind a reverse proxy this is only the real client IP if Express is told to
 * trust forwarded headers, e.g. `app.set('trust proxy', true)`; see
 * https://expressjs.com/en/guide/behind-proxies.html
 */
const defaultKey: CalcKeyFn = ({ req }: XRPCReqContext) => req.ip

/** Default request cost: every request consumes one point. */
const DEFAULT_POINT_COST = 1
const defaultPoints: CalcPointsFn = () => DEFAULT_POINT_COST