@nmtjs/ws-transport 0.1.7 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- {"version":3,"sources":["../../../lib/utils.ts"],"sourcesContent":["import { ApiError, type Format } from '@nmtjs/application'\nimport {\n ErrorCode,\n type RpcResponse,\n concat,\n encodeNumber,\n} from '@nmtjs/common'\nimport type { HttpRequest } from 'uWebSockets.js'\nimport type { WsTransportSocket } from './types.ts'\n\nexport const sendMessage = (\n ws: WsTransportSocket,\n type: number,\n payload: any,\n) => {\n return send(ws, type, ws.getUserData().format.encoder.encode(payload))\n}\n\nexport const sendRpcMessage = (\n ws: WsTransportSocket,\n type: number,\n rpc: RpcResponse,\n) => {\n const data = ws.getUserData()\n return send(ws, type, data.format.encoder.encodeRpc(rpc, data.context.encode))\n}\n\nexport const send = (\n ws: WsTransportSocket,\n type: number,\n ...buffers: ArrayBuffer[]\n): boolean | null => {\n const data = ws.getUserData()\n try {\n const result = ws.send(\n concat(encodeNumber(type, 'Uint8'), ...buffers),\n true,\n )\n if (result === 0) {\n data.backpressure = Promise.withResolvers()\n return false\n }\n if (result === 2) {\n return null\n }\n return true\n } catch (error) {\n return null\n }\n}\n\nexport const getFormat = ({ headers, query }: RequestData, format: Format) => {\n const contentType = headers.get('content-type') || query.get('content-type')\n const acceptType = headers.get('accept') || query.get('accept')\n\n const encoder = contentType ? format.supportsEncoder(contentType) : undefined\n if (!encoder) throw new Error('Unsupported content-type')\n\n const decoder = acceptType ? format.supportsDecoder(acceptType) : undefined\n if (!decoder) throw new Error('Unsupported accept')\n\n return {\n encoder,\n decoder,\n }\n}\n\nexport const toRecord = (input: {\n forEach: (cb: (value, key) => void) => void\n}) => {\n const obj: Record<string, string> = {}\n input.forEach((value, key) => {\n obj[key] = value\n })\n return obj\n}\n\ntype RequestData = {\n url: string\n origin: URL | null\n method: string\n headers: Map<string, string>\n query: URLSearchParams\n}\n\nexport const getRequestData = (req: HttpRequest): RequestData => {\n const url = req.getUrl()\n const method = req.getMethod()\n const headers = new Map()\n req.forEach((key, value) => headers.set(key, value))\n const query = new URLSearchParams(req.getQuery())\n const origin = headers.has('origin')\n ? new URL(url, headers.get('origin'))\n : null\n\n return {\n url,\n origin,\n method,\n headers,\n query,\n }\n}\n\nexport const InternalError = (message = 'Internal Server Error') =>\n new ApiError(ErrorCode.InternalServerError, message)\n\nexport const NotFoundError = (message = 'Not Found') =>\n new ApiError(ErrorCode.NotFound, message)\n\nexport const ForbiddenError = (message = 'Forbidden') =>\n new ApiError(ErrorCode.Forbidden, message)\n\nexport const RequestTimeoutError = (message = 'Request Timeout') =>\n new ApiError(ErrorCode.RequestTimeout, message)\n"],"names":["ApiError","ErrorCode","concat","encodeNumber","sendMessage","ws","type","payload","send","getUserData","format","encoder","encode","sendRpcMessage","rpc","data","encodeRpc","context","buffers","result","backpressure","Promise","withResolvers","error","getFormat","headers","query","contentType","get","acceptType","supportsEncoder","undefined","Error","decoder","supportsDecoder","toRecord","input","obj","forEach","value","key","getRequestData","req","url","getUrl","method","getMethod","Map","set","URLSearchParams","getQuery","origin","has","URL","InternalError","message","InternalServerError","NotFoundError","NotFound","ForbiddenError","Forbidden","RequestTimeoutError","RequestTimeout"],"mappings":"AAAA,SAASA,QAAQ,QAAqB,qBAAoB;AAC1D,SACEC,SAAS,EAETC,MAAM,EACNC,YAAY,QACP,gBAAe;AAItB,OAAO,MAAMC,cAAc,CACzBC,IACAC,MACAC;IAEA,OAAOC,KAAKH,IAAIC,MAAMD,GAAGI,WAAW,GAAGC,MAAM,CAACC,OAAO,CAACC,MAAM,CAACL;AAC/D,EAAC;AAED,OAAO,MAAMM,iBAAiB,CAC5BR,IACAC,MACAQ;IAEA,MAAMC,OAAOV,GAAGI,WAAW;IAC3B,OAAOD,KAAKH,IAAIC,MAAMS,KAAKL,MAAM,CAACC,OAAO,CAACK,SAAS,CAACF,KAAKC,KAAKE,OAAO,CAACL,MAAM;AAC9E,EAAC;AAED,OAAO,MAAMJ,OAAO,CAClBH,IACAC,MACA,GAAGY;IAEH,MAAMH,OAAOV,GAAGI,WAAW;IAC3B,IAAI;QACF,MAAMU,SAASd,GAAGG,IAAI,CACpBN,OAAOC,aAAaG,MAAM,aAAaY,UACvC;QAEF,IAAIC,WAAW,GAAG;YAChBJ,KAAKK,YAAY,GAAGC,QAAQC,aAAa;YACzC,OAAO;QACT;QACA,IAAIH,WAAW,GAAG;YAChB,OAAO;QACT;QACA,OAAO;IACT,EAAE,OAAOI,OAAO;QACd,OAAO;IACT;AACF,EAAC;AAED,OAAO,MAAMC,YAAY,CAAC,EAAEC,OAAO,EAAEC,KAAK,EAAe,EAAEhB;IACzD,MAAMiB,cAAcF,QAAQG,GAAG,CAAC,mBAAmBF,MAAME,GAAG,CAAC;IAC7D,MAAMC,aAAaJ,QAAQG,GAAG,CAAC,aAAaF,MAAME,GAAG,CAAC;IAEtD,MAAMjB,UAAUgB,cAAcjB,OAAOoB,eAAe,CAACH,eAAeI;IACpE,IAAI,CAACpB,SAAS,MAAM,IAAIqB,MAAM;IAE9B,MAAMC,UAAUJ,aAAanB,OAAOwB,eAAe,CAACL,cAAcE;IAClE,IAAI,CAACE,SAAS,MAAM,IAAID,MAAM;IAE9B,OAAO;QACLrB;QACAsB;IACF;AACF,EAAC;AAED,OAAO,MAAME,WAAW,CAACC;IAGvB,MAAMC,MAA8B,CAAC;IACrCD,MAAME,OAAO,CAAC,CAACC,OAAOC;QACpBH,GAAG,CAACG,IAAI,GAAGD;IACb;IACA,OAAOF;AACT,EAAC;AAUD,OAAO,MAAMI,iBAAiB,CAACC;IAC7B,MAAMC,MAAMD,IAAIE,MAAM;IACtB,MAAMC,SAASH,IAAII,SAAS;IAC5B,MAAMrB,UAAU,IAAIsB;IACpBL,IAAIJ,OAAO,CAAC,CAACE,KAAKD,QAAUd,QAAQuB,GAAG,CAACR,KAAKD;IAC7C,MAAMb,QAAQ,IAAIuB,gBAAgBP,IAAIQ,QAAQ;IAC9C,MAAMC,SAAS1B,QAAQ2B,GAAG,CAAC,YACvB,IAAIC,IAAIV,KAAKlB,QAAQG,GAAG,CAAC,aACzB;IAEJ,OAAO;QACLe;QACAQ;QACAN;QACApB;QACAC;IACF;AACF,EAAC;AAED,OAAO,MAAM4B,gBAAgB,CAACC,UAAU,uBAAuB,GAC7D,IAAIvD,SAASC,UAAUuD,mBAAmB,EAAED,SAAQ;AAEtD,OAAO,MAAME,gBAAgB,CAACF,UAAU,WAAW,GACjD,IAAIvD,SAASC,UAAUyD,QAAQ,EAAEH,SAAQ;AAE3C,OAAO,MAAMI,iBAAiB,CAACJ,UAAU,WAAW,GAClD,IAAIvD,SAASC,UAAU2D,SAAS,EAAEL,SAAQ;AAE5C,OAAO,MAAMM,sBAAsB,CAACN,UAAU,iBAAiB,GAC7D,IAAIvD,SAASC,UAAU6D,cAAc,EAAEP,SAAQ"}
1
+ {"version":3,"sources":["../../../lib/utils.ts"],"sourcesContent":["import { createPromise } from '@nmtjs/common'\nimport { ErrorCode, concat, encodeNumber } from '@nmtjs/protocol/common'\nimport { ProtocolError } from '@nmtjs/protocol/server'\nimport type { HttpRequest } from 'uWebSockets.js'\nimport type { WsTransportSocket } from './types.ts'\n\nexport const send = (\n ws: WsTransportSocket,\n type: number,\n ...buffers: ArrayBuffer[]\n): boolean | null => {\n const data = ws.getUserData()\n try {\n const buffer = concat(encodeNumber(type, 'Uint8'), ...buffers)\n const result = ws.send(buffer, true)\n if (result === 0) {\n data.backpressure = createPromise()\n return false\n }\n if (result === 2) {\n return null\n }\n return true\n } catch (error) {\n return null\n }\n}\n\nexport const toRecord = (input: {\n forEach: (cb: (value, key) => void) => void\n}) => {\n const obj: Record<string, string> = {}\n input.forEach((value, key) => {\n obj[key] = value\n })\n return obj\n}\n\ntype RequestData = {\n url: string\n origin: URL | null\n method: string\n headers: Map<string, string>\n query: URLSearchParams\n}\n\nexport const getRequestData = (req: HttpRequest): RequestData => {\n const url = req.getUrl()\n const method = req.getMethod()\n const headers = new Map()\n req.forEach((key, value) => headers.set(key, value))\n const query = new URLSearchParams(req.getQuery())\n const origin = headers.has('origin')\n ? new URL(url, headers.get('origin'))\n : null\n\n return {\n url,\n origin,\n method,\n headers,\n query,\n }\n}\n\nexport const InternalError = (message = 'Internal Server Error') =>\n new ProtocolError(ErrorCode.InternalServerError, message)\n\nexport const NotFoundError = (message = 'Not Found') =>\n new ProtocolError(ErrorCode.NotFound, message)\n\nexport const ForbiddenError = (message = 'Forbidden') =>\n new ProtocolError(ErrorCode.Forbidden, message)\n\nexport const RequestTimeoutError = (message = 'Request Timeout') =>\n new ProtocolError(ErrorCode.RequestTimeout, message)\n"],"names":["createPromise","ErrorCode","concat","encodeNumber","ProtocolError","send","ws","type","buffers","data","getUserData","buffer","result","backpressure","error","toRecord","input","obj","forEach","value","key","getRequestData","req","url","getUrl","method","getMethod","headers","Map","set","query","URLSearchParams","getQuery","origin","has","URL","get","InternalError","message","InternalServerError","NotFoundError","NotFound","ForbiddenError","Forbidden","RequestTimeoutError","RequestTimeout"],"mappings":"AAAA,SAASA,aAAa,QAAQ,gBAAe;AAC7C,SAASC,SAAS,EAAEC,MAAM,EAAEC,YAAY,QAAQ,yBAAwB;AACxE,SAASC,aAAa,QAAQ,yBAAwB;AAItD,OAAO,MAAMC,OAAO,CAClBC,IACAC,MACA,GAAGC;IAEH,MAAMC,OAAOH,GAAGI,WAAW;IAC3B,IAAI;QACF,MAAMC,SAAST,OAAOC,aAAaI,MAAM,aAAaC;QACtD,MAAMI,SAASN,GAAGD,IAAI,CAACM,QAAQ;QAC/B,IAAIC,WAAW,GAAG;YAChBH,KAAKI,YAAY,GAAGb;YACpB,OAAO;QACT;QACA,IAAIY,WAAW,GAAG;YAChB,OAAO;QACT;QACA,OAAO;IACT,EAAE,OAAOE,OAAO;QACd,OAAO;IACT;AACF,EAAC;AAED,OAAO,MAAMC,WAAW,CAACC;IAGvB,MAAMC,MAA8B,CAAC;IACrCD,MAAME,OAAO,CAAC,CAACC,OAAOC;QACpBH,GAAG,CAACG,IAAI,GAAGD;IACb;IACA,OAAOF;AACT,EAAC;AAUD,OAAO,MAAMI,iBAAiB,CAACC;IAC7B,MAAMC,MAAMD,IAAIE,MAAM;IACtB,MAAMC,SAASH,IAAII,SAAS;IAC5B,MAAMC,UAAU,IAAIC;IACpBN,IAAIJ,OAAO,CAAC,CAACE,KAAKD,QAAUQ,QAAQE,GAAG,CAACT,KAAKD;IAC7C,MAAMW,QAAQ,IAAIC,gBAAgBT,IAAIU,QAAQ;IAC9C,MAAMC,SAASN,QAAQO,GAAG,CAAC,YACvB,IAAIC,IAAIZ,KAAKI,QAAQS,GAAG,CAAC,aACzB;IAEJ,OAAO;QACLb;QACAU;QACAR;QACAE;QACAG;IACF;AACF,EAAC;AAED,OAAO,MAAMO,gBAAgB,CAACC,UAAU,uBAAuB,GAC7D,IAAIlC,cAAcH,UAAUsC,mBAAmB,EAAED,SAAQ;AAE3D,OAAO,MAAME,gBAAgB,CAACF,UAAU,WAAW,GACjD,IAAIlC,cAAcH,UAAUwC,QAAQ,EAAEH,SAAQ;AAEhD,OAAO,MAAMI,iBAAiB,CAACJ,UAAU,WAAW,GAClD,IAAIlC,cAAcH,UAAU0C,SAAS,EAAEL,SAAQ;AAEjD,OAAO,MAAMM,sBAAsB,CAACN,UAAU,iBAAiB,GAC7D,IAAIlC,cAAcH,UAAU4C,cAAc,EAAEP,SAAQ"}
@@ -1,11 +1,13 @@
1
- import {
2
- type LazyInjectable,
3
- type Scope,
4
- injectables,
5
- } from '@nmtjs/application'
6
- import type { WsConnectionData } from './types.ts'
1
+ import type { LazyInjectable, Scope } from '@nmtjs/core'
2
+ import { ProtocolInjectables } from '@nmtjs/protocol/server'
3
+ import type { WsUserData } from './types.ts'
7
4
 
8
- export const connectionData = injectables.connectionData as LazyInjectable<
9
- WsConnectionData,
10
- Scope.Connection
11
- >
5
+ const connectionData =
6
+ ProtocolInjectables.connectionData as unknown as LazyInjectable<
7
+ WsUserData['request'],
8
+ Scope.Connection
9
+ >
10
+
11
+ export const WsTransportInjectables = {
12
+ connectionData,
13
+ } as const
package/lib/server.ts CHANGED
@@ -1,26 +1,4 @@
1
1
  import { randomUUID } from 'node:crypto'
2
- import {
3
- type AnyBaseProcedure,
4
- ApiError,
5
- type ApplicationContext,
6
- type Connection,
7
- type Container,
8
- Scope,
9
- ServerDownStream,
10
- ServerUpStream,
11
- type Service,
12
- SubscriptionResponse,
13
- builtin,
14
- onAbort,
15
- } from '@nmtjs/application'
16
- import {
17
- type ApiBlobMetadata,
18
- type EncodeRpcContext,
19
- MessageType,
20
- TransportType,
21
- decodeNumber,
22
- encodeNumber,
23
- } from '@nmtjs/common'
24
2
  import {
25
3
  App,
26
4
  type HttpRequest,
@@ -29,27 +7,36 @@ import {
29
7
  type TemplatedApp,
30
8
  } from 'uWebSockets.js'
31
9
 
32
- import { connectionData } from './injectables.ts'
10
+ import { createPromise } from '@nmtjs/common'
11
+ import {
12
+ ClientMessageType,
13
+ type ServerMessageType,
14
+ decodeNumber,
15
+ decodeText,
16
+ } from '@nmtjs/protocol/common'
17
+ import {
18
+ type Connection,
19
+ ProtocolInjectables,
20
+ type Transport,
21
+ type TransportPluginContext,
22
+ } from '@nmtjs/protocol/server'
33
23
  import type {
34
24
  WsTransportOptions,
35
25
  WsTransportSocket,
36
26
  WsUserData,
37
27
  } from './types.ts'
38
- import {
39
- InternalError,
40
- getFormat,
41
- getRequestData,
42
- send,
43
- sendMessage,
44
- sendRpcMessage,
45
- } from './utils.ts'
28
+ import { getRequestData, send } from './utils.ts'
46
29
 
47
- export class WsTransportServer {
30
+ export type WsConnectionData = {
31
+ ws: WsTransportSocket
32
+ }
33
+
34
+ export class WsTransportServer implements Transport<WsConnectionData> {
48
35
  protected server!: TemplatedApp
49
- protected readonly transportType = TransportType.WS
36
+ protected clients: Map<string, WsTransportSocket> = new Map()
50
37
 
51
38
  constructor(
52
- protected readonly application: ApplicationContext,
39
+ protected readonly context: TransportPluginContext,
53
40
  protected readonly options: WsTransportOptions,
54
41
  ) {
55
42
  this.server = this.options.tls ? SSLApp(options.tls!) : App()
@@ -70,36 +57,17 @@ export class WsTransportServer {
70
57
  maxPayloadLength: this.options.maxPayloadLength,
71
58
  upgrade: (res, req, context) => {
72
59
  const requestData = getRequestData(req)
73
- const container = this.application.container.createScope(
74
- Scope.Connection,
75
- )
76
- const services = requestData.query.getAll('services')
77
60
 
78
- for (const serviceName of services) {
79
- const service = this.application.registry.services.get(serviceName)
80
- if (!service)
81
- return void res
82
- .writeStatus('400 Bad Request')
83
- .end(`Service ${service} not found`)
84
- if (this.transportType in service.contract.transports === false)
85
- return void res
86
- .writeStatus('400 Bad Request')
87
- .end(`Service ${service} not supported`)
88
- }
61
+ const contentType =
62
+ requestData.headers.get('content-type') ||
63
+ requestData.query.get('content-type')
64
+
65
+ const acceptType =
66
+ requestData.headers.get('accept') || requestData.query.get('accept')
89
67
 
90
68
  const data: WsUserData = {
91
69
  id: randomUUID(),
92
- format: getFormat(requestData, this.application.format),
93
- container,
94
- streams: {
95
- up: new Map(),
96
- down: new Map(),
97
- streamId: 0,
98
- },
99
- abortControllers: new Map(),
100
- subscriptions: new Map(),
101
- services,
102
- data: {
70
+ request: {
103
71
  query: requestData.query,
104
72
  headers: requestData.headers,
105
73
  proxiedRemoteAddress: Buffer.from(
@@ -108,7 +76,10 @@ export class WsTransportServer {
108
76
  remoteAddress: Buffer.from(
109
77
  res.getRemoteAddressAsText(),
110
78
  ).toString(),
79
+ contentType,
80
+ acceptType,
111
81
  },
82
+ opening: createPromise(),
112
83
  backpressure: null,
113
84
  context: {} as any,
114
85
  }
@@ -121,31 +92,38 @@ export class WsTransportServer {
121
92
  context,
122
93
  )
123
94
  },
124
- open: (ws: WsTransportSocket) => {
125
- const data = ws.getUserData()
126
- this.logger.debug('Connection %s opened', data.id)
127
- data.context.decode = this.createDecodeRpcContext(ws)
128
- data.context.encode = this.createEncodeRpcContext(ws)
129
- data.container.provide(connectionData, data.data)
130
-
131
- const connection = this.application.connections.add({
132
- id: data.id,
133
- services: data.services,
134
- type: this.transportType,
135
- subscriptions: data.subscriptions,
136
- sendEvent: (service, event, payload) =>
137
- sendMessage(ws, MessageType.Event, [service, event, payload]),
138
- })
139
-
140
- data.container.provide(builtin.connection, connection)
95
+ open: async (ws: WsTransportSocket) => {
96
+ const { id, context, request, opening } = ws.getUserData()
97
+ this.clients.set(id, ws)
98
+ this.logger.debug('Connection %s opened', id)
99
+ try {
100
+ const { context: _context, connection } =
101
+ await this.context.protocol.connections.add(
102
+ this,
103
+ { id, data: request },
104
+ {
105
+ acceptType: request.acceptType,
106
+ contentType: request.contentType,
107
+ },
108
+ )
109
+ Object.assign(context, _context)
110
+ context.container.provide(
111
+ ProtocolInjectables.connection,
112
+ connection,
113
+ )
114
+ opening.resolve()
115
+ } catch (error) {
116
+ opening.reject(error)
117
+ }
141
118
  },
142
- message: async (ws: WsTransportSocket, event) => {
143
- const buffer = event as unknown as ArrayBuffer
119
+ message: async (ws: WsTransportSocket, buffer) => {
120
+ const { opening } = ws.getUserData()
144
121
  const messageType = decodeNumber(buffer, 'Uint8')
145
122
  if (messageType in this === false) {
146
123
  ws.end(1011, 'Unknown message type')
147
124
  } else {
148
125
  try {
126
+ await opening.promise
149
127
  await this[messageType](
150
128
  ws,
151
129
  buffer.slice(Uint8Array.BYTES_PER_ELEMENT),
@@ -160,44 +138,29 @@ export class WsTransportServer {
160
138
  data.backpressure?.resolve()
161
139
  data.backpressure = null
162
140
  },
163
- close: (ws: WsTransportSocket, code, message) => {
164
- const data = ws.getUserData()
141
+ close: async (ws: WsTransportSocket, code, message) => {
142
+ const { id } = ws.getUserData()
165
143
 
166
144
  this.logger.debug(
167
145
  'Connection %s closed with code %s: %s',
168
- data.id,
146
+ id,
169
147
  code,
170
148
  Buffer.from(message).toString(),
171
149
  )
172
-
173
- this.application.connections.remove(data.id)
174
- const error = new Error('Connection closed')
175
-
176
- for (const ac of data.abortControllers.values()) {
177
- ac.abort(error)
178
- }
179
- data.abortControllers.clear()
180
-
181
- for (const stream of data.streams.down.values()) {
182
- stream.destroy(error)
183
- }
184
- data.streams.down.clear()
185
-
186
- for (const stream of data.streams.up.values()) {
187
- stream.destroy(error)
188
- }
189
- data.streams.up.clear()
190
-
191
- for (const subscription of data.subscriptions.values()) {
192
- subscription.destroy()
193
- }
194
- data.subscriptions.clear()
195
-
196
- this.handleContainerDisposal(data.container)
150
+ this.clients.delete(id)
151
+ await this.protocol.connections.remove(id)
197
152
  },
198
153
  })
199
154
  }
200
155
 
156
+ send(
157
+ connection: Connection<WsConnectionData>,
158
+ messageType: ServerMessageType,
159
+ buffer: ArrayBuffer,
160
+ ) {
161
+ send(connection.data.ws, messageType, buffer)
162
+ }
163
+
201
164
  async start() {
202
165
  return new Promise<void>((resolve, reject) => {
203
166
  const hostname = this.options.hostname ?? '127.0.0.1'
@@ -220,12 +183,12 @@ export class WsTransportServer {
220
183
  this.server.close()
221
184
  }
222
185
 
223
- protected get api() {
224
- return this.application.api
186
+ protected get protocol() {
187
+ return this.context.protocol
225
188
  }
226
189
 
227
190
  protected get logger() {
228
- return this.application.logger
191
+ return this.context.logger
229
192
  }
230
193
 
231
194
  protected async logError(
@@ -245,256 +208,76 @@ export class WsTransportServer {
245
208
  res.writeHeader('Access-Control-Allow-Credentials', 'true')
246
209
  }
247
210
 
248
- protected async handleContainerDisposal(container: Container) {
249
- await container.dispose()
250
- }
251
-
252
- protected async handleRPC(options: {
253
- connection: Connection
254
- service: Service
255
- procedure: AnyBaseProcedure
256
- container: Container
257
- signal: AbortSignal
258
- payload: any
259
- }) {
260
- return await this.api.call({
261
- ...options,
262
- transport: this.transportType,
263
- })
264
- }
265
-
266
- protected createEncodeRpcContext(ws: WsTransportSocket): EncodeRpcContext {
267
- const data = ws.getUserData()
268
- return {
269
- addStream: (blob) => {
270
- const id = ++data.streams.streamId
271
- const downstream = new ServerDownStream(id, blob)
272
- downstream.pause()
273
- downstream.on('error', console.dir)
274
- downstream.once('error', (err) => {
275
- console.log({ err })
276
- if (downstream.errored?.message !== 'Aborted by client')
277
- send(ws, MessageType.DownStreamAbort, encodeNumber(id, 'Uint32'))
278
- })
279
- downstream.on('data', (chunk) => {
280
- downstream.pause()
281
- send(
282
- ws,
283
- MessageType.DownStreamPush,
284
- encodeNumber(id, 'Uint32'),
285
- Buffer.from(chunk).buffer as ArrayBuffer,
286
- )
287
- })
288
- data.streams.down.set(id, downstream)
289
- return { id, metadata: blob.metadata }
290
- },
291
- getStream: (id) => {
292
- return data.streams.down.get(id)!
293
- },
294
- }
295
- }
296
-
297
- protected createDecodeRpcContext(ws: WsTransportSocket) {
298
- const data = ws.getUserData()
299
- return {
300
- addStream(signal: AbortSignal, id: number, metadata: ApiBlobMetadata) {
301
- const upstream = new ServerUpStream(metadata, {
302
- read: (size) => {
303
- send(
304
- ws,
305
- MessageType.UpStreamPull,
306
- encodeNumber(id, 'Uint32'),
307
- encodeNumber(size, 'Uint32'),
308
- )
309
- },
310
- })
311
-
312
- data.streams.up.set(id, upstream)
313
- onAbort(signal, () => {
314
- upstream.destroy(new Error('Call aborted by client'))
315
- })
316
- upstream.once('error', (error) => {
317
- if (error.message !== 'Aborted by server')
318
- send(ws, MessageType.UpStreamAbort, encodeNumber(id, 'Uint32'))
319
- })
320
- upstream.once('close', () => {
321
- if (upstream.errored?.message !== 'Aborted by server')
322
- send(ws, MessageType.UpStreamEnd, encodeNumber(id, 'Uint32'))
323
- })
324
-
325
- return upstream
326
- },
327
- getStream(id) {
328
- return data.streams.down.get(id)
329
- },
330
- }
331
- }
332
-
333
- protected async [MessageType.Rpc](
211
+ protected [ClientMessageType.Rpc](
334
212
  ws: WsTransportSocket,
335
213
  buffer: ArrayBuffer,
336
214
  ) {
337
- const data = ws.getUserData()
338
-
339
- const connection = this.application.connections.get(data.id)
340
- if (!connection) return void ws.end(1011, 'Unknown connection')
341
-
342
- const ac = new AbortController()
343
- const rpc = data.format.decoder.decodeRpc(buffer, {
344
- ...data.context.decode,
345
- addStream: data.context.decode.addStream.bind(null, ac.signal),
346
- })
347
-
348
- data.abortControllers.set(rpc.callId, ac)
349
-
350
- const container = data.container.createScope(Scope.Call)
351
-
352
- try {
353
- const { service, procedure } = this.api.find(
354
- rpc.service,
355
- rpc.procedure,
356
- this.transportType,
357
- )
358
-
359
- const response = await this.handleRPC({
360
- connection,
361
- service,
362
- procedure,
363
- container,
364
- signal: ac.signal,
365
- payload: rpc.payload,
366
- })
367
-
368
- if (response instanceof SubscriptionResponse) {
369
- sendRpcMessage(ws, MessageType.RpcSubscription, {
370
- callId: rpc.callId,
371
- error: null,
372
- payload: [response.subscription.key, response.payload],
373
- })
374
-
375
- response.subscription.on('event', (event, payload) => {
376
- sendMessage(ws, MessageType.ServerSubscriptionEvent, [
377
- response.subscription.key,
378
- event,
379
- payload,
380
- ])
381
- })
382
- response.subscription.once('end', () => {
383
- sendMessage(ws, MessageType.ServerUnsubscribe, [
384
- response.subscription.key,
385
- ])
386
- })
387
- } else {
388
- sendRpcMessage(ws, MessageType.Rpc, {
389
- callId: rpc.callId,
390
- error: null,
391
- payload: response,
392
- })
393
- }
394
- } catch (error) {
395
- if (error instanceof ApiError) {
396
- this.logger.debug(new Error('Api error', { cause: error }))
397
- sendRpcMessage(ws, MessageType.Rpc, {
398
- callId: rpc.callId,
399
- error,
400
- payload: null,
401
- })
402
- } else {
403
- this.logger.error(new Error('Unexpected error', { cause: error }))
404
- sendRpcMessage(ws, MessageType.Rpc, {
405
- callId: rpc.callId,
406
- error: InternalError(),
407
- payload: null,
408
- })
409
- }
410
- } finally {
411
- data.abortControllers.delete(rpc.callId)
412
- this.handleContainerDisposal(container)
413
- }
215
+ const { id } = ws.getUserData()
216
+ this.protocol.rpcRaw(id, buffer)
414
217
  }
415
218
 
416
- async [MessageType.UpStreamPush](ws: WsTransportSocket, buffer: ArrayBuffer) {
417
- const data = ws.getUserData()
418
- const id = decodeNumber(buffer, 'Uint32')
419
- const stream = data.streams.up.get(id)
420
- if (!stream) return ws.end(1011, 'Unknown stream')
421
- stream.push(Buffer.from(buffer.slice(Uint32Array.BYTES_PER_ELEMENT)))
422
- }
423
-
424
- async [MessageType.UpStreamEnd](ws: WsTransportSocket, buffer: ArrayBuffer) {
425
- const data = ws.getUserData()
426
- const id = decodeNumber(buffer, 'Uint32')
427
- const stream = data.streams.up.get(id)
428
- if (!stream) return ws.end(1011, 'Unknown stream')
429
- stream.push(null)
430
- data.streams.up.delete(id)
219
+ protected [ClientMessageType.RpcAbort](
220
+ ws: WsTransportSocket,
221
+ buffer: ArrayBuffer,
222
+ ) {
223
+ const { id } = ws.getUserData()
224
+ this.protocol.rpcAbortRaw(id, buffer)
431
225
  }
432
226
 
433
- async [MessageType.UpStreamAbort](
227
+ protected [ClientMessageType.RpcStreamAbort](
434
228
  ws: WsTransportSocket,
435
229
  buffer: ArrayBuffer,
436
230
  ) {
437
- const data = ws.getUserData()
438
- const id = decodeNumber(buffer, 'Uint32')
439
- const stream = data.streams.up.get(id)
440
- if (!stream) return ws.end(1011, 'Unknown stream')
441
- stream.destroy(new Error('Aborted by client'))
442
- data.streams.up.delete(id)
231
+ const { id } = ws.getUserData()
232
+ this.protocol.rpcStreamAbortRaw(id, buffer)
443
233
  }
444
234
 
445
- async [MessageType.DownStreamPull](
235
+ protected [ClientMessageType.ClientStreamPush](
446
236
  ws: WsTransportSocket,
447
237
  buffer: ArrayBuffer,
448
238
  ) {
449
- const data = ws.getUserData()
450
- const id = decodeNumber(buffer, 'Uint32')
451
- const stream = data.streams.down.get(id)
452
- if (!stream) return ws.end(1011, 'Unknown stream')
453
- await data.backpressure?.promise
454
- if (stream.readableEnded)
455
- send(ws, MessageType.DownStreamEnd, encodeNumber(id, 'Uint32'))
456
- else stream.resume()
239
+ const { id } = ws.getUserData()
240
+ const streamId = decodeNumber(buffer, 'Uint32')
241
+ this.protocol.clientStreams.push(
242
+ id,
243
+ streamId,
244
+ buffer.slice(Uint32Array.BYTES_PER_ELEMENT),
245
+ )
457
246
  }
458
247
 
459
- async [MessageType.DownStreamEnd](
248
+ protected [ClientMessageType.ClientStreamEnd](
460
249
  ws: WsTransportSocket,
461
250
  buffer: ArrayBuffer,
462
251
  ) {
463
- const data = ws.getUserData()
464
- const id = decodeNumber(buffer, 'Uint32')
465
- const stream = data.streams.down.get(id)
466
- if (!stream) return ws.end(1011, 'Unknown stream')
467
- data.streams.down.delete(id)
252
+ const { id } = ws.getUserData()
253
+ const streamId = decodeNumber(buffer, 'Uint32')
254
+ this.protocol.clientStreams.end(id, streamId)
468
255
  }
469
256
 
470
- async [MessageType.DownStreamAbort](
257
+ protected [ClientMessageType.ClientStreamAbort](
471
258
  ws: WsTransportSocket,
472
259
  buffer: ArrayBuffer,
473
260
  ) {
474
- const data = ws.getUserData()
475
- const id = decodeNumber(buffer, 'Uint32')
476
- const stream = data.streams.down.get(id)
477
- if (!stream) return ws.end(1011, 'Unknown stream')
478
- stream.destroy(new Error('Aborted by client'))
479
- data.streams.down.delete(id)
261
+ const { id } = ws.getUserData()
262
+ const streamId = decodeNumber(buffer, 'Uint32')
263
+ this.protocol.clientStreams.abort(id, streamId)
480
264
  }
481
265
 
482
- async [MessageType.ClientUnsubscribe](
266
+ protected [ClientMessageType.ServerStreamPull](
483
267
  ws: WsTransportSocket,
484
268
  buffer: ArrayBuffer,
485
269
  ) {
486
- const data = ws.getUserData()
487
- const [key] = data.format.decoder.decode(buffer)
488
- const subscription = data.subscriptions.get(key)
489
- if (!subscription) return void ws.end()
490
- subscription.destroy()
491
- data.subscriptions.delete(key)
270
+ const { id } = ws.getUserData()
271
+ const streamId = decodeNumber(buffer, 'Uint32')
272
+ this.protocol.serverStreams.pull(id, streamId)
492
273
  }
493
274
 
494
- async [MessageType.RpcAbort](ws: WsTransportSocket, buffer: ArrayBuffer) {
495
- const data = ws.getUserData()
496
- const callId = decodeNumber(buffer, 'Uint32')
497
- const ac = data.abortControllers.get(callId)
498
- if (ac) ac.abort(new Error('Aborted by client'))
275
+ protected [ClientMessageType.ServerStreamAbort](
276
+ ws: WsTransportSocket,
277
+ buffer: ArrayBuffer,
278
+ ) {
279
+ const { id } = ws.getUserData()
280
+ const streamId = decodeNumber(buffer, 'Uint32')
281
+ this.protocol.serverStreams.abort(id, streamId)
499
282
  }
500
283
  }
package/lib/transport.ts CHANGED
@@ -1,16 +1,10 @@
1
- import { Hook, createPlugin } from '@nmtjs/application'
2
- import { WsTransportServer } from './server.ts'
1
+ import { createTransport } from '@nmtjs/protocol/server'
2
+ import { type WsConnectionData, WsTransportServer } from './server.ts'
3
3
  import type { WsTransportOptions } from './types.ts'
4
4
 
5
- export const WsTransport = createPlugin<WsTransportOptions>(
6
- 'WsTransport',
7
- (app, options) => {
8
- const server = new WsTransportServer(app, options)
9
- app.hooks.add(Hook.OnStartup, async () => {
10
- await server.start()
11
- })
12
- app.hooks.add(Hook.OnShutdown, async () => {
13
- await server.stop()
14
- })
15
- },
16
- )
5
+ export const WsTransport = createTransport<
6
+ WsConnectionData,
7
+ WsTransportOptions
8
+ >('WsTransport', (context, options) => {
9
+ return new WsTransportServer(context, options)
10
+ })