socket-function 0.8.30 → 0.8.32

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/SocketFunction.ts CHANGED
@@ -1,12 +1,10 @@
1
- import { SocketExposedInterface, CallContextType, SocketFunctionHook, SocketFunctionClientHook, SocketExposedShape, SocketRegistered, NetworkLocation, CallerContext, SocketExposedInterfaceClass, CallType, FullCallType } from "./SocketFunctionTypes";
1
+ import { SocketExposedInterface, CallContextType, SocketFunctionHook, SocketFunctionClientHook, SocketExposedShape, SocketRegistered, CallerContext, FullCallType } from "./SocketFunctionTypes";
2
2
  import { exposeClass, registerClass, registerGlobalClientHook, registerGlobalHook, runClientHooks } from "./src/callManager";
3
3
  import { SocketServerConfig, startSocketServer } from "./src/webSocketServer";
4
- import { getCallFactoryFromNodeId, getCreateCallFactoryLocation, getLocationFromNodeId, getNetworkLocationHash } from "./src/nodeCache";
4
+ import { getCreateCallFactoryLocation, getNodeId, getNodeIdLocation } from "./src/nodeCache";
5
5
  import { getCallProxy } from "./src/nodeProxy";
6
6
  import { Args } from "./src/types";
7
7
  import { setDefaultHTTPCall } from "./src/callHTTPHandler";
8
- import { isNode } from "./src/misc";
9
- import { getOwnNodeId } from "./src/nodeAuthentication";
10
8
 
11
9
  module.allowclient = true;
12
10
 
@@ -31,7 +29,6 @@ export class SocketFunction {
31
29
  };
32
30
  public static httpETagCache = false;
33
31
  public static rejectUnauthorized = true;
34
- public static additionalTrustedRootCAs: string[] = [];
35
32
 
36
33
  public static register<
37
34
  ClassInstance extends object,
@@ -40,12 +37,19 @@ export class SocketFunction {
40
37
  >(
41
38
  classGuid: string,
42
39
  instance: ClassInstance,
43
- shape: Shape
40
+ shape: Shape,
41
+ defaultHooks?: SocketExposedShape[""]
44
42
  ):
45
43
  (
46
44
  SocketRegistered<ExtractShape<ClassInstance, Shape>, CallContext>
47
45
  ) {
48
46
 
47
// Fold the class-wide defaults into every function's shape.
// - Default hooks run BEFORE the function's own hooks.
// - Scalar options (dataImmutable) only fall back to the default when the function did not set
//   them itself; an explicit per-function value must win over a class-wide default.
// NOTE: this mutates `shape` in place, which is what both registerClass(...) and the call proxy's
// `shape[functionName]` lookup read.
for (const functionShape of Object.values(shape)) {
    if (!functionShape) continue;
    functionShape.clientHooks = [...(defaultHooks?.clientHooks ?? []), ...(functionShape.clientHooks ?? [])];
    functionShape.hooks = [...(defaultHooks?.hooks ?? []), ...(functionShape.hooks ?? [])];
    functionShape.dataImmutable = functionShape.dataImmutable ?? defaultHooks?.dataImmutable;
}
49
53
  registerClass(classGuid, instance as SocketExposedInterface, shape as any as SocketExposedShape);
50
54
 
51
55
  let nodeProxy = getCallProxy(classGuid, async (call) => {
@@ -56,10 +60,7 @@ export class SocketFunction {
56
60
  console.log(`START\t\t\t${classGuid}.${functionName}`);
57
61
  }
58
62
  try {
59
- let callFactory = await getCallFactoryFromNodeId(nodeId);
60
- if (!callFactory) {
61
- throw new Error(`Cannot reach node ${nodeId}. It might have been incorrect provided to us via another node, which should have provided us a NetworkLocation instead.`);
62
- }
63
+ let callFactory = await getCreateCallFactoryLocation(nodeId, SocketFunction.mountedNodeId);
63
64
 
64
65
  let shapeObj = shape[functionName];
65
66
  if (!shapeObj) {
@@ -72,6 +73,20 @@ export class SocketFunction {
72
73
  return hookResult.overrideResult;
73
74
  }
74
75
 
76
if (hookResult.callTimeout !== undefined) {
    // Fail the call if it has not settled within callTimeout ms, for ANY reason (unlike
    // reconnectTimeout, which only applies while the connection is down).
    const timeoutMs = hookResult.callTimeout;
    const startedAt = Date.now();
    let timer: ReturnType<typeof setTimeout> | undefined;
    const timeoutPromise = new Promise<never>((_resolve, reject) => {
        timer = setTimeout(() => {
            reject(new Error(`Call timed out after ${Date.now() - startedAt}ms`));
        }, timeoutMs);
    });
    try {
        return await Promise.race([
            callFactory.performCall(call),
            timeoutPromise,
        ]);
    } finally {
        // Always clear the timer. Otherwise every call that finishes in time leaves a pending
        // timer behind, which keeps the Node event loop alive (and retains this closure) for
        // the full timeout duration.
        clearTimeout(timer);
    }
}
89
+
75
90
  return await callFactory.performCall(call);
76
91
  } finally {
77
92
  time = Date.now() - time;
@@ -90,14 +105,15 @@ export class SocketFunction {
90
105
  return output as any;
91
106
  }
92
107
 
93
- /** NOTE: Only works if the call has been loaded from a url (we can't convert arbitrary nodeIds into urls,
94
- * as we have no way of knowing how to contain a nodeId). */
108
+ /** NOTE: Only works if the nodeId used is from SocketFunction.connect (we can't convert arbitrary nodeIds into urls,
109
+ * as we have no way of knowing how to contact a nodeId).
110
+ * */
95
111
  public static getHTTPCallLink(call: FullCallType): string {
96
- let location = getLocationFromNodeId(call.nodeId);
112
+ let location = getNodeIdLocation(call.nodeId);
97
113
  if (!location) {
98
114
  throw new Error(`Cannot find call location for nodeId, and so do not know where call location is. NodeId ${call.nodeId}`);
99
115
  }
100
- let url = new URL(`https://${location.address}:${location.listeningPorts[0]}`);
116
+ let url = new URL(`https://${location.address}:${location.port}`);
101
117
  url.searchParams.set("classGuid", call.classGuid);
102
118
  url.searchParams.set("functionName", call.functionName);
103
119
  url.searchParams.set("args", JSON.stringify(call.args));
@@ -112,12 +128,13 @@ export class SocketFunction {
112
128
  exposeClass(socketRegistered);
113
129
  }
114
130
 
131
/** nodeId returned by startSocketServer once mount() has succeeded; "NOTMOUNTED" until then. */
public static mountedNodeId: string = "NOTMOUNTED";
/** True while a mount() call is in flight, so a concurrent second mount() is rejected as well. */
private static mountInProgress = false;

/**
 * Starts the socket server for this thread. Only one mount per thread is allowed.
 * @returns the nodeId produced by startSocketServer (also stored in SocketFunction.mountedNodeId).
 * @throws if SocketFunction is already mounted, or a mount is already in progress.
 */
public static async mount(config: SocketServerConfig): Promise<string> {
    // Reference the class explicitly (not `this`): the call proxy reads SocketFunction.mountedNodeId,
    // and `this` would be wrong for detached calls or subclasses.
    if (SocketFunction.mountedNodeId !== "NOTMOUNTED" || SocketFunction.mountInProgress) {
        throw new Error("SocketFunction already mounted, mounting twice in one thread is not allowed.");
    }
    SocketFunction.mountInProgress = true;
    try {
        SocketFunction.mountedNodeId = await startSocketServer(config);
    } finally {
        // Reset even on failure, so a failed mount can be retried.
        SocketFunction.mountInProgress = false;
    }
    return SocketFunction.mountedNodeId;
}
122
139
 
123
140
  /** Sets the default call when an http request is made, but no classGuid is set. */
@@ -136,28 +153,8 @@ export class SocketFunction {
136
153
  });
137
154
  }
138
155
 
139
- public static async connect(location: NetworkLocation | { address: string; port: number }): Promise<string> {
140
- if (!("listeningPorts" in location)) {
141
- location = {
142
- address: location.address,
143
- listeningPorts: [location.port]
144
- };
145
- }
146
- return await getCreateCallFactoryLocation(location);
147
- }
148
-
149
- public static connectSync(location: NetworkLocation | { address: string; port: number }): string {
150
- if (!("listeningPorts" in location)) {
151
- location = {
152
- address: location.address,
153
- listeningPorts: [location.port]
154
- };
155
- }
156
- let tempNodeId = "syncTempNodeId_" + getNetworkLocationHash(location);
157
-
158
- void getCreateCallFactoryLocation(location, tempNodeId);
159
-
160
- return tempNodeId;
156
/**
 * Returns the nodeId for the server at `location`. Use that nodeId to address calls to that server.
 * This only computes the id via getNodeId; the call proxy obtains the actual call factory for each
 * call via getCreateCallFactoryLocation.
 */
public static connect(location: { address: string, port: number }): string {
    const { address, port } = location;
    return getNodeId(address, port);
}
162
159
 
163
160
  public static addGlobalHook<CallContext extends CallContextType>(hook: SocketFunctionHook<SocketExposedInterface, CallContext>) {
@@ -172,6 +169,11 @@ export class SocketFunction {
172
169
  const curSocketContext: SocketRegistered["context"] = {
173
170
  curContext: undefined,
174
171
  caller: undefined,
172
+ getCaller() {
173
+ const caller = curSocketContext.caller;
174
+ if (!caller) throw new Error(`Tried to access caller when not in the synchronous phase of a function call`);
175
+ return caller;
176
+ }
175
177
  };
176
178
  let socketContextSeqNum = 1;
177
179
 
@@ -2,6 +2,10 @@ module.allowclient = true;
2
2
 
3
3
  import debugbreak from "debugbreak";
4
4
  import * as tls from "tls";
5
+ import { SenderInterface } from "./src/CallFactory";
6
+ import { isNode } from "./src/misc";
7
+ import { CertInfo, getNodeIdFromCert } from "./src/nodeAuthentication";
8
+ import { getClientNodeId } from "./src/nodeCache";
5
9
  import { getCallObj } from "./src/nodeProxy";
6
10
  import { Args, MaybePromise } from "./src/types";
7
11
 
@@ -49,14 +53,18 @@ export interface SocketFunctionHook<ExposedType extends SocketExposedInterface =
49
53
  }
50
54
  export type HookContext<ExposedType extends SocketExposedInterface = SocketExposedInterface, CallContext extends CallContextType = CallContextType> = {
51
55
  call: CallType;
52
- context: SocketRegistered["context"];
53
- // If the result is overriden, we continue evaluating hooks BUT NOT perform the final call
56
+ context: SocketRegistered<ExposedType, CallContext>["context"];
57
+ // If the result is overriden, we continue evaluating hooks BUT DO NOT perform the final call
54
58
  overrideResult?: unknown;
55
59
  };
56
60
 
57
61
  export type ClientHookContext<ExposedType extends SocketExposedInterface = SocketExposedInterface, CallContext extends CallContextType = CallContextType> = {
58
62
  call: CallType;
59
- // If the result is overriden, we continue evaluating hooks BUT NOT perform the final call
63
+ /** If the calls takes longer than this (for ANY reason), we return with an error.
64
+ * - Different from reconnectTimeout, which only errors if we lose the connection.
65
+ */
66
+ callTimeout?: number;
67
+ // If the result is overriden, we continue evaluating hooks BUT DO NOT perform the final call
60
68
  overrideResult?: unknown;
61
69
  };
62
70
  export interface SocketFunctionClientHook<ExposedType extends SocketExposedInterface = SocketExposedInterface, CallContext extends CallContextType = CallContextType> {
@@ -81,44 +89,29 @@ export interface SocketRegistered<ExposedType = any, DynamicCallContext extends
81
89
  // If undefined we are not synchronously in a call
82
90
  curContext: DynamicCallContext | undefined;
83
91
  caller: CallerContext | undefined;
92
+ getCaller(): CallerContext;
84
93
  };
85
94
  _classGuid: string;
86
95
  }
87
- export type CallerContext = {
96
/** Read-only view of the caller, as exposed to application code. */
export type CallerContext = Readonly<CallerContextBase>;
export type CallerContextBase = {
    /**
     * Id of the remote node. It is unique per thread, so it is only useful for temporary
     * communication; if you need a more permanent identity, derive it from certInfo yourself.
     * IMPORTANT! Do not pass nodeId to other nodes expecting them to call it directly. Pass a
     * location instead, and have them use connect.
     */
    nodeId: string;

    /**
     * The nodeId the caller contacted (for HTTP this is built from the Host header, so it is not
     * necessarily our actual nodeId). Use it to check the caller's intent: without it, requests
     * redirected to us would be accepted even though they are being blatantly MITMed.
     */
    localNodeId: string;

    /**
     * Further information about the remote node's certificate. It is intended to only ever hold a
     * certificate with a verified issuer, so the common name cannot be claimed by someone else.
     * App code may also set it, and must then do that verification in some sense itself (not
     * necessarily via the machine's trust store). If set, it is used directly to derive nodeId
     * (by nodeAuthentication.ts).
     * NOTE(review): createCallFactory assigns this straight from getPeerCertificate(true), with no
     * issuer check visible there; confirm the verification guarantee.
     */
    certInfo: CertInfo | undefined;

    /**
     * Supplies a new certificate for the caller. createCallFactory re-derives nodeId from it via
     * getNodeIdFromCert(certInfo, callbackPort), and only updates nodeId/certInfo when that
     * yields a nodeId. callbackPort is presumably the port the caller can be reached back on;
     * confirm against nodeAuthentication.
     */
    updateCertInfo?: (certInfo: CertInfo, callbackPort: number | undefined) => void;
};
package/package.json CHANGED
@@ -1,21 +1,27 @@
1
1
  {
2
2
  "name": "socket-function",
3
- "version": "0.8.30",
3
+ "version": "0.8.32",
4
4
  "main": "index.js",
5
5
  "license": "MIT",
6
+ "note1": "note on node-forge fork, see https://github.com/digitalbazaar/forge/issues/744 for details",
6
7
  "dependencies": {
7
8
  "@types/cookie": "^0.5.1",
8
9
  "@types/node": "^18.0.0",
10
+ "@types/node-forge": "^1.3.1",
9
11
  "@types/ws": "^8.5.3",
10
12
  "cookie": "^0.5.0",
11
13
  "debugbreak": "^0.6.5",
12
14
  "mobx": "^6.6.2",
15
+ "node-forge": "https://github.com/sliftist/forge#name",
13
16
  "preact": "^10.10.6",
14
- "typenode": "^4.8.5",
17
+ "typenode": "^4.8.7",
15
18
  "ws": "^8.8.0"
16
19
  },
17
20
  "scripts": {
18
21
  "test": "yarn typenode ./test/server.ts",
19
22
  "type": "yarn tsc --noEmit"
23
+ },
24
+ "devDependencies": {
25
+ "typedev": "^0.1.0"
20
26
  }
21
27
  }
@@ -1,14 +1,12 @@
1
- import { CallerContext, CallType, NetworkLocation, setCertInfo } from "../SocketFunctionTypes";
1
+ import { CallerContext, CallerContextBase, CallType } from "../SocketFunctionTypes";
2
2
  import * as ws from "ws";
3
- import type * as net from "net";
4
3
  import { performLocalCall } from "./callManager";
5
4
  import { convertErrorStackToError, formatNumberSuffixed, isNode } from "./misc";
6
- import { createWebsocketFactory, getNodeId, getTLSSocket } from "./nodeAuthentication";
7
- import debugbreak from "debugbreak";
8
- import http from "http";
5
+ import { createWebsocketFactory, getNodeIdFromCert, getTLSSocket } from "./nodeAuthentication";
9
6
  import { SocketFunction } from "../SocketFunction";
10
7
  import { gzip } from "zlib";
11
8
  import * as tls from "tls";
9
+ import { getClientNodeId, getNodeIdLocation, registerNodeClient } from "./nodeCache";
12
10
 
13
11
  const retryInterval = 2000;
14
12
 
@@ -30,53 +28,10 @@ type InternalReturnType = {
30
28
 
31
29
  export interface CallFactory {
32
30
  nodeId: string;
33
- location: NetworkLocation;
34
31
  // NOTE: May or may not have reconnection or retry logic inside of performCall.
35
32
  // Trigger performLocalCall on the other side of the connection
36
33
  performCall(call: CallType): Promise<unknown>;
37
- }
38
-
39
-
40
- export async function callFactoryFromLocation(
41
- location: NetworkLocation
42
- ): Promise<CallFactory> {
43
- let listeningPort = location.listeningPorts[0];
44
- if (typeof listeningPort !== "number") {
45
- throw new Error(`Expected listeningPorts to be provided, but it was empty`);
46
- }
47
-
48
- // Because we are the client, we don't get to know our NetworkLocation (but we shouldn't
49
- // need to anyway).
50
- let serverLocation: NetworkLocation = {
51
- address: "localhost",
52
- listeningPorts: [],
53
- };
54
-
55
- return await createCallFactory(undefined, location, serverLocation);
56
- }
57
-
58
- export async function callFactoryFromWS(
59
- webSocket: ws.WebSocket & { nodeId?: string },
60
- serverLocation: NetworkLocation,
61
- ): Promise<CallFactory> {
62
- let socket = getTLSSocket(webSocket);
63
- let remoteAddress = socket.remoteAddress;
64
- let remotePort = socket.remotePort;
65
- if (!remoteAddress) {
66
- throw new Error("No remote address?");
67
- }
68
- if (!remotePort) {
69
- throw new Error("No remote port?");
70
- }
71
-
72
- // NOTE: We COULD reconnect to clients, but... chances are... when they go down,
73
- // their process is dead, and is going to stay dead.
74
- let location: NetworkLocation = {
75
- address: remoteAddress,
76
- listeningPorts: [],
77
- };
78
-
79
- return await createCallFactory(webSocket, location, serverLocation);
34
+ closedForever: boolean;
80
35
  }
81
36
 
82
37
  export interface SenderInterface {
@@ -92,27 +47,16 @@ export interface SenderInterface {
92
47
  addEventListener(event: "message", listener: (data: ws.RawData | ws.MessageEvent | string) => void): void;
93
48
  }
94
49
 
95
- async function createCallFactory(
50
+ export async function createCallFactory(
96
51
  webSocketBase: SenderInterface | undefined,
97
- location: NetworkLocation,
98
- serverLocation: NetworkLocation,
52
+ nodeId: string,
53
+ localNodeId: string,
99
54
  ): Promise<CallFactory> {
100
-
101
- let closedForever = false;
102
-
103
- let fromPort = 0;
104
- if (webSocketBase && webSocketBase instanceof ws.WebSocket) {
105
- let socket = getTLSSocket(webSocketBase);
106
- fromPort = socket.remotePort ?? fromPort;
107
- }
108
- let niceConnectionName = `${location.address}:${location.listeningPorts.join("|")}`;
109
- if (fromPort && location.listeningPorts.length === 0) {
110
- niceConnectionName += `(${fromPort})`;
111
- }
55
+ let niceConnectionName = nodeId;
112
56
 
113
57
  const createWebsocket = createWebsocketFactory();
114
58
 
115
- let retriesEnabled = location.listeningPorts.length > 0;
59
+ let retriesEnabled = !!getNodeIdLocation(nodeId);
116
60
 
117
61
  let lastReceivedSeqNum = 0;
118
62
 
@@ -131,8 +75,59 @@ async function createCallFactory(
131
75
  // in return calls.
132
76
  let nextSeqNum = Math.random();
133
77
 
134
- const pendingNodeId = "PENDING";
135
- let callerContext: CallerContext = { location, nodeId: pendingNodeId, serverLocation, fromPort, certInfo: undefined };
78
// Node's getPeerCertificate returns an empty object when the peer presented no certificate, and null
// once the socket is destroyed. Only keep a certificate that actually has an issuer; otherwise an
// uncertified caller would get a truthy certInfo. This mirrors the `cert?.issuer` check the previous
// setCertInfo helper performed.
const peerCert = webSocketBase?.socket?.getPeerCertificate(true);
let callerContext: CallerContextBase = {
    nodeId,
    localNodeId,
    certInfo: peerCert?.issuer ? peerCert : undefined,
    updateCertInfo: (certRaw, port) => {
        // Re-derive the caller's nodeId from the new certificate; ignore certs that don't yield one.
        // (Named distinctly so it doesn't shadow the factory's `nodeId` parameter.)
        const certNodeId = getNodeIdFromCert(certRaw, port);
        if (!certNodeId) {
            return;
        }
        callerContext.nodeId = certNodeId;
        callerContext.certInfo = certRaw;
    },
};
91
+
92
let callFactory: CallFactory = {
    nodeId,
    closedForever: false,
    /** Sends `call` to the remote node and resolves with the result it returns. */
    async performCall(call: CallType) {
        if (callFactory.closedForever) {
            throw new Error(`Connection lost to ${niceConnectionName}`);
        }

        const seqNum = nextSeqNum++;
        const fullCall: InternalCallType = {
            isReturn: false,
            args: call.args,
            classGuid: call.classGuid,
            functionName: call.functionName,
            seqNum,
            compress: !!SocketFunction.compression,
        };
        const data = Buffer.from(JSON.stringify(fullCall));
        const resultPromise = new Promise((resolve, reject) => {
            // Invoked with the return message for this seqNum. It is also invoked with an error by
            // the reconnect logic when pending calls are aborted.
            const callback = (result: InternalReturnType) => {
                if (SocketFunction.logMessages) {
                    console.log(`SIZE\t${(formatNumberSuffixed(result.resultSize) + "B").padEnd(4, " ")}\t${call.classGuid}.${call.functionName}`);
                }
                pendingCalls.delete(seqNum);
                if (result.error) {
                    reject(convertErrorStackToError(result.error));
                } else {
                    resolve(result.result);
                }
            };
            pendingCalls.set(seqNum, { callback, data, call: fullCall, reconnectTimeout: call.reconnectTimeout });
        });
        // resultPromise can be rejected (via callback, e.g. when pending calls are aborted) while we
        // are still inside sendWithRetry, before `await resultPromise` attaches a handler. Mark it
        // handled up front so Node does not report an unhandled rejection; the await below still
        // receives the rejection.
        void resultPromise.catch(() => { });

        try {
            await sendWithRetry(call.reconnectTimeout, data);
        } catch (error) {
            // The caller gets this error instead of the result, so nothing will ever consume a
            // response for seqNum. Drop the pending entry so it doesn't leak.
            pendingCalls.delete(seqNum);
            throw error;
        }

        return await resultPromise;
    }
};
130
+
136
131
  let webSocket!: SenderInterface;
137
132
  if (!webSocketBase) {
138
133
  await tryToReconnect();
@@ -140,11 +135,6 @@ async function createCallFactory(
140
135
  webSocket = webSocketBase;
141
136
  setupWebsocket(webSocketBase);
142
137
  }
143
- if (isNode()) {
144
- callerContext.nodeId = getNodeId(webSocket);
145
- } else {
146
- callerContext.nodeId = location.address + ":" + location.listeningPorts[0];
147
- }
148
138
 
149
139
  niceConnectionName = `${niceConnectionName} (${callerContext.nodeId})`;
150
140
 
@@ -189,11 +179,9 @@ async function createCallFactory(
189
179
  if (reconnectingPromise) return reconnectingPromise;
190
180
  return reconnectingPromise = (async () => {
191
181
  while (true) {
192
- let ports = location.listeningPorts;
193
-
194
- if (ports.length === 0) {
195
- closedForever = true;
196
- console.log(`No ports to reconnect for ${niceConnectionName}, pendingCall count: ${pendingCalls.size}`);
182
+ if (!retriesEnabled) {
183
+ callFactory.closedForever = true;
184
+ console.log(`Cannot reconnect to ${niceConnectionName}, aborting pendingCalls: ${pendingCalls.size}`);
197
185
  for (let call of pendingCalls.values()) {
198
186
  call.callback({
199
187
  isReturn: true,
@@ -207,8 +195,7 @@ async function createCallFactory(
207
195
  return;
208
196
  }
209
197
 
210
- let port = ports[reconnectAttempts % ports.length];
211
- let newWebSocket = createWebsocket(location.address, port);
198
+ let newWebSocket = createWebsocket(nodeId);
212
199
 
213
200
  let connectError = await new Promise<string | undefined>(resolve => {
214
201
  newWebSocket.addEventListener("open", () => {
@@ -225,24 +212,9 @@ async function createCallFactory(
225
212
  setupWebsocket(newWebSocket);
226
213
 
227
214
  if (!connectError) {
228
- console.log(`Reconnected to ${location.address}:${port}`);
229
-
230
- // NOTE: Clientside doesn't have access to peer certificates, so it can't know the nodeId of the server
231
- // that way. However, it can
232
- if (isNode()) {
233
- let newNodeId = getNodeId(newWebSocket);
234
- let prevNodeId = callerContext.nodeId;
235
- if (prevNodeId === pendingNodeId) {
236
- callerContext.nodeId = newNodeId;
237
- } else {
238
- if (newNodeId !== prevNodeId) {
239
- throw new Error(`Connection lost to at ${niceConnectionName} ("${prevNodeId}"), but then re-established, however it is now "${newNodeId}"!`);
240
- }
241
- }
242
- }
215
+ console.log(`Reconnected to ${niceConnectionName}`);
243
216
 
244
- // I'm not sure if we should clear reconnectAttempts? All the ports should be the same, and actually...
245
- // why would there even be a bad port?
217
+ // I'm not sure if we should clear reconnectAttempts? Maybe if we eventually have a max reconnectAttempts?
246
218
  //reconnectAttempts = 0;
247
219
  reconnectingPromise = undefined;
248
220
 
@@ -264,13 +236,15 @@ async function createCallFactory(
264
236
  }
265
237
 
266
238
  reconnectAttempts++;
267
- console.error(`Connection retry to ${location.address}:${port} failed (attempt ${reconnectAttempts}), retrying in ${retryInterval}ms, error: ${JSON.stringify(connectError)}`);
239
+ console.error(`Connection retry to ${niceConnectionName} failed (attempt ${reconnectAttempts}), retrying in ${retryInterval}ms, error: ${JSON.stringify(connectError)}`);
268
240
  await new Promise(resolve => setTimeout(resolve, retryInterval));
269
241
  }
270
242
  })();
271
243
  }
272
244
 
273
245
  function setupWebsocket(webSocket: SenderInterface) {
246
+ registerNodeClient(callFactory);
247
+
274
248
  webSocket.addEventListener("error", e => {
275
249
  console.log(`Websocket error for ${niceConnectionName}`, e);
276
250
  });
@@ -283,8 +257,6 @@ async function createCallFactory(
283
257
  });
284
258
 
285
259
  webSocket.addEventListener("message", onMessage);
286
-
287
- setCertInfo(webSocket.socket || (webSocket as any)._socket, callerContext);
288
260
  }
289
261
 
290
262
 
@@ -374,42 +346,5 @@ async function createCallFactory(
374
346
  }
375
347
  }
376
348
 
377
- return {
378
- nodeId: callerContext.nodeId,
379
- location,
380
- async performCall(call: CallType) {
381
- if (closedForever) {
382
- throw new Error(`Connection lost to ${niceConnectionName}`);
383
- }
384
-
385
- let seqNum = nextSeqNum++;
386
- let fullCall: InternalCallType = {
387
- isReturn: false,
388
- args: call.args,
389
- classGuid: call.classGuid,
390
- functionName: call.functionName,
391
- seqNum,
392
- compress: !!SocketFunction.compression,
393
- };
394
- let data = Buffer.from(JSON.stringify(fullCall));
395
- let resultPromise = new Promise((resolve, reject) => {
396
- let callback = (result: InternalReturnType) => {
397
- if (SocketFunction.logMessages) {
398
- console.log(`SIZE\t${(formatNumberSuffixed(result.resultSize) + "B").padEnd(4, " ")}\t${call.classGuid}.${call.functionName}`);
399
- }
400
- pendingCalls.delete(seqNum);
401
- if (result.error) {
402
- reject(convertErrorStackToError(result.error));
403
- } else {
404
- resolve(result.result);
405
- }
406
- };
407
- pendingCalls.set(seqNum, { callback, data, call: fullCall, reconnectTimeout: call.reconnectTimeout });
408
- });
409
-
410
- await sendWithRetry(call.reconnectTimeout, data);
411
-
412
- return await resultPromise;
413
- }
414
- };
349
+ return callFactory;
415
350
  }
@@ -1,17 +1,11 @@
1
- import https from "https";
2
1
  import http from "http";
3
- import net from "net";
4
2
  import tls from "tls";
5
- import { CallerContext, CallType, NetworkLocation, setCertInfo } from "../SocketFunctionTypes";
3
+ import { CallerContext, CallType } from "../SocketFunctionTypes";
6
4
  import { isDataImmutable, performLocalCall } from "./callManager";
7
- import { getNodeIdRaw } from "./nodeAuthentication";
8
- import debugbreak from "debugbreak";
9
- import * as cookie from "cookie";
10
5
  import { SocketFunction } from "../SocketFunction";
11
6
  import { gzip } from "zlib";
12
7
  import { formatNumberSuffixed, sha256Hash } from "./misc";
13
-
14
- const nodeIdCookie = "node-id4";
8
+ import { getClientNodeId, getNodeId } from "./nodeCache";
15
9
 
16
10
  let defaultHTTPCall: CallType | undefined;
17
11
 
@@ -19,15 +13,7 @@ export function setDefaultHTTPCall(call: CallType) {
19
13
  defaultHTTPCall = call;
20
14
  }
21
15
 
22
- const cookieNodeIdPrefix = "COOKIE_nodeId_";
23
- export function getNodeIdFromRequest(request: http.IncomingMessage): string | undefined {
24
- let cookies = cookie.parse(request.headers.cookie ?? "");
25
- let value = cookies[nodeIdCookie];
26
- if (!value) return value;
27
- if (!value.startsWith(cookieNodeIdPrefix)) return undefined;
28
- return value;
29
- }
30
- export function getServerLocationFromRequest(request: http.IncomingMessage): NetworkLocation {
16
+ export function getServerLocationFromRequest(request: http.IncomingMessage) {
31
17
  let host = request.headers.host;
32
18
  if (!host) {
33
19
  throw new Error(`Missing host in request headers`);
@@ -41,10 +27,27 @@ export function getServerLocationFromRequest(request: http.IncomingMessage): Net
41
27
  address: host,
42
28
  // This is OUR location, so whatever they connected to us... we must be listening on!
43
29
  // (and the localPort doesn't matter in this case)
44
- listeningPorts: [port],
30
+ port,
45
31
  };
46
32
  }
47
33
 
34
/**
 * Derives the ids used for an HTTP call.
 * @returns `nodeId` for the caller (from getClientNodeId on the socket's remote address), and
 *   `localNodeId`, built from the Host the client addressed. IMPORTANT! localNodeId is not the
 *   actual local id, but the id the client called.
 * @throws if the socket has no remoteAddress, or (via getServerLocationFromRequest) the Host
 *   header is missing.
 */
export function getNodeIdsFromRequest(request: http.IncomingMessage) {
    // TODO: Support passing signed proof of userCertificate via headers in the HTTP request.
    // THAT WAY HTTP can have consistent nodeIds, instead of making them randomly every time!
    // (This isn't needed or possible for websockets, but they stay open, so calling functions
    // after they open to set the nodeId is possible, and preferred).
    const { remoteAddress } = request.socket;
    if (!remoteAddress) {
        throw new Error(`Missing remoteAddress`);
    }
    // Resolve the caller id before inspecting the Host header (same order as before).
    const nodeId = getClientNodeId(remoteAddress);
    const { address, port } = getServerLocationFromRequest(request);
    return {
        nodeId,
        localNodeId: getNodeId(address, port),
    };
}
50
+
48
51
  export async function httpCallHandler(request: http.IncomingMessage, response: http.ServerResponse) {
49
52
  try {
50
53
 
@@ -72,48 +75,13 @@ export async function httpCallHandler(request: http.IncomingMessage, response: h
72
75
  ;
73
76
  });
74
77
 
75
- let socket = request.connection as tls.TLSSocket;
76
-
77
- let address = socket.remoteAddress;
78
- let port = socket.remotePort;
79
- if (!address) {
80
- throw new Error("Missing remote address");
81
- }
82
- if (!port) {
83
- throw new Error("Missing remote port");
84
- }
85
-
86
- let nodeId = getNodeIdRaw(socket);
87
- if (!nodeId) {
88
- let cookieNodeId = getNodeIdFromRequest(request);
89
- if (typeof cookieNodeId === "string") {
90
- nodeId = cookieNodeId;
91
- }
92
- }
93
- if (!nodeId) {
94
- nodeId = cookieNodeIdPrefix + Date.now() + "_" + Math.random();
95
- response.setHeader("Set-Cookie", cookie.serialize(nodeIdCookie, nodeId, {
96
- httpOnly: true,
97
- path: "/",
98
- secure: true,
99
- domain: urlObj.hostname,
100
- sameSite: "none"
101
- }));
102
-
103
- response.setHeader(nodeIdCookie, nodeId);
104
- }
78
+ const { nodeId, localNodeId } = getNodeIdsFromRequest(request);
105
79
 
106
80
  let caller: CallerContext = {
107
81
  nodeId,
108
- fromPort: port,
109
- location: {
110
- address,
111
- listeningPorts: [],
112
- },
113
- serverLocation: getServerLocationFromRequest(request),
114
82
  certInfo: undefined,
83
+ localNodeId,
115
84
  };
116
- setCertInfo(socket, caller);
117
85
 
118
86
  let classGuid = urlObj.searchParams.get("classGuid");
119
87
  let functionName = urlObj.searchParams.get("functionName");
@@ -1,4 +1,4 @@
1
- import { CallContextType, CallerContext, CallType, ClientHookContext, HookContext, NetworkLocation, SocketExposedInterface, SocketExposedInterfaceClass, SocketExposedShape, SocketFunctionClientHook, SocketFunctionHook, SocketRegistered } from "../SocketFunctionTypes";
1
+ import { CallContextType, CallerContext, CallType, ClientHookContext, HookContext, SocketExposedInterface, SocketExposedInterfaceClass, SocketExposedShape, SocketFunctionClientHook, SocketFunctionHook, SocketRegistered } from "../SocketFunctionTypes";
2
2
  import { _setSocketContext } from "../SocketFunction";
3
3
 
4
4
  let classes: {
@@ -40,7 +40,7 @@ export async function performLocalCall(
40
40
  }
41
41
 
42
42
  let curContext: CallContextType = {};
43
- let serverContext = await runServerHooks(call, { caller, curContext }, functionShape);
43
+ let serverContext = await runServerHooks(call, { caller, curContext, getCaller: () => caller }, functionShape);
44
44
  if ("overrideResult" in serverContext) {
45
45
  return serverContext.overrideResult;
46
46
  }
@@ -75,9 +75,21 @@ export function exposeClass(exposedClass: SocketRegistered) {
75
75
  export function registerGlobalHook(hook: SocketFunctionHook) {
76
76
  globalHooks.push(hook);
77
77
  }
78
+ export function unregisterGlobalHook(hook: SocketFunctionHook) {
79
+ let index = globalHooks.indexOf(hook);
80
+ if (index >= 0) {
81
+ globalHooks.splice(index, 1);
82
+ }
83
+ }
78
84
  export function registerGlobalClientHook(hook: SocketFunctionClientHook) {
79
85
  globalClientHooks.push(hook);
80
86
  }
87
+ export function unregisterGlobalClientHook(hook: SocketFunctionClientHook) {
88
+ let index = globalClientHooks.indexOf(hook);
89
+ if (index >= 0) {
90
+ globalClientHooks.splice(index, 1);
91
+ }
92
+ }
81
93
 
82
94
  export async function runClientHooks(
83
95
  callType: CallType,
@@ -0,0 +1,50 @@
1
+ import * as os from "os";
2
+ import * as fs from "fs/promises";
3
+ import * as fsSync from "fs";
4
+ import * as child_process from "child_process";
5
+ import * as tls from "tls";
6
+ import { SocketFunction } from "../SocketFunction";
7
+ import { isNode, isNodeTrue, sha256Hash } from "./misc";
8
+ import { lazy } from "./caching";
9
+
10
/** Certificate texts trusted via trustUserCertificate or loaded from storePath. */
let trustedCerts = new Set<string>();
/** Set once loadTrustedUserCertificates has finished reading storePath. */
let loadedTrustedCerts = false;
/** Listeners that receive the full trusted list whenever a new certificate is trusted. */
let watchCallbacks = new Set<(certs: string[]) => void>();

/**
 * Folder holding the persisted certificates: a "certstore" folder next to the entry script.
 * Falls back to the current working directory when there is no entry script
 * (e.g. `node -e` or the REPL, where process.argv[1] is undefined).
 */
function getCertStoreFolder(): string {
    let entryScript = process.argv[1];
    let baseFolder = entryScript
        ? entryScript.replaceAll("\\", "/").split("/").slice(0, -1).join("/")
        : process.cwd().replaceAll("\\", "/");
    return baseFolder + "/certstore/";
}

let storePath = isNodeTrue() && getCertStoreFolder();
if (isNode()) {
    // recursive: true is a no-op when the folder already exists, avoiding the exists-then-mkdir race.
    fsSync.mkdirSync(storePath, { recursive: true });
}
20
+
21
/**
 * Trusts `cert` (certificate text), persisting it into the cert store as "<sha256>.cer".
 * Every listener registered with watchUserCertificates is then called with the full trusted list.
 * Connections created afterwards read the list through getTrustedUserCertificates.
 */
export async function trustUserCertificate(cert: string) {
    // Ensure on-disk certs are loaded first. Otherwise getTrustedUserCertificates (below) would throw
    // after we had already written the file.
    await loadTrustedUserCertificates();
    if (trustedCerts.has(cert)) return;
    // Persist BEFORE trusting in memory. If the write fails, a retry must not be skipped by the `has` check.
    await fs.writeFile(storePath + sha256Hash(Buffer.from(cert)) + ".cer", cert);
    trustedCerts.add(cert);
    let certs = getTrustedUserCertificates();
    for (let callback of watchCallbacks) {
        callback(certs);
    }
}
31
/**
 * Loads every persisted certificate from storePath into the trusted set (runs once; lazy).
 * Must be awaited before getTrustedUserCertificates / watchUserCertificates are used.
 */
export const loadTrustedUserCertificates = lazy(async () => {
    let entries = await fs.readdir(storePath, { withFileTypes: true });
    // Only regular, non-hidden files are certificates (trustUserCertificate writes "<sha256>.cer").
    // Skipping the rest avoids EISDIR on sub-folders and keeps junk like ".DS_Store" out of the CA list.
    let certFiles = entries.filter(entry => entry.isFile() && !entry.name.startsWith("."));
    let certs = await Promise.all(certFiles.map(entry => fs.readFile(storePath + entry.name, "utf8")));
    for (let cert of certs) {
        trustedCerts.add(cert);
    }
    loadedTrustedCerts = true;
});
39
/**
 * Returns a snapshot array of all trusted user certificates.
 * Throws if loadTrustedUserCertificates has not finished yet.
 */
export function getTrustedUserCertificates(): string[] {
    if (loadedTrustedCerts) {
        return [...trustedCerts];
    }
    throw new Error("Must call loadTrustedUserCertificates (and await it) before calling getTrustedUserCertificates");
}
45
+
46
/**
 * Calls `callback` immediately with the current trusted certificates, and again after each
 * trustUserCertificate adds a new one.
 * @returns A function that unregisters the callback.
 * @throws If loadTrustedUserCertificates has not been awaited yet (nothing is registered in that case).
 */
export function watchUserCertificates(callback: (certs: string[]) => void) {
    // Read the list first. If it throws (certs not loaded), no dangling callback is left registered.
    let certs = getTrustedUserCertificates();
    watchCallbacks.add(callback);
    callback(certs);
    return () => watchCallbacks.delete(callback);
}
package/src/misc.ts CHANGED
@@ -7,7 +7,7 @@ export function convertErrorStackToError(error: string): Error {
7
7
  return errorObj;
8
8
  }
9
9
 
10
/** Returns the lowercase hex SHA-256 digest of `buffer` (string input is hashed as UTF-8). */
export function sha256Hash(buffer: Buffer | string) {
    const hasher = crypto.createHash("sha256");
    hasher.update(buffer);
    return hasher.digest("hex");
}
13
13
  /** Async, but works both clientside and serverside. */
@@ -12,6 +12,10 @@ import { isNode, sha256Hash } from "./misc";
12
12
  import { getArgs } from "./args";
13
13
  import { SenderInterface } from "./CallFactory";
14
14
  import { SocketFunction } from "../SocketFunction";
15
+ import { getTrustedUserCertificates } from "./certStore";
16
+ import { getClientNodeId, getNodeId, getNodeIdLocation } from "./nodeCache";
17
+
18
+ export type CertInfo = { raw: Buffer | string; issuerCertificate: { raw: Buffer | string } };
15
19
 
16
20
  let certKeyPairOverride: { key: Buffer; cert: Buffer } | undefined;
17
21
  export function getCertKeyPair(): { key: Buffer; cert: Buffer } {
@@ -19,13 +23,6 @@ export function getCertKeyPair(): { key: Buffer; cert: Buffer } {
19
23
  return getCertKeyPairBase();
20
24
  }
21
25
  const getCertKeyPairBase = lazy((): { key: Buffer; cert: Buffer } => {
22
- // TODO: Also get this working clientside...
23
- // - Use https://developer.mozilla.org/en-US/docs/Web/API/SubtleCrypto/generateKey
24
- // - We might need node-forge for the Certificate Signing Request and x509 stuff
25
- // - Use ECDSA keys
26
- // - ALSO, get our nodeId set in our cookies, so HTTP requests can work as well
27
- // - We will need callHTTPHandler to support this
28
-
29
26
  // https://nodejs.org/en/knowledge/HTTP/servers/how-to-create-a-HTTPS-server/
30
27
 
31
28
  let folder = getAppFolder();
@@ -45,6 +42,9 @@ const getCertKeyPairBase = lazy((): { key: Buffer; cert: Buffer } => {
45
42
  });
46
43
 
47
44
  export function overrideCertKeyPair<T>(certKey: { key: Buffer; cert: Buffer; }, code: () => T): T {
45
+ if (!isNode()) {
46
+ throw new Error(`Cannot override cert/key pair in browser`);
47
+ }
48
48
  let prevOverride = certKeyPairOverride;
49
49
  certKeyPairOverride = certKey;
50
50
  try {
@@ -70,66 +70,47 @@ export async function getOwnNodeId() {
70
70
  throw new Error(`TODO: Implement getOwnNodeId`);
71
71
  }
72
72
 
73
/**
 * Derives a nodeId from a certificate.
 *
 * @param certRaw The certificate (DER Buffer or PEM string) in `raw`. A missing/empty value yields undefined.
 * @param callbackPort Port the certificate's owner can be reached on. When falsy, the owner cannot be
 *      called back, so a one-off client nodeId (getClientNodeId) is returned instead.
 * @returns The nodeId, or undefined when no certificate was supplied.
 */
export function getNodeIdFromCert(certRaw: { raw: Buffer | string } | undefined, callbackPort: number | undefined) {
    if (!certRaw?.raw) return undefined;
    let cert = new crypto.X509Certificate(certRaw.raw);
    if (!callbackPort) {
        return getClientNodeId(cert.subject);
    }
    return getNodeId(getCommonNameFromSubject(cert.subject), callbackPort);
}

/**
 * Extracts the CN value from an X509Certificate `subject` string, which holds one "key=value"
 * entry per line. Only checking `startsWith("CN=")` on the whole string breaks as soon as the cert
 * has other subject fields (C=, O=, ...). Falls back to the whole subject when there is no CN entry.
 */
function getCommonNameFromSubject(subject: string): string {
    for (let line of subject.split("\n")) {
        if (line.startsWith("CN=")) {
            return line.slice("CN=".length);
        }
    }
    return subject;
}
109
85
 
110
86
  /** NOTE: We create a factory, which embeds the key/cert information. Otherwise retries might use
111
87
  * a different key/cert context.
112
88
  */
113
- export function createWebsocketFactory(): (address: string, port: number) => SenderInterface {
89
+ export function createWebsocketFactory(): (nodeId: string) => SenderInterface {
114
90
 
115
91
  if (!isNode()) {
116
- // NOTE: We assume an HTTP request has already been made, which will setup a nodeId cookie
117
- // (And as this point we can't even use peer certificates if we wanted to, as this must be done
118
- // directly in the browser)
119
- return (address: string, port: number) => {
92
+ return (nodeId: string) => {
93
+ let location = getNodeIdLocation(nodeId);
94
+ if (!location) throw new Error(`Cannot connect to ${nodeId}, no address known`);
95
+ let { address, port } = location;
96
+
120
97
  console.log(`Connecting to ${address}:${port}`);
121
98
  return new WebSocket(`wss://${address}:${port}`);
122
99
  };
123
100
  } else {
124
101
  let { key, cert } = getCertKeyPair();
125
102
  let rejectUnauthorized = SocketFunction.rejectUnauthorized;
126
- return (address: string, port: number) => {
103
+ return (nodeId: string) => {
104
+ let location = getNodeIdLocation(nodeId);
105
+ if (!location) throw new Error(`Cannot connect to ${nodeId}, no address known`);
106
+ let { address, port } = location;
107
+
127
108
  console.log(`Connecting to ${address}:${port}`);
128
109
  let webSocket = new ws.WebSocket(`wss://${address}:${port}`, {
129
110
  cert,
130
111
  key,
131
112
  rejectUnauthorized,
132
- ca: tls.rootCertificates.concat(SocketFunction.additionalTrustedRootCAs),
113
+ ca: tls.rootCertificates.concat(getTrustedUserCertificates()),
133
114
  });
134
115
  let result = Object.assign(webSocket, { socket: undefined as tls.TLSSocket | undefined });
135
116
  webSocket.once("upgrade", e => {
@@ -138,4 +119,5 @@ export function createWebsocketFactory(): (address: string, port: number) => Sen
138
119
  return result;
139
120
  };
140
121
  }
141
- }
122
+ }
123
+
package/src/nodeCache.ts CHANGED
@@ -1,6 +1,7 @@
1
- import { callFactoryFromLocation, CallFactory } from "./CallFactory";
2
- import { NetworkLocation } from "../SocketFunctionTypes";
1
+ import { CallFactory, createCallFactory } from "./CallFactory";
3
2
  import { MaybePromise } from "./types";
3
+ import { lazy } from "./caching";
4
+ import { SocketFunction } from "../SocketFunction";
4
5
 
5
6
  // TODO: Add CallInstanceFactory.isClosed, so nodeCache can clean up old entries.
6
7
  // This is only needed for memory management, and not for correctness. Entries never
@@ -11,91 +12,66 @@ import { MaybePromise } from "./types";
11
12
  // a value to a new value... then they should be obtained using connect() anyway,
12
13
  // and so whatever way the user got the NetworkLocation to begin with, they should use again.
13
14
 
14
-
15
/**
 * Builds the nodeId of a reachable node from its domain and listening port ("domain:port").
 * Domains are never reused, so no random component is needed for uniqueness.
 */
export function getNodeId(domain: string, port: number): string {
    return domain + ":" + port;
}
26
19
 
27
/**
 * Builds a unique nodeId for a peer that cannot be reconnected to (client-only).
 * The "client_" prefix is what getNodeIdLocation uses to recognize these ids.
 */
export function getClientNodeId(address: string): string {
    const uniqueSuffix = [Date.now(), Math.random()].join(":");
    return "client_" + address + ":" + uniqueSuffix;
}
38
24
 
39
/**
 * Parses a nodeId created by getNodeId ("address:port") back into its address and port.
 *
 * @returns undefined for client-only nodeIds (getClientNodeId), or when the nodeId has no valid port.
 */
export function getNodeIdLocation(nodeId: string): { address: string, port: number; } | undefined {
    if (nodeId.startsWith("client_")) {
        return undefined;
    }
    // Split on the LAST ":" so an address that itself contains ":" (IPv6 literal) stays intact.
    let separatorIndex = nodeId.lastIndexOf(":");
    if (separatorIndex < 0) return undefined;
    let address = nodeId.slice(0, separatorIndex);
    let port = parseInt(nodeId.slice(separatorIndex + 1), 10);
    // A NaN port would only surface later as an obscure "wss://host:NaN" connection failure.
    if (Number.isNaN(port)) return undefined;
    return { address, port };
}
55
32
 
56
/**
 * Returns the last two DNS labels of a nodeId's address, e.g. "api.example.com:443" -> "example.com".
 * IP-address nodeIds are returned unchanged (cutting an IP into labels is meaningless).
 *
 * @throws If `nodeId` is client-only (it has no address).
 */
export function getNodeIdDomain(nodeId: string): string {
    let location = getNodeIdLocation(nodeId);
    if (!location) {
        throw new Error(`Cannot get domain from nodeId, which is only usable as a client. NodeId: ${JSON.stringify(nodeId)}`);
    }
    // nodeIds hold a bare hostname (see getNodeId), which `new URL` rejects as "Invalid URL",
    // so only parse as a URL when a scheme is actually present.
    let hostname = location.address.includes("://") ? new URL(location.address).hostname : location.address;
    if (hostname.includes(":") || /^\d{1,3}(\.\d{1,3}){3}$/.test(hostname)) {
        return hostname;
    }
    return hostname.split(".").slice(-2).join(".");
}
62
40
 
63
// nodeId => CallFactory. While getCreateCallFactoryLocation's createCallFactory is still resolving,
// the entry is the pending Promise; registerNodeClient later stores the actual CallFactory.
const nodeCache = new Map<string, MaybePromise<CallFactory>>();

/**
 * Stores `callFactory` under its current nodeId (overwriting any previous entry) and makes sure
 * the background cleanup loop is running.
 *
 * NOTE: Should be called directly inside the call factory constructor whenever its nodeId
 * changes (and on construction).
 */
export function registerNodeClient(callFactory: CallFactory) {
    const { nodeId } = callFactory;
    nodeCache.set(nodeId, callFactory);
    startCleanupLoop();
}
51
+
52
/**
 * Returns the cached CallFactory (or pending Promise of one) for `nodeId`, creating it with
 * createCallFactory on first use.
 *
 * @param nodeId The node to call.
 * @param mountedNodeId Forwarded to createCallFactory (presumably our own nodeId — TODO confirm).
 */
export function getCreateCallFactoryLocation(nodeId: string, mountedNodeId: string): MaybePromise<CallFactory> {
    let callFactory = nodeCache.get(nodeId);
    if (callFactory !== undefined) {
        return callFactory;
    }
    let created = createCallFactory(undefined, nodeId, mountedNodeId);
    nodeCache.set(nodeId, created);
    // A rejected creation must not stay cached. Otherwise every later call for this nodeId would get
    // the same rejection forever, because the cleanup loop skips Promise entries. Only evict if the
    // entry was not replaced in the meantime (e.g. by registerNodeClient).
    Promise.resolve(created).catch(() => {
        if (nodeCache.get(nodeId) === created) {
            nodeCache.delete(nodeId);
        }
    });
    return created;
}
73
60
 
74
/** How often CallFactories marked closedForever are purged from nodeCache. */
const CLEANUP_INTERVAL_MS = 1000 * 60 * 5;

/**
 * Starts (once, via lazy) a background loop that periodically deletes nodeCache entries whose
 * CallFactory is `closedForever`. Pending (Promise) entries are skipped. Loop failures are logged.
 */
const startCleanupLoop = lazy(() => {
    (async () => {
        while (true) {
            for (let [key, value] of Array.from(nodeCache.entries())) {
                if (!(value instanceof Promise) && value.closedForever) {
                    nodeCache.delete(key);
                }
            }
            await sleepUnref(CLEANUP_INTERVAL_MS);
        }
    })().catch(e => {
        console.error(`nodeCache cleanup loop failed, ${e.stack}`);
    });
});

/**
 * Waits `ms` milliseconds. In NodeJS the timer is unref'd, so this housekeeping loop never keeps
 * the process alive on its own. A plain setTimeout would prevent exit forever once any client registered.
 */
function sleepUnref(ms: number): Promise<void> {
    return new Promise<void>(resolve => {
        let timer: { unref?: () => void } | number = setTimeout(resolve, ms);
        if (typeof timer === "object") {
            timer.unref?.();
        }
    });
}
@@ -3,18 +3,11 @@ import http from "http";
3
3
  import net from "net";
4
4
  import tls from "tls";
5
5
  import * as ws from "ws";
6
- import { performLocalCall } from "./callManager";
7
- import { CallerContext, CallType, NetworkLocation } from "../SocketFunctionTypes";
8
- import { CallFactory, callFactoryFromWS } from "./CallFactory";
9
- import { registerNodeClient } from "./nodeCache";
10
- import { getCertKeyPair, getNodeId, getNodeIdRaw } from "./nodeAuthentication";
11
- import debugbreak from "debugbreak";
12
- import { cache } from "./caching";
13
- import { getNodeIdFromRequest, getServerLocationFromRequest, httpCallHandler } from "./callHTTPHandler";
6
+ import { getCertKeyPair, getNodeIdFromCert } from "./nodeAuthentication";
7
+ import { getNodeIdsFromRequest, httpCallHandler } from "./callHTTPHandler";
14
8
  import { SocketFunction } from "../SocketFunction";
15
-
16
- // TODO: Support conditional peer certificate requests, as it the certificate prompt
17
- // seems suspicious in the browser (the user can just click cancel though).
9
+ import { getTrustedUserCertificates, loadTrustedUserCertificates, watchUserCertificates } from "./certStore";
10
+ import { createCallFactory } from "./CallFactory";
18
11
 
19
12
  export type SocketServerConfig = (
20
13
  {
@@ -30,7 +23,7 @@ export type SocketServerConfig = (
30
23
 
31
24
  export async function startSocketServer(
32
25
  config: SocketServerConfig
33
- ) {
26
+ ): Promise<string> {
34
27
  let isSecure = "cert" in config || "key" in config || "pfx" in config;
35
28
  if (!isSecure) {
36
29
  let { key, cert } = getCertKeyPair();
@@ -38,19 +31,30 @@ export async function startSocketServer(
38
31
  config.cert = cert;
39
32
  }
40
33
 
34
+ await loadTrustedUserCertificates();
35
+
41
36
  // TODO: Only allow unauthorized for ip certificates, and then for domains use the domain as the nodeId,
42
37
  // so it is easy to read, and consistent.
43
- let httpsServer = https.createServer({
38
+ let options: https.ServerOptions = {
44
39
  ...config,
45
40
  rejectUnauthorized: SocketFunction.rejectUnauthorized,
46
41
  requestCert: true,
47
- ca: tls.rootCertificates.concat(SocketFunction.additionalTrustedRootCAs),
42
+ };
43
+
44
+ let httpsServer = https.createServer(options);
45
+ watchUserCertificates(() => {
46
+ options.ca = tls.rootCertificates.concat(getTrustedUserCertificates());
47
+ httpsServer.setSecureContext(options);
48
48
  });
49
+
49
50
  httpsServer.on("connection", socket => {
50
51
  console.log("Client connection established");
51
52
  socket.on("error", e => {
52
53
  console.log(`Client socket error ${e.message}`);
53
54
  });
55
+ socket.on("close", () => {
56
+ console.log("Client socket closed");
57
+ });
54
58
  });
55
59
  httpsServer.on("error", e => {
56
60
  console.error(`Connection attempt error ${e.message}`);
@@ -83,15 +87,11 @@ export async function startSocketServer(
83
87
  return;
84
88
  }
85
89
  }
86
- webSocketServer.handleUpgrade(request, socket, upgradeHead, async (ws) => {
87
- // NOTE: For the browser, the request will likely have a nodeId, from making an HTTP request.
88
- // We would prefer peer certificates, so this isn't the default (in getNodeId), but it will
89
- // likely be used most of the time.
90
- let requestNodeId = getNodeIdFromRequest(request);
91
- Object.assign(ws, { nodeId: requestNodeId });
92
-
93
- let clientCallFactory = await callFactoryFromWS(ws, getServerLocationFromRequest(request));
94
- registerNodeClient(clientCallFactory);
90
+ webSocketServer.handleUpgrade(request, socket, upgradeHead, (ws) => {
91
+ const { nodeId, localNodeId } = getNodeIdsFromRequest(request);
92
+ createCallFactory(ws, nodeId, localNodeId).catch(e => {
93
+ console.error(`Error in creating call factory, ${e.stack}`);
94
+ });
95
95
  });
96
96
  });
97
97
 
@@ -146,5 +146,13 @@ export async function startSocketServer(
146
146
 
147
147
  await listenPromise;
148
148
 
149
- console.log(`Started Listening on ${host}:${config.port}`);
149
+ let port = (realServer.address() as net.AddressInfo).port;
150
+
151
+ console.log(`Started Listening on ${host}:${port}`);
152
+
153
+ let serverNodeId = getNodeIdFromCert({ raw: config.cert as Buffer | string }, port);
154
+ if (!serverNodeId) {
155
+ throw new Error(`Something is wrong with our cert, we don't have a nodeId?`);
156
+ }
157
+ return serverNodeId;
150
158
  }