@atproto/xrpc-server 0.0.1 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/dist/auth.d.ts +15 -0
  2. package/dist/index.d.ts +2 -0
  3. package/dist/index.js +40116 -29848
  4. package/dist/index.js.map +4 -4
  5. package/dist/server.d.ts +9 -3
  6. package/dist/src/index.d.ts +2 -0
  7. package/dist/src/logger.d.ts +2 -0
  8. package/dist/src/server.d.ts +19 -0
  9. package/dist/src/types.d.ts +115 -0
  10. package/dist/src/util.d.ts +10 -0
  11. package/dist/stream/frames.d.ts +25 -0
  12. package/dist/stream/index.d.ts +5 -0
  13. package/dist/stream/logger.d.ts +2 -0
  14. package/dist/stream/server.d.ts +11 -0
  15. package/dist/stream/stream.d.ts +5 -0
  16. package/dist/stream/subscription.d.ts +24 -0
  17. package/dist/stream/types.d.ts +64 -0
  18. package/dist/tsconfig.build.tsbuildinfo +1 -0
  19. package/dist/types.d.ts +16 -0
  20. package/dist/util.d.ts +3 -2
  21. package/package.json +14 -2
  22. package/src/auth.ts +111 -0
  23. package/src/index.ts +2 -0
  24. package/src/server.ts +148 -10
  25. package/src/stream/frames.ts +95 -0
  26. package/src/stream/index.ts +5 -0
  27. package/src/stream/logger.ts +5 -0
  28. package/src/stream/server.ts +65 -0
  29. package/src/stream/stream.ts +26 -0
  30. package/src/stream/subscription.ts +175 -0
  31. package/src/stream/types.ts +43 -0
  32. package/src/types.ts +27 -2
  33. package/src/util.ts +38 -7
  34. package/tests/_util.ts +36 -1
  35. package/tests/auth.test.ts +15 -36
  36. package/tests/bodies.test.ts +50 -9
  37. package/tests/errors.test.ts +38 -11
  38. package/tests/frames.test.ts +137 -0
  39. package/tests/ipld.test.ts +96 -0
  40. package/tests/parameters.test.ts +13 -45
  41. package/tests/procedures.test.ts +7 -3
  42. package/tests/queries.test.ts +7 -3
  43. package/tests/stream.test.ts +169 -0
  44. package/tests/subscriptions.test.ts +347 -0
  45. package/tsconfig.build.tsbuildinfo +1 -1
  46. package/tsconfig.json +1 -0
package/src/auth.ts ADDED
@@ -0,0 +1,111 @@
1
+ import * as common from '@atproto/common'
2
+ import { MINUTE } from '@atproto/common'
3
+ import * as crypto from '@atproto/crypto'
4
+ import * as ui8 from 'uint8arrays'
5
+ import { AuthRequiredError } from './types'
6
+
7
/**
 * Inputs for minting a short-lived service-to-service JWT.
 *
 * - `keypair` signs the token; its `jwtAlg` becomes the header `alg`.
 * - `iss` is placed in the `iss` claim.
 * - `aud` is placed in the `aud` claim.
 * - `exp` is an optional expiry in seconds since the epoch. When omitted,
 *   `createServiceJwt` uses one minute from now.
 */
type ServiceJwtParams = {
  keypair: crypto.Keypair
  iss: string
  aud: string
  exp?: number
}
13
+
14
/**
 * Mints a compact JWT (`header.payload.signature`) signed by `params.keypair`.
 *
 * The header is `{ typ: 'JWT', alg: keypair.jwtAlg }` and the payload is
 * `{ iss, aud, exp }`. Both are JSON-serialized and base64url-encoded. The
 * signature covers the UTF-8 bytes of `<header>.<payload>`.
 *
 * @param params - issuer, audience, optional expiry, and signing keypair
 * @returns the serialized token
 */
export const createServiceJwt = async (
  params: ServiceJwtParams,
): Promise<string> => {
  const { iss, aud, keypair } = params
  // Default expiry: one minute from now, in whole seconds.
  const exp = params.exp ?? Math.floor((Date.now() + MINUTE) / 1000)
  const encodedHeader = jsonToB64Url({ typ: 'JWT', alg: keypair.jwtAlg })
  const encodedPayload = jsonToB64Url({ iss, aud, exp })
  const signingInput = [encodedHeader, encodedPayload].join('.')
  const signature = await keypair.sign(ui8.fromString(signingInput, 'utf8'))
  return [signingInput, ui8.toString(signature, 'base64url')].join('.')
}
33
+
34
/**
 * Builds request options carrying a freshly minted service JWT as a
 * `Bearer` token in the `authorization` header.
 *
 * @param params - forwarded unchanged to `createServiceJwt`
 * @returns `{ headers: { authorization: 'Bearer <jwt>' } }`
 */
export const createServiceAuthHeaders = async (params: ServiceJwtParams) => {
  const token = await createServiceJwt(params)
  const authorization = `Bearer ${token}`
  return { headers: { authorization } }
}
40
+
41
/** JSON-serializes `json` and base64url-encodes the resulting UTF-8 text. */
const jsonToB64Url = (json: Record<string, unknown>): string => {
  const serialized = JSON.stringify(json)
  return common.utf8ToB64Url(serialized)
}
44
+
45
/**
 * Verifies a service-auth JWT and returns its issuer DID.
 *
 * Checks are performed in this order:
 * 1. the token has exactly three dot-separated segments;
 * 2. the payload segment is parsed by `parsePayload`;
 * 3. `exp` is not in the past;
 * 4. `aud` equals `ownDid`;
 * 5. the signature segment decodes as base64url;
 * 6. the signature over `<header>.<payload>` verifies against the key
 *    returned by `getSigningKey(iss)`.
 *
 * The JWT header segment is not inspected.
 *
 * @param jwtStr - compact-serialized token (`header.payload.signature`)
 * @param ownDid - this service's DID; must equal the token's `aud`
 * @param getSigningKey - resolves the issuer DID to the key handed to
 *   `crypto.verifySignature`; errors it throws propagate unchanged
 * @returns the `iss` claim of the verified token
 * @throws AuthRequiredError with code `BadJwt`, `JwtExpired`,
 *   `BadJwtAudience`, or `BadJwtSignature` when a check above fails
 */
export const verifyJwt = async (
  jwtStr: string,
  ownDid: string,
  getSigningKey: (did: string) => Promise<string>,
): Promise<string> => {
  const parts = jwtStr.split('.')
  if (parts.length !== 3) {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  const payload = parsePayload(parts[1])
  const sig = parts[2]

  // `exp` is in seconds; Date.now() is in milliseconds.
  if (Date.now() / 1000 > payload.exp) {
    throw new AuthRequiredError('jwt expired', 'JwtExpired')
  }
  if (payload.aud !== ownDid) {
    throw new AuthRequiredError(
      'jwt audience does not match service did',
      'BadJwtAudience',
    )
  }

  const msgBytes = ui8.fromString(parts.slice(0, 2).join('.'), 'utf8')
  // The signature segment is attacker-controlled. The base64url decoder throws
  // on invalid characters, so map that to an auth error rather than letting a
  // raw decoder error escape. Decoding happens before the (possibly remote)
  // signing-key lookup so garbage tokens fail fast.
  let sigBytes: Uint8Array
  try {
    sigBytes = ui8.fromString(sig, 'base64url')
  } catch (err) {
    throw new AuthRequiredError(
      'could not verify jwt signature',
      'BadJwtSignature',
    )
  }

  const signingKey = await getSigningKey(payload.iss)

  let validSig: boolean
  try {
    validSig = await crypto.verifySignature(signingKey, msgBytes, sigBytes)
  } catch (err) {
    throw new AuthRequiredError(
      'could not verify jwt signature',
      'BadJwtSignature',
    )
  }
  if (!validSig) {
    throw new AuthRequiredError(
      'jwt signature does not match jwt issuer',
      'BadJwtSignature',
    )
  }

  return payload.iss
}
90
+
91
/**
 * Decodes a base64url segment to UTF-8 text and parses it as JSON.
 * Any error thrown by `JSON.parse` is not caught here.
 */
const parseB64UrlToJson = (b64: string) => {
  const text = common.b64UrlToUtf8(b64)
  return JSON.parse(text)
}
94
+
95
/**
 * Parses and minimally validates a JWT payload segment.
 *
 * Requires the decoded value to be a non-null object with a numeric `exp`
 * and a string `iss`. `aud` is not validated here.
 *
 * @param b64 - the base64url-encoded payload segment
 * @returns the payload typed as `JwtPayload`
 * @throws AuthRequiredError('poorly formatted jwt', 'BadJwt') if the segment
 *   is not valid JSON or lacks the required claims
 */
const parsePayload = (b64: string): JwtPayload => {
  let payload: unknown
  try {
    payload = parseB64UrlToJson(b64)
  } catch (err) {
    // The segment comes from the client. Malformed JSON must surface as an
    // auth failure rather than escaping as a raw SyntaxError.
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  if (!payload || typeof payload !== 'object') {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  const claims = payload as Record<string, unknown>
  if (typeof claims.exp !== 'number' || typeof claims.iss !== 'string') {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  return payload as JwtPayload
}
106
+
107
/**
 * Claims this module reads from a service JWT.
 * - `exp`: expiry, compared against `Date.now() / 1000` (seconds)
 * - `iss`: issuer DID, used to look up the signing key
 * - `aud`: audience, compared against the service's own DID
 */
type JwtPayload = {
  exp: number
  iss: string
  aud: string
}
package/src/index.ts CHANGED
@@ -1,2 +1,4 @@
1
1
  export * from './types'
2
+ export * from './auth'
2
3
  export * from './server'
4
+ export * from './stream'
package/src/server.ts CHANGED
@@ -4,7 +4,15 @@ import express, {
4
4
  NextFunction,
5
5
  RequestHandler,
6
6
  } from 'express'
7
- import { Lexicons, LexXrpcProcedure, LexXrpcQuery } from '@atproto/lexicon'
7
+ import {
8
+ Lexicons,
9
+ lexToJson,
10
+ LexXrpcProcedure,
11
+ LexXrpcQuery,
12
+ LexXrpcSubscription,
13
+ } from '@atproto/lexicon'
14
+ import { check, forwardStreamErrors, schema } from '@atproto/common'
15
+ import { ErrorFrame, Frame, MessageFrame, XrpcStreamServer } from './stream'
8
16
  import {
9
17
  XRPCHandler,
10
18
  XRPCError,
@@ -18,8 +26,17 @@ import {
18
26
  AuthVerifier,
19
27
  isHandlerError,
20
28
  Options,
29
+ XRPCStreamHandlerConfig,
30
+ XRPCStreamHandler,
31
+ Params,
32
+ InternalServerError,
21
33
  } from './types'
22
- import { decodeQueryParams, validateInput, validateOutput } from './util'
34
+ import {
35
+ decodeQueryParams,
36
+ getQueryParams,
37
+ validateInput,
38
+ validateOutput,
39
+ } from './util'
23
40
  import log from './logger'
24
41
 
25
42
  export function createServer(lexicons?: unknown[], options?: Options) {
@@ -27,8 +44,9 @@ export function createServer(lexicons?: unknown[], options?: Options) {
27
44
  }
28
45
 
29
46
  export class Server {
30
- router = express.Router()
47
+ router = express()
31
48
  routes = express.Router()
49
+ subscriptions = new Map<string, XrpcStreamServer>()
32
50
  lex = new Lexicons()
33
51
  options: Options
34
52
  middleware: Record<'json' | 'text', RequestHandler>
@@ -40,6 +58,9 @@ export class Server {
40
58
  this.router.use(this.routes)
41
59
  this.router.use('/xrpc/:methodId', this.catchall.bind(this))
42
60
  this.router.use(errorMiddleware)
61
+ this.router.once('mount', (app: express.Application) => {
62
+ this.enableStreamingOnListen(app)
63
+ })
43
64
  this.options = opts ?? {}
44
65
  this.middleware = {
45
66
  json: express.json({ limit: opts?.payload?.jsonLimit }),
@@ -58,10 +79,32 @@ export class Server {
58
79
  const config =
59
80
  typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
60
81
  const def = this.lex.getDef(nsid)
61
- if (!def || (def.type !== 'query' && def.type !== 'procedure')) {
82
+ if (def?.type === 'query' || def?.type === 'procedure') {
83
+ this.addRoute(nsid, def, config)
84
+ } else {
62
85
  throw new Error(`Lex def for ${nsid} is not a query or a procedure`)
63
86
  }
64
- this.addRoute(nsid, def, config)
87
+ }
88
+
89
/**
 * Registers a subscription (websocket stream) handler for `nsid`.
 * This is an alias that forwards to `addStreamMethod`.
 */
streamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  return this.addStreamMethod(nsid, configOrFn)
}
95
+
96
/**
 * Registers a subscription handler for `nsid`.
 *
 * A bare function is wrapped as `{ handler }`. The lexicon definition for
 * `nsid` must be of type `subscription`.
 *
 * @throws Error if the lexicon definition is missing or is not a subscription
 */
addStreamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  const config =
    typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
  const def = this.lex.getDef(nsid)
  if (def?.type !== 'subscription') {
    throw new Error(`Lex def for ${nsid} is not a subscription`)
  }
  this.addSubscription(nsid, def, config)
}
66
109
 
67
110
  // schemas
@@ -145,9 +188,9 @@ export class Server {
145
188
  return async function (req, res, next) {
146
189
  try {
147
190
  // validate request
148
- const params = decodeQueryParams(def, req.query)
191
+ let params = decodeQueryParams(def, req.query)
149
192
  try {
150
- assertValidXrpcParams(params)
193
+ params = assertValidXrpcParams(params) as Params
151
194
  } catch (e) {
152
195
  throw new InvalidRequestError(String(e))
153
196
  }
@@ -181,10 +224,16 @@ export class Server {
181
224
  output?.encoding === 'application/json' ||
182
225
  output?.encoding === 'json'
183
226
  ) {
184
- res.status(200).json(output.body)
185
- } else if (output) {
227
+ const json = lexToJson(output.body)
228
+ res.status(200).json(json)
229
+ } else if (output?.body instanceof Readable) {
186
230
  res.header('Content-Type', output.encoding)
231
+ res.status(200)
232
+ forwardStreamErrors(output.body, res)
233
+ output.body.pipe(res)
234
+ } else if (output) {
187
235
  res
236
+ .header('Content-Type', output.encoding)
188
237
  .status(200)
189
238
  .send(
190
239
  output.body instanceof Uint8Array
@@ -196,10 +245,99 @@ export class Server {
196
245
  }
197
246
  }
198
247
  } catch (err: unknown) {
199
- next(err)
248
+ // Express will not call the next middleware (errorMiddleware in this case)
249
+ // if the value passed to next is falsy (e.g. null, undefined, 0).
250
+ // Hence we replace it with an InternalServerError.
251
+ if (!err) {
252
+ next(new InternalServerError())
253
+ } else {
254
+ next(err)
255
+ }
200
256
  }
201
257
  }
202
258
  }
259
+
260
/**
 * Creates a no-server `XrpcStreamServer` for `nsid` and stores it in
 * `this.subscriptions`.
 *
 * For each connection the stream handler does the following:
 * - runs `config.auth`, if present;
 * - decodes and validates query params against the lexicon;
 * - forwards items from `config.handler` as frames.
 *
 * `Frame` instances pass through unchanged. Map-shaped items with a string
 * `$type` become a `MessageFrame` without `$type`, carrying the type in the
 * frame header. Anything else is wrapped as-is. Any error ends the stream
 * with an `ErrorFrame`.
 */
protected async addSubscription(
  nsid: string,
  def: LexXrpcSubscription,
  config: XRPCStreamHandlerConfig,
) {
  const assertValidXrpcParams = (params: unknown) =>
    this.lex.assertValidXrpcParams(nsid, params)
  // Refs local to this lexicon ("#foo" or "<nsid>#foo") are shortened to
  // "#foo"; any other `$type` is passed through untouched.
  const toFrameType = (type: string): string => {
    const segments = type.split('#')
    const isLocalRef =
      segments.length === 2 && (segments[0] === '' || segments[0] === nsid)
    return isLocalRef ? `#${segments[1]}` : type
  }
  const streamServer = new XrpcStreamServer({
    noServer: true,
    handler: async function* (req, signal) {
      try {
        // authenticate request
        const auth = await config.auth?.({ req })
        if (isHandlerError(auth)) {
          throw XRPCError.fromError(auth)
        }
        // validate request
        let params = decodeQueryParams(def, getQueryParams(req.url))
        try {
          params = assertValidXrpcParams(params) as Params
        } catch (e) {
          throw new InvalidRequestError(String(e))
        }
        // stream
        for await (const item of config.handler({ req, params, auth, signal })) {
          if (item instanceof Frame) {
            yield item
          } else {
            const type = item?.['$type']
            if (check.is(item, schema.map) && typeof type === 'string') {
              const clone = { ...item }
              delete clone['$type']
              yield new MessageFrame(clone, { type: toFrameType(type) })
            } else {
              yield new MessageFrame(item)
            }
          }
        }
      } catch (err) {
        const xrpcErrPayload = XRPCError.fromError(err).payload
        yield new ErrorFrame({
          error: xrpcErrPayload.error ?? 'Unknown',
          message: xrpcErrPayload.message,
        })
      }
    },
  })
  this.subscriptions.set(nsid, streamServer)
}
322
+
323
/**
 * Wraps `app.listen` so the returned HTTP server routes websocket upgrades
 * on `/xrpc/<nsid>` to the matching registered subscription.
 *
 * Upgrade requests for any other path, or for an unregistered nsid, have
 * their socket destroyed.
 */
private enableStreamingOnListen(app: express.Application) {
  const originalListen = app.listen
  app.listen = (...args) => {
    // @ts-ignore the args spread
    const httpServer = originalListen.call(app, ...args)
    httpServer.on('upgrade', (req, socket, head) => {
      // The base URL is only a placeholder so relative request paths parse.
      const { pathname } = new URL(req.url || '', 'http://x')
      let sub: XrpcStreamServer | undefined
      if (pathname.startsWith('/xrpc/')) {
        sub = this.subscriptions.get(pathname.replace('/xrpc/', ''))
      }
      if (!sub) return socket.destroy()
      const target = sub
      target.wss.handleUpgrade(req, socket, head, (ws) =>
        target.wss.emit('connection', ws, req),
      )
    })
    return httpServer
  }
}
203
341
  }
204
342
 
205
343
  function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
@@ -0,0 +1,95 @@
1
+ import * as uint8arrays from 'uint8arrays'
2
+ import { cborEncode, cborDecodeMulti } from '@atproto/common'
3
+ import {
4
+ frameHeader,
5
+ FrameHeader,
6
+ FrameType,
7
+ MessageFrameHeader,
8
+ ErrorFrameHeader,
9
+ ErrorFrameBody,
10
+ errorFrameBody,
11
+ } from './types'
12
+
13
/**
 * Base class for a single frame of an xrpc event stream.
 *
 * On the wire a frame is two concatenated CBOR data items: the header, which
 * carries the `op` discriminator, followed by the body.
 */
export abstract class Frame {
  header: FrameHeader
  body: unknown

  /** Opcode taken from the frame header. */
  get op(): FrameType {
    const { op } = this.header
    return op
  }

  /** Serializes the frame as CBOR(header) followed by CBOR(body). */
  toBytes(): Uint8Array {
    const encodedHeader = cborEncode(this.header)
    const encodedBody = cborEncode(this.body)
    return uint8arrays.concat([encodedHeader, encodedBody])
  }

  isMessage(): this is MessageFrame<unknown> {
    return this.op === FrameType.Message
  }

  isError(): this is ErrorFrame {
    return this.op === FrameType.Error
  }

  /**
   * Decodes a frame from its wire bytes.
   *
   * @throws Error if there are more than two CBOR items, the header fails
   *   schema validation, the body is missing, or an error frame's body is
   *   malformed
   */
  static fromBytes(bytes: Uint8Array) {
    const items = cborDecodeMulti(bytes)
    if (items.length > 2) {
      throw new Error('Too many CBOR data items in frame')
    }
    const rawHeader = items[0]
    // A sentinel distinguishes "no second item" from a body that decodes to
    // undefined or null.
    const rawBody: unknown = items.length > 1 ? items[1] : kUnset
    const headerResult = frameHeader.safeParse(rawHeader)
    if (!headerResult.success) {
      throw new Error(`Invalid frame header: ${headerResult.error.message}`)
    }
    if (rawBody === kUnset) {
      throw new Error('Missing frame body')
    }
    const frameOp = headerResult.data.op
    if (frameOp === FrameType.Message) {
      return new MessageFrame(rawBody, { type: headerResult.data.t })
    } else if (frameOp === FrameType.Error) {
      const bodyResult = errorFrameBody.safeParse(rawBody)
      if (!bodyResult.success) {
        throw new Error(`Invalid error frame body: ${bodyResult.error.message}`)
      }
      return new ErrorFrame(bodyResult.data)
    } else {
      const exhaustiveCheck: never = frameOp
      throw new Error(`Unknown frame op: ${exhaustiveCheck}`)
    }
  }
}
62
+
63
/**
 * A data-carrying stream frame.
 *
 * The optional `type` is stored as the header's `t` field only when it is
 * provided; otherwise the header holds just the opcode.
 */
export class MessageFrame<T = Record<string, unknown>> extends Frame {
  header: MessageFrameHeader
  body: T
  constructor(body: T, opts?: { type?: string }) {
    super()
    const type = opts?.type
    if (type === undefined) {
      this.header = { op: FrameType.Message }
    } else {
      this.header = { op: FrameType.Message, t: type }
    }
    this.body = body
  }

  /** Message type from the header, or undefined when none was set. */
  get type() {
    const { t } = this.header
    return t
  }
}
78
+
79
/**
 * A stream frame reporting an error.
 * The body holds the error code (`error`) and an optional `message`.
 */
export class ErrorFrame<T extends string = string> extends Frame {
  header: ErrorFrameHeader
  body: ErrorFrameBody<T>
  constructor(body: ErrorFrameBody<T>) {
    super()
    this.body = body
    this.header = { op: FrameType.Error }
  }

  /** Error code carried in the body. */
  get code() {
    const { error } = this.body
    return error
  }

  /** Human-readable message carried in the body, if any. */
  get message() {
    const { message } = this.body
    return message
  }
}
94
+
95
// Private sentinel used by Frame.fromBytes to mark "no body item decoded".
const kUnset: unique symbol = Symbol('unset')
@@ -0,0 +1,5 @@
1
+ export * from './types'
2
+ export * from './frames'
3
+ export * from './stream'
4
+ export * from './subscription'
5
+ export * from './server'
@@ -0,0 +1,5 @@
1
+ import { subsystemLogger } from '@atproto/common'
2
+
3
/** Subsystem logger for the xrpc websocket streaming code ('xrpc-stream'). */
const logger = subsystemLogger('xrpc-stream')

export { logger }
export default logger
@@ -0,0 +1,65 @@
1
+ import { IncomingMessage } from 'http'
2
+ import { WebSocketServer, ServerOptions, WebSocket } from 'ws'
3
+ import { ErrorFrame, Frame } from './frames'
4
+ import logger from './logger'
5
+ import { CloseCode, DisconnectError } from './types'
6
+
7
/**
 * Wraps a `ws` WebSocketServer and pumps frames from `handler` to each
 * connected socket.
 *
 * Per connection:
 * - each frame is sent as a binary message, waiting for the send callback
 *   before pulling the next frame;
 * - after an `ErrorFrame` is sent, the socket is closed with
 *   `CloseCode.Policy` and the error code as the reason;
 * - when the iterator finishes, the socket is closed with `CloseCode.Normal`;
 * - any other failure is logged and the socket is terminated.
 *
 * When the socket closes, the handler's iterator is returned and its abort
 * signal fires.
 */
export class XrpcStreamServer {
  wss: WebSocketServer
  constructor(opts: ServerOptions & { handler: Handler }) {
    const { handler, ...wsOpts } = opts
    this.wss = new WebSocketServer(wsOpts)
    this.wss.on('connection', async (socket, req) => {
      socket.on('error', (err) => logger.error(err, 'websocket error'))
      // Resolves once ws reports the send as complete; rejects on send error.
      // @TODO this callback may give more aggressive on backpressure than
      // we ultimately want, but trying it out for the time being.
      const sendBinary = (bytes: Uint8Array) =>
        new Promise((resolve, reject) => {
          socket.send(bytes, { binary: true }, (err) => {
            if (err) return reject(err)
            resolve(undefined)
          })
        })
      try {
        const abortController = new AbortController()
        const frames = unwrapIterator(
          handler(req, abortController.signal, socket, this),
        )
        socket.once('close', () => {
          frames.return?.()
          abortController.abort()
        })
        for await (const frame of wrapIterator(frames)) {
          await sendBinary(frame.toBytes())
          if (frame instanceof ErrorFrame) {
            throw new DisconnectError(CloseCode.Policy, frame.body.error)
          }
        }
      } catch (err) {
        if (!(err instanceof DisconnectError)) {
          logger.error(err, 'websocket server error')
          return socket.terminate()
        }
        return socket.close(err.wsCode, err.xrpcCode)
      }
      socket.close(CloseCode.Normal)
    })
  }
}
47
+
48
/**
 * Produces the frames to send over one websocket connection.
 *
 * Receives the upgrade request, an abort signal that fires when the socket
 * closes, the socket itself, and the owning stream server.
 */
export interface Handler {
  (
    req: IncomingMessage,
    signal: AbortSignal,
    socket: WebSocket,
    server: XrpcStreamServer,
  ): AsyncIterable<Frame>
}
54
+
55
/** Returns the iterator produced by `iterable[Symbol.asyncIterator]()`. */
function unwrapIterator<T>(iterable: AsyncIterable<T>): AsyncIterator<T> {
  const iterator = iterable[Symbol.asyncIterator]()
  return iterator
}
58
+
59
/**
 * Exposes an existing iterator as an iterable whose
 * `[Symbol.asyncIterator]()` always returns that same iterator. This lets it
 * be consumed by `for await` without creating a fresh iteration.
 */
function wrapIterator<T>(iterator: AsyncIterator<T>): AsyncIterable<T> {
  const iterable: AsyncIterable<T> = {
    [Symbol.asyncIterator]: () => iterator,
  }
  return iterable
}
@@ -0,0 +1,26 @@
1
+ import { XRPCError, ResponseType } from '@atproto/xrpc'
2
+ import { DuplexOptions } from 'stream'
3
+ import { createWebSocketStream, WebSocket } from 'ws'
4
+ import { Frame } from './frames'
5
+
6
/**
 * Yields each incoming websocket message on `ws`, decoded with
 * `Frame.fromBytes`.
 *
 * @param options - extra Duplex options; `readableObjectMode` is always
 *   forced to `true`
 */
export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
  // Object mode keeps each websocket message as its own chunk, so frame
  // bytes are never buffered together or split apart.
  const duplexOptions: DuplexOptions = { ...options, readableObjectMode: true }
  const messages = createWebSocketStream(ws, duplexOptions)
  for await (const bytes of messages) {
    const frame = Frame.fromBytes(bytes)
    yield frame
  }
}
15
+
16
/**
 * Yields only message frames from `ws`.
 *
 * @throws XRPCError with status -1 and the frame's code and message when an
 *   error frame arrives
 * @throws XRPCError(ResponseType.Unknown) for any other frame kind
 */
export async function* byMessage(ws: WebSocket, options?: DuplexOptions) {
  for await (const frame of byFrame(ws, options)) {
    if (frame.isMessage()) {
      yield frame
      continue
    }
    if (frame.isError()) {
      throw new XRPCError(-1, frame.code, frame.message)
    }
    throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
  }
}