@atproto/xrpc-server 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/server.d.ts CHANGED
@@ -1,19 +1,25 @@
1
1
  import express, { NextFunction, RequestHandler } from 'express';
2
- import { Lexicons, LexXrpcProcedure, LexXrpcQuery } from '@atproto/lexicon';
3
- import { XRPCHandler, XRPCHandlerConfig, Options } from './types';
2
+ import { Lexicons, LexXrpcProcedure, LexXrpcQuery, LexXrpcSubscription } from '@atproto/lexicon';
3
+ import { XrpcStreamServer } from './stream';
4
+ import { XRPCHandler, XRPCHandlerConfig, Options, XRPCStreamHandlerConfig, XRPCStreamHandler } from './types';
4
5
  export declare function createServer(lexicons?: unknown[], options?: Options): Server;
5
6
  export declare class Server {
6
- router: import("express-serve-static-core").Router;
7
+ router: import("express-serve-static-core").Express;
7
8
  routes: import("express-serve-static-core").Router;
9
+ subscriptions: Map<string, XrpcStreamServer>;
8
10
  lex: Lexicons;
9
11
  options: Options;
10
12
  middleware: Record<'json' | 'text', RequestHandler>;
11
13
  constructor(lexicons?: unknown[], opts?: Options);
12
14
  method(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
13
15
  addMethod(nsid: string, configOrFn: XRPCHandlerConfig | XRPCHandler): void;
16
+ streamMethod(nsid: string, configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler): void;
17
+ addStreamMethod(nsid: string, configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler): void;
14
18
  addLexicon(doc: unknown): void;
15
19
  addLexicons(docs: unknown[]): void;
16
20
  protected addRoute(nsid: string, def: LexXrpcQuery | LexXrpcProcedure, config: XRPCHandlerConfig): Promise<void>;
17
21
  catchall(req: express.Request, _res: express.Response, next: NextFunction): Promise<void>;
18
22
  createHandler(nsid: string, def: LexXrpcQuery | LexXrpcProcedure, handler: XRPCHandler): RequestHandler;
23
+ protected addSubscription(nsid: string, def: LexXrpcSubscription, config: XRPCStreamHandlerConfig): Promise<void>;
24
+ private enableStreamingOnListen;
19
25
  }
@@ -0,0 +1,25 @@
1
+ import { FrameHeader, FrameType, MessageFrameHeader, ErrorFrameHeader, ErrorFrameBody } from './types';
2
+ export declare abstract class Frame {
3
+ header: FrameHeader;
4
+ body: unknown;
5
+ get op(): FrameType;
6
+ toBytes(): Uint8Array;
7
+ isMessage(): this is MessageFrame<unknown>;
8
+ isError(): this is ErrorFrame;
9
+ static fromBytes(bytes: Uint8Array): MessageFrame<unknown> | ErrorFrame<string>;
10
+ }
11
+ export declare class MessageFrame<T = Record<string, unknown>> extends Frame {
12
+ header: MessageFrameHeader;
13
+ body: T;
14
+ constructor(body: T, opts?: {
15
+ type?: string;
16
+ });
17
+ get type(): string | undefined;
18
+ }
19
+ export declare class ErrorFrame<T extends string = string> extends Frame {
20
+ header: ErrorFrameHeader;
21
+ body: ErrorFrameBody<T>;
22
+ constructor(body: ErrorFrameBody<T>);
23
+ get code(): T & string;
24
+ get message(): string | undefined;
25
+ }
@@ -0,0 +1,5 @@
1
+ export * from './types';
2
+ export * from './frames';
3
+ export * from './stream';
4
+ export * from './subscription';
5
+ export * from './server';
@@ -0,0 +1,2 @@
1
+ export declare const logger: import("pino").default.Logger<import("pino").default.LoggerOptions>;
2
+ export default logger;
@@ -0,0 +1,11 @@
1
+ /// <reference types="node" />
2
+ import { IncomingMessage } from 'http';
3
+ import { WebSocketServer, ServerOptions, WebSocket } from 'ws';
4
+ import { Frame } from './frames';
5
+ export declare class XrpcStreamServer {
6
+ wss: WebSocketServer;
7
+ constructor(opts: ServerOptions & {
8
+ handler: Handler;
9
+ });
10
+ }
11
+ export declare type Handler = (req: IncomingMessage, socket: WebSocket, server: XrpcStreamServer) => AsyncIterable<Frame>;
@@ -0,0 +1,5 @@
1
+ /// <reference types="node" />
2
+ import { DuplexOptions } from 'stream';
3
+ import { WebSocket } from 'ws';
4
+ export declare function byFrame(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown> | import("./frames").ErrorFrame<string>, void, unknown>;
5
+ export declare function byMessage(ws: WebSocket, options?: DuplexOptions): AsyncGenerator<import("./frames").MessageFrame<unknown>, void, unknown>;
@@ -0,0 +1,24 @@
1
+ import { ClientOptions } from 'ws';
2
+ export declare class Subscription<T = unknown> {
3
+ opts: ClientOptions & {
4
+ service: string;
5
+ method: string;
6
+ maxReconnectSeconds?: number;
7
+ signal?: AbortSignal;
8
+ validate: (obj: unknown) => T | undefined;
9
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
10
+ getParams?: () => Record<string, unknown> | Promise<Record<string, unknown> | undefined> | undefined;
11
+ };
12
+ constructor(opts: ClientOptions & {
13
+ service: string;
14
+ method: string;
15
+ maxReconnectSeconds?: number;
16
+ signal?: AbortSignal;
17
+ validate: (obj: unknown) => T | undefined;
18
+ onReconnectError?: (error: unknown, n: number, initialSetup: boolean) => void;
19
+ getParams?: () => Record<string, unknown> | Promise<Record<string, unknown> | undefined> | undefined;
20
+ });
21
+ [Symbol.asyncIterator](): AsyncGenerator<T>;
22
+ private getSocket;
23
+ }
24
+ export default Subscription;
@@ -0,0 +1,64 @@
1
+ import { z } from 'zod';
2
+ export declare enum FrameType {
3
+ Message = 1,
4
+ Error = -1
5
+ }
6
+ export declare const messageFrameHeader: z.ZodObject<{
7
+ op: z.ZodLiteral<FrameType.Message>;
8
+ t: z.ZodOptional<z.ZodString>;
9
+ }, "strip", z.ZodTypeAny, {
10
+ t?: string | undefined;
11
+ op: FrameType.Message;
12
+ }, {
13
+ t?: string | undefined;
14
+ op: FrameType.Message;
15
+ }>;
16
+ export declare type MessageFrameHeader = z.infer<typeof messageFrameHeader>;
17
+ export declare const errorFrameHeader: z.ZodObject<{
18
+ op: z.ZodLiteral<FrameType.Error>;
19
+ }, "strip", z.ZodTypeAny, {
20
+ op: FrameType.Error;
21
+ }, {
22
+ op: FrameType.Error;
23
+ }>;
24
+ export declare const errorFrameBody: z.ZodObject<{
25
+ error: z.ZodString;
26
+ message: z.ZodOptional<z.ZodString>;
27
+ }, "strip", z.ZodTypeAny, {
28
+ message?: string | undefined;
29
+ error: string;
30
+ }, {
31
+ message?: string | undefined;
32
+ error: string;
33
+ }>;
34
+ export declare type ErrorFrameHeader = z.infer<typeof errorFrameHeader>;
35
+ export declare type ErrorFrameBody<T extends string = string> = {
36
+ error: T;
37
+ } & z.infer<typeof errorFrameBody>;
38
+ export declare const frameHeader: z.ZodUnion<[z.ZodObject<{
39
+ op: z.ZodLiteral<FrameType.Message>;
40
+ t: z.ZodOptional<z.ZodString>;
41
+ }, "strip", z.ZodTypeAny, {
42
+ t?: string | undefined;
43
+ op: FrameType.Message;
44
+ }, {
45
+ t?: string | undefined;
46
+ op: FrameType.Message;
47
+ }>, z.ZodObject<{
48
+ op: z.ZodLiteral<FrameType.Error>;
49
+ }, "strip", z.ZodTypeAny, {
50
+ op: FrameType.Error;
51
+ }, {
52
+ op: FrameType.Error;
53
+ }>]>;
54
+ export declare type FrameHeader = z.infer<typeof frameHeader>;
55
+ export declare class DisconnectError extends Error {
56
+ wsCode: CloseCode;
57
+ xrpcCode?: string | undefined;
58
+ constructor(wsCode?: CloseCode, xrpcCode?: string | undefined);
59
+ }
60
+ export declare enum CloseCode {
61
+ Normal = 1000,
62
+ Abnormal = 1006,
63
+ Policy = 1008
64
+ }
package/dist/types.d.ts CHANGED
@@ -1,3 +1,5 @@
1
+ /// <reference types="node" />
2
+ import { IncomingMessage } from 'http';
1
3
  import express from 'express';
2
4
  import zod from 'zod';
3
5
  import { ResponseType } from '@atproto/xrpc';
@@ -67,15 +69,27 @@ export declare type XRPCHandler = (ctx: {
67
69
  req: express.Request;
68
70
  res: express.Response;
69
71
  }) => Promise<HandlerOutput> | HandlerOutput | undefined;
72
+ export declare type XRPCStreamHandler = (ctx: {
73
+ auth: HandlerAuth | undefined;
74
+ params: Params;
75
+ req: IncomingMessage;
76
+ }) => AsyncIterable<unknown>;
70
77
  export declare type AuthOutput = HandlerAuth | HandlerError;
71
78
  export declare type AuthVerifier = (ctx: {
72
79
  req: express.Request;
73
80
  res: express.Response;
74
81
  }) => Promise<AuthOutput> | AuthOutput;
82
+ export declare type StreamAuthVerifier = (ctx: {
83
+ req: IncomingMessage;
84
+ }) => Promise<AuthOutput> | AuthOutput;
75
85
  export declare type XRPCHandlerConfig = {
76
86
  auth?: AuthVerifier;
77
87
  handler: XRPCHandler;
78
88
  };
89
+ export declare type XRPCStreamHandlerConfig = {
90
+ auth?: StreamAuthVerifier;
91
+ handler: XRPCStreamHandler;
92
+ };
79
93
  export declare class XRPCError extends Error {
80
94
  type: ResponseType;
81
95
  errorMessage?: string | undefined;
@@ -85,6 +99,7 @@ export declare class XRPCError extends Error {
85
99
  error: string | undefined;
86
100
  message: string | undefined;
87
101
  };
102
+ get typeName(): string | undefined;
88
103
  get typeStr(): string | undefined;
89
104
  static fromError(error: unknown): XRPCError;
90
105
  }
package/dist/util.d.ts CHANGED
@@ -1,8 +1,9 @@
1
1
  import express from 'express';
2
- import { Lexicons, LexXrpcProcedure, LexXrpcQuery } from '@atproto/lexicon';
2
+ import { Lexicons, LexXrpcProcedure, LexXrpcQuery, LexXrpcSubscription } from '@atproto/lexicon';
3
3
  import { UndecodedParams, Params, HandlerInput, HandlerSuccess, Options } from './types';
4
- export declare function decodeQueryParams(def: LexXrpcProcedure | LexXrpcQuery, params: UndecodedParams): Params;
4
+ export declare function decodeQueryParams(def: LexXrpcProcedure | LexXrpcQuery | LexXrpcSubscription, params: UndecodedParams): Params;
5
5
  export declare function decodeQueryParam(type: string, value: unknown): string | number | boolean | undefined;
6
+ export declare function getQueryParams(url?: string): Record<string, string | string[]>;
6
7
  export declare function validateInput(nsid: string, def: LexXrpcProcedure | LexXrpcQuery, req: express.Request, opts: Options, lexicons: Lexicons): HandlerInput | undefined;
7
8
  export declare function validateOutput(nsid: string, def: LexXrpcProcedure | LexXrpcQuery, output: HandlerSuccess | undefined, lexicons: Lexicons): HandlerSuccess | undefined;
8
9
  export declare function normalizeMime(v: string): any;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@atproto/xrpc-server",
3
- "version": "0.0.1",
3
+ "version": "0.1.0",
4
4
  "main": "dist/index.js",
5
5
  "scripts": {
6
6
  "test": "jest",
@@ -21,15 +21,21 @@
21
21
  "dependencies": {
22
22
  "@atproto/common": "*",
23
23
  "@atproto/lexicon": "*",
24
+ "cbor-x": "^1.5.1",
24
25
  "express": "^4.17.2",
25
26
  "http-errors": "^2.0.0",
26
27
  "mime-types": "^2.1.35",
28
+ "uint8arrays": "3.0.0",
29
+ "ws": "^8.12.0",
27
30
  "zod": "^3.14.2"
28
31
  },
29
32
  "devDependencies": {
30
33
  "@atproto/crypto": "*",
31
34
  "@atproto/xrpc": "*",
32
35
  "@types/express": "^4.17.13",
33
- "@types/http-errors": "^2.0.1"
36
+ "@types/http-errors": "^2.0.1",
37
+ "@types/ws": "^8.5.4",
38
+ "get-port": "^6.1.2",
39
+ "multiformats": "^9.6.4"
34
40
  }
35
41
  }
package/src/index.ts CHANGED
@@ -1,2 +1,3 @@
1
1
  export * from './types'
2
2
  export * from './server'
3
+ export * from './stream'
package/src/server.ts CHANGED
@@ -4,7 +4,15 @@ import express, {
4
4
  NextFunction,
5
5
  RequestHandler,
6
6
  } from 'express'
7
- import { Lexicons, LexXrpcProcedure, LexXrpcQuery } from '@atproto/lexicon'
7
+ import {
8
+ Lexicons,
9
+ lexToJson,
10
+ LexXrpcProcedure,
11
+ LexXrpcQuery,
12
+ LexXrpcSubscription,
13
+ } from '@atproto/lexicon'
14
+ import { check, forwardStreamErrors, schema } from '@atproto/common'
15
+ import { ErrorFrame, Frame, MessageFrame, XrpcStreamServer } from './stream'
8
16
  import {
9
17
  XRPCHandler,
10
18
  XRPCError,
@@ -18,8 +26,16 @@ import {
18
26
  AuthVerifier,
19
27
  isHandlerError,
20
28
  Options,
29
+ XRPCStreamHandlerConfig,
30
+ XRPCStreamHandler,
31
+ Params,
21
32
  } from './types'
22
- import { decodeQueryParams, validateInput, validateOutput } from './util'
33
+ import {
34
+ decodeQueryParams,
35
+ getQueryParams,
36
+ validateInput,
37
+ validateOutput,
38
+ } from './util'
23
39
  import log from './logger'
24
40
 
25
41
  export function createServer(lexicons?: unknown[], options?: Options) {
@@ -27,8 +43,9 @@ export function createServer(lexicons?: unknown[], options?: Options) {
27
43
  }
28
44
 
29
45
  export class Server {
30
- router = express.Router()
46
+ router = express()
31
47
  routes = express.Router()
48
+ subscriptions = new Map<string, XrpcStreamServer>()
32
49
  lex = new Lexicons()
33
50
  options: Options
34
51
  middleware: Record<'json' | 'text', RequestHandler>
@@ -40,6 +57,9 @@ export class Server {
40
57
  this.router.use(this.routes)
41
58
  this.router.use('/xrpc/:methodId', this.catchall.bind(this))
42
59
  this.router.use(errorMiddleware)
60
+ this.router.once('mount', (app: express.Application) => {
61
+ this.enableStreamingOnListen(app)
62
+ })
43
63
  this.options = opts ?? {}
44
64
  this.middleware = {
45
65
  json: express.json({ limit: opts?.payload?.jsonLimit }),
@@ -58,10 +78,32 @@ export class Server {
58
78
  const config =
59
79
  typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
60
80
  const def = this.lex.getDef(nsid)
61
- if (!def || (def.type !== 'query' && def.type !== 'procedure')) {
81
+ if (def?.type === 'query' || def?.type === 'procedure') {
82
+ this.addRoute(nsid, def, config)
83
+ } else {
62
84
  throw new Error(`Lex def for ${nsid} is not a query or a procedure`)
63
85
  }
64
- this.addRoute(nsid, def, config)
86
+ }
87
+
88
/**
 * Alias of `addStreamMethod`: registers a websocket subscription handler
 * for the lexicon method `nsid`.
 */
streamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  this.addStreamMethod(nsid, configOrFn)
}

/**
 * Registers a subscription handler for `nsid`. A bare handler function is
 * wrapped into a config object with no `auth` verifier.
 *
 * @throws Error when the lexicon def for `nsid` is missing or its type is
 *   not `subscription`.
 */
addStreamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  const def = this.lex.getDef(nsid)
  if (def?.type !== 'subscription') {
    throw new Error(`Lex def for ${nsid} is not a subscription`)
  }
  const config: XRPCStreamHandlerConfig =
    typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
  this.addSubscription(nsid, def, config)
}
66
108
 
67
109
  // schemas
@@ -145,9 +187,9 @@ export class Server {
145
187
  return async function (req, res, next) {
146
188
  try {
147
189
  // validate request
148
- const params = decodeQueryParams(def, req.query)
190
+ let params = decodeQueryParams(def, req.query)
149
191
  try {
150
- assertValidXrpcParams(params)
192
+ params = assertValidXrpcParams(params) as Params
151
193
  } catch (e) {
152
194
  throw new InvalidRequestError(String(e))
153
195
  }
@@ -181,10 +223,16 @@ export class Server {
181
223
  output?.encoding === 'application/json' ||
182
224
  output?.encoding === 'json'
183
225
  ) {
184
- res.status(200).json(output.body)
185
- } else if (output) {
226
+ const json = lexToJson(output.body)
227
+ res.status(200).json(json)
228
+ } else if (output?.body instanceof Readable) {
186
229
  res.header('Content-Type', output.encoding)
230
+ res.status(200)
231
+ forwardStreamErrors(output.body, res)
232
+ output.body.pipe(res)
233
+ } else if (output) {
187
234
  res
235
+ .header('Content-Type', output.encoding)
188
236
  .status(200)
189
237
  .send(
190
238
  output.body instanceof Uint8Array
@@ -200,6 +248,88 @@ export class Server {
200
248
  }
201
249
  }
202
250
  }
251
+
252
/**
 * Creates a `noServer` websocket server for `nsid` and stores it in
 * `this.subscriptions`, keyed by nsid.
 *
 * For each connection the generated handler:
 * 1. runs `config.auth` (if any) on the upgrade request; a handler error is
 *    rethrown as an XRPCError;
 * 2. decodes the query string against the lexicon def and validates it,
 *    turning validation failures into InvalidRequestError;
 * 3. streams the user handler's items as frames. `Frame` instances pass
 *    through unchanged. Maps with a string `$type` have `$type` moved into
 *    the frame header. Types local to `nsid` (`#x` or `<nsid>#x`) are
 *    shortened to `#x`. Everything else becomes an untyped MessageFrame.
 *
 * Any error thrown along the way is converted into a final ErrorFrame.
 */
protected async addSubscription(
  nsid: string,
  def: LexXrpcSubscription,
  config: XRPCStreamHandlerConfig,
) {
  const assertValidXrpcParams = (params: unknown) =>
    this.lex.assertValidXrpcParams(nsid, params)
  this.subscriptions.set(
    nsid,
    new XrpcStreamServer({
      noServer: true,
      handler: async function* (req) {
        try {
          // authenticate request
          const auth = await config.auth?.({ req })
          if (isHandlerError(auth)) {
            throw XRPCError.fromError(auth)
          }
          // validate request
          let params = decodeQueryParams(def, getQueryParams(req.url))
          try {
            params = assertValidXrpcParams(params) as Params
          } catch (e) {
            throw new InvalidRequestError(String(e))
          }
          // stream
          const items = config.handler({ req, params, auth })
          for await (const item of items) {
            if (item instanceof Frame) {
              yield item
              continue
            }
            // `item` is `unknown`: narrow it to a map before reading `$type`
            // or spreading it. Non-maps never carry a `$type`, so this yields
            // the same frames as checking `$type` first.
            if (!check.is(item, schema.map)) {
              yield new MessageFrame(item)
              continue
            }
            const type = item['$type']
            if (typeof type !== 'string') {
              yield new MessageFrame(item)
              continue
            }
            const split = type.split('#')
            let t: string
            if (
              split.length === 2 &&
              (split[0] === '' || split[0] === nsid)
            ) {
              t = `#${split[1]}`
            } else {
              t = type
            }
            const clone = { ...item }
            delete clone['$type']
            yield new MessageFrame(clone, { type: t })
          }
        } catch (err) {
          const xrpcErrPayload = XRPCError.fromError(err).payload
          yield new ErrorFrame({
            error: xrpcErrPayload.error ?? 'Unknown',
            message: xrpcErrPayload.message,
          })
        }
      },
    }),
  )
}
314
+
315
/**
 * Wraps `app.listen` so that the http.Server it returns routes websocket
 * upgrade requests for `/xrpc/<nsid>` to the matching entry in
 * `this.subscriptions`. Upgrade requests for any other path, or with a
 * request-target that cannot be parsed as a URL, have their socket destroyed.
 */
private enableStreamingOnListen(app: express.Application) {
  const _listen = app.listen
  app.listen = (...args) => {
    // @ts-ignore the args spread
    const httpServer = _listen.call(app, ...args)
    httpServer.on('upgrade', (req, socket, head) => {
      // `req.url` is client-controlled, and `new URL` throws a TypeError for
      // some request-targets (e.g. "//"). A throw inside an 'upgrade'
      // listener is not caught by Node's http server and would crash the
      // process, so refuse the upgrade instead.
      let pathname: string
      try {
        pathname = new URL(req.url || '', 'http://x').pathname
      } catch {
        return socket.destroy()
      }
      const sub = pathname.startsWith('/xrpc/')
        ? this.subscriptions.get(pathname.replace('/xrpc/', ''))
        : undefined
      if (!sub) return socket.destroy()
      sub.wss.handleUpgrade(req, socket, head, (ws) =>
        sub.wss.emit('connection', ws, req),
      )
    })
    return httpServer
  }
}
203
333
  }
204
334
 
205
335
  function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
@@ -0,0 +1,95 @@
1
+ import * as uint8arrays from 'uint8arrays'
2
+ import { cborEncode, cborDecodeMulti } from '@atproto/common'
3
+ import {
4
+ frameHeader,
5
+ FrameHeader,
6
+ FrameType,
7
+ MessageFrameHeader,
8
+ ErrorFrameHeader,
9
+ ErrorFrameBody,
10
+ errorFrameBody,
11
+ } from './types'
12
+
13
// Sentinel meaning "the byte string held no second CBOR item". It is distinct
// from a body that decoded to `undefined` or `null`.
const kUnset = Symbol('unset')

/**
 * Base class for one stream frame. On the wire a frame is two CBOR data items
 * written back to back: the header, then the body.
 */
export abstract class Frame {
  header: FrameHeader
  body: unknown

  /** Opcode taken from the frame header. */
  get op(): FrameType {
    const { op } = this.header
    return op
  }

  /** Serializes the frame as CBOR(header) followed by CBOR(body). */
  toBytes(): Uint8Array {
    const encodedHeader = cborEncode(this.header)
    const encodedBody = cborEncode(this.body)
    return uint8arrays.concat([encodedHeader, encodedBody])
  }

  isMessage(): this is MessageFrame<unknown> {
    return FrameType.Message === this.op
  }

  isError(): this is ErrorFrame {
    return FrameType.Error === this.op
  }

  /**
   * Parses bytes produced by `toBytes()` back into a MessageFrame or
   * ErrorFrame.
   *
   * @throws Error when the input holds more than two CBOR items, the header
   *   fails schema validation, the body item is missing, or an error frame's
   *   body fails schema validation.
   */
  static fromBytes(bytes: Uint8Array) {
    const items = cborDecodeMulti(bytes)
    if (items.length > 2) {
      throw new Error('Too many CBOR data items in frame')
    }
    const rawHeader = items[0]
    const rawBody: unknown = items.length > 1 ? items[1] : kUnset

    const headerResult = frameHeader.safeParse(rawHeader)
    if (!headerResult.success) {
      throw new Error(`Invalid frame header: ${headerResult.error.message}`)
    }
    if (rawBody === kUnset) {
      throw new Error('Missing frame body')
    }

    const header = headerResult.data
    const op = header.op
    if (op === FrameType.Message) {
      return new MessageFrame(rawBody, { type: header.t })
    }
    if (op === FrameType.Error) {
      const bodyResult = errorFrameBody.safeParse(rawBody)
      if (!bodyResult.success) {
        throw new Error(`Invalid error frame body: ${bodyResult.error.message}`)
      }
      return new ErrorFrame(bodyResult.data)
    }
    const exhaustiveCheck: never = op
    throw new Error(`Unknown frame op: ${exhaustiveCheck}`)
  }
}

/**
 * Frame carrying an application message. The optional `type` is stored in
 * the header as `t`. The key is omitted entirely when no type is given.
 */
export class MessageFrame<T = Record<string, unknown>> extends Frame {
  header: MessageFrameHeader
  body: T
  constructor(body: T, opts?: { type?: string }) {
    super()
    const type = opts?.type
    if (type === undefined) {
      this.header = { op: FrameType.Message }
    } else {
      this.header = { op: FrameType.Message, t: type }
    }
    this.body = body
  }
  get type() {
    const { t } = this.header
    return t
  }
}

/** Frame reporting an error. `code` is the body's `error` name. */
export class ErrorFrame<T extends string = string> extends Frame {
  header: ErrorFrameHeader
  body: ErrorFrameBody<T>
  constructor(body: ErrorFrameBody<T>) {
    super()
    this.header = { op: FrameType.Error }
    this.body = body
  }
  get code() {
    const { error } = this.body
    return error
  }
  get message() {
    const { message } = this.body
    return message
  }
}
@@ -0,0 +1,5 @@
1
+ export * from './types'
2
+ export * from './frames'
3
+ export * from './stream'
4
+ export * from './subscription'
5
+ export * from './server'
@@ -0,0 +1,5 @@
1
+ import { subsystemLogger } from '@atproto/common'
2
+
3
/** Logger for the `xrpc-stream` subsystem, built with @atproto/common's `subsystemLogger`. */
const streamLogger = subsystemLogger('xrpc-stream')

export { streamLogger as logger, streamLogger as default }
@@ -0,0 +1,60 @@
1
+ import { IncomingMessage } from 'http'
2
+ import { WebSocketServer, ServerOptions, WebSocket } from 'ws'
3
+ import { ErrorFrame, Frame } from './frames'
4
+ import logger from './logger'
5
+ import { CloseCode, DisconnectError } from './types'
6
+
7
/**
 * Websocket server for one XRPC subscription. For every accepted connection
 * `handler` is invoked, and each Frame it yields is written to the socket as
 * a binary message.
 *
 * How a connection ends:
 * - the handler completes: close with CloseCode.Normal;
 * - the handler yields an ErrorFrame: the frame is sent, then the socket is
 *   closed with CloseCode.Policy and the frame's error name as the reason;
 * - anything else throws: the error is logged and the socket is terminated;
 * - the client closes first: the handler's iterator is `return()`ed so it can
 *   clean up.
 */
export class XrpcStreamServer {
  wss: WebSocketServer
  constructor(opts: ServerOptions & { handler: Handler }) {
    const { handler, ...serverOpts } = opts
    this.wss = new WebSocketServer(serverOpts)
    this.wss.on('connection', async (socket, req) => {
      socket.on('error', (err) => logger.error(err, 'websocket error'))
      try {
        const iterator = unwrapIterator(handler(req, socket, this))
        socket.once('close', () => {
          // return() settles asynchronously. Catch it so a failure in the
          // handler's cleanup is logged instead of becoming an unhandled
          // promise rejection.
          Promise.resolve(iterator.return?.()).catch((err) =>
            logger.error(err, 'websocket handler cleanup error'),
          )
        })
        const safeFrames = wrapIterator(iterator)
        for await (const frame of safeFrames) {
          // The client may have gone away while this frame was being
          // produced. Stop rather than write to a closing/closed socket.
          if (socket.readyState !== WebSocket.OPEN) break
          // Await every write, not only error frames. This makes a slow
          // consumer apply backpressure to the handler, and write failures
          // reach the catch below instead of being dropped.
          await sendFrame(socket, frame)
          if (frame instanceof ErrorFrame) {
            throw new DisconnectError(CloseCode.Policy, frame.body.error)
          }
        }
      } catch (err) {
        if (err instanceof DisconnectError) {
          return socket.close(err.wsCode, err.xrpcCode)
        } else {
          logger.error(err, 'websocket server error')
          return socket.terminate()
        }
      }
      socket.close(CloseCode.Normal)
    })
  }
}

export type Handler = (
  req: IncomingMessage,
  socket: WebSocket,
  server: XrpcStreamServer,
) => AsyncIterable<Frame>

/** Sends one frame as a binary message; resolves when ws invokes the send callback without an error. */
function sendFrame(socket: WebSocket, frame: Frame): Promise<void> {
  return new Promise((resolve, reject) => {
    socket.send(frame.toBytes(), { binary: true }, (err) => {
      if (err) return reject(err)
      resolve()
    })
  })
}

function unwrapIterator<T>(iterable: AsyncIterable<T>): AsyncIterator<T> {
  return iterable[Symbol.asyncIterator]()
}

// Re-exposes an existing iterator as an iterable. The for-await loop then
// drives the same iterator whose return() the 'close' listener calls.
function wrapIterator<T>(iterator: AsyncIterator<T>): AsyncIterable<T> {
  return {
    [Symbol.asyncIterator]() {
      return iterator
    },
  }
}
@@ -0,0 +1,26 @@
1
+ import { XRPCError, ResponseType } from '@atproto/xrpc'
2
+ import { DuplexOptions } from 'stream'
3
+ import { createWebSocketStream, WebSocket } from 'ws'
4
+ import { Frame } from './frames'
5
+
6
/**
 * Reads a websocket as a sequence of decoded frames. Each websocket message
 * becomes one Frame.
 *
 * @throws whatever `Frame.fromBytes` throws for a malformed message.
 */
export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
  const streamOptions: DuplexOptions = {
    ...options,
    // Object mode keeps every websocket message as its own chunk, so frame
    // bytes don't get buffered/combined together.
    readableObjectMode: true,
  }
  const messages = createWebSocketStream(ws, streamOptions)
  for await (const message of messages) {
    yield Frame.fromBytes(message)
  }
}

/**
 * Like `byFrame`, but yields only message frames. An error frame ends
 * iteration by throwing an XRPCError that carries the frame's error name and
 * message.
 */
export async function* byMessage(ws: WebSocket, options?: DuplexOptions) {
  for await (const frame of byFrame(ws, options)) {
    if (frame.isError()) {
      throw new XRPCError(-1, frame.code, frame.message)
    }
    if (!frame.isMessage()) {
      throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
    }
    yield frame
  }
}