@rivetkit/engine-runner 2.0.21

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/tunnel.ts ADDED
@@ -0,0 +1,613 @@
1
+ import type * as protocol from "@rivetkit/engine-runner-protocol";
2
+ import type { MessageId, RequestId } from "@rivetkit/engine-runner-protocol";
3
+ import { v4 as uuidv4 } from "uuid";
4
+ import { logger } from "./log";
5
+ import type { ActorInstance, Runner } from "./mod";
6
+ import { unreachable } from "./utils";
7
+ import { WebSocketTunnelAdapter } from "./websocket-tunnel-adapter";
8
+
9
/** Period between garbage-collector sweeps of unacknowledged tunnel messages, in ms (one minute). */
const GC_INTERVAL = 60 * 1000;
/** Age after which an unacknowledged outbound tunnel message is purged, in ms (five seconds). */
const MESSAGE_ACK_TIMEOUT = 5 * 1000;
11
+
12
/** Bookkeeping for an in-flight tunneled HTTP request. */
interface PendingRequest {
	/** ID of the actor the request targets, when known. */
	actorId?: string;
	/** Controller of the request-body stream; inbound chunks are enqueued here. */
	streamController?: ReadableStreamDefaultController<Uint8Array>;
	/** Settles the request with a response. */
	resolve: (response: Response) => void;
	/** Fails the request. */
	reject: (error: Error) => void;
}

/** An outbound tunnel message that the gateway has not yet acknowledged. */
interface PendingTunnelMessage {
	/** Request/WebSocket ID the message belongs to, as encoded by `bufferToString`. */
	requestIdStr: string;
	/** `Date.now()` timestamp taken when the message was sent. */
	sentAt: number;
}
23
+
24
/**
 * Multiplexes HTTP requests and WebSockets between the engine gateway and the
 * actors hosted by this runner.
 *
 * Every outbound tunnel message is recorded until the gateway acks it. A
 * periodic garbage collector fails the associated request or WebSocket when an
 * ack does not arrive within `MESSAGE_ACK_TIMEOUT`.
 */
export class Tunnel {
	#runner: Runner;

	/** Requests over the tunnel to the actor that are in flight. */
	#actorPendingRequests: Map<string, PendingRequest> = new Map();
	/** WebSockets over the tunnel to the actor that are in flight. */
	#actorWebSockets: Map<string, WebSocketTunnelAdapter> = new Map();
	/**
	 * Owning actor ID for each entry in `#actorWebSockets`. This lets a
	 * WebSocket be removed from its actor's `webSockets` set no matter which
	 * side closes it.
	 */
	#webSocketActorIds: Map<string, string> = new Map();

	/** Messages sent from the actor over the tunnel that have not been acked by the gateway. */
	#pendingTunnelMessages: Map<string, PendingTunnelMessage> = new Map();

	#gcInterval?: NodeJS.Timeout;

	constructor(runner: Runner) {
		this.#runner = runner;
	}

	/** Starts (or restarts) the periodic sweep for unacknowledged messages. */
	start(): void {
		this.#startGarbageCollector();
	}

	/**
	 * Stops the garbage collector and fails every in-flight request, erroring
	 * any streaming request body. Then closes every tunneled WebSocket and drops
	 * all ack bookkeeping.
	 */
	shutdown() {
		if (this.#gcInterval) {
			clearInterval(this.#gcInterval);
			this.#gcInterval = undefined;
		}

		// Fail all pending requests, including their body streams
		for (const request of this.#actorPendingRequests.values()) {
			this.#failPendingRequest(
				request,
				new Error("Tunnel shutting down"),
			);
		}
		this.#actorPendingRequests.clear();

		// Close all WebSockets. If `close()` invokes the adapter's close
		// callback, that callback deletes from the map mid-iteration, which
		// `Map` iteration tolerates.
		for (const ws of this.#actorWebSockets.values()) {
			ws.close();
		}
		this.#actorWebSockets.clear();
		this.#webSocketActorIds.clear();

		// Every request and WebSocket is gone, so outstanding acks are moot.
		this.#pendingTunnelMessages.clear();
	}

	/**
	 * Sends a tunnel message for `requestId` and records it as awaiting an ack.
	 * If the engine socket is not ready, the message is dropped and a warning
	 * is logged.
	 */
	#sendMessage(
		requestId: RequestId,
		messageKind: protocol.ToServerTunnelMessageKind,
	) {
		// TODO: Switch this with runner WS
		if (!this.#runner.__webSocketReady()) {
			logger()?.warn({
				msg: "cannot send tunnel message, socket not connected to engine",
			});
			return;
		}

		// Build message and remember it until the gateway acks it
		const messageId = generateUuidBuffer();
		this.#pendingTunnelMessages.set(bufferToString(messageId), {
			sentAt: Date.now(),
			requestIdStr: bufferToString(requestId),
		});

		const message: protocol.ToServer = {
			tag: "ToServerTunnelMessage",
			val: {
				requestId,
				messageId,
				messageKind,
			},
		};
		this.#runner.__sendToServer(message);
	}

	/**
	 * Acknowledges an inbound tunnel message. Acks are not themselves recorded
	 * as pending.
	 */
	#sendAck(requestId: RequestId, messageId: MessageId) {
		if (!this.#runner.__webSocketReady()) {
			return;
		}

		const message: protocol.ToServer = {
			tag: "ToServerTunnelMessage",
			val: {
				requestId,
				messageId,
				messageKind: { tag: "TunnelAck", val: null },
			},
		};

		this.#runner.__sendToServer(message);
	}

	/** (Re)arms the interval that runs `#gc` every `GC_INTERVAL` ms. */
	#startGarbageCollector() {
		if (this.#gcInterval) {
			clearInterval(this.#gcInterval);
		}

		this.#gcInterval = setInterval(() => {
			this.#gc();
		}, GC_INTERVAL);
	}

	/**
	 * Purges messages left unacked for longer than `MESSAGE_ACK_TIMEOUT` and
	 * fails the request or WebSocket each one belonged to.
	 */
	#gc() {
		const now = Date.now();
		const messagesToDelete: string[] = [];

		for (const [messageId, pendingMessage] of this.#pendingTunnelMessages) {
			if (now - pendingMessage.sentAt <= MESSAGE_ACK_TIMEOUT) {
				continue;
			}
			messagesToDelete.push(messageId);

			const requestIdStr = pendingMessage.requestIdStr;

			// HTTP request: reject it and error its body stream, if any
			const pendingRequest = this.#actorPendingRequests.get(requestIdStr);
			if (pendingRequest) {
				this.#failPendingRequest(
					pendingRequest,
					new Error("Message acknowledgment timeout"),
				);
				this.#actorPendingRequests.delete(requestIdStr);
			}

			// WebSocket: close it and stop tracking it
			const webSocket = this.#actorWebSockets.get(requestIdStr);
			if (webSocket) {
				webSocket.close(1000, "Message acknowledgment timeout");
				this.#untrackWebSocket(requestIdStr);
			}
		}

		// Remove timed out messages
		if (messagesToDelete.length > 0) {
			logger()?.warn({
				msg: "purging unacked tunnel messages, this indicates that the Rivet Engine is disconnected or not responding",
				count: messagesToDelete.length,
			});
			for (const messageId of messagesToDelete) {
				this.#pendingTunnelMessages.delete(messageId);
			}
		}
	}

	/**
	 * Terminates everything the tunnel is tracking for `actor`. Pending requests
	 * are failed, including their body streams, and WebSockets are closed with
	 * code 1000.
	 */
	unregisterActor(actor: ActorInstance) {
		const actorId = actor.actorId;

		// Terminate all requests for this actor
		for (const requestId of actor.requests) {
			const pending = this.#actorPendingRequests.get(requestId);
			if (pending) {
				this.#failPendingRequest(
					pending,
					new Error(`Actor ${actorId} stopped`),
				);
				this.#actorPendingRequests.delete(requestId);
			}
		}
		actor.requests.clear();

		// Close all WebSockets for this actor. Deleting the current element
		// from a Set during for-of iteration is safe.
		for (const webSocketId of actor.webSockets) {
			const ws = this.#actorWebSockets.get(webSocketId);
			if (ws) {
				ws.close(1000, "Actor stopped");
			}
			this.#untrackWebSocket(webSocketId);
		}
		actor.webSockets.clear();
	}

	/**
	 * Dispatches `request` to the configured fetch handler for `actorId`.
	 * Returns 404 for an unknown actor and 501 when the handler yields no
	 * response.
	 */
	async #fetch(actorId: string, request: Request): Promise<Response> {
		// Validate actor exists
		if (!this.#runner.hasActor(actorId)) {
			logger()?.warn({
				msg: "ignoring request for unknown actor",
				actorId,
			});
			return new Response("Actor not found", { status: 404 });
		}

		// Await before checking, so the 501 fallback also covers a handler
		// whose promise resolves to no response, not only one that returns
		// nothing synchronously.
		const response = await this.#runner.config.fetch(
			this.#runner,
			actorId,
			request,
		);

		if (!response) {
			return new Response("Not Implemented", { status: 501 });
		}

		return response;
	}

	/**
	 * Entry point for every tunnel message from the engine. An ack clears the
	 * matching pending message. Every other kind is acked immediately and then
	 * dispatched to its handler.
	 */
	async handleTunnelMessage(message: protocol.ToClientTunnelMessage) {
		if (message.messageKind.tag === "TunnelAck") {
			// Mark pending message as acknowledged and remove it
			this.#pendingTunnelMessages.delete(
				bufferToString(message.messageId),
			);
		} else {
			this.#sendAck(message.requestId, message.messageId);
			switch (message.messageKind.tag) {
				case "ToClientRequestStart":
					await this.#handleRequestStart(
						message.requestId,
						message.messageKind.val,
					);
					break;
				case "ToClientRequestChunk":
					await this.#handleRequestChunk(
						message.requestId,
						message.messageKind.val,
					);
					break;
				case "ToClientRequestAbort":
					await this.#handleRequestAbort(message.requestId);
					break;
				case "ToClientWebSocketOpen":
					await this.#handleWebSocketOpen(
						message.requestId,
						message.messageKind.val,
					);
					break;
				case "ToClientWebSocketMessage":
					await this.#handleWebSocketMessage(
						message.requestId,
						message.messageKind.val,
					);
					break;
				case "ToClientWebSocketClose":
					await this.#handleWebSocketClose(
						message.requestId,
						message.messageKind.val,
					);
					break;
				default:
					unreachable(message.messageKind);
			}
		}
	}

	/**
	 * Builds a `Request` from a request-start message, runs it through the
	 * actor's fetch handler, and sends back the response. Any thrown error
	 * produces a 500. For streaming requests, the body is a stream that
	 * `#handleRequestChunk` feeds.
	 */
	async #handleRequestStart(
		requestId: ArrayBuffer,
		req: protocol.ToClientRequestStart,
	) {
		// Track this request for the actor
		const requestIdStr = bufferToString(requestId);
		const actor = this.#runner.getActor(req.actorId);
		if (actor) {
			actor.requests.add(requestIdStr);
		}

		try {
			// Convert headers map to Headers object
			const headers = new Headers();
			for (const [key, value] of req.headers) {
				headers.append(key, value);
			}

			const request = new Request(`http://localhost${req.path}`, {
				method: req.method,
				headers,
				body: req.body ? new Uint8Array(req.body) : undefined,
			});

			if (req.stream) {
				// Body arrives later as chunks; expose the controller to them
				const stream = new ReadableStream<Uint8Array>({
					start: (controller) => {
						const existing =
							this.#actorPendingRequests.get(requestIdStr);
						if (existing) {
							existing.streamController = controller;
							existing.actorId = req.actorId;
						} else {
							this.#actorPendingRequests.set(requestIdStr, {
								resolve: () => {},
								reject: () => {},
								streamController: controller,
								actorId: req.actorId,
							});
						}
					},
				});

				// `duplex: "half"` accompanies a stream body; the cast is kept
				// presumably because the RequestInit typings in use do not
				// declare `duplex` — TODO confirm and drop the cast if they do.
				const streamingRequest = new Request(request, {
					body: stream,
					duplex: "half",
				} as any);

				const response = await this.#fetch(
					req.actorId,
					streamingRequest,
				);
				await this.#sendResponse(requestId, response);
			} else {
				const response = await this.#fetch(req.actorId, request);
				await this.#sendResponse(requestId, response);
			}
		} catch (error) {
			logger()?.error({ msg: "error handling request", error });
			this.#sendResponseError(requestId, 500, "Internal Server Error");
		} finally {
			// Clean up request tracking
			const owner = this.#runner.getActor(req.actorId);
			if (owner) {
				owner.requests.delete(requestIdStr);
			}
		}
	}

	/**
	 * Feeds a body chunk into a streaming request. On `finish`, it closes the
	 * stream and stops tracking the request.
	 */
	async #handleRequestChunk(
		requestId: ArrayBuffer,
		chunk: protocol.ToClientRequestChunk,
	) {
		const requestIdStr = bufferToString(requestId);
		const pending = this.#actorPendingRequests.get(requestIdStr);
		if (pending?.streamController) {
			pending.streamController.enqueue(new Uint8Array(chunk.body));
			if (chunk.finish) {
				pending.streamController.close();
				this.#actorPendingRequests.delete(requestIdStr);
			}
		}
	}

	/** Errors the request's body stream, if any, and stops tracking it. */
	async #handleRequestAbort(requestId: ArrayBuffer) {
		const requestIdStr = bufferToString(requestId);
		const pending = this.#actorPendingRequests.get(requestIdStr);
		if (pending?.streamController) {
			pending.streamController.error(new Error("Request aborted"));
		}
		this.#actorPendingRequests.delete(requestIdStr);
	}

	/**
	 * Buffers the whole response body and sends it as a single non-streaming
	 * response. A `content-length` header is added when the handler set none.
	 */
	async #sendResponse(requestId: ArrayBuffer, response: Response) {
		// Always treat responses as non-streaming for now
		// In the future, we could detect streaming responses based on:
		// - Transfer-Encoding: chunked
		// - Content-Type: text/event-stream
		// - Explicit stream flag from the handler
		const body = response.body ? await response.arrayBuffer() : null;

		// `Headers.forEach` yields lower-cased names, so the
		// `has("content-length")` check below is case-safe.
		const headers = new Map<string, string>();
		response.headers.forEach((value, key) => {
			headers.set(key, value);
		});

		if (body && !headers.has("content-length")) {
			headers.set("content-length", String(body.byteLength));
		}

		this.#sendMessage(requestId, {
			tag: "ToServerResponseStart",
			val: {
				status: response.status as protocol.u16,
				headers,
				body,
				stream: false,
			},
		});
	}

	/** Sends a plain-text error response with the given status. */
	#sendResponseError(
		requestId: ArrayBuffer,
		status: number,
		message: string,
	) {
		const headers = new Map<string, string>();
		headers.set("content-type", "text/plain");

		this.#sendMessage(requestId, {
			tag: "ToServerResponseStart",
			val: {
				status: status as protocol.u16,
				headers,
				body: new TextEncoder().encode(message).buffer as ArrayBuffer,
				stream: false,
			},
		});
	}

	/**
	 * Opens a tunneled WebSocket for an actor. It creates an adapter that relays
	 * messages and closes through the tunnel, confirms the open to the gateway,
	 * and hands the adapter to the configured websocket handler. Each failure
	 * path sends a 1011 close.
	 */
	async #handleWebSocketOpen(
		requestId: ArrayBuffer,
		open: protocol.ToClientWebSocketOpen,
	) {
		const webSocketId = bufferToString(requestId);

		// Validate actor exists
		const actor = this.#runner.getActor(open.actorId);
		if (!actor) {
			logger()?.warn({
				msg: "ignoring websocket for unknown actor",
				actorId: open.actorId,
			});
			this.#sendMessage(requestId, {
				tag: "ToServerWebSocketClose",
				val: {
					code: 1011,
					reason: "Actor not found",
				},
			});
			return;
		}

		const websocketHandler = this.#runner.config.websocket;
		if (!websocketHandler) {
			logger()?.error({
				msg: "no websocket handler configured for tunnel",
			});
			this.#sendMessage(requestId, {
				tag: "ToServerWebSocketClose",
				val: {
					code: 1011,
					reason: "Not Implemented",
				},
			});
			return;
		}

		// Track this WebSocket for the actor
		actor.webSockets.add(webSocketId);
		this.#webSocketActorIds.set(webSocketId, open.actorId);

		try {
			const adapter = new WebSocketTunnelAdapter(
				webSocketId,
				(data: ArrayBuffer | string, isBinary: boolean) => {
					// Send message through tunnel
					const dataBuffer =
						typeof data === "string"
							? (new TextEncoder().encode(data)
									.buffer as ArrayBuffer)
							: data;

					this.#sendMessage(requestId, {
						tag: "ToServerWebSocketMessage",
						val: {
							data: dataBuffer,
							binary: isBinary,
						},
					});
				},
				(code?: number, reason?: string) => {
					// Send close through tunnel, then forget the socket
					this.#sendMessage(requestId, {
						tag: "ToServerWebSocketClose",
						val: {
							code: code ?? null,
							reason: reason ?? null,
						},
					});
					this.#untrackWebSocket(webSocketId);
				},
			);

			this.#actorWebSockets.set(webSocketId, adapter);

			// Send open confirmation
			this.#sendMessage(requestId, {
				tag: "ToServerWebSocketOpen",
				val: null,
			});

			adapter._handleOpen();

			// Minimal upgrade request for the handler, carrying the original
			// headers from the open message
			const headerInit: Record<string, string> = {};
			if (open.headers) {
				for (const [k, v] of open.headers as ReadonlyMap<
					string,
					string
				>) {
					headerInit[k] = v;
				}
			}
			headerInit["Upgrade"] = "websocket";
			headerInit["Connection"] = "Upgrade";

			const request = new Request(`http://localhost${open.path}`, {
				method: "GET",
				headers: headerInit,
			});

			await websocketHandler(
				this.#runner,
				open.actorId,
				adapter,
				request,
			);
		} catch (error) {
			logger()?.error({ msg: "error handling websocket open", error });
			this.#sendMessage(requestId, {
				tag: "ToServerWebSocketClose",
				val: {
					code: 1011,
					reason: "Server Error",
				},
			});
			this.#untrackWebSocket(webSocketId);
		}
	}

	/**
	 * Delivers an inbound WebSocket frame to its adapter. Text frames are
	 * UTF-8 decoded.
	 */
	async #handleWebSocketMessage(
		requestId: ArrayBuffer,
		msg: protocol.ToClientWebSocketMessage,
	) {
		const webSocketId = bufferToString(requestId);
		const adapter = this.#actorWebSockets.get(webSocketId);
		if (adapter) {
			const data = msg.binary
				? new Uint8Array(msg.data)
				: new TextDecoder().decode(new Uint8Array(msg.data));

			adapter._handleMessage(data, msg.binary);
		}
	}

	/**
	 * Applies a close initiated by the remote side. It notifies the adapter and
	 * stops tracking the socket, including in the owning actor's set.
	 */
	async #handleWebSocketClose(
		requestId: ArrayBuffer,
		close: protocol.ToClientWebSocketClose,
	) {
		const webSocketId = bufferToString(requestId);
		const adapter = this.#actorWebSockets.get(webSocketId);
		if (adapter) {
			adapter._handleClose(
				close.code ?? undefined,
				close.reason ?? undefined,
			);
			this.#untrackWebSocket(webSocketId);
		}
	}

	/**
	 * Fails an in-flight request. It rejects the request and, when present,
	 * errors its body stream so that readers of the body do not hang. Erroring
	 * a stream that is no longer readable is a no-op per the Streams spec.
	 */
	#failPendingRequest(pending: PendingRequest, error: Error) {
		pending.reject(error);
		pending.streamController?.error(error);
	}

	/**
	 * Forgets a tunneled WebSocket. It removes the adapter and removes the ID
	 * from the owning actor's `webSockets` set, if that actor is still
	 * registered. Safe to call repeatedly.
	 */
	#untrackWebSocket(webSocketId: string) {
		this.#actorWebSockets.delete(webSocketId);
		const actorId = this.#webSocketActorIds.get(webSocketId);
		if (actorId !== undefined) {
			this.#webSocketActorIds.delete(webSocketId);
			this.#runner.getActor(actorId)?.webSockets.delete(webSocketId);
		}
	}
}
602
+
603
/**
 * Encodes raw bytes as base64 so binary IDs (request/message IDs) can be used
 * as `Map` keys.
 */
function bufferToString(buffer: ArrayBuffer): string {
	const view = Buffer.from(buffer);
	return view.toString("base64");
}
607
+
608
/**
 * Produces a fresh v4 UUID in its raw 16-byte form, backed by a new
 * `ArrayBuffer`.
 */
function generateUuidBuffer(): ArrayBuffer {
	const bytes = new Uint8Array(16);
	// Given an output buffer, `uuidv4` writes the binary UUID into it
	// instead of producing the hyphenated string form.
	uuidv4(undefined, bytes);
	return bytes.buffer;
}
package/src/utils.ts ADDED
@@ -0,0 +1,31 @@
1
/**
 * Renders an unexpected value for an error message. It prefers JSON, which
 * shows the contents of tagged-union objects. It falls back to `String()` when
 * JSON has no representation, and to the internal type tag when serialization
 * throws (BigInt, cyclic objects).
 */
function describeUnreachableValue(x: unknown): string {
	try {
		return JSON.stringify(x) ?? String(x);
	} catch {
		return Object.prototype.toString.call(x);
	}
}

/**
 * Exhaustiveness guard: call in the `default` branch of a switch over a
 * discriminated union. It fails to compile when a variant is unhandled and
 * throws at runtime if an unexpected value slips through anyway.
 *
 * @throws {Error} Always, with a description of the offending value.
 */
export function unreachable(x: never): never {
	throw new Error(`Unreachable: ${describeUnreachableValue(x)}`);
}
4
+
5
/** Tuning knobs for {@link calculateBackoff}; every field is optional. */
export interface BackoffOptions {
	/** Delay for attempt 0, in milliseconds (default 1000). */
	initialDelay?: number;
	/** Cap applied to the exponential delay before jitter, in milliseconds (default 30000). */
	maxDelay?: number;
	/** Growth factor applied once per attempt (default 2). */
	multiplier?: number;
	/** When truthy, scale the delay up by a random 0-25% (default true). */
	jitter?: boolean;
}

/**
 * Computes a retry delay in whole milliseconds using capped exponential backoff.
 *
 * The raw delay is `initialDelay * multiplier ** attempt`, capped at `maxDelay`.
 * With jitter enabled, the capped delay is then increased by up to 25%, so the
 * result can exceed `maxDelay` by that margin.
 *
 * @param attempt - Retry attempt number; attempt 0 starts from `initialDelay`.
 * @param options - Optional overrides for the backoff parameters.
 * @returns The delay, rounded down to an integer.
 */
export function calculateBackoff(
	attempt: number,
	options: BackoffOptions = {},
): number {
	const {
		initialDelay: base = 1000,
		maxDelay: ceiling = 30000,
		multiplier: growth = 2,
		jitter: randomize = true,
	} = options;

	const capped = Math.min(base * growth ** attempt, ceiling);
	const scaled = randomize ? capped * (1 + Math.random() * 0.25) : capped;

	return Math.floor(scaled);
}