@nmtjs/gateway 0.15.0-beta.3 → 0.15.0-beta.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/gateway.ts ADDED
@@ -0,0 +1,698 @@
1
+ import { randomUUID } from 'node:crypto'
2
+ import { isTypedArray } from 'node:util/types'
3
+
4
+ import type {
5
+ Container,
6
+ Hooks,
7
+ Logger,
8
+ LoggerChildOptions,
9
+ ResolveInjectableType,
10
+ } from '@nmtjs/core'
11
+ import type {
12
+ ClientStreamConsumer,
13
+ ProtocolFormats,
14
+ MessageContext as ProtocolMessageContext,
15
+ } from '@nmtjs/protocol/server'
16
+ import { anyAbortSignal, isAbortError } from '@nmtjs/common'
17
+ import { createFactoryInjectable, provide, Scope } from '@nmtjs/core'
18
+ import {
19
+ ClientMessageType,
20
+ isBlobInterface,
21
+ kBlobKey,
22
+ ProtocolBlob,
23
+ ServerMessageType,
24
+ } from '@nmtjs/protocol'
25
+ import { getFormat, ProtocolError, versions } from '@nmtjs/protocol/server'
26
+
27
+ import type { GatewayApi } from './api.ts'
28
+ import type { GatewayConnection } from './connections.ts'
29
+ import type { ProxyableTransportType } from './enums.ts'
30
+ import type { StreamConfig } from './streams.ts'
31
+ import type { TransportWorker, TransportWorkerParams } from './transport.ts'
32
+ import type {
33
+ ConnectionIdentity,
34
+ GatewayRpc,
35
+ GatewayRpcContext,
36
+ } from './types.ts'
37
+ import { ConnectionManager } from './connections.ts'
38
+ import { StreamTimeout } from './enums.ts'
39
+ import * as injectables from './injectables.ts'
40
+ import { RpcManager } from './rpcs.ts'
41
+ import { BlobStreamsManager } from './streams.ts'
42
+
43
/**
 * Options accepted by the `Gateway` constructor.
 */
export interface GatewayOptions {
  /** Parent logger; the gateway logs through a child logger derived from it. */
  logger: Logger
  /** Root DI container; the gateway forks it for every new connection. */
  container: Container
  hooks: Hooks
  /**
   * Formats used to pick each connection's encoder/decoder (from the client's
   * `accept` / `contentType`); also handed to every transport on start/stop.
   */
  formats: ProtocolFormats
  /** API that executes procedure calls routed through the gateway. */
  api: GatewayApi
  /**
   * Transports to run, keyed by name. Entries that set `proxyable` are
   * reported (URL + proxy type) from `Gateway.start()`.
   */
  transports: Record<
    string,
    {
      transport: TransportWorker
      proxyable?: ProxyableTransportType
    }
  >
  /**
   * Injectable resolved in each connection's container to identify the
   * connection. When omitted, the gateway uses the generated connection id.
   */
  identity?: ConnectionIdentity
  /**
   * Milliseconds to wait for the client's first pull of a streamed RPC
   * response. Defaults to 5000 when omitted; `0` disables the timeout.
   */
  rpcStreamConsumeTimeout?: number
  /**
   * Blob-stream timeouts forwarded to `BlobStreamsManager`. When omitted the
   * gateway uses 5000 (pull), 5000 (consume) and 10000 (finish) —
   * presumably milliseconds; confirm against `StreamConfig`.
   */
  streamTimeouts?: Partial<StreamConfig['timeouts']>
}
59
+
60
/**
 * Routes traffic between transports and the application API.
 *
 * `start()` boots every configured transport and hands it callbacks bound to
 * that transport's key (`onConnect`, `onDisconnect`, `onMessage`, `onRpc`).
 * The gateway tracks live connections (`ConnectionManager`), in-flight RPCs
 * (`RpcManager`) and blob streams (`BlobStreamsManager`). It forks the DI
 * container once per connection (`Scope.Connection`) and once per call
 * (`Scope.Call`).
 */
export class Gateway {
  readonly logger: Logger
  readonly connections: ConnectionManager
  readonly rpcs: RpcManager
  readonly blobStreams: BlobStreamsManager
  /** Normalized options: every optional setting has been filled in. */
  public options: Required<
    Omit<GatewayOptions, 'streamTimeouts'> & {
      streamTimeouts: Required<
        Exclude<GatewayOptions['streamTimeouts'], undefined>
      >
    }
  >

  constructor(options: GatewayOptions) {
    this.options = {
      // Caller options are spread FIRST so the normalized values below win.
      // Spreading them after the defaults would let a partial
      // `streamTimeouts` object replace the merged defaults wholesale,
      // leaving the unspecified timeouts undefined.
      ...options,
      rpcStreamConsumeTimeout: options.rpcStreamConsumeTimeout ?? 5000,
      streamTimeouts: {
        //@ts-expect-error
        [StreamTimeout.Pull]:
          options.streamTimeouts?.[StreamTimeout.Pull] ?? 5000,
        //@ts-expect-error
        [StreamTimeout.Consume]:
          options.streamTimeouts?.[StreamTimeout.Consume] ?? 5000,
        //@ts-expect-error
        [StreamTimeout.Finish]:
          options.streamTimeouts?.[StreamTimeout.Finish] ?? 10000,
      },
      // Default identity: the connection's generated id.
      identity:
        options.identity ??
        createFactoryInjectable({
          dependencies: { connectionId: injectables.connectionId },
          factory: ({ connectionId }) => connectionId,
        }),
    }
    this.logger = options.logger.child({}, gatewayLoggerOptions)
    this.connections = new ConnectionManager()
    this.rpcs = new RpcManager()
    this.blobStreams = new BlobStreamsManager({
      timeouts: this.options.streamTimeouts,
    })
  }

  /**
   * Starts every configured transport, one after another.
   *
   * @returns URL and proxy type of each transport marked `proxyable`.
   */
  async start() {
    const hosts: { url: string; type: ProxyableTransportType }[] = []
    for (const [key, { transport, proxyable }] of Object.entries(
      this.options.transports,
    )) {
      const url = await transport.start({
        formats: this.options.formats,
        onConnect: this.onConnect(key),
        onDisconnect: this.onDisconnect(key),
        onMessage: this.onMessage(key),
        onRpc: this.onRpc(key),
      })
      this.logger.info(`Transport [${key}] started on [${url}]`)
      if (proxyable) hosts.push({ url, type: proxyable })
    }
    return hosts
  }

  /** Closes every live connection, then stops every transport. */
  async stop() {
    // Snapshot first: closeConnection() removes entries from the manager
    // while we iterate.
    for (const connection of [...this.connections.getAll()]) {
      await this.closeConnection(connection.id)
    }

    for (const [key, { transport }] of Object.entries(
      this.options.transports,
    )) {
      await transport.stop({ formats: this.options.formats })
      this.logger.debug(`Transport [${key}] stopped`)
    }
  }

  /**
   * Sends raw bytes to a connection through the named transport.
   *
   * @returns Whatever the transport's `send` returns, or `undefined` when the
   *   transport is unknown or does not implement `send`.
   */
  send(transport: string, connectionId: string, data: ArrayBufferView) {
    // Own-property check: `in` would also match inherited keys such as
    // "toString" and then crash on the missing `.transport`.
    if (!Object.hasOwn(this.options.transports, transport)) return undefined
    return this.options.transports[transport].transport.send?.(
      connectionId,
      data,
    )
  }

  /**
   * Disposes the DI container of every live connection. The connections
   * themselves are not aborted or removed.
   */
  async reload() {
    for (const connection of this.connections.connections.values()) {
      await connection.container.dispose()
    }
  }

  /**
   * Builds the per-call context: registers an abort controller for the call,
   * forks a `Scope.Call` container and returns an async-disposable context
   * whose disposal aborts the call's client streams, unregisters the call
   * and disposes the call container.
   *
   * @param signal Optional external signal; combined with the call's own
   *   controller so either one aborts the call.
   */
  protected createRpcContext(
    connection: GatewayConnection,
    messageContext: ReturnType<typeof this.createMessageContext>,
    logger: Logger,
    gatewayRpc: GatewayRpc,
    signal?: AbortSignal,
  ): GatewayRpcContext {
    const { callId, payload, procedure, metadata } = gatewayRpc
    const controller = new AbortController()
    this.rpcs.set(connection.id, callId, controller)

    signal = signal
      ? anyAbortSignal(signal, controller.signal)
      : controller.signal

    const container = connection.container.fork(Scope.Call)

    const dispose = async () => {
      const streamAbortReason = 'Stream is not consumed by a user'

      // Client streams the procedure never consumed are aborted here.
      this.blobStreams.abortClientCallStreams(
        connection.id,
        callId,
        streamAbortReason,
      )

      this.rpcs.delete(connection.id, callId)
      this.rpcs.releasePull(connection.id, callId)

      await container.dispose()
    }

    return {
      ...messageContext,
      callId,
      payload,
      procedure,
      metadata,
      container,
      signal,
      logger: logger.child({ callId, procedure }),
      [Symbol.asyncDispose]: dispose,
    }
  }

  /**
   * Builds the protocol message context for a connection: its codec, the
   * transport instance, a stream-id generator, and `addClientStream`, which
   * registers an incoming client stream and returns a consumer function
   * tagged with `kBlobKey` and the stream's metadata.
   */
  protected createMessageContext(
    connection: GatewayConnection,
    transportKey: string,
  ) {
    const transport = this.options.transports[transportKey].transport
    const { id: connectionId, protocol, decoder, encoder } = connection

    const streamId = this.connections.getStreamId.bind(
      this.connections,
      connectionId,
    )

    return {
      connectionId,
      protocol,
      encoder,
      decoder,
      transport,
      streamId,
      addClientStream: ({ streamId, callId, metadata }) => {
        const stream = this.blobStreams.createClientStream(
          connectionId,
          callId,
          streamId,
          metadata,
          {
            // Each read request asks the client for more data.
            read: (size) => {
              transport.send!(
                connectionId,
                protocol.encodeMessage(
                  this.createMessageContext(connection, transportKey),
                  ServerMessageType.ClientStreamPull,
                  { streamId, size: size || 65535 },
                ),
              )
            },
          },
        )

        // Tell the client to stop sending once the stream errors.
        stream.once('error', () => {
          this.send(
            transportKey,
            connectionId,
            protocol.encodeMessage(
              this.createMessageContext(connection, transportKey),
              ServerMessageType.ClientStreamAbort,
              { streamId },
            ),
          )
        })

        const consume = () => {
          this.blobStreams.consumeClientStream(connectionId, callId, streamId)
          return stream
        }

        const consumer = Object.defineProperties(consume, {
          [kBlobKey]: {
            enumerable: false,
            configurable: false,
            writable: false,
            value: true,
          },
          metadata: {
            value: metadata,
            enumerable: true,
            configurable: false,
            writable: false,
          },
        }) as ClientStreamConsumer

        return consumer
      },
    } satisfies ProtocolMessageContext & { [key: string]: unknown }
  }

  /**
   * Creates the `onConnect` handler for a transport. The handler validates
   * the protocol version, forks a `Scope.Connection` container, provides
   * connection data/id/injections, resolves the identity, negotiates the
   * codec and registers the connection.
   *
   * @throws Error('Unsupported protocol version') for unknown versions; any
   *   error from the setup steps is rethrown after cleanup.
   */
  protected onConnect(transport: string): TransportWorkerParams['onConnect'] {
    const logger = this.logger.child({ transport })
    return async (options, ...injections) => {
      logger.debug('Initiating new connection')

      const protocol = versions[options.protocolVersion]
      if (!protocol) throw new Error('Unsupported protocol version')

      const id = randomUUID()
      const container = this.options.container.fork(Scope.Connection)

      try {
        await container.provide([
          provide(injectables.connectionData, options.data),
          provide(injectables.connectionId, id),
        ])
        await container.provide(injections)

        const identity = await container.resolve(this.options.identity)

        const { accept, contentType, type } = options
        const { decoder, encoder } = getFormat(this.options.formats, {
          accept,
          contentType,
        })

        const connection: GatewayConnection = {
          id,
          type,
          identity,
          transport,
          protocol,
          container,
          encoder,
          decoder,
          abortController: new AbortController(),
        }

        this.connections.add(connection)

        await container.provide(
          injectables.connectionAbortSignal,
          connection.abortController.signal,
        )

        logger.debug(
          {
            id,
            protocol: options.protocolVersion,
            type,
            accept,
            contentType,
            identity,
            transportData: options.data,
          },
          'Connection established',
        )

        return Object.assign(connection, {
          [Symbol.asyncDispose]: async () => {
            await this.onDisconnect(transport)(connection.id)
          },
        })
      } catch (error) {
        logger.debug({ error }, 'Error establishing connection')
        // A late failure (after `connections.add`) must not leave a
        // half-initialized connection registered.
        if (this.connections.has(id)) this.connections.remove(id)
        try {
          await container.dispose()
        } catch (disposeError) {
          // Keep the original error as the one surfaced to the transport.
          logger.error(
            { error: disposeError },
            'Error disposing connection container',
          )
        }
        throw error
      }
    }
  }

  /** Creates the `onDisconnect` handler for a transport. */
  protected onDisconnect(
    transport: string,
  ): TransportWorkerParams['onDisconnect'] {
    const logger = this.logger.child({ transport })
    return async (connectionId) => {
      logger.debug({ connectionId }, 'Disconnecting connection')
      await this.closeConnection(connectionId)
    }
  }

  /**
   * Creates the `onMessage` handler for a transport: decodes a client
   * message and dispatches it by `ClientMessageType`. An `Rpc` message is
   * handled to completion (including any streamed response) before the
   * handler resolves.
   *
   * @throws Error('Unknown message type') for unhandled message types; errors
   *   are logged at trace level and rethrown.
   */
  protected onMessage(transport: string): TransportWorkerParams['onMessage'] {
    const _logger = this.logger.child({ transport })

    return async ({ connectionId, data }, ...injections) => {
      const logger = _logger.child({ connectionId })
      try {
        const connection = this.connections.get(connectionId)
        const messageContext = this.createMessageContext(connection, transport)

        const message = connection.protocol.decodeMessage(
          messageContext,
          Buffer.from(data),
        )

        logger.trace(message, 'Received message')

        switch (message.type) {
          case ClientMessageType.Rpc: {
            const rpcContext = this.createRpcContext(
              connection,
              messageContext,
              logger,
              message.rpc,
            )
            try {
              await rpcContext.container.provide([
                ...injections,
                provide(
                  injectables.createBlob,
                  this.createBlobFunction(rpcContext),
                ),
              ])
              await this.handleRpcMessage(connection, rpcContext)
            } finally {
              await rpcContext[Symbol.asyncDispose]()
            }
            break
          }
          case ClientMessageType.RpcPull: {
            this.rpcs.releasePull(connectionId, message.callId)
            break
          }
          case ClientMessageType.RpcAbort: {
            this.rpcs.abort(connectionId, message.callId)
            break
          }
          case ClientMessageType.ClientStreamAbort: {
            this.blobStreams.abortClientStream(
              connectionId,
              message.streamId,
              message.reason,
            )
            break
          }
          case ClientMessageType.ClientStreamPush: {
            this.blobStreams.pushToClientStream(
              connectionId,
              message.streamId,
              message.chunk,
            )
            break
          }
          case ClientMessageType.ClientStreamEnd: {
            this.blobStreams.endClientStream(connectionId, message.streamId)
            break
          }
          case ClientMessageType.ServerStreamAbort: {
            this.blobStreams.abortServerStream(
              connectionId,
              message.streamId,
              message.reason,
            )
            break
          }
          case ClientMessageType.ServerStreamPull: {
            this.blobStreams.pullServerStream(connectionId, message.streamId)
            break
          }
          default:
            throw new Error('Unknown message type')
        }
      } catch (error) {
        logger.trace({ error }, 'Error handling message')
        throw error
      }
    }
  }

  /**
   * Creates the `onRpc` handler for a transport that delivers RPCs directly
   * (rather than as encoded messages). The call's resources are released
   * immediately for plain results. For function (streaming) results,
   * release is delegated to the function via the dispose callback it
   * receives.
   */
  protected onRpc(transport: string): TransportWorkerParams['onRpc'] {
    const _logger = this.logger.child({ transport })
    return async (connection, rpc, signal, ...injections) => {
      const logger = _logger.child({ connectionId: connection.id })
      const messageContext = this.createMessageContext(
        connection,
        connection.transport,
      )
      const rpcContext = this.createRpcContext(
        connection,
        messageContext,
        logger,
        rpc,
        signal,
      )

      // Idempotent: a failing dispose on the success path must not be
      // re-run by the catch below, and a streaming consumer may call it late.
      let disposed = false
      const dispose = async () => {
        if (disposed) return
        disposed = true
        await rpcContext[Symbol.asyncDispose]()
      }

      try {
        await rpcContext.container.provide([
          ...injections,
          // Inject the combined signal (transport signal + the call's own
          // controller), matching what `api.call` receives and what
          // `handleRpcMessage` injects, so gateway-side aborts are visible.
          provide(injectables.rpcAbortSignal, rpcContext.signal),
          provide(injectables.createBlob, this.createBlobFunction(rpcContext)),
        ])

        const result = await this.options.api.call({
          connection,
          payload: rpc.payload,
          procedure: rpc.procedure,
          metadata: rpc.metadata,
          container: rpcContext.container,
          signal: rpcContext.signal,
        })

        if (typeof result === 'function') {
          return result(dispose)
        }
        await dispose()
        return result
      } catch (error) {
        await dispose()
        throw error
      }
    }
  }

  /**
   * Executes an RPC received as a protocol message and writes the response
   * back over the transport.
   *
   * A function result is treated as a stream. The gateway sends
   * `RpcStreamResponse`, then waits for the first client pull, bounded by
   * `rpcStreamConsumeTimeout` when non-zero. It then sends one
   * `RpcStreamChunk` per yielded value, waiting for a pull after each, and
   * finishes with `RpcStreamEnd`, or `RpcStreamAbort` on failure. Any other
   * result is sent as a single `RpcResponse` together with the metadata of
   * server streams created during the call. Errors from the call are sent
   * as an `RpcResponse` with `error`.
   */
  protected async handleRpcMessage(
    connection: GatewayConnection,
    context: GatewayRpcContext,
  ): Promise<void> {
    const {
      container,
      connectionId,
      transport,
      protocol,
      signal,
      callId,
      procedure,
      payload,
      metadata,
      encoder,
    } = context
    try {
      await container.provide(injectables.rpcAbortSignal, signal)
      const response = await this.options.api.call({
        // NOTE(review): cast carried over; verify whether it is still needed.
        connection: connection as any,
        container,
        payload,
        procedure,
        // Forward call metadata, as the direct `onRpc` path does.
        metadata,
        signal,
      })

      if (typeof response === 'function') {
        transport.send!(
          connectionId,
          protocol.encodeMessage(context, ServerMessageType.RpcStreamResponse, {
            callId,
          }),
        )

        try {
          const consumeTimeoutSignal = this.options.rpcStreamConsumeTimeout
            ? AbortSignal.timeout(this.options.rpcStreamConsumeTimeout)
            : undefined

          const streamSignal = consumeTimeoutSignal
            ? anyAbortSignal(signal, consumeTimeoutSignal)
            : signal

          // Only the first pull is bounded by the consume timeout.
          await this.rpcs.awaitPull(connectionId, callId, streamSignal)

          for await (const chunk of response()) {
            signal.throwIfAborted()
            const chunkEncoded = encoder.encode(chunk)
            transport.send!(
              connectionId,
              protocol.encodeMessage(
                context,
                ServerMessageType.RpcStreamChunk,
                { callId, chunk: chunkEncoded },
              ),
            )
            await this.rpcs.awaitPull(connectionId, callId)
          }

          transport.send!(
            connectionId,
            protocol.encodeMessage(context, ServerMessageType.RpcStreamEnd, {
              callId,
            }),
          )
        } catch (error) {
          if (!isAbortError(error)) {
            this.logger.error(error)
          }
          transport.send!(
            connectionId,
            protocol.encodeMessage(context, ServerMessageType.RpcStreamAbort, {
              callId,
            }),
          )
        }
      } else {
        const streams = this.blobStreams.getServerStreamsMetadata(
          connectionId,
          callId,
        )
        transport.send!(
          connectionId,
          protocol.encodeMessage(context, ServerMessageType.RpcResponse, {
            callId,
            result: response,
            streams,
            error: null,
          }),
        )
      }
    } catch (error) {
      transport.send!(
        connectionId,
        protocol.encodeMessage(context, ServerMessageType.RpcResponse, {
          callId,
          result: null,
          streams: {},
          error,
        }),
      )
      // ProtocolErrors are expected and only traced; anything else is logged.
      const level = error instanceof ProtocolError ? 'trace' : 'error'
      this.logger[level](error)
    }
  }

  /**
   * Tears down a connection: aborts its signal, closes its RPCs and blob
   * streams, unregisters it and disposes its container. Safe to call for an
   * unknown or already-closed id. Container disposal errors are logged, not
   * thrown, so `stop()` can continue with the remaining connections.
   */
  protected async closeConnection(connectionId: string) {
    const connection = this.connections.has(connectionId)
      ? this.connections.get(connectionId)
      : undefined

    connection?.abortController.abort()

    this.rpcs.close(connectionId)
    this.blobStreams.cleanupConnection(connectionId)
    // Unregister before awaiting disposal so a concurrent close for the same
    // id finds nothing and does not dispose the container twice.
    this.connections.remove(connectionId)

    if (connection) {
      try {
        await connection.container.dispose()
      } catch (error) {
        this.logger.error(
          { error, connectionId },
          'Error disposing connection container',
        )
      }
    }
  }

  /**
   * Returns the `createBlob` implementation for a call. Each created blob is
   * registered as a server stream. Its data is pushed to the client as
   * `ServerStreamPush` messages, errors as `ServerStreamAbort` and
   * completion as `ServerStreamEnd`. The stream is unregistered when it
   * closes.
   */
  protected createBlobFunction(
    context: GatewayRpcContext,
  ): ResolveInjectableType<typeof injectables.createBlob> {
    const {
      streamId: getStreamId,
      transport,
      protocol,
      connectionId,
      callId,
      encoder,
    } = context

    return (source, metadata) => {
      const streamId = getStreamId()
      const blob = ProtocolBlob.from(source, metadata, () => {
        return encoder.encodeBlob(streamId)
      })
      const stream = this.blobStreams.createServerStream(
        connectionId,
        callId,
        streamId,
        blob,
      )

      stream.on('data', (chunk) => {
        transport.send!(
          connectionId,
          protocol.encodeMessage(context, ServerMessageType.ServerStreamPush, {
            streamId,
            chunk: Buffer.from(chunk),
          }),
        )
      })

      stream.on('error', (error) => {
        transport.send!(
          connectionId,
          protocol.encodeMessage(context, ServerMessageType.ServerStreamAbort, {
            streamId,
            reason: error.message,
          }),
        )
      })

      stream.once('finish', () => {
        transport.send!(
          connectionId,
          protocol.encodeMessage(context, ServerMessageType.ServerStreamEnd, {
            streamId,
          }),
        )
      })

      stream.once('close', () => {
        this.blobStreams.removeServerStream(connectionId, streamId)
      })

      return blob
    }
  }
}
663
+
664
/**
 * Child-logger options for the gateway. The serializers keep binary data,
 * blob streams and `Headers` instances compact and readable in log output.
 */
const gatewayLoggerOptions: LoggerChildOptions = {
  serializers: {
    // Summarize binary chunks instead of dumping their bytes.
    chunk: (chunk) =>
      isTypedArray(chunk) ? `<Buffer length=${chunk.byteLength}>` : chunk,
    // Recursively replace typed arrays and blobs in an RPC payload with
    // short placeholders; everything else is copied as-is.
    payload: (payload) => {
      // Objects on the current traversal path; used to stop on cycles
      // (which would otherwise overflow the stack inside the logger).
      // Shared, non-cyclic references are still serialized each time.
      const ancestors = new Set<object>()

      function traverseObject(obj: any): any {
        // Checked before the generic object branch so that object-shaped
        // blobs are summarized rather than walked field by field.
        if (isBlobInterface(obj)) {
          return `<ClientBlobStream metadata=${JSON.stringify(obj.metadata)}>`
        }
        if (isTypedArray(obj)) {
          return `<${obj.constructor.name} length=${obj.byteLength}>`
        }
        if (typeof obj !== 'object' || obj === null) {
          return obj
        }
        if (ancestors.has(obj)) {
          return '<Circular>'
        }

        ancestors.add(obj)
        try {
          if (Array.isArray(obj)) {
            return obj.map((item) => traverseObject(item))
          }
          const result: Record<string, any> = {}
          for (const [key, value] of Object.entries(obj)) {
            result[key] = traverseObject(value)
          }
          return result
        } finally {
          ancestors.delete(obj)
        }
      }

      return traverseObject(payload)
    },
    // Convert a `Headers` instance into a plain object so it is logged.
    headers: (value) => {
      if (value instanceof Headers) {
        const obj: Record<string, any> = {}
        value.forEach((v, k) => {
          obj[k] = v
        })
        return obj
      }
      return value
    },
  },
}
package/src/index.ts ADDED
@@ -0,0 +1,9 @@
1
+ export * from './api.ts'
2
+ export * from './connections.ts'
3
+ export * from './enums.ts'
4
+ export * from './gateway.ts'
5
+ export * from './injectables.ts'
6
+ export * from './rpcs.ts'
7
+ export * from './streams.ts'
8
+ export * from './transport.ts'
9
+ export * from './types.ts'