@linkshell/gateway 0.2.16 → 0.2.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/tunnel.ts ADDED
@@ -0,0 +1,363 @@
1
+ import type { IncomingMessage, ServerResponse } from "node:http";
2
+ import { randomUUID } from "node:crypto";
3
+ import type WebSocket from "ws";
4
+ import {
5
+ createEnvelope,
6
+ serializeEnvelope,
7
+ } from "@linkshell/protocol";
8
+ import type { SessionManager } from "./sessions.js";
9
+ import type { TokenManager } from "./tokens.js";
10
+
11
/** Delay, in milliseconds, passed to the per-request timer that answers 504 when it fires. */
const TUNNEL_TIMEOUT = 30 * 1000;
/** Upper bound, in bytes, on a buffered tunnel request body (10 MiB); larger bodies get a 413. */
const MAX_TUNNEL_BODY = 10 * 1024 ** 2;
13
+
14
/** Bookkeeping for one HTTP request that has been forwarded to the host and awaits its response. */
export interface PendingTunnelRequest {
  /** Timer that fails the request when the host does not answer in time. */
  timeout: ReturnType<typeof setTimeout>;
  /** Set to true once the status line and headers have been written to `res`. */
  headersSent: boolean;
  /** Response object of the original client request; host data is written to it. */
  res: ServerResponse;
}

/** Bookkeeping for one client WebSocket that is being relayed through the tunnel. */
export interface PendingTunnelWs {
  /** The client-side socket that host frames are delivered to. */
  ws: WebSocket;
}
23
+
24
// In-flight tunnel state, keyed by requestId.
const pendingRequests: Map<string, PendingTunnelRequest> = new Map();
const pendingWsSockets: Map<string, PendingTunnelWs> = new Map();

// sessionId -> requestIds opened through that session, so everything can be
// torn down when the session's host disconnects.
const sessionRequests: Map<string, Set<string>> = new Map();
30
+
31
/**
 * Remembers that `requestId` was opened through `sessionId`, creating the
 * session's id set on first use.
 */
function trackRequest(sessionId: string, requestId: string): void {
  const known = sessionRequests.get(sessionId);
  if (known) {
    known.add(requestId);
    return;
  }
  sessionRequests.set(sessionId, new Set([requestId]));
}
39
+
40
/**
 * Forgets `requestId` for `sessionId`. The session's entry is dropped entirely
 * once its last request id is gone; unknown sessions are ignored.
 */
function untrackRequest(sessionId: string, requestId: string): void {
  const ids = sessionRequests.get(sessionId);
  if (ids === undefined) return;
  ids.delete(requestId);
  if (ids.size > 0) return;
  sessionRequests.delete(sessionId);
}
47
+
48
/**
 * Finds the caller's token, trying in order: the `Authorization: Bearer …`
 * header, the `token` query parameter, then the `lsh_token` cookie.
 *
 * @param req - Incoming request whose headers are inspected.
 * @param url - Parsed request URL, used for the `token` query parameter.
 * @returns The first non-empty token found, or `null` when none is present.
 */
function extractToken(req: IncomingMessage, url: URL): string | null {
  // Check Authorization header
  const auth = req.headers.authorization;
  if (auth) {
    const match = auth.match(/^Bearer\s+(.+)$/i);
    if (match?.[1]) return match[1];
  }
  // Check query param
  const qToken = url.searchParams.get("token");
  if (qToken) return qToken;
  // Check cookie (token-only cookie). The name is anchored to the start of a
  // cookie pair so that a different cookie whose name merely ends in
  // "lsh_token" (e.g. "not_lsh_token") is not mistaken for the auth cookie.
  const cookie = req.headers.cookie;
  if (cookie) {
    const match = cookie.match(/(?:^|;\s*)lsh_token=([^;]+)/);
    if (match?.[1]) return match[1];
  }
  return null;
}
66
+
67
+ /** Parse lsh_tunnel cookie: "sessionId:port:token" */
68
/**
 * Parses the `lsh_tunnel` cookie, whose URI-decoded value has the form
 * `sessionId:port:token`.
 *
 * @param req - Incoming request whose `Cookie` header is inspected.
 * @returns The decoded parts, or `null` when the cookie is absent, cannot be
 *   URI-decoded, has fewer than three parts, or carries a port that is not a
 *   plain decimal integer in 1–65535. Never throws on malformed input.
 */
export function parseTunnelCookie(req: IncomingMessage): { sessionId: string; port: number; token: string } | null {
  const cookie = req.headers.cookie;
  if (!cookie) return null;
  // Anchor the name to the start of a cookie pair so e.g. "xlsh_tunnel" does not match.
  const match = cookie.match(/(?:^|;\s*)lsh_tunnel=([^;]+)/);
  if (!match?.[1]) return null;

  let decoded: string;
  try {
    decoded = decodeURIComponent(match[1]);
  } catch {
    // The cookie is client-controlled; a malformed escape sequence must not throw.
    return null;
  }

  const parts = decoded.split(":");
  if (parts.length < 3) return null;
  const sessionId = parts[0]!;
  const portText = parts[1]!;
  const token = parts.slice(2).join(":"); // token may contain colons
  // Digits only: Number() alone would also accept "80.5", "1e3", "0x50" or " 80 ".
  if (!/^\d{1,5}$/.test(portText)) return null;
  const port = Number(portText);
  if (!sessionId || port < 1 || port > 65535 || !token) return null;
  return { sessionId, port, token };
}
81
+
82
/**
 * Sends a plain-text error with a permissive CORS header, unless the response
 * headers have already gone out, in which case nothing is written.
 */
function errorResponse(res: ServerResponse, status: number, message: string): void {
  if (!res.headersSent) {
    const headers = {
      "content-type": "text/plain",
      "access-control-allow-origin": "*",
    };
    res.writeHead(status, headers);
    res.end(message);
  }
}
90
+
91
/**
 * Splits a `/tunnel/<sessionId>/<port>[/<rest>]` pathname into its parts.
 *
 * @returns The session id, the port (1–65535) and the remaining path
 *   (defaulting to "/"), or `null` when the pathname has another shape or the
 *   port is out of range.
 */
export function parseTunnelPath(pathname: string): { sessionId: string; port: number; path: string } | null {
  const found = /^\/tunnel\/([^/]+)\/(\d+)(\/.*)?$/.exec(pathname);
  if (found === null) return null;

  const [, sessionId, portDigits, rest] = found;
  const port = Number(portDigits);
  if (port < 1 || port > 65535) return null;

  return { sessionId: sessionId!, port, path: rest || "/" };
}
102
+
103
/**
 * Forwards one client HTTP request to the session's host as a
 * `tunnel.request` envelope and parks the response object until the host
 * answers, the client goes away, or the timeout fires.
 *
 * Steps: authenticate (401), set the `lsh_tunnel` cookie, check the host is
 * connected (502), buffer the body for non-GET/HEAD methods (413 above
 * `MAX_TUNNEL_BODY`), register the pending request, then send the envelope.
 *
 * @param req - Incoming client request.
 * @param res - Response that host data will later be written to.
 * @param sessions - Session lookup, used to find the host socket.
 * @param tokens - Token store, used to check the token owns the session.
 * @param parsed - Session id, target port and path taken from the tunnel URL.
 * @param url - Parsed request URL; its query string is forwarded to the host.
 */
export async function handleTunnelRequest(
  req: IncomingMessage,
  res: ServerResponse,
  sessions: SessionManager,
  tokens: TokenManager,
  parsed: { sessionId: string; port: number; path: string },
  url: URL,
): Promise<void> {
  const { sessionId, port, path } = parsed;

  // Auth
  const token = extractToken(req, url);
  if (!token || !tokens.owns(token, sessionId)) {
    errorResponse(res, 401, "Unauthorized");
    return;
  }

  // Set auth cookie for subsequent sub-resource requests (root path so /_next/... etc. are covered)
  const cookieVal = encodeURIComponent(`${sessionId}:${port}:${token}`);
  res.setHeader("Set-Cookie", `lsh_tunnel=${cookieVal}; Path=/; HttpOnly; SameSite=Lax`);

  // Validate session & host before buffering a potentially large body
  const session = sessions.get(sessionId);
  if (!session || !session.host || session.host.socket.readyState !== session.host.socket.OPEN) {
    errorResponse(res, 502, "Host not connected");
    return;
  }

  const requestId = randomUUID();
  const method = req.method ?? "GET";

  // Read request body (base64-encoded for transport inside the envelope)
  let body: string | null = null;
  if (method !== "GET" && method !== "HEAD") {
    const chunks: Buffer[] = [];
    let size = 0;
    for await (const chunk of req) {
      const buf = Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk);
      size += buf.length;
      if (size > MAX_TUNNEL_BODY) {
        errorResponse(res, 413, "Request body too large");
        return;
      }
      chunks.push(buf);
    }
    if (chunks.length > 0) {
      body = Buffer.concat(chunks).toString("base64");
    }
  }

  // Reading the body yields to the event loop, so the host may have gone away
  // in the meantime; look it up again instead of using the earlier snapshot.
  const liveSession = sessions.get(sessionId);
  if (!liveSession || !liveSession.host || liveSession.host.socket.readyState !== liveSession.host.socket.OPEN) {
    errorResponse(res, 502, "Host not connected");
    return;
  }

  // Build forwarded headers (strip hop-by-hop)
  const headers: Record<string, string> = {};
  const skipHeaders = new Set(["host", "connection", "upgrade", "transfer-encoding", "keep-alive"]);
  for (const [key, val] of Object.entries(req.headers)) {
    if (!skipHeaders.has(key) && typeof val === "string") {
      headers[key] = val;
    }
  }

  // Reconstruct URL with query string
  const fullUrl = path + (url.search || "");

  // Register pending request
  const pending: PendingTunnelRequest = {
    res,
    headersSent: false,
    timeout: setTimeout(() => {
      pendingRequests.delete(requestId);
      untrackRequest(sessionId, requestId);
      if (res.headersSent) {
        // A partial response is already on the wire, so a 504 cannot be sent
        // any more; abort the connection rather than leave the client hanging.
        res.destroy();
      } else {
        errorResponse(res, 504, "Tunnel request timed out");
      }
    }, TUNNEL_TIMEOUT),
  };
  pendingRequests.set(requestId, pending);
  trackRequest(sessionId, requestId);

  // Clean up when the response is finished or the client disconnects. This is
  // attached to `res`, not `req`: on current Node versions the request's
  // "close" event fires once the request has been consumed, which would drop
  // the pending entry before the host had a chance to answer. Untracking is
  // unconditional so ids of completed requests do not pile up per session.
  res.on("close", () => {
    const p = pendingRequests.get(requestId);
    if (p) {
      clearTimeout(p.timeout);
      pendingRequests.delete(requestId);
    }
    untrackRequest(sessionId, requestId);
  });

  // Send tunnel.request to host
  const envelope = createEnvelope({
    type: "tunnel.request",
    sessionId,
    payload: {
      requestId,
      method,
      url: fullUrl,
      headers,
      body,
      port,
    },
  });
  liveSession.host.socket.send(serializeEnvelope(envelope));
}
203
+
204
/**
 * Writes one response chunk from the host to the parked client response.
 *
 * The first chunk for a request writes the status line and headers (with
 * `access-control-allow-origin: *` forced); every chunk may carry a
 * base64-encoded body part; the chunk flagged `isFinal` ends the response and
 * releases the pending entry. Chunks for unknown request ids are ignored.
 *
 * @param payload - Host message: request id, status, headers, base64 body
 *   part, and whether this is the last chunk.
 */
export function handleTunnelResponse(payload: {
  requestId: string;
  statusCode: number;
  headers: Record<string, string>;
  body: string;
  isFinal: boolean;
}): void {
  const pending = pendingRequests.get(payload.requestId);
  if (!pending) return;
  const { res } = pending;

  if (!pending.headersSent) {
    // Merge CORS headers
    const responseHeaders: Record<string, string | string[]> = {
      ...payload.headers,
      "access-control-allow-origin": "*",
    };

    // Headers given to writeHead() take precedence over ones set earlier with
    // setHeader(), so a Set-Cookie from the host would replace any cookie the
    // gateway already put on this response. Send both instead.
    const gatewayCookies = res.getHeader("set-cookie");
    if (gatewayCookies !== undefined) {
      const hostCookieKey = Object.keys(payload.headers).find((k) => k.toLowerCase() === "set-cookie");
      const hostCookie = hostCookieKey === undefined ? undefined : payload.headers[hostCookieKey];
      if (hostCookieKey !== undefined && hostCookie !== undefined) {
        const existing = Array.isArray(gatewayCookies) ? gatewayCookies : [String(gatewayCookies)];
        responseHeaders[hostCookieKey] = [...existing, hostCookie];
      }
    }

    try {
      res.writeHead(payload.statusCode, responseHeaders);
    } catch {
      // Status code and headers come from the remote host; writeHead() throws
      // on out-of-range codes or illegal header characters. Fail this one
      // request instead of letting the exception escape the message handler.
      clearTimeout(pending.timeout);
      pendingRequests.delete(payload.requestId);
      errorResponse(res, 502, "Invalid response from host");
      return;
    }
    pending.headersSent = true;
  }

  // Write body chunk
  if (payload.body) {
    res.write(Buffer.from(payload.body, "base64"));
  }

  if (payload.isFinal) {
    clearTimeout(pending.timeout);
    pendingRequests.delete(payload.requestId);
    res.end();
  }
}
235
+
236
/**
 * Relays one frame from the host to the client WebSocket registered under
 * `payload.requestId`. The base64 payload is decoded and sent as raw bytes
 * when `isBinary` is set, otherwise as a UTF-8 string. Unknown ids are ignored.
 */
export function handleTunnelWsData(payload: {
  requestId: string;
  data: string;
  isBinary: boolean;
}): void {
  const entry = pendingWsSockets.get(payload.requestId);
  if (entry === undefined) return;

  const decoded = Buffer.from(payload.data, "base64");
  if (payload.isBinary) {
    entry.ws.send(decoded);
  } else {
    entry.ws.send(decoded.toString("utf8"));
  }
}
246
+
247
/**
 * Closes the client WebSocket registered under `payload.requestId` on behalf
 * of the host and forgets it. Unknown ids are ignored.
 *
 * The host-supplied code and reason are sanitized first, because a close
 * frame may not carry a reserved status code or a reason longer than 123
 * bytes: a missing or 1005 ("no status") code becomes 1000, any other
 * unsendable code becomes 1011, and an over-long reason is truncated.
 *
 * @param payload - Request id plus the optional close code and reason.
 */
export function handleTunnelWsClose(payload: {
  requestId: string;
  code?: number;
  reason?: string;
}): void {
  const pending = pendingWsSockets.get(payload.requestId);
  if (!pending) return;
  // Forget the socket before closing so a throwing close() cannot leak the entry.
  pendingWsSockets.delete(payload.requestId);

  // Codes that may be sent: 1000-1003, 1007-1014 and 3000-4999.
  // 1004, 1005, 1006 and 1015 are reserved and must never appear in a frame.
  const requested = payload.code ?? 1000;
  const sendable =
    Number.isInteger(requested) &&
    ((requested >= 1000 && requested <= 1003) ||
      (requested >= 1007 && requested <= 1014) ||
      (requested >= 3000 && requested <= 4999));
  const code = sendable ? requested : requested === 1005 ? 1000 : 1011;

  // The reason is limited to 123 bytes of UTF-8. Every character is at least
  // one byte, so cap the character count first, then trim until it fits.
  let reason = (payload.reason ?? "").slice(0, 123);
  while (Buffer.byteLength(reason) > 123) {
    reason = reason.slice(0, -1);
  }

  pending.ws.close(code, reason);
}
257
+
258
/** Stores `ws` under `requestId` so host frames for that id can be routed to it. */
export function registerTunnelWs(requestId: string, ws: WebSocket): void {
  const entry = { ws };
  pendingWsSockets.set(requestId, entry);
}
261
+
262
/** Drops the WebSocket stored under `requestId`, if there is one. */
export function removeTunnelWs(requestId: string): void {
  if (pendingWsSockets.has(requestId)) {
    pendingWsSockets.delete(requestId);
  }
}
265
+
266
/**
 * Tears down everything that was tunnelled through `sessionId`; meant to be
 * called when that session's host disconnects.
 *
 * Pending HTTP requests get a 502, or have their connection aborted when a
 * partial response is already on the wire. Relayed WebSockets are closed with
 * code 1001. Sessions with no tracked requests are a no-op.
 *
 * @param sessionId - Session whose tracked requests should be released.
 */
export function cleanupSessionTunnels(sessionId: string): void {
  const requestIds = sessionRequests.get(sessionId);
  if (!requestIds) return;
  // Detach the set first so the loop works on a collection nothing else can reach.
  sessionRequests.delete(sessionId);

  for (const rid of requestIds) {
    const pending = pendingRequests.get(rid);
    if (pending) {
      clearTimeout(pending.timeout);
      pendingRequests.delete(rid);
      if (pending.headersSent || pending.res.headersSent) {
        // Headers already went out, so a 502 can no longer be written. With
        // the timer cleared and the entry gone nothing else would ever finish
        // this response; abort it so the client is not left waiting forever.
        pending.res.destroy();
      } else {
        errorResponse(pending.res, 502, "Host disconnected");
      }
    }
    const tunnelWs = pendingWsSockets.get(rid);
    if (tunnelWs) {
      pendingWsSockets.delete(rid);
      tunnelWs.ws.close(1001, "Host disconnected");
    }
  }
}
284
+
285
/**
 * Wires an already-upgraded client WebSocket to the session's host.
 *
 * After authenticating, a `tunnel.request` envelope carrying an
 * `upgrade: websocket` header is sent to the host; from then on every client
 * frame is forwarded as `tunnel.ws.data` and the client's close as
 * `tunnel.ws.close`. The socket is closed with 4001 when the token is missing
 * or does not own the session, and with 4002 when the host is not connected.
 *
 * @param ws - Client-side WebSocket.
 * @param parsed - Session id, target port and path taken from the tunnel URL.
 * @param url - Parsed upgrade URL; supplies the token and the query string.
 * @param sessions - Session lookup, used to find the host socket.
 * @param tokens - Token store, used to check the token owns the session.
 */
export function handleTunnelWsUpgrade(
  ws: WebSocket,
  parsed: { sessionId: string; port: number; path: string },
  url: URL,
  sessions: SessionManager,
  tokens: TokenManager,
): void {
  const { sessionId, port, path } = parsed;

  // Auth comes from the `token` query parameter only: the upgrade request
  // (and therefore its cookies) is not available to this function.
  const token = url.searchParams.get("token");
  if (!token || !tokens.owns(token, sessionId)) {
    ws.close(4001, "Unauthorized");
    return;
  }

  const session = sessions.get(sessionId);
  if (!session || !session.host || session.host.socket.readyState !== session.host.socket.OPEN) {
    ws.close(4002, "Host not connected");
    return;
  }

  const requestId = randomUUID();
  const fullUrl = path + (url.search || "");

  // Register this WS so host responses route here
  registerTunnelWs(requestId, ws);
  trackRequest(sessionId, requestId);

  // Send tunnel.request with upgrade header to host
  const envelope = createEnvelope({
    type: "tunnel.request",
    sessionId,
    payload: {
      requestId,
      method: "GET",
      url: fullUrl,
      headers: { "upgrade": "websocket" },
      body: null,
      port,
    },
  });
  session.host.socket.send(serializeEnvelope(envelope));

  // Forward data from browser WS to host
  ws.on("message", (data: Buffer | ArrayBuffer | Buffer[] | string, isBinaryFlag?: boolean) => {
    const s = sessions.get(sessionId);
    if (!s?.host || s.host.socket.readyState !== s.host.socket.OPEN) return;

    // Newer ws releases deliver text frames as Buffers too and report the
    // frame type in a second argument, so the payload's JS type alone cannot
    // tell text from binary. Use the flag when present; otherwise fall back
    // to the older convention where text arrives as a string.
    const isBinary = typeof isBinaryFlag === "boolean" ? isBinaryFlag : typeof data !== "string";

    // Normalize every payload shape the socket may deliver into one Buffer.
    let buf: Buffer;
    if (typeof data === "string") {
      buf = Buffer.from(data);
    } else if (Buffer.isBuffer(data)) {
      buf = data;
    } else if (Array.isArray(data)) {
      buf = Buffer.concat(data);
    } else {
      buf = Buffer.from(data);
    }

    const fwd = createEnvelope({
      type: "tunnel.ws.data",
      sessionId,
      payload: {
        requestId,
        data: buf.toString("base64"),
        isBinary,
      },
    });
    s.host.socket.send(serializeEnvelope(fwd));
  });

  ws.on("close", (code, reason) => {
    // Always release local bookkeeping, even when the host is already gone.
    removeTunnelWs(requestId);
    untrackRequest(sessionId, requestId);
    const s = sessions.get(sessionId);
    if (!s?.host || s.host.socket.readyState !== s.host.socket.OPEN) return;
    const fwd = createEnvelope({
      type: "tunnel.ws.close",
      sessionId,
      payload: {
        requestId,
        code,
        reason: reason?.toString() || "",
      },
    });
    s.host.socket.send(serializeEnvelope(fwd));
  });
}