@linkshell/gateway 0.2.16 → 0.2.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/tunnel.ts ADDED
@@ -0,0 +1,347 @@
1
+ import type { IncomingMessage, ServerResponse } from "node:http";
2
+ import { randomUUID } from "node:crypto";
3
+ import type WebSocket from "ws";
4
+ import {
5
+ createEnvelope,
6
+ serializeEnvelope,
7
+ } from "@linkshell/protocol";
8
+ import type { SessionManager } from "./sessions.js";
9
+ import type { TokenManager } from "./tokens.js";
10
+
11
/** Milliseconds a tunneled HTTP request may wait on the host before the gateway answers 504. */
const TUNNEL_TIMEOUT = 30 * 1000;
/** Largest request body (bytes) the gateway will buffer and forward: 10 MiB. */
const MAX_TUNNEL_BODY = 10 * 1024 ** 2;
13
+
14
/** Book-keeping for one in-flight tunneled HTTP request (stored under its requestId). */
export interface PendingTunnelRequest {
  /** Timer armed when the request is sent to the host; fires the 504 path. */
  timeout: ReturnType<typeof setTimeout>;
  /** Whether `writeHead` has already been called for `res`. */
  headersSent: boolean;
  /** Client response that host `tunnel.response` chunks are written into. */
  res: ServerResponse;
}

/** A browser WebSocket whose frames are relayed to/from the host under a requestId. */
export interface PendingTunnelWs {
  ws: WebSocket;
}
23
+
24
// Registries keyed by requestId: in-flight HTTP responses and relayed browser sockets.
const pendingRequests: Map<string, PendingTunnelRequest> = new Map();
const pendingWsSockets: Map<string, PendingTunnelWs> = new Map();

// sessionId -> requestIds currently tracked for that session, so they can all be
// torn down together when the session's host disconnects.
const sessionRequests: Map<string, Set<string>> = new Map();
30
+
31
/** Record that `requestId` belongs to `sessionId`, creating the session's set on first use. */
function trackRequest(sessionId: string, requestId: string): void {
  const existing = sessionRequests.get(sessionId);
  if (existing) {
    existing.add(requestId);
    return;
  }
  sessionRequests.set(sessionId, new Set([requestId]));
}
39
+
40
/** Forget `requestId` for `sessionId`; drops the session's set once it becomes empty. */
function untrackRequest(sessionId: string, requestId: string): void {
  const ids = sessionRequests.get(sessionId);
  if (!ids) return;
  ids.delete(requestId);
  if (ids.size === 0) {
    sessionRequests.delete(sessionId);
  }
}
47
+
48
/**
 * Find the gateway auth token presented with a tunnel request.
 *
 * Sources are checked in order: `Authorization: Bearer <token>`, the `token`
 * query parameter, then the `lsh_tunnel` cookie.
 *
 * @param req - incoming request (headers are read)
 * @param url - parsed request URL (query string is read)
 * @returns the first token found, or null when none is present
 */
function extractToken(req: IncomingMessage, url: URL): string | null {
  const auth = req.headers.authorization;
  if (auth) {
    const match = auth.match(/^Bearer\s+(.+)$/i);
    if (match?.[1]) return match[1];
  }

  const qToken = url.searchParams.get("token");
  if (qToken) return qToken;

  const cookie = req.headers.cookie;
  if (cookie) {
    // Anchor on a cookie boundary (start of header or after ';') so a cookie
    // whose name merely ends in "lsh_tunnel" (e.g. "xlsh_tunnel") is not
    // mistaken for the gateway's cookie.
    const match = cookie.match(/(?:^|;)\s*lsh_tunnel=([^;]+)/);
    const value = match?.[1]?.trim();
    if (value) return value;
  }
  return null;
}
66
+
67
/**
 * Send a plain-text error reply with an open CORS header.
 * Does nothing once headers have already gone out, since the status can no longer change.
 */
function errorResponse(res: ServerResponse, status: number, message: string): void {
  if (!res.headersSent) {
    const headers = {
      "content-type": "text/plain",
      "access-control-allow-origin": "*",
    };
    res.writeHead(status, headers);
    res.end(message);
  }
}
75
+
76
/**
 * Parse a tunnel pathname of the form `/tunnel/<sessionId>/<port>[/<path>]`.
 *
 * @returns the session id, the port (1–65535), and the forwarded path
 *   (defaults to "/"), or null when the pathname does not match or the port is out of range
 */
export function parseTunnelPath(pathname: string): { sessionId: string; port: number; path: string } | null {
  const m = /^\/tunnel\/([^/]+)\/(\d+)(\/.*)?$/.exec(pathname);
  if (m === null) return null;
  const [, sessionId, portText, rest] = m;
  const port = Number(portText);
  const inRange = port >= 1 && port <= 65535;
  return inRange ? { sessionId: sessionId!, port, path: rest || "/" } : null;
}
87
+
88
/** Outcome of buffering a tunneled request body. */
type TunnelBodyResult =
  | { ok: true; body: string | null }
  | { ok: false; status: number; message: string };

/**
 * Buffer a request body and base64-encode it for the tunnel envelope.
 *
 * @param req - incoming client request
 * @param limit - maximum body size in bytes
 * @returns `{ ok: true, body }` (`body` is null when there were no body bytes), or
 *   an error status/message: 413 when the declared or actual size exceeds `limit`,
 *   400 when reading the stream fails (e.g. the client aborted the upload)
 */
async function readTunnelBody(req: IncomingMessage, limit: number): Promise<TunnelBodyResult> {
  const tooLarge: TunnelBodyResult = { ok: false, status: 413, message: "Request body too large" };

  // Reject an upload that declares an oversized body before reading any of it.
  const declared = Number(req.headers["content-length"]);
  if (Number.isFinite(declared) && declared > limit) return tooLarge;

  const chunks: Buffer[] = [];
  let size = 0;
  try {
    for await (const chunk of req) {
      const buf = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
      size += buf.length;
      if (size > limit) return tooLarge;
      chunks.push(buf);
    }
  } catch {
    // Stream error mid-upload; report it instead of rejecting the request handler.
    return { ok: false, status: 400, message: "Failed to read request body" };
  }
  return { ok: true, body: chunks.length > 0 ? Buffer.concat(chunks).toString("base64") : null };
}

/**
 * Copy the client's request headers for forwarding to the host.
 *
 * Drops hop-by-hop headers the gateway handles itself (`host`, `connection`,
 * `upgrade`, `transfer-encoding`, `keep-alive`) and non-string (multi-valued)
 * entries. Also strips the gateway's own `lsh_tunnel` auth cookie so the
 * tunneled app never receives the gateway credential.
 */
function buildTunnelForwardHeaders(req: IncomingMessage): Record<string, string> {
  const skip = new Set(["host", "connection", "upgrade", "transfer-encoding", "keep-alive"]);
  const headers: Record<string, string> = {};
  for (const [key, val] of Object.entries(req.headers)) {
    if (skip.has(key) || typeof val !== "string") continue;
    if (key === "cookie") {
      const kept = val
        .split(";")
        .map((part) => part.trim())
        .filter((part) => part !== "" && !part.startsWith("lsh_tunnel="));
      if (kept.length > 0) headers[key] = kept.join("; ");
      continue;
    }
    headers[key] = val;
  }
  return headers;
}

/**
 * Relay one HTTP request through the session's host connection.
 *
 * Flow:
 * 1. Authenticate the request with `extractToken` and `tokens.owns`.
 * 2. Set the `lsh_tunnel` cookie.
 * 3. Buffer the body (non-GET/HEAD).
 * 4. Register the pending response and send a `tunnel.request` envelope to the host.
 *
 * The host's `tunnel.response` messages later stream into `res`.
 *
 * Error replies:
 * - 401: bad or missing token.
 * - 502: host not connected.
 * - 413 / 400: body too large or unreadable.
 * - 504: host silent for `TUNNEL_TIMEOUT`.
 */
export async function handleTunnelRequest(
  req: IncomingMessage,
  res: ServerResponse,
  sessions: SessionManager,
  tokens: TokenManager,
  parsed: { sessionId: string; port: number; path: string },
  url: URL,
): Promise<void> {
  const { sessionId, port, path } = parsed;

  const token = extractToken(req, url);
  if (!token || !tokens.owns(token, sessionId)) {
    errorResponse(res, 401, "Unauthorized");
    return;
  }

  // Set auth cookie so sub-resource requests from the tunneled page authenticate.
  res.setHeader("Set-Cookie", `lsh_tunnel=${token}; Path=/tunnel/${sessionId}/; HttpOnly; SameSite=Lax`);

  // The session's host connection, if its socket is currently open.
  const openHost = () => {
    const host = sessions.get(sessionId)?.host;
    return host && host.socket.readyState === host.socket.OPEN ? host : null;
  };

  // Fail fast before reading any body.
  if (!openHost()) {
    errorResponse(res, 502, "Host not connected");
    return;
  }

  const requestId = randomUUID();
  const method = req.method ?? "GET";

  let body: string | null = null;
  if (method !== "GET" && method !== "HEAD") {
    const result = await readTunnelBody(req, MAX_TUNNEL_BODY);
    if (!result.ok) {
      errorResponse(res, result.status, result.message);
      return;
    }
    body = result.body;
  }

  // Re-check: the host may have disconnected (or been replaced) during the upload.
  const host = openHost();
  if (!host) {
    errorResponse(res, 502, "Host not connected");
    return;
  }

  const headers = buildTunnelForwardHeaders(req);
  // Reconstruct URL with query string
  const fullUrl = path + (url.search || "");

  const pending: PendingTunnelRequest = {
    res,
    headersSent: false,
    timeout: setTimeout(() => {
      pendingRequests.delete(requestId);
      untrackRequest(sessionId, requestId);
      if (res.headersSent) {
        // Part of the response was already streamed, so a 504 is impossible;
        // abort the connection rather than leave the client waiting forever.
        res.destroy();
      } else {
        errorResponse(res, 504, "Tunnel request timed out");
      }
    }, TUNNEL_TIMEOUT),
  };
  pendingRequests.set(requestId, pending);
  trackRequest(sessionId, requestId);

  // ServerResponse "close" fires both after a normal end and when the client
  // connection drops, so this releases the registry entries in every case.
  // (IncomingMessage "close" is not a reliable disconnect signal once the body
  // has been consumed.)
  res.on("close", () => {
    clearTimeout(pending.timeout);
    pendingRequests.delete(requestId);
    untrackRequest(sessionId, requestId);
  });

  const envelope = createEnvelope({
    type: "tunnel.request",
    sessionId,
    payload: {
      requestId,
      method,
      url: fullUrl,
      headers,
      body,
      port,
    },
  });
  host.socket.send(serializeEnvelope(envelope));
}
187
+
188
/**
 * Build client-facing headers for a host `tunnel.response`.
 *
 * - Header names are lower-cased.
 * - Connection-scoped headers (`connection`, `keep-alive`, `transfer-encoding`) are
 *   dropped, since the gateway manages its own client connection.
 * - CORS is forced open (`access-control-allow-origin: *`), replacing any
 *   case-variant from the app.
 * - A Set-Cookie already staged on `res` (the gateway's `lsh_tunnel` cookie) is kept
 *   alongside an app-supplied Set-Cookie. Node gives headers passed to `writeHead()`
 *   precedence over `setHeader()` ones, so the staged cookie would otherwise be lost.
 */
function buildTunnelResponseHeaders(
  res: Pick<ServerResponse, "getHeader">,
  hostHeaders: Record<string, string> | undefined,
): Record<string, string | string[]> {
  const hopByHop = new Set(["connection", "keep-alive", "transfer-encoding"]);
  const out: Record<string, string | string[]> = {};
  for (const [name, value] of Object.entries(hostHeaders ?? {})) {
    const lower = name.toLowerCase();
    if (!hopByHop.has(lower)) out[lower] = value;
  }

  const appCookie = out["set-cookie"];
  const gatewayCookie = res.getHeader("set-cookie");
  if (appCookie !== undefined && gatewayCookie !== undefined) {
    const staged = Array.isArray(gatewayCookie) ? gatewayCookie : [String(gatewayCookie)];
    out["set-cookie"] = [...staged, ...(Array.isArray(appCookie) ? appCookie : [appCookie])];
  }

  out["access-control-allow-origin"] = "*";
  return out;
}

/** Stop the request's timer and drop it from both registries. */
function releaseTunnelRequest(requestId: string, pending: PendingTunnelRequest): void {
  clearTimeout(pending.timeout);
  pendingRequests.delete(requestId);
  // Only the requestId is known here, so locate the owning session by scanning.
  for (const [sid, ids] of sessionRequests) {
    if (ids.has(requestId)) {
      untrackRequest(sid, requestId);
      break;
    }
  }
}

/**
 * Apply one host `tunnel.response` message to the matching client response.
 *
 * The first message for a request writes the status line and headers. Every
 * message may carry a base64 body chunk. `isFinal` ends the response and
 * releases the request. Each non-final chunk restarts the request timer, so
 * `TUNNEL_TIMEOUT` acts as an idle timeout and long streams are not cut off
 * mid-response. Messages for unknown (finished or timed-out) requests are ignored.
 */
export function handleTunnelResponse(payload: {
  requestId: string;
  statusCode: number;
  headers: Record<string, string>;
  body: string;
  isFinal: boolean;
}): void {
  const pending = pendingRequests.get(payload.requestId);
  if (!pending) return;
  const { res } = pending;

  if (!pending.headersSent) {
    try {
      res.writeHead(payload.statusCode, buildTunnelResponseHeaders(res, payload.headers));
    } catch {
      // writeHead() throws on an out-of-range status code or an invalid header
      // name/value; answer 502 instead of letting the error escape to the caller.
      releaseTunnelRequest(payload.requestId, pending);
      if (!res.headersSent) {
        // Discard any host headers applied before the failure (keep the auth cookie).
        for (const name of res.getHeaderNames()) {
          if (name !== "set-cookie") res.removeHeader(name);
        }
      }
      errorResponse(res, 502, "Invalid response from tunneled host");
      return;
    }
    pending.headersSent = true;
  }

  if (payload.body) {
    res.write(Buffer.from(payload.body, "base64"));
  }

  if (payload.isFinal) {
    releaseTunnelRequest(payload.requestId, pending);
    res.end();
  } else {
    pending.timeout.refresh();
  }
}
219
+
220
/**
 * Deliver a host → browser WebSocket frame.
 * `data` is base64; binary frames are sent as a Buffer, text frames as a UTF-8 string.
 * Frames for unknown requestIds are dropped.
 */
export function handleTunnelWsData(payload: {
  requestId: string;
  data: string;
  isBinary: boolean;
}): void {
  const entry = pendingWsSockets.get(payload.requestId);
  if (entry === undefined) return;
  const { ws } = entry;
  const bytes = Buffer.from(payload.data, "base64");
  if (payload.isBinary) {
    ws.send(bytes);
  } else {
    ws.send(bytes.toString("utf8"));
  }
}
230
+
231
/**
 * Whether `code` may be sent in a WebSocket Close frame (RFC 6455 §7.4).
 * 1004–1006 and 1015 are reserved; 1000–1014 and 3000–4999 are otherwise allowed.
 * This matches the check `ws` applies before `close()`, which throws on anything else.
 */
function isSendableCloseCode(code: number): boolean {
  return (
    Number.isInteger(code) &&
    ((code >= 1000 && code <= 1014 && code !== 1004 && code !== 1005 && code !== 1006) ||
      (code >= 3000 && code <= 4999))
  );
}

/** Trim a close reason to the 123-byte protocol limit without splitting a UTF-8 character. */
function truncateCloseReason(reason: string): string {
  const maxBytes = 123;
  if (Buffer.byteLength(reason) <= maxBytes) return reason;
  let out = "";
  let used = 0;
  for (const ch of reason) {
    const n = Buffer.byteLength(ch);
    if (used + n > maxBytes) break;
    out += ch;
    used += n;
  }
  return out;
}

/**
 * Close the browser WebSocket for a host `tunnel.ws.close` message.
 *
 * The host may forward codes that cannot legally be sent (e.g. 1005 "no status"
 * or 1006 "abnormal closure" observed on its upstream socket) or over-long
 * reasons; `ws.close()` would throw on those, so they are mapped to 1000 and
 * truncated respectively. Unknown requestIds are ignored.
 */
export function handleTunnelWsClose(payload: {
  requestId: string;
  code?: number;
  reason?: string;
}): void {
  const entry = pendingWsSockets.get(payload.requestId);
  if (!entry) return;
  // Deregister first so no further host frames are routed to a closing socket.
  pendingWsSockets.delete(payload.requestId);
  const code = payload.code !== undefined && isSendableCloseCode(payload.code) ? payload.code : 1000;
  entry.ws.close(code, truncateCloseReason(payload.reason ?? ""));
}
241
+
242
/** Route host frames for `requestId` to `ws`, replacing any earlier registration. */
export function registerTunnelWs(requestId: string, ws: WebSocket): void {
  const entry: PendingTunnelWs = { ws };
  pendingWsSockets.set(requestId, entry);
}
245
+
246
/** Stop routing host frames for `requestId`; unknown ids are a no-op. */
export function removeTunnelWs(requestId: string): void {
  void pendingWsSockets.delete(requestId);
}
249
+
250
/**
 * Tear down every tunneled request and WebSocket tracked for `sessionId`,
 * typically because the session's host disconnected.
 *
 * - HTTP requests that have not started responding get a 502.
 * - Requests whose headers were already sent cannot change status, so their
 *   connection is destroyed. The client sees a truncated response instead of
 *   waiting indefinitely.
 * - Relayed WebSockets are closed with 1001.
 */
export function cleanupSessionTunnels(sessionId: string): void {
  const requestIds = sessionRequests.get(sessionId);
  if (!requestIds) return;
  // Detach first and iterate a snapshot: closing responses/sockets below can
  // trigger listeners that call untrackRequest() for this same session.
  sessionRequests.delete(sessionId);

  for (const rid of [...requestIds]) {
    const pending = pendingRequests.get(rid);
    if (pending) {
      clearTimeout(pending.timeout);
      pendingRequests.delete(rid);
      if (pending.res.headersSent) {
        pending.res.destroy();
      } else {
        errorResponse(pending.res, 502, "Host disconnected");
      }
    }

    const tunnelWs = pendingWsSockets.get(rid);
    if (tunnelWs) {
      pendingWsSockets.delete(rid);
      tunnelWs.ws.close(1001, "Host disconnected");
    }
  }
}
268
+
269
/**
 * Normalize a `ws` "message" payload to a Buffer.
 *
 * The payload shape depends on the `ws` version and `binaryType`: a Buffer, an
 * ArrayBuffer, a Buffer[] of fragments, or (ws 7.x text frames) a string.
 */
function tunnelMessageToBuffer(data: Buffer | ArrayBuffer | Buffer[] | string): Buffer {
  if (typeof data === "string") return Buffer.from(data);
  if (Array.isArray(data)) return Buffer.concat(data);
  if (Buffer.isBuffer(data)) return data;
  return Buffer.from(data);
}

/**
 * Bridge an accepted browser WebSocket to a WebSocket on the host's local port.
 *
 * Steps:
 * 1. Authenticate the connection.
 * 2. Register the socket under a fresh requestId.
 * 3. Send the host a `tunnel.request` carrying `upgrade: websocket`.
 * 4. Relay browser frames as `tunnel.ws.data` and the browser close as `tunnel.ws.close`.
 *
 * Host → browser traffic arrives via `handleTunnelWsData` / `handleTunnelWsClose`.
 *
 * @param req - optional upgrade request. When given, the token may also come
 *   from the Authorization header or the `lsh_tunnel` cookie, so pages served
 *   through the tunnel can open WebSockets without embedding the token.
 *   Without it, only the `token` query parameter is accepted.
 */
export function handleTunnelWsUpgrade(
  ws: WebSocket,
  parsed: { sessionId: string; port: number; path: string },
  url: URL,
  sessions: SessionManager,
  tokens: TokenManager,
  req?: IncomingMessage,
): void {
  const { sessionId, port, path } = parsed;

  const token = req ? extractToken(req, url) : url.searchParams.get("token");
  if (!token || !tokens.owns(token, sessionId)) {
    ws.close(4001, "Unauthorized");
    return;
  }

  const session = sessions.get(sessionId);
  if (!session || !session.host || session.host.socket.readyState !== session.host.socket.OPEN) {
    ws.close(4002, "Host not connected");
    return;
  }

  const requestId = randomUUID();
  const fullUrl = path + (url.search || "");

  // Register this WS so host responses route here
  registerTunnelWs(requestId, ws);
  trackRequest(sessionId, requestId);

  const envelope = createEnvelope({
    type: "tunnel.request",
    sessionId,
    payload: {
      requestId,
      method: "GET",
      url: fullUrl,
      headers: { "upgrade": "websocket" },
      body: null,
      port,
    },
  });
  session.host.socket.send(serializeEnvelope(envelope));

  // Forward data from browser WS to host
  ws.on("message", (data: Buffer | ArrayBuffer | Buffer[] | string, isBinaryFrame?: boolean) => {
    const s = sessions.get(sessionId);
    if (!s?.host || s.host.socket.readyState !== s.host.socket.OPEN) return;
    // ws >= 8 delivers text frames as Buffers too and reports the frame type in
    // the second argument; ws 7 passes text frames as strings with no flag.
    const isBinary = typeof isBinaryFrame === "boolean" ? isBinaryFrame : typeof data !== "string";
    const fwd = createEnvelope({
      type: "tunnel.ws.data",
      sessionId,
      payload: {
        requestId,
        data: tunnelMessageToBuffer(data).toString("base64"),
        isBinary,
      },
    });
    s.host.socket.send(serializeEnvelope(fwd));
  });

  ws.on("close", (code, reason) => {
    removeTunnelWs(requestId);
    untrackRequest(sessionId, requestId);
    const s = sessions.get(sessionId);
    if (!s?.host || s.host.socket.readyState !== s.host.socket.OPEN) return;
    const fwd = createEnvelope({
      type: "tunnel.ws.close",
      sessionId,
      payload: {
        requestId,
        code,
        reason: reason?.toString() || "",
      },
    });
    s.host.socket.send(serializeEnvelope(fwd));
  });
}