@atproto/xrpc-server 0.0.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/auth.d.ts +15 -0
- package/dist/index.d.ts +2 -0
- package/dist/index.js +40116 -29848
- package/dist/index.js.map +4 -4
- package/dist/server.d.ts +9 -3
- package/dist/src/index.d.ts +2 -0
- package/dist/src/logger.d.ts +2 -0
- package/dist/src/server.d.ts +19 -0
- package/dist/src/types.d.ts +115 -0
- package/dist/src/util.d.ts +10 -0
- package/dist/stream/frames.d.ts +25 -0
- package/dist/stream/index.d.ts +5 -0
- package/dist/stream/logger.d.ts +2 -0
- package/dist/stream/server.d.ts +11 -0
- package/dist/stream/stream.d.ts +5 -0
- package/dist/stream/subscription.d.ts +24 -0
- package/dist/stream/types.d.ts +64 -0
- package/dist/tsconfig.build.tsbuildinfo +1 -0
- package/dist/types.d.ts +16 -0
- package/dist/util.d.ts +3 -2
- package/package.json +14 -2
- package/src/auth.ts +111 -0
- package/src/index.ts +2 -0
- package/src/server.ts +148 -10
- package/src/stream/frames.ts +95 -0
- package/src/stream/index.ts +5 -0
- package/src/stream/logger.ts +5 -0
- package/src/stream/server.ts +65 -0
- package/src/stream/stream.ts +26 -0
- package/src/stream/subscription.ts +175 -0
- package/src/stream/types.ts +43 -0
- package/src/types.ts +27 -2
- package/src/util.ts +38 -7
- package/tests/_util.ts +36 -1
- package/tests/auth.test.ts +15 -36
- package/tests/bodies.test.ts +50 -9
- package/tests/errors.test.ts +38 -11
- package/tests/frames.test.ts +137 -0
- package/tests/ipld.test.ts +96 -0
- package/tests/parameters.test.ts +13 -45
- package/tests/procedures.test.ts +7 -3
- package/tests/queries.test.ts +7 -3
- package/tests/stream.test.ts +169 -0
- package/tests/subscriptions.test.ts +347 -0
- package/tsconfig.build.tsbuildinfo +1 -1
- package/tsconfig.json +1 -0
package/src/auth.ts
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import * as common from '@atproto/common'
|
|
2
|
+
import { MINUTE } from '@atproto/common'
|
|
3
|
+
import * as crypto from '@atproto/crypto'
|
|
4
|
+
import * as ui8 from 'uint8arrays'
|
|
5
|
+
import { AuthRequiredError } from './types'
|
|
6
|
+
|
|
7
|
+
/** Inputs for minting a service-auth JWT. */
type ServiceJwtParams = {
  /** `iss` claim — the issuer of the token */
  iss: string
  /** `aud` claim — the intended recipient of the token */
  aud: string
  /** `exp` claim in Unix seconds; when omitted, `MINUTE` past the current time is used */
  exp?: number
  /** signs the token; its `jwtAlg` is written to the header `alg` */
  keypair: crypto.Keypair
}

/**
 * Builds a compact-serialized JWT (`header.payload.signature`, each segment
 * base64url-encoded) signed with `params.keypair`.
 *
 * The header is `{ typ: 'JWT', alg }` and the payload is `{ iss, aud, exp }`.
 *
 * @param params - claims and signing keypair
 * @returns the serialized token
 */
export const createServiceJwt = async (
  params: ServiceJwtParams,
): Promise<string> => {
  const { iss, aud, keypair } = params
  // `??` (not a destructuring default) so an explicit null also falls back
  const expiresAt = params.exp ?? Math.floor((Date.now() + MINUTE) / 1000)
  const encodedHeader = jsonToB64Url({ typ: 'JWT', alg: keypair.jwtAlg })
  const encodedPayload = jsonToB64Url({ iss, aud, exp: expiresAt })
  const signingInput = [encodedHeader, encodedPayload].join('.')
  const signature = await keypair.sign(ui8.fromString(signingInput, 'utf8'))
  return `${signingInput}.${ui8.toString(signature, 'base64url')}`
}
|
|
33
|
+
|
|
34
|
+
/**
 * Mints a service JWT and wraps it as a bearer `authorization` header.
 *
 * @param params - same inputs as `createServiceJwt`
 * @returns `{ headers: { authorization: 'Bearer <jwt>' } }`
 */
export const createServiceAuthHeaders = async (params: ServiceJwtParams) => {
  const token = await createServiceJwt(params)
  const authorization = `Bearer ${token}`
  return { headers: { authorization } }
}
|
|
40
|
+
|
|
41
|
+
/** Serializes `json` with JSON.stringify and base64url-encodes the UTF-8 text. */
const jsonToB64Url = (json: Record<string, unknown>): string => {
  const serialized = JSON.stringify(json)
  return common.utf8ToB64Url(serialized)
}
|
|
44
|
+
|
|
45
|
+
/**
 * Verifies a service-auth JWT and returns its issuer.
 *
 * Checks, in order:
 * 1. the token has exactly three dot-separated segments;
 * 2. the payload parses (see `parsePayload`);
 * 3. the token has not expired;
 * 4. `aud` equals `ownDid`;
 * 5. the signature segment decodes;
 * 6. the signature verifies against the key returned by `getSigningKey(iss)`.
 *
 * @param jwtStr - compact-serialized JWT (`header.payload.signature`)
 * @param ownDid - this service's DID; must equal the token's `aud`
 * @param getSigningKey - resolves the issuer to the key handed to
 *   `crypto.verifySignature` (presumably a did:key string — confirm)
 * @returns the token's `iss`
 * @throws AuthRequiredError for malformed, expired, mis-addressed or
 *   badly-signed tokens. Errors thrown by `getSigningKey` propagate unchanged.
 */
export const verifyJwt = async (
  jwtStr: string,
  ownDid: string,
  getSigningKey: (did: string) => Promise<string>,
): Promise<string> => {
  const parts = jwtStr.split('.')
  if (parts.length !== 3) {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  const payload = parsePayload(parts[1])
  const sig = parts[2]

  // `exp` is compared in seconds; Date.now() is in milliseconds
  if (Date.now() / 1000 > payload.exp) {
    throw new AuthRequiredError('jwt expired', 'JwtExpired')
  }
  if (payload.aud !== ownDid) {
    throw new AuthRequiredError(
      'jwt audience does not match service did',
      'BadJwtAudience',
    )
  }

  // the signature covers the "header.payload" text exactly as received
  const msgBytes = ui8.fromString(parts.slice(0, 2).join('.'), 'utf8')

  let sigBytes: Uint8Array
  try {
    // NOTE(review): the base64url decoder presumably throws on characters
    // outside its alphabet; treat that as a bad signature instead of letting
    // a raw decoder error escape as a non-auth (500) failure.
    sigBytes = ui8.fromString(sig, 'base64url')
  } catch (err) {
    throw new AuthRequiredError(
      'could not verify jwt signature',
      'BadJwtSignature',
    )
  }

  const signingKey = await getSigningKey(payload.iss)

  let validSig: boolean
  try {
    validSig = await crypto.verifySignature(signingKey, msgBytes, sigBytes)
  } catch (err) {
    throw new AuthRequiredError(
      'could not verify jwt signature',
      'BadJwtSignature',
    )
  }
  if (!validSig) {
    throw new AuthRequiredError(
      'jwt signature does not match jwt issuer',
      'BadJwtSignature',
    )
  }

  return payload.iss
}
|
|
90
|
+
|
|
91
|
+
/** Decodes a base64url string to UTF-8 text and parses it as JSON (throws on invalid JSON). */
const parseB64UrlToJson = (b64: string) => {
  const text = common.b64UrlToUtf8(b64)
  return JSON.parse(text)
}
|
|
94
|
+
|
|
95
|
+
/**
 * Decodes and shape-checks the payload segment of a JWT.
 *
 * @param b64 - the base64url-encoded payload segment
 * @returns the decoded claims object. Only `exp` (number) and `iss` (string)
 *   are checked at runtime; `aud` is not type-checked here.
 * @throws AuthRequiredError('poorly formatted jwt', 'BadJwt') when the
 *   segment is not base64url-encoded JSON, is not an object, or lacks a
 *   numeric `exp` or a string `iss`.
 */
const parsePayload = (b64: string): JwtPayload => {
  let decoded: unknown
  try {
    decoded = parseB64UrlToJson(b64)
  } catch (err) {
    // A garbled payload segment (bad encoding / invalid JSON) is a client
    // error; previously the raw SyntaxError escaped instead of an auth error.
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  if (!decoded || typeof decoded !== 'object') {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  const claims = decoded as Record<string, unknown>
  if (typeof claims.exp !== 'number' || typeof claims.iss !== 'string') {
    throw new AuthRequiredError('poorly formatted jwt', 'BadJwt')
  }
  return claims as JwtPayload
}

/** Claims read from a service-auth JWT payload. */
type JwtPayload = {
  iss: string
  aud: string
  exp: number
}
|
package/src/index.ts
CHANGED
package/src/server.ts
CHANGED
|
@@ -4,7 +4,15 @@ import express, {
|
|
|
4
4
|
NextFunction,
|
|
5
5
|
RequestHandler,
|
|
6
6
|
} from 'express'
|
|
7
|
-
import {
|
|
7
|
+
import {
|
|
8
|
+
Lexicons,
|
|
9
|
+
lexToJson,
|
|
10
|
+
LexXrpcProcedure,
|
|
11
|
+
LexXrpcQuery,
|
|
12
|
+
LexXrpcSubscription,
|
|
13
|
+
} from '@atproto/lexicon'
|
|
14
|
+
import { check, forwardStreamErrors, schema } from '@atproto/common'
|
|
15
|
+
import { ErrorFrame, Frame, MessageFrame, XrpcStreamServer } from './stream'
|
|
8
16
|
import {
|
|
9
17
|
XRPCHandler,
|
|
10
18
|
XRPCError,
|
|
@@ -18,8 +26,17 @@ import {
|
|
|
18
26
|
AuthVerifier,
|
|
19
27
|
isHandlerError,
|
|
20
28
|
Options,
|
|
29
|
+
XRPCStreamHandlerConfig,
|
|
30
|
+
XRPCStreamHandler,
|
|
31
|
+
Params,
|
|
32
|
+
InternalServerError,
|
|
21
33
|
} from './types'
|
|
22
|
-
import {
|
|
34
|
+
import {
|
|
35
|
+
decodeQueryParams,
|
|
36
|
+
getQueryParams,
|
|
37
|
+
validateInput,
|
|
38
|
+
validateOutput,
|
|
39
|
+
} from './util'
|
|
23
40
|
import log from './logger'
|
|
24
41
|
|
|
25
42
|
export function createServer(lexicons?: unknown[], options?: Options) {
|
|
@@ -27,8 +44,9 @@ export function createServer(lexicons?: unknown[], options?: Options) {
|
|
|
27
44
|
}
|
|
28
45
|
|
|
29
46
|
export class Server {
|
|
30
|
-
router = express
|
|
47
|
+
router = express()
|
|
31
48
|
routes = express.Router()
|
|
49
|
+
subscriptions = new Map<string, XrpcStreamServer>()
|
|
32
50
|
lex = new Lexicons()
|
|
33
51
|
options: Options
|
|
34
52
|
middleware: Record<'json' | 'text', RequestHandler>
|
|
@@ -40,6 +58,9 @@ export class Server {
|
|
|
40
58
|
this.router.use(this.routes)
|
|
41
59
|
this.router.use('/xrpc/:methodId', this.catchall.bind(this))
|
|
42
60
|
this.router.use(errorMiddleware)
|
|
61
|
+
this.router.once('mount', (app: express.Application) => {
|
|
62
|
+
this.enableStreamingOnListen(app)
|
|
63
|
+
})
|
|
43
64
|
this.options = opts ?? {}
|
|
44
65
|
this.middleware = {
|
|
45
66
|
json: express.json({ limit: opts?.payload?.jsonLimit }),
|
|
@@ -58,10 +79,32 @@ export class Server {
|
|
|
58
79
|
const config =
|
|
59
80
|
typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
|
|
60
81
|
const def = this.lex.getDef(nsid)
|
|
61
|
-
if (
|
|
82
|
+
if (def?.type === 'query' || def?.type === 'procedure') {
|
|
83
|
+
this.addRoute(nsid, def, config)
|
|
84
|
+
} else {
|
|
62
85
|
throw new Error(`Lex def for ${nsid} is not a query or a procedure`)
|
|
63
86
|
}
|
|
64
|
-
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
/**
 * Registers a handler for the subscription method `nsid`.
 * Thin alias of `addStreamMethod`.
 */
streamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  this.addStreamMethod(nsid, configOrFn)
}

/**
 * Registers a handler for the subscription method `nsid`.
 *
 * @param nsid - lexicon id of the method
 * @param configOrFn - a bare handler function, or a full config object
 * @throws Error if the lexicon def for `nsid` is absent (nullish from
 *   `getDef`) or is not of type `subscription`
 */
addStreamMethod(
  nsid: string,
  configOrFn: XRPCStreamHandlerConfig | XRPCStreamHandler,
) {
  const def = this.lex.getDef(nsid)
  if (def?.type !== 'subscription') {
    throw new Error(`Lex def for ${nsid} is not a subscription`)
  }
  // normalize the shorthand function form into a config object
  const config =
    typeof configOrFn === 'function' ? { handler: configOrFn } : configOrFn
  // NOTE(review): addSubscription is declared async; its promise is not awaited here
  this.addSubscription(nsid, def, config)
}
|
|
66
109
|
|
|
67
110
|
// schemas
|
|
@@ -145,9 +188,9 @@ export class Server {
|
|
|
145
188
|
return async function (req, res, next) {
|
|
146
189
|
try {
|
|
147
190
|
// validate request
|
|
148
|
-
|
|
191
|
+
let params = decodeQueryParams(def, req.query)
|
|
149
192
|
try {
|
|
150
|
-
assertValidXrpcParams(params)
|
|
193
|
+
params = assertValidXrpcParams(params) as Params
|
|
151
194
|
} catch (e) {
|
|
152
195
|
throw new InvalidRequestError(String(e))
|
|
153
196
|
}
|
|
@@ -181,10 +224,16 @@ export class Server {
|
|
|
181
224
|
output?.encoding === 'application/json' ||
|
|
182
225
|
output?.encoding === 'json'
|
|
183
226
|
) {
|
|
184
|
-
|
|
185
|
-
|
|
227
|
+
const json = lexToJson(output.body)
|
|
228
|
+
res.status(200).json(json)
|
|
229
|
+
} else if (output?.body instanceof Readable) {
|
|
186
230
|
res.header('Content-Type', output.encoding)
|
|
231
|
+
res.status(200)
|
|
232
|
+
forwardStreamErrors(output.body, res)
|
|
233
|
+
output.body.pipe(res)
|
|
234
|
+
} else if (output) {
|
|
187
235
|
res
|
|
236
|
+
.header('Content-Type', output.encoding)
|
|
188
237
|
.status(200)
|
|
189
238
|
.send(
|
|
190
239
|
output.body instanceof Uint8Array
|
|
@@ -196,10 +245,99 @@ export class Server {
|
|
|
196
245
|
}
|
|
197
246
|
}
|
|
198
247
|
} catch (err: unknown) {
|
|
199
|
-
next(
|
|
248
|
+
// Express will not call the next middleware (errorMiddleware in this case)
|
|
249
|
+
// if the value passed to next is falsy (e.g. null, undefined, 0).
|
|
250
|
+
// Hence we replace it with an InternalServerError.
|
|
251
|
+
if (!err) {
|
|
252
|
+
next(new InternalServerError())
|
|
253
|
+
} else {
|
|
254
|
+
next(err)
|
|
255
|
+
}
|
|
200
256
|
}
|
|
201
257
|
}
|
|
202
258
|
}
|
|
259
|
+
|
|
260
|
+
/**
 * Creates an `XrpcStreamServer` for `nsid` and stores it in
 * `this.subscriptions`.
 *
 * For each connection the stream handler:
 * 1. runs `config.auth`;
 * 2. decodes the upgrade URL's query params and validates them against the
 *    lexicon;
 * 3. relays `config.handler`'s output. `Frame` items pass through unchanged.
 *    Map items with a string `$type` become a `MessageFrame` whose header
 *    `t` is that type, with `$type` removed from the body. Anything else
 *    becomes an untyped `MessageFrame`.
 *
 * Any error thrown along the way is turned into a single trailing
 * `ErrorFrame`.
 */
protected async addSubscription(
  nsid: string,
  def: LexXrpcSubscription,
  config: XRPCStreamHandlerConfig,
) {
  // validates params for `nsid` against this server's lexicons
  const assertValidXrpcParams = (params: unknown) =>
    this.lex.assertValidXrpcParams(nsid, params)

  const streamServer = new XrpcStreamServer({
    noServer: true,
    handler: async function* (req, signal) {
      try {
        // authenticate request
        const auth = await config.auth?.({ req })
        if (isHandlerError(auth)) {
          throw XRPCError.fromError(auth)
        }
        // decode + validate the query params carried on the upgrade URL
        let params = decodeQueryParams(def, getQueryParams(req.url))
        try {
          params = assertValidXrpcParams(params) as Params
        } catch (e) {
          throw new InvalidRequestError(String(e))
        }
        // relay handler output, normalizing plain values into message frames
        const items = config.handler({ req, params, auth, signal })
        for await (const item of items) {
          if (item instanceof Frame) {
            yield item
          } else {
            const $type = item?.['$type']
            if (check.is(item, schema.map) && typeof $type === 'string') {
              // "#name" and "<nsid>#name" collapse to "#name"; any other
              // $type is kept verbatim
              const parts = $type.split('#')
              const isLocalRef =
                parts.length === 2 && (parts[0] === '' || parts[0] === nsid)
              const t = isLocalRef ? `#${parts[1]}` : $type
              const body = { ...item }
              delete body['$type']
              yield new MessageFrame(body, { type: t })
            } else {
              yield new MessageFrame(item)
            }
          }
        }
      } catch (err) {
        const xrpcErrPayload = XRPCError.fromError(err).payload
        yield new ErrorFrame({
          error: xrpcErrPayload.error ?? 'Unknown',
          message: xrpcErrPayload.message,
        })
      }
    },
  })
  this.subscriptions.set(nsid, streamServer)
}
|
|
322
|
+
|
|
323
|
+
/**
 * Patches `app.listen` so every HTTP server it returns also routes
 * WebSocket `upgrade` requests for `/xrpc/<nsid>` to the matching entry in
 * `this.subscriptions`. Upgrades for any other path get their socket
 * destroyed.
 */
private enableStreamingOnListen(app: express.Application) {
  const originalListen = app.listen
  app.listen = (...listenArgs) => {
    // @ts-ignore the args spread
    const httpServer = originalListen.call(app, ...listenArgs)
    httpServer.on('upgrade', (req, socket, head) => {
      const { pathname } = new URL(req.url || '', 'http://x')
      const target = pathname.startsWith('/xrpc/')
        ? this.subscriptions.get(pathname.slice('/xrpc/'.length))
        : undefined
      // no subscription registered for this path: drop the connection
      if (target === undefined) return socket.destroy()
      const { wss } = target
      wss.handleUpgrade(req, socket, head, (ws) =>
        wss.emit('connection', ws, req),
      )
    })
    return httpServer
  }
}
|
|
203
341
|
}
|
|
204
342
|
|
|
205
343
|
function isHandlerSuccess(v: HandlerOutput): v is HandlerSuccess {
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
import * as uint8arrays from 'uint8arrays'
|
|
2
|
+
import { cborEncode, cborDecodeMulti } from '@atproto/common'
|
|
3
|
+
import {
|
|
4
|
+
frameHeader,
|
|
5
|
+
FrameHeader,
|
|
6
|
+
FrameType,
|
|
7
|
+
MessageFrameHeader,
|
|
8
|
+
ErrorFrameHeader,
|
|
9
|
+
ErrorFrameBody,
|
|
10
|
+
errorFrameBody,
|
|
11
|
+
} from './types'
|
|
12
|
+
|
|
13
|
+
/**
 * Base class for one frame of an XRPC event stream.
 *
 * On the wire a frame is two CBOR data items back to back: the header
 * (carrying the `op` discriminator) followed by the body.
 */
export abstract class Frame {
  header: FrameHeader
  body: unknown

  /** The frame's op code, taken from its header. */
  get op(): FrameType {
    const { op } = this.header
    return op
  }

  /** Serializes the frame as CBOR(header) immediately followed by CBOR(body). */
  toBytes(): Uint8Array {
    const headerBytes = cborEncode(this.header)
    const bodyBytes = cborEncode(this.body)
    return uint8arrays.concat([headerBytes, bodyBytes])
  }

  /** Type guard: true when the op is `FrameType.Message`. */
  isMessage(): this is MessageFrame<unknown> {
    return FrameType.Message === this.op
  }

  /** Type guard: true when the op is `FrameType.Error`. */
  isError(): this is ErrorFrame {
    return FrameType.Error === this.op
  }

  /**
   * Decodes bytes laid out as by `toBytes` into a concrete frame.
   *
   * @throws Error if the bytes hold more than two CBOR items, the header fails
   *   `frameHeader` validation, the body item is missing, or an error frame's
   *   body fails `errorFrameBody` validation
   */
  static fromBytes(bytes: Uint8Array) {
    const items = cborDecodeMulti(bytes)
    if (items.length > 2) {
      throw new Error('Too many CBOR data items in frame')
    }
    const rawHeader = items[0]
    // kUnset marks "no second item" so it can't be confused with any
    // decoded body value (including null)
    const rawBody: unknown = items.length > 1 ? items[1] : kUnset

    const headerResult = frameHeader.safeParse(rawHeader)
    if (!headerResult.success) {
      throw new Error(`Invalid frame header: ${headerResult.error.message}`)
    }
    if (rawBody === kUnset) {
      throw new Error('Missing frame body')
    }

    const frameOp = headerResult.data.op
    if (frameOp === FrameType.Message) {
      return new MessageFrame(rawBody, {
        type: headerResult.data.t,
      })
    }
    if (frameOp === FrameType.Error) {
      const bodyResult = errorFrameBody.safeParse(rawBody)
      if (!bodyResult.success) {
        throw new Error(`Invalid error frame body: ${bodyResult.error.message}`)
      }
      return new ErrorFrame(bodyResult.data)
    }
    const exhaustiveCheck: never = frameOp
    throw new Error(`Unknown frame op: ${exhaustiveCheck}`)
  }
}
|
|
62
|
+
|
|
63
|
+
/**
 * A frame carrying a message body, optionally tagged with a type `t` in its
 * header.
 */
export class MessageFrame<T = Record<string, unknown>> extends Frame {
  header: MessageFrameHeader
  body: T

  /**
   * @param body - the message payload
   * @param opts - `type` becomes `header.t`; when it is undefined the header
   *   gets no `t` key at all (rather than `t: undefined`)
   */
  constructor(body: T, opts?: { type?: string }) {
    super()
    const tag = opts?.type
    if (tag === undefined) {
      this.header = { op: FrameType.Message }
    } else {
      this.header = { op: FrameType.Message, t: tag }
    }
    this.body = body
  }

  /** The header's type tag, if any. */
  get type() {
    const { t } = this.header
    return t
  }
}
|
|
78
|
+
|
|
79
|
+
/**
 * A frame reporting an error. Its body holds an `error` code and a
 * `message`, exposed via the `code` and `message` getters.
 */
export class ErrorFrame<T extends string = string> extends Frame {
  header: ErrorFrameHeader
  body: ErrorFrameBody<T>

  constructor(body: ErrorFrameBody<T>) {
    super()
    const header: ErrorFrameHeader = { op: FrameType.Error }
    this.header = header
    this.body = body
  }

  /** The error code from the body. */
  get code() {
    const { error } = this.body
    return error
  }

  /** The error message from the body. */
  get message() {
    const { message } = this.body
    return message
  }
}
|
|
94
|
+
|
|
95
|
+
/** Sentinel used by `Frame.fromBytes` to mark a frame whose body item is absent. */
const kUnset: unique symbol = Symbol('unset')
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import { IncomingMessage } from 'http'
|
|
2
|
+
import { WebSocketServer, ServerOptions, WebSocket } from 'ws'
|
|
3
|
+
import { ErrorFrame, Frame } from './frames'
|
|
4
|
+
import logger from './logger'
|
|
5
|
+
import { CloseCode, DisconnectError } from './types'
|
|
6
|
+
|
|
7
|
+
/**
 * Serves one XRPC subscription over WebSockets.
 *
 * For each accepted connection the `handler` is invoked once. The frames it
 * yields are sent to the socket as binary messages, one at a time; each send
 * waits for its write callback before the next frame is pulled.
 *
 * - Yielding an `ErrorFrame` sends it, then closes the socket with
 *   `CloseCode.Policy` and the frame's error code as the reason.
 * - Exhausting the frames closes the socket with `CloseCode.Normal`.
 * - Any other failure is logged and the socket is terminated.
 */
export class XrpcStreamServer {
  wss: WebSocketServer
  constructor(opts: ServerOptions & { handler: Handler }) {
    const { handler, ...serverOpts } = opts
    this.wss = new WebSocketServer(serverOpts)
    this.wss.on('connection', async (socket, req) => {
      socket.on('error', (err) => logger.error(err, 'websocket error'))
      try {
        // aborted once the socket closes; handed to the handler as `signal`
        const ac = new AbortController()
        const iterator = unwrapIterator(handler(req, ac.signal, socket, this))
        socket.once('close', () => {
          // return() is async and nothing else awaits it. Wrapping it in an
          // async function turns a synchronous throw into a rejection too,
          // and the .catch() logs it instead of leaving an unhandled
          // rejection (e.g. when a handler's `finally` throws).
          const stopIteration = async () => {
            await iterator.return?.()
          }
          stopIteration().catch((err) =>
            logger.error(err, 'websocket iterator cleanup error'),
          )
          ac.abort()
        })
        // the loop below and the close listener share this same iterator
        const safeFrames = wrapIterator(iterator)
        for await (const frame of safeFrames) {
          await new Promise((res, rej) => {
            socket.send(frame.toBytes(), { binary: true }, (err) => {
              // @TODO waiting on this callback may apply backpressure more
              // aggressively than we ultimately want; trying it for now.
              if (err) return rej(err)
              res(undefined)
            })
          })
          if (frame instanceof ErrorFrame) {
            throw new DisconnectError(CloseCode.Policy, frame.body.error)
          }
        }
      } catch (err) {
        if (err instanceof DisconnectError) {
          return socket.close(err.wsCode, err.xrpcCode)
        } else {
          logger.error(err, 'websocket server error')
          return socket.terminate()
        }
      }
      socket.close(CloseCode.Normal)
    })
  }
}

/**
 * Produces the frames for a single WebSocket connection.
 *
 * @param req - the HTTP upgrade request
 * @param signal - aborted when the socket closes
 * @param socket - the connected socket
 * @param server - the XrpcStreamServer that accepted the connection
 */
export type Handler = (
  req: IncomingMessage,
  signal: AbortSignal,
  socket: WebSocket,
  server: XrpcStreamServer,
) => AsyncIterable<Frame>
|
|
54
|
+
|
|
55
|
+
/** Obtains the async iterator from an async iterable by calling its `[Symbol.asyncIterator]()` method. */
function unwrapIterator<T>(iterable: AsyncIterable<T>): AsyncIterator<T> {
  const getIterator = iterable[Symbol.asyncIterator]
  return getIterator.call(iterable)
}
|
|
58
|
+
|
|
59
|
+
/**
 * Wraps an existing async iterator in an iterable whose
 * `[Symbol.asyncIterator]()` always hands back that same iterator.
 */
function wrapIterator<T>(iterator: AsyncIterator<T>): AsyncIterable<T> {
  const iterable: AsyncIterable<T> = {
    [Symbol.asyncIterator]: () => iterator,
  }
  return iterable
}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import { XRPCError, ResponseType } from '@atproto/xrpc'
|
|
2
|
+
import { DuplexOptions } from 'stream'
|
|
3
|
+
import { createWebSocketStream, WebSocket } from 'ws'
|
|
4
|
+
import { Frame } from './frames'
|
|
5
|
+
|
|
6
|
+
/**
 * Yields every frame received on `ws`, decoding each WebSocket message with
 * `Frame.fromBytes`.
 *
 * @param options - extra duplex-stream options. `readableObjectMode` is
 *   always forced to true so frame bytes are not buffered or combined
 *   together.
 */
export async function* byFrame(ws: WebSocket, options?: DuplexOptions) {
  const streamOptions: DuplexOptions = { ...options, readableObjectMode: true }
  const messages = createWebSocketStream(ws, streamOptions)
  for await (const message of messages) {
    yield Frame.fromBytes(message)
  }
}
|
|
15
|
+
|
|
16
|
+
/**
 * Like `byFrame`, but yields only message frames.
 *
 * @throws XRPCError with status -1 and the frame's code/message when an
 *   error frame arrives
 * @throws XRPCError with `ResponseType.Unknown` for any other kind of frame
 */
export async function* byMessage(ws: WebSocket, options?: DuplexOptions) {
  const frames = byFrame(ws, options)
  for await (const current of frames) {
    if (current.isMessage()) {
      yield current
      continue
    }
    if (current.isError()) {
      throw new XRPCError(-1, current.code, current.message)
    }
    throw new XRPCError(ResponseType.Unknown, undefined, 'Unknown frame type')
  }
}
|