openlattice-ssh 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,559 @@
1
+ import { readFileSync } from "fs";
2
+ import { Client as SSHClient } from "ssh2";
3
+ import type { ClientChannel, SFTPWrapper } from "ssh2";
4
+ import type {
5
+ ComputeProvider,
6
+ ComputeSpec,
7
+ ExecOpts,
8
+ ExecResult,
9
+ ExtensionMap,
10
+ FileEntry,
11
+ FileExtension,
12
+ HealthStatus,
13
+ NetworkExtension,
14
+ ProviderCapabilities,
15
+ ProviderNode,
16
+ ProviderNodeStatus,
17
+ } from "openlattice";
18
+ import type { SSHProviderConfig, SSHHostConfig } from "./config";
19
+
20
/**
 * In-memory bookkeeping for a node that `provision()` has handed out.
 */
interface NodeState {
  /** Host configuration the node was placed on. */
  host: SSHHostConfig;
  /** Remote working directory created for the node. */
  workdir: string;
  /** Time at which the node state was created. */
  startedAt: Date;
  /** PID echoed by the remote shell for the background command, when one was launched. */
  pid?: number;
  /** Original command, kept for restart support. */
  command?: string;
  /** Original env prefix, kept for restart support. */
  envPrefix?: string;
}

/** Module-wide counter used to mint `ssh-<n>` session ids. */
let nextSessionId = 1;
33
+
34
/**
 * Compute provider that provisions nodes on a fixed set of hosts over SSH.
 *
 * Connections are cached per `host:port`, and node state is kept in memory
 * only.
 */
export class SSHProvider implements ComputeProvider {
  readonly name = "ssh";
  readonly capabilities: ProviderCapabilities;
  private readonly config: SSHProviderConfig;
  /** Open SSH clients keyed by `host:port`. */
  private readonly connections = new Map<string, SSHClient>();
  /** Provisioned nodes keyed by external id. */
  private readonly nodes = new Map<string, NodeState>();
  private roundRobinIdx = 0;

  /**
   * @param config - Provider configuration; must list at least one host.
   *   When `defaultKeyPath` is set and `defaultPrivateKey` is not, the key
   *   file is read synchronously here.
   * @throws If no hosts are configured, or the key file cannot be read.
   */
  constructor(config: SSHProviderConfig) {
    if (!config.hosts || config.hosts.length === 0) {
      throw new Error("[ssh] at least one host must be configured");
    }

    // Work on a shallow copy so the loaded key material is never written
    // back into the caller's configuration object.
    const resolved: SSHProviderConfig = { ...config };
    if (resolved.defaultKeyPath && !resolved.defaultPrivateKey) {
      resolved.defaultPrivateKey = readFileSync(resolved.defaultKeyPath);
    }
    this.config = resolved;

    const hasGpu = resolved.hosts.some((h) => h.gpuAvailable);
    this.capabilities = {
      restart: true,
      pause: false, // SIGSTOP unreliable via SSH
      snapshot: false,
      gpu: hasGpu,
      logs: false,
      tailscale: true,
      coldStartMs: 1000,
      maxConcurrent: 0,
      architectures: ["x86_64", "arm64"],
      persistentStorage: true,
    };
  }
65
+
66
+ // ── Required methods ────────────────────────────────────────────
67
+
68
/**
 * Provisions a node on one of the configured hosts.
 *
 * Picks a host, creates `/tmp/openlattice/<sessionId>` on it, optionally
 * installs and brings up Tailscale, and, when `spec.runtime.command` is
 * non-empty, launches that command in the background with stdout/stderr
 * redirected to `.stdout` / `.stderr` inside the working directory.
 *
 * If any step after the working directory is created fails, the node is
 * unregistered and the directory is removed (best effort) before the error
 * is rethrown.
 *
 * @param spec - Requested compute; `gpu`, `labels`, `network` and `runtime` are consulted.
 * @returns External id (`<host>:<port>/<sessionId>`), the SSH endpoint, and metadata.
 * @throws If no host matches, the connection fails, the working directory
 *   cannot be created, or Tailscale setup fails.
 */
async provision(spec: ComputeSpec): Promise<ProviderNode> {
  const host = this.selectHost(spec);
  const conn = await this.getConnection(host);
  const sessionId = `ssh-${nextSessionId++}`;
  const hostKey = hostKeyOf(host);
  const externalId = `${hostKey}/${sessionId}`;
  const workdir = `/tmp/openlattice/${sessionId}`;

  // sshExec resolves even on a non-zero exit, so the status has to be
  // checked explicitly for the directory creation to validate anything.
  const mkdir = await this.sshExec(conn, `mkdir -p ${workdir}`);
  if (mkdir.exitCode !== 0) {
    throw new Error(
      `[ssh] failed to create workdir ${workdir}: ${mkdir.stderr.trim()}`
    );
  }

  const state: NodeState = {
    host,
    workdir,
    startedAt: new Date(),
  };
  this.nodes.set(externalId, state);

  try {
    // Ensure Tailscale is installed and running if requested (authKey implies tailscale)
    if (spec.network?.tailscale || spec.network?.tailscaleAuthKey) {
      await this.ensureTailscale(conn);
      if (spec.network?.tailscaleAuthKey) {
        await this.tailscaleUp(conn, spec.network.tailscaleAuthKey);
      }
    }

    // Run initial command in background if specified. The argv entries are
    // joined with single spaces and interpreted by the remote shell.
    if (spec.runtime.command && spec.runtime.command.length > 0) {
      const cmd = spec.runtime.command.join(" ");
      const envPrefix = buildEnvPrefix(spec.runtime.env);
      state.command = cmd;
      state.envPrefix = envPrefix;
      const result = await this.sshExec(
        conn,
        `cd ${workdir} && ${envPrefix}nohup ${cmd} > ${workdir}/.stdout 2> ${workdir}/.stderr & echo $!`
      );
      // `echo $!` prints the PID of the backgrounded job.
      const pid = parseInt(result.stdout.trim(), 10);
      if (!isNaN(pid)) {
        state.pid = pid;
      }
    }
  } catch (err: unknown) {
    // The caller never receives an external id on failure, so it could not
    // clean up itself: undo the registration and remove the directory.
    this.nodes.delete(externalId);
    await this.sshExec(conn, `rm -rf ${workdir}`).catch(() => undefined);
    throw err;
  }

  return {
    externalId,
    endpoints: [
      {
        type: "ssh",
        host: host.host,
        port: host.port ?? 22,
      },
    ],
    metadata: { hostKey, sessionId, workdir },
  };
}
123
+
124
/**
 * Runs a command on the node's host and resolves with its result.
 *
 * The argv entries are joined with single spaces and interpreted by the
 * remote shell. The working directory is single-quoted, so it is taken
 * literally (spaces are safe; `~` is not expanded).
 *
 * @param externalId - Id returned by `provision()`.
 * @param command - Command and arguments.
 * @param opts - Optional `cwd` (defaults to the node's working directory),
 *   `env`, and the options forwarded to the SSH exec helper.
 * @throws If the node is unknown or the SSH exec cannot be started.
 */
async exec(
  externalId: string,
  command: string[],
  opts?: ExecOpts
): Promise<ExecResult> {
  const state = this.getNodeState(externalId);
  const conn = await this.getConnection(state.host);
  const cmdStr = command.join(" ");
  const cwd = opts?.cwd ?? state.workdir;
  // Same single-quote escaping as the file extension's mkdir and
  // buildEnvPrefix: close the quote, emit an escaped quote, reopen.
  const quotedCwd = "'" + cwd.replace(/'/g, "'\\''") + "'";
  const envPrefix = buildEnvPrefix(opts?.env);
  const fullCmd = `cd ${quotedCwd} && ${envPrefix}${cmdStr}`;

  return this.sshExec(conn, fullCmd, opts);
}
138
+
139
/**
 * Tears down a node: kills its tracked process (if any), removes its
 * working directory, and forgets it. Unknown ids are ignored, so repeated
 * calls are harmless. Remote cleanup failures are swallowed on purpose.
 */
async destroy(externalId: string): Promise<void> {
  const node = this.nodes.get(externalId);
  if (node) {
    try {
      const client = await this.getConnection(node.host);

      // Kill managed process if any
      if (node.pid) {
        await this.sshExec(client, `kill ${node.pid} 2>/dev/null || true`);
      }

      // Remove working directory
      await this.sshExec(client, `rm -rf ${node.workdir}`);
    } catch {
      // Best-effort cleanup
    }

    this.nodes.delete(externalId);
  }
}
159
+
160
/**
 * Reports the status of a node.
 *
 * - unknown id: `terminated`
 * - PID tracked: `running` or `stopped` according to `kill -0`
 * - no PID tracked: `running` as long as a trivial command can be executed
 * - any connection or exec error: `unknown`
 */
async inspect(externalId: string): Promise<ProviderNodeStatus> {
  const node = this.nodes.get(externalId);
  if (node === undefined) {
    return { status: "terminated" };
  }

  try {
    const client = await this.getConnection(node.host);
    const { pid, startedAt } = node;

    if (!pid) {
      // No PID tracked — just check SSH connectivity
      await this.sshExec(client, "echo ok");
      return { status: "running", startedAt };
    }

    // Check if process is still running
    const probe = await this.sshExec(
      client,
      `kill -0 ${pid} 2>/dev/null && echo running || echo stopped`
    );
    const alive = probe.stdout.trim() === "running";
    return { status: alive ? "running" : "stopped", startedAt };
  } catch {
    return { status: "unknown" };
  }
}
192
+
193
+ // ── Optional: stop / start ──────────────────────────────────────
194
+
195
/**
 * Stops the node's managed background process, if one is tracked.
 *
 * Sends the default `kill` signal to the recorded PID. The PID is
 * deliberately kept afterwards: `inspect()` probes a tracked PID with
 * `kill -0` and reports `stopped` once it is gone, whereas with no PID it
 * only checks connectivity and would report a stopped node as `running`.
 * `start()` overwrites the PID when it relaunches the command.
 *
 * @param externalId - Id returned by `provision()`.
 * @throws If the node is unknown or the connection fails.
 */
async stop(externalId: string): Promise<void> {
  const state = this.getNodeState(externalId);
  if (!state.pid) return;

  const conn = await this.getConnection(state.host);
  await this.sshExec(conn, `kill ${state.pid} 2>/dev/null || true`);
}
203
+
204
/**
 * Starts (or restarts) the node's command.
 *
 * When a command was recorded at provision time it is launched again in
 * the background and the new PID is stored; otherwise only connectivity
 * is verified.
 */
async start(externalId: string): Promise<void> {
  const node = this.getNodeState(externalId);
  const client = await this.getConnection(node.host);

  if (!node.command) {
    // No command to restart, just verify connectivity
    await this.sshExec(client, "echo ok");
    return;
  }

  // Re-run the original command if one was stored
  const prefix = node.envPrefix ?? "";
  const launch = await this.sshExec(
    client,
    `cd ${node.workdir} && ${prefix}nohup ${node.command} > ${node.workdir}/.stdout 2> ${node.workdir}/.stderr & echo $!`
  );
  const newPid = parseInt(launch.stdout.trim(), 10);
  if (!isNaN(newPid)) {
    node.pid = newPid;
  }
}
224
+
225
+ // ── Optional: healthCheck ───────────────────────────────────────
226
+
227
/**
 * Probes every configured host, one after another, by running `echo ok`.
 *
 * @returns `healthy` only when every host responded; `latencyMs` covers
 *   the whole sweep; `message` lists the failing hosts with their errors.
 */
async healthCheck(): Promise<HealthStatus> {
  const began = Date.now();
  const failures: string[] = [];

  for (const target of this.config.hosts) {
    try {
      const client = await this.getConnection(target);
      await this.sshExec(client, "echo ok");
    } catch (err: unknown) {
      const reason = err instanceof Error ? err.message : String(err);
      failures.push(`${target.host} (${reason})`);
    }
  }

  const healthy = failures.length === 0;
  return {
    healthy,
    latencyMs: Date.now() - began,
    message: healthy ? undefined : `Unhealthy hosts: ${failures.join(", ")}`,
  };
}
256
+
257
+ // ── Optional: extensions ────────────────────────────────────────
258
+
259
/**
 * Returns the requested extension for a node: `files` and `network` are
 * supported, anything else yields `undefined`.
 */
getExtension<K extends keyof ExtensionMap>(
  externalId: string,
  extension: K
): ExtensionMap[K] | undefined {
  switch (extension) {
    case "files":
      return this.createFileExtension(externalId) as ExtensionMap[K];
    case "network":
      return this.createNetworkExtension(externalId) as ExtensionMap[K];
    default:
      return undefined;
  }
}
271
+
272
/**
 * Shuts the provider down: ends every cached SSH connection and drops all
 * connection and node bookkeeping.
 */
async close(): Promise<void> {
  this.connections.forEach((client) => {
    client.end();
  });
  this.connections.clear();
  this.nodes.clear();
}
280
+
281
+ // ── Private helpers ─────────────────────────────────────────────
282
+
283
/**
 * Builds the network extension for a node. URLs point at the node's host
 * over plain HTTP on the requested port.
 *
 * @throws If the node is unknown.
 */
private createNetworkExtension(externalId: string): NetworkExtension {
  const address = this.getNodeState(externalId).host.host;
  return {
    getUrl: (port: number): Promise<string> =>
      Promise.resolve(`http://${address}:${port}`),
  };
}
292
+
293
/**
 * Chooses a host for a spec: narrows the configured hosts by GPU
 * availability and exact label match, then rotates through the survivors
 * using a provider-wide round-robin counter.
 *
 * @throws If a filter leaves no candidate.
 */
private selectHost(spec: ComputeSpec): SSHHostConfig {
  let pool = this.config.hosts.slice();

  // Filter by GPU if requested
  if (spec.gpu && spec.gpu.count > 0) {
    pool = pool.filter((h) => h.gpuAvailable);
    if (pool.length === 0) {
      throw new Error("[ssh] no hosts with GPU available");
    }
  }

  // Filter by labels if specified
  const wanted = spec.labels;
  if (wanted) {
    const required = Object.entries(wanted);
    pool = pool.filter((h) => {
      const have = h.labels;
      // Hosts without labels never match a label selector.
      return !!have && required.every(([k, v]) => have[k] === v);
    });
    if (pool.length === 0) {
      throw new Error("[ssh] no hosts matching labels");
    }
  }

  if (pool.length === 0) {
    throw new Error("[ssh] no hosts available");
  }

  // Round-robin selection
  const slot = this.roundRobinIdx++ % pool.length;
  return pool[slot];
}
326
+
327
/**
 * Looks up the state of a provisioned node.
 *
 * @throws If no node with this id is tracked.
 */
private getNodeState(externalId: string): NodeState {
  const found = this.nodes.get(externalId);
  if (found) {
    return found;
  }
  throw new Error(`[ssh] node not found: ${externalId}`);
}
334
+
335
/**
 * Returns the cached SSH client for a host, or opens a new one.
 *
 * A client is cached once it reports `ready` and is evicted again on
 * `error` or `close`. The returned promise rejects when an error is
 * emitted before the client became ready.
 */
private async getConnection(host: SSHHostConfig): Promise<SSHClient> {
  const cacheKey = hostKeyOf(host);
  const cached = this.connections.get(cacheKey);
  if (cached) return cached;

  const client = new SSHClient();
  const evict = (): void => {
    this.connections.delete(cacheKey);
  };

  return new Promise<SSHClient>((resolve, reject) => {
    client.on("close", evict);
    client.on("error", (err) => {
      evict();
      reject(new Error(`[ssh] connection error (${cacheKey}): ${err.message}`));
    });
    client.on("ready", () => {
      this.connections.set(cacheKey, client);
      resolve(client);
    });

    // Per-host credentials win over the provider-wide defaults.
    client.connect({
      host: host.host,
      port: host.port ?? 22,
      username: host.username ?? this.config.defaultUser,
      privateKey: host.privateKey ?? this.config.defaultPrivateKey,
      password: host.password,
      readyTimeout: this.config.connectTimeoutMs ?? 20_000,
      keepaliveInterval: this.config.keepaliveIntervalMs ?? 10_000,
      keepaliveCountMax: 3,
    });
  });
}
367
+
368
/**
 * Makes sure Tailscale is installed on the host and attempts to get
 * `tailscaled` running.
 *
 * Installation uses the official install script and throws on failure.
 * Starting the daemon is best effort: a running daemon is left alone,
 * otherwise systemd is tried first and a detached direct launch is the
 * fallback.
 *
 * @throws If the install script exits non-zero.
 */
private async ensureTailscale(conn: SSHClient): Promise<void> {
  // Check if tailscale is installed
  const check = await this.sshExec(conn, "which tailscale 2>/dev/null");
  if (check.exitCode !== 0) {
    // Install Tailscale via official install script
    const install = await this.sshExec(
      conn,
      "curl -fsSL https://tailscale.com/install.sh | sh",
      { timeoutMs: 120_000 }
    );
    if (install.exitCode !== 0) {
      throw new Error(
        `[ssh] failed to install tailscale: ${install.stderr.trim()}`
      );
    }
  }

  // Nothing to do when the daemon is already running.
  const running = await this.sshExec(conn, "pgrep tailscaled >/dev/null 2>&1", {
    timeoutMs: 10_000,
  });
  if (running.exitCode === 0) return;

  // Prefer the service manager; its exit status says whether it worked.
  const viaSystemd = await this.sshExec(
    conn,
    "sudo systemctl start tailscaled 2>/dev/null",
    { timeoutMs: 10_000 }
  );
  if (viaSystemd.exitCode === 0) return;

  // Fallback: launch the daemon directly. A backgrounded command always
  // yields shell status 0, so it cannot be used to detect failure, and all
  // stdio must be redirected or the exec channel stays open until timeout.
  await this.sshExec(
    conn,
    "sudo nohup tailscaled --state=/var/lib/tailscale/tailscaled.state >/dev/null 2>&1 </dev/null &",
    { timeoutMs: 10_000 }
  );
}
398
+
399
/**
 * Joins the host to a tailnet with `tailscale up`.
 *
 * @param conn - Open SSH client for the host.
 * @param authKey - Tailscale auth key; single-quoted before it is placed
 *   on the shell command line so it cannot be split or interpreted.
 * @throws If `tailscale up` exits non-zero (including the 30s timeout).
 */
private async tailscaleUp(conn: SSHClient, authKey: string): Promise<void> {
  // Close the quote, emit an escaped quote, reopen — same escaping as
  // buildEnvPrefix and the file extension's mkdir.
  const quotedKey = "'" + authKey.replace(/'/g, "'\\''") + "'";
  const result = await this.sshExec(
    conn,
    `sudo tailscale up --authkey=${quotedKey}`,
    { timeoutMs: 30_000 }
  );
  if (result.exitCode !== 0) {
    throw new Error(`[ssh] tailscale up failed: ${result.stderr.trim()}`);
  }
}
409
+
410
/**
 * Executes a shell command over an open SSH client and collects its output.
 *
 * Resolves with the exit code (1 when the channel closes without one) and
 * the accumulated stdout/stderr; chunks are also forwarded to
 * `opts.onStdout` / `opts.onStderr`. When `opts.timeoutMs` is positive and
 * elapses first, the channel is closed and the promise resolves with exit
 * code 124 and a timeout note appended to stderr. Rejects only when the
 * exec request itself fails.
 */
private sshExec(
  conn: SSHClient,
  command: string,
  opts?: ExecOpts
): Promise<ExecResult> {
  return new Promise<ExecResult>((resolve, reject) => {
    let finished = false;
    let deadline: ReturnType<typeof setTimeout> | undefined;

    // Runs `action` at most once and cancels any pending timeout.
    const finish = (action: () => void): void => {
      if (finished) return;
      finished = true;
      if (deadline) clearTimeout(deadline);
      action();
    };

    conn.exec(command, (err: Error | undefined, stream: ClientChannel) => {
      if (err) {
        finish(() => reject(new Error(`[ssh] exec failed: ${err.message}`)));
        return;
      }

      let outText = "";
      let errText = "";

      stream.on("data", (chunk: Buffer) => {
        const piece = chunk.toString();
        outText += piece;
        opts?.onStdout?.(piece);
      });

      stream.stderr.on("data", (chunk: Buffer) => {
        const piece = chunk.toString();
        errText += piece;
        opts?.onStderr?.(piece);
      });

      stream.on("close", (code: number | null) => {
        finish(() =>
          resolve({ exitCode: code ?? 1, stdout: outText, stderr: errText })
        );
      });

      // Enforce timeout
      const limit = opts?.timeoutMs;
      if (limit && limit > 0) {
        deadline = setTimeout(() => {
          stream.close();
          finish(() =>
            resolve({
              exitCode: 124, // Conventional timeout exit code
              stdout: outText,
              stderr: errText + `\n[ssh] command timed out after ${limit}ms`,
            })
          );
        }, limit);
      }
    });
  });
}
466
+
467
/**
 * Opens an SFTP session on an SSH client, as a promise.
 *
 * @throws (rejects) with a `[ssh] SFTP failed` error when the session
 *   cannot be opened.
 */
private getSftp(conn: SSHClient): Promise<SFTPWrapper> {
  return new Promise<SFTPWrapper>((resolve, reject) => {
    conn.sftp((err, session) => {
      if (err) {
        reject(new Error(`[ssh] SFTP failed: ${err.message}`));
        return;
      }
      resolve(session);
    });
  });
}
475
+
476
/**
 * Builds the file extension for a node, backed by SFTP (and `mkdir -p`
 * over exec for recursive directory creation).
 *
 * Each operation resolves the node and connection at call time, opens its
 * own SFTP session and always ends that session afterwards, so repeated
 * file operations do not pile up open channels on the connection.
 */
private createFileExtension(externalId: string): FileExtension {
  // Opens an SFTP session for the node, runs `op`, and ends the session
  // whether `op` succeeded or failed.
  const withSftp = async <T>(
    op: (sftp: SFTPWrapper) => Promise<T>
  ): Promise<T> => {
    const state = this.getNodeState(externalId);
    const conn = await this.getConnection(state.host);
    const sftp = await this.getSftp(conn);
    try {
      return await op(sftp);
    } finally {
      sftp.end();
    }
  };

  return {
    /** Reads a remote file and returns its contents decoded as a string. */
    read: (path: string): Promise<string | Buffer> =>
      withSftp(
        (sftp) =>
          new Promise<string | Buffer>((resolve, reject) => {
            sftp.readFile(path, (err, data) => {
              if (err) return reject(err);
              resolve(data.toString());
            });
          })
      ),

    /** Writes `content` to a remote file. */
    write: (path: string, content: string | Buffer): Promise<void> =>
      withSftp(
        (sftp) =>
          new Promise<void>((resolve, reject) => {
            sftp.writeFile(path, content, (err) => {
              if (err) return reject(err);
              resolve();
            });
          })
      ),

    /** Lists a remote directory; entry type is derived from the long listing. */
    list: (dirPath: string): Promise<FileEntry[]> =>
      withSftp(
        (sftp) =>
          new Promise<FileEntry[]>((resolve, reject) => {
            sftp.readdir(dirPath, (err, list) => {
              if (err) return reject(err);
              // Avoid a doubled separator when dirPath already ends in "/".
              const base = dirPath.endsWith("/") ? dirPath : `${dirPath}/`;
              resolve(
                list.map((entry) => ({
                  name: entry.filename,
                  path: `${base}${entry.filename}`,
                  type: (entry.longname.startsWith("d")
                    ? "directory"
                    : "file") as "file" | "directory",
                  size: entry.attrs.size,
                }))
              );
            });
          })
      ),

    /** Removes a remote file via SFTP unlink. */
    remove: (path: string): Promise<void> =>
      withSftp(
        (sftp) =>
          new Promise<void>((resolve, reject) => {
            sftp.unlink(path, (err) => {
              if (err) return reject(err);
              resolve();
            });
          })
      ),

    /** Creates a remote directory, including missing parents. */
    mkdir: async (dirPath: string): Promise<void> => {
      const state = this.getNodeState(externalId);
      const conn = await this.getConnection(state.host);
      // Use exec with mkdir -p for recursive directory creation
      await this.sshExec(conn, `mkdir -p '${dirPath.replace(/'/g, "'\\''")}'`);
    },
  };
}
544
+ }
545
+
546
+ // ── Utility functions ───────────────────────────────────────────────
547
+
548
/** Builds the `host:port` key for a host, defaulting the port to 22. */
function hostKeyOf(host: SSHHostConfig): string {
  const port = host.port ?? 22;
  return host.host + ":" + String(port);
}
551
+
552
/**
 * Renders an environment map as a shell assignment prefix
 * (`A='1' B='two' `, with a trailing space) to put in front of a command.
 *
 * Values are single-quoted with embedded quotes escaped. Names are placed
 * on the command line verbatim, so they are validated first: anything
 * that is not a plain shell identifier is rejected rather than being
 * handed to the remote shell.
 *
 * @param env - Variables to set; `undefined` or an empty map yields `""`.
 * @returns The prefix, or an empty string when there is nothing to set.
 * @throws If a variable name is not of the form `[A-Za-z_][A-Za-z0-9_]*`.
 */
function buildEnvPrefix(env?: Record<string, string>): string {
  if (!env) return "";

  const validName = /^[A-Za-z_][A-Za-z0-9_]*$/;
  const assignments = Object.entries(env).map(([name, value]) => {
    if (!validName.test(name)) {
      throw new Error(
        `[ssh] invalid environment variable name: ${JSON.stringify(name)}`
      );
    }
    // Close the quote, emit an escaped quote, reopen.
    return `${name}='${value.replace(/'/g, "'\\''")}'`;
  });

  if (assignments.length === 0) return "";
  return assignments.join(" ") + " ";
}
@@ -0,0 +1,28 @@
1
+ import { describe, it, expect, beforeEach, afterEach } from "vitest";
2
+ import { runProviderConformanceTests } from "openlattice/testing";
3
+ import { SSHProvider } from "../src/ssh-provider";
4
+
5
// The conformance suite needs a reachable SSH server; opt in with TEST_SSH=1.
const HAS_SSH = process.env.TEST_SSH === "1";

/**
 * Builds a provider for the single host described by the SSH_HOST,
 * SSH_PORT, SSH_USER and SSH_KEY environment variables (read on each call).
 */
function createProvider(): SSHProvider {
  const { SSH_HOST, SSH_PORT, SSH_USER, SSH_KEY } = process.env;
  return new SSHProvider({
    hosts: [
      {
        host: SSH_HOST ?? "localhost",
        port: parseInt(SSH_PORT ?? "22", 10),
        username: SSH_USER ?? "root",
        privateKey: SSH_KEY,
      },
    ],
  });
}

describe.skipIf(!HAS_SSH)("SSHProvider conformance", () => {
  const harness = { describe, it, expect, beforeEach, afterEach };
  runProviderConformanceTests(
    {
      createProvider,
      createSpec: () => ({
        runtime: { image: "any" },
      }),
      timeoutMs: 30_000,
    },
    harness
  );
});