wafer-cli 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +18 -7
- wafer/api_client.py +4 -0
- wafer/cli.py +1177 -278
- wafer/corpus.py +158 -32
- wafer/evaluate.py +75 -6
- wafer/kernel_scope.py +132 -31
- wafer/nsys_analyze.py +903 -73
- wafer/nsys_profile.py +511 -0
- wafer/output.py +241 -0
- wafer/skills/wafer-guide/SKILL.md +13 -0
- wafer/ssh_keys.py +261 -0
- wafer/targets_ops.py +718 -0
- wafer/wevin_cli.py +127 -18
- wafer/workspaces.py +232 -184
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.11.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.11.dist-info}/RECORD +19 -15
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.11.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.11.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.9.dist-info → wafer_cli-0.2.11.dist-info}/top_level.txt +0 -0
wafer/output.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""Structured output formatting for CLI commands.
|
|
2
|
+
|
|
3
|
+
Provides JSON and JSONL output formats for machine-readable CLI output.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import asdict, dataclass, field
|
|
11
|
+
from datetime import UTC, datetime
|
|
12
|
+
from enum import StrEnum
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
14
|
+
|
|
15
|
+
import typer
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .evaluate import EvaluateResult
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OutputFormat(StrEnum):
|
|
22
|
+
"""Output format for CLI commands."""
|
|
23
|
+
|
|
24
|
+
TEXT = "text" # Human-readable (default)
|
|
25
|
+
JSON = "json" # Single JSON object at end
|
|
26
|
+
JSONL = "jsonl" # Streaming JSON Lines
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class EvalOutput:
|
|
31
|
+
"""Structured evaluation result for JSON output."""
|
|
32
|
+
|
|
33
|
+
status: Literal["success", "failure", "error"]
|
|
34
|
+
target: str | None = None
|
|
35
|
+
phase: str | None = None
|
|
36
|
+
correctness: dict[str, Any] | None = None
|
|
37
|
+
benchmark: dict[str, Any] | None = None
|
|
38
|
+
profile: dict[str, Any] | None = None
|
|
39
|
+
error: dict[str, Any] | None = None
|
|
40
|
+
raw_compiler_output: str | None = None
|
|
41
|
+
|
|
42
|
+
def to_json(self, indent: int | None = 2) -> str:
|
|
43
|
+
"""Serialize to JSON, excluding None values."""
|
|
44
|
+
data = {k: v for k, v in asdict(self).items() if v is not None}
|
|
45
|
+
return json.dumps(data, indent=indent, default=str)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class OutputCollector:
    """Collect CLI output events and render them in the selected format.

    Behaviour per ``format``:

    * ``TEXT``  - ``emit`` echoes ``[wafer] <event>`` (plus ``: <status>`` when
      a ``status`` keyword is given); ``finalize`` prints nothing further.
    * ``JSONL`` - ``emit`` prints one JSON object per event immediately;
      ``finalize`` prints a closing ``{"event": "completed", ...}`` line that
      carries the result.
    * ``JSON``  - ``emit`` prints nothing (events are not buffered);
      ``finalize`` prints the result as one indented JSON document.

    Attributes:
        format: Selected output format.
        target: Optional target name stamped onto the result.
    """

    format: OutputFormat
    target: str | None = None
    _result: EvalOutput = field(default_factory=lambda: EvalOutput(status="success"))

    def __post_init__(self) -> None:
        # set_result() and set_error() always stamp self.target onto the result.
        # Do the same for the default result so that finalize() still reports
        # the target when neither of them is called.
        if self._result.target is None:
            self._result.target = self.target

    def emit(self, event: str, **data: Any) -> None:
        """Report a progress event.

        JSONL prints the event immediately as one JSON line, with a UTC
        timestamp and any extra keyword data. TEXT echoes a one-line summary
        that uses only the optional ``status`` keyword. JSON ignores the
        event: nothing is buffered, and only the final result is printed by
        ``finalize``.
        """
        if self.format == OutputFormat.JSONL:
            obj = {
                "event": event,
                "timestamp": datetime.now(UTC).isoformat(),
                # NOTE: a "timestamp" key in data overrides the generated one.
                **data,
            }
            print(json.dumps(obj, default=str), flush=True)
        elif self.format == OutputFormat.TEXT:
            status = data.get("status", "")
            if status:
                typer.echo(f"[wafer] {event}: {status}")
            else:
                typer.echo(f"[wafer] {event}")

    def set_result(
        self,
        *,
        status: Literal["success", "failure", "error"],
        phase: str | None = None,
        correctness: dict[str, Any] | None = None,
        benchmark: dict[str, Any] | None = None,
        profile: dict[str, Any] | None = None,
        error: dict[str, Any] | None = None,
        raw_compiler_output: str | None = None,
    ) -> None:
        """Replace the final result wholesale. ``target`` is taken from ``self``."""
        self._result = EvalOutput(
            status=status,
            target=self.target,
            phase=phase,
            correctness=correctness,
            benchmark=benchmark,
            profile=profile,
            error=error,
            raw_compiler_output=raw_compiler_output,
        )

    def set_error(self, phase: str, error_type: str, **details: Any) -> None:
        """Mark the current result as an error.

        Only ``status``, ``phase``, ``error`` (``{"type": error_type, **details}``)
        and ``target`` are overwritten; other sections already set are kept.
        """
        self._result.status = "error"
        self._result.phase = phase
        self._result.error = {"type": error_type, **details}
        self._result.target = self.target

    def finalize(self) -> None:
        """Print the final result for JSON / JSONL. This is a no-op for TEXT."""
        if self.format == OutputFormat.JSON:
            print(self._result.to_json(), flush=True)
        elif self.format == OutputFormat.JSONL:
            result = {k: v for k, v in asdict(self._result).items() if v is not None}
            completed = {
                "event": "completed",
                "timestamp": datetime.now(UTC).isoformat(),
                "result": result,
            }
            # Flush like emit() does, so stream consumers see the last line promptly.
            print(json.dumps(completed, default=str), flush=True)
        # TEXT output was already printed incrementally by emit()/output_text_*().

    def output_text_result(self, result: EvaluateResult) -> None:
        """Print a human-readable PASS/FAIL summary (TEXT format only)."""
        if self.format != OutputFormat.TEXT:
            return

        typer.echo("")
        typer.echo("=" * 60)
        status = "PASS" if result.all_correct else "FAIL"
        typer.echo(f"Result: {status}")
        score_pct = f"{result.correctness_score:.1%}"
        typer.echo(f"Correctness: {result.passed_tests}/{result.total_tests} ({score_pct})")
        # Speedup is only shown when a positive geomean was measured.
        if result.geomean_speedup > 0:
            typer.echo(f"Speedup: {result.geomean_speedup:.2f}x")
        typer.echo("=" * 60)

    def output_text_error(self, error_message: str) -> None:
        """Echo ``error_message`` to stderr (TEXT format only)."""
        if self.format == OutputFormat.TEXT:
            typer.echo(f"Error: {error_message}", err=True)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def format_evaluate_result(result: EvaluateResult, target: str | None = None) -> EvalOutput:
    """Translate an ``EvaluateResult`` into the structured ``EvalOutput`` form.

    Three outcomes:
      * ``result.success`` is falsy -> status "error", with phase and error
        details derived from ``result.error_message`` via ``parse_error_message``.
      * not all tests passed -> status "failure", phase "correctness".
      * otherwise -> status "success"; a ``benchmark`` section is added only
        when ``result.geomean_speedup`` is positive.
    """
    # The evaluation itself did not complete.
    if not result.success:
        details = parse_error_message(result.error_message or "Unknown error")
        return EvalOutput(
            status="error",
            target=target,
            phase=details.get("phase", "unknown"),
            error=details,
        )

    passed_all = bool(result.all_correct)
    correctness = {
        "passed": passed_all,
        "tests_run": result.total_tests,
        "tests_passed": result.passed_tests,
    }

    # It ran, but produced wrong answers.
    if not passed_all:
        return EvalOutput(
            status="failure",
            target=target,
            phase="correctness",
            correctness=correctness,
        )

    speedup_section = (
        {"speedup": result.geomean_speedup} if result.geomean_speedup > 0 else None
    )
    return EvalOutput(
        status="success",
        target=target,
        correctness=correctness,
        benchmark=speedup_section,
    )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def parse_error_message(error_message: str) -> dict[str, Any]:
    """Classify a free-form evaluation error message.

    The returned dict always has ``message``, ``phase`` and ``type``.
    Classification is case-insensitive and keyword based. It is checked in
    this order, and the first match wins:

      1. compilation   - "compilation" / "compile"
      2. runtime memory - "hsa_status", "memory", "segfault", "segmentation fault"
      3. runtime timeout - "timeout", "timed out"
      4. correctness   - "correctness"
      5. unknown

    For compilation errors, the first ``file:line:col: error:`` diagnostic
    found by ``parse_compilation_error`` is merged in. Its ``message`` then
    replaces the full input message.
    """
    error_info: dict[str, Any] = {"message": error_message}
    lowered = error_message.lower()

    # "Segmentation fault" is what the OS/shell actually prints, so matching
    # only "segfault" misses it. Likewise "timed out" vs "timeout".
    memory_markers = ("hsa_status", "memory", "segfault", "segmentation fault")
    timeout_markers = ("timeout", "timed out")

    # "compilation" does not contain "compile", so both spellings are checked.
    if "compilation" in lowered or "compile" in lowered:
        error_info["phase"] = "compilation"
        error_info["type"] = "CompilationError"
        # Try to parse compiler error format: file:line:col: error: message
        parsed = parse_compilation_error(error_message)
        if parsed:
            error_info.update(parsed)
    elif any(marker in lowered for marker in memory_markers):
        error_info["phase"] = "runtime"
        error_info["type"] = "MemoryViolation"
    elif any(marker in lowered for marker in timeout_markers):
        error_info["phase"] = "runtime"
        error_info["type"] = "Timeout"
    elif "correctness" in lowered:
        error_info["phase"] = "correctness"
        error_info["type"] = "CorrectnessError"
    else:
        error_info["phase"] = "unknown"
        error_info["type"] = "UnknownError"

    return error_info
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def parse_compilation_error(raw_output: str) -> dict[str, Any] | None:
|
|
217
|
+
"""Extract structured info from compiler error output.
|
|
218
|
+
|
|
219
|
+
Matches patterns like: file.hip:10:14: error: message
|
|
220
|
+
"""
|
|
221
|
+
match = re.search(
|
|
222
|
+
r"(?P<file>[\w./]+):(?P<line>\d+):(?P<col>\d+): error: (?P<message>.+)",
|
|
223
|
+
raw_output,
|
|
224
|
+
)
|
|
225
|
+
if match:
|
|
226
|
+
return {
|
|
227
|
+
"file": match.group("file"),
|
|
228
|
+
"line": int(match.group("line")),
|
|
229
|
+
"column": int(match.group("col")),
|
|
230
|
+
"message": match.group("message").strip(),
|
|
231
|
+
}
|
|
232
|
+
return None
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_output_format(json_flag: bool, jsonl_flag: bool) -> OutputFormat:
    """Map the JSON / JSONL CLI flags to an ``OutputFormat``.

    JSONL takes precedence when both flags are set; with neither set the
    human-readable TEXT format is used.
    """
    if not (json_flag or jsonl_flag):
        return OutputFormat.TEXT
    return OutputFormat.JSONL if jsonl_flag else OutputFormat.JSON
|
|
@@ -7,6 +7,19 @@ description: GPU kernel development with the Wafer CLI. Use when working on CUDA
|
|
|
7
7
|
|
|
8
8
|
GPU development primitives for optimizing CUDA and HIP kernels.
|
|
9
9
|
|
|
10
|
+
## Installation
|
|
11
|
+
|
|
12
|
+
Before using Wafer CLI commands, install the tool:
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
# Install wafer-cli using uv (recommended)
|
|
16
|
+
uv tool install wafer-cli
|
|
17
|
+
|
|
18
|
+
# Authenticate (one-time setup)
|
|
19
|
+
wafer login
|
|
20
|
+
|
|
21
|
+
```
|
|
22
|
+
|
|
10
23
|
## When to Use This Skill
|
|
11
24
|
|
|
12
25
|
Activate this skill when:
|
wafer/ssh_keys.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""SSH Keys CLI - Manage SSH public keys for workspace access.
|
|
2
|
+
|
|
3
|
+
This module provides the implementation for the `wafer ssh-keys` subcommand.
|
|
4
|
+
Users register their SSH public keys here, which are then installed in all
|
|
5
|
+
workspaces they attach to (BYOK - Bring Your Own Key model).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
|
|
16
|
+
from .api_client import get_api_url
|
|
17
|
+
from .auth import get_auth_headers
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
|
|
21
|
+
class SshKey:
|
|
22
|
+
"""Registered SSH key info."""
|
|
23
|
+
|
|
24
|
+
id: str
|
|
25
|
+
public_key: str
|
|
26
|
+
name: str | None
|
|
27
|
+
created_at: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_client() -> tuple[str, dict[str, str]]:
    """Return the configured API base URL and the auth headers for requests.

    Raises:
        RuntimeError: If no API URL is configured, or it is not an http(s) URL.
            These are explicit raises rather than ``assert`` so the checks
            still run under ``python -O``, and so the error type matches the
            RuntimeError raised by every other failure in this module.
    """
    api_url = get_api_url()
    headers = get_auth_headers()

    if not api_url:
        raise RuntimeError("API URL must be configured")
    # Require an actual scheme prefix; a bare "http" prefix would accept "httpfoo".
    if not api_url.startswith(("http://", "https://")):
        raise RuntimeError(f"API URL must be a valid HTTP(S) URL, got: {api_url!r}")

    return api_url, headers
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_key_fingerprint(public_key: str) -> str:
|
|
42
|
+
"""Extract a short fingerprint from a public key for display.
|
|
43
|
+
|
|
44
|
+
Returns the first 12 characters of the base64 data portion.
|
|
45
|
+
"""
|
|
46
|
+
parts = public_key.strip().split()
|
|
47
|
+
if len(parts) >= 2:
|
|
48
|
+
return parts[1][:12] + "..."
|
|
49
|
+
return public_key[:12] + "..."
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _get_key_type(public_key: str) -> str:
|
|
53
|
+
"""Extract the key type from a public key."""
|
|
54
|
+
parts = public_key.strip().split()
|
|
55
|
+
if parts:
|
|
56
|
+
return parts[0]
|
|
57
|
+
return "unknown"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _detect_ssh_keys() -> list[Path]:
|
|
61
|
+
"""Detect existing SSH public keys on disk.
|
|
62
|
+
|
|
63
|
+
Returns list of paths to found public key files, in preference order.
|
|
64
|
+
"""
|
|
65
|
+
ssh_dir = Path.home() / ".ssh"
|
|
66
|
+
candidates = [
|
|
67
|
+
"id_ed25519.pub", # Preferred (modern, secure, fast)
|
|
68
|
+
"id_rsa.pub", # Legacy but common
|
|
69
|
+
"id_ecdsa.pub", # Less common
|
|
70
|
+
"id_dsa.pub", # Deprecated
|
|
71
|
+
]
|
|
72
|
+
|
|
73
|
+
found = []
|
|
74
|
+
for filename in candidates:
|
|
75
|
+
key_path = ssh_dir / filename
|
|
76
|
+
if key_path.exists():
|
|
77
|
+
found.append(key_path)
|
|
78
|
+
|
|
79
|
+
return found
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def list_ssh_keys(json_output: bool = False) -> str:
    """Fetch the current user's registered SSH keys and render them.

    Args:
        json_output: When true, return the API response re-serialized as
            indented JSON instead of the human-readable listing.

    Returns:
        The rendered output string.

    Raises:
        RuntimeError: On HTTP 401 (not logged in), any other HTTP error
            status, or a network failure reaching the API.
    """
    api_url, headers = _get_client()
    endpoint = f"{api_url}/v1/user/ssh-keys"

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            resp = client.get(endpoint)
            resp.raise_for_status()
            keys = resp.json()
    except httpx.HTTPStatusError as exc:
        code = exc.response.status_code
        if code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from exc
        raise RuntimeError(f"API error: {code} - {exc.response.text}") from exc
    except httpx.RequestError as exc:
        raise RuntimeError(f"Could not reach API: {exc}") from exc

    if json_output:
        return json.dumps(keys, indent=2)

    if not keys:
        return (
            "No SSH keys registered.\n"
            "\n"
            "Add your SSH key:\n"
            " wafer ssh-keys add\n"
            "\n"
            "This will auto-detect your key from ~/.ssh/"
        )

    rows = ["SSH Keys:"]
    for entry in keys:
        pub = entry["public_key"]
        label = entry.get("name") or "(no name)"
        rows += [
            f" • {label}: {_get_key_type(pub)} {_get_key_fingerprint(pub)}",
            f" ID: {entry['id']}",
        ]
    return "\n".join(rows)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def add_ssh_key(
    pubkey_path: Path | None = None,
    name: str | None = None,
    json_output: bool = False,
) -> str:
    """Register an SSH public key with the API.

    Args:
        pubkey_path: Path to the public key file; a leading ``~`` is expanded.
            When ``None``, the first key found by ``_detect_ssh_keys`` is used.
        name: Optional friendly name. Defaults to the key type with the
            ``ssh-`` / ``ecdsa-sha2-`` prefix removed (e.g. ``"ed25519"``).
        json_output: Return the API response as indented JSON instead of a
            human-readable summary.

    Returns:
        Formatted output string.

    Raises:
        RuntimeError: If no key can be found, or the file is missing,
            unreadable, or does not look like an OpenSSH public key. Also
            raised when the API rejects the key (400), authentication fails
            (401), any other HTTP error occurs, or the API is unreachable.
    """
    # Auto-detect if no path provided
    if pubkey_path is None:
        detected = _detect_ssh_keys()
        if not detected:
            raise RuntimeError(
                "No SSH key found in ~/.ssh/\n"
                "\n"
                "Generate one with:\n"
                " ssh-keygen -t ed25519\n"
                "\n"
                "Or specify a path:\n"
                " wafer ssh-keys add /path/to/key.pub"
            )
        pubkey_path = detected[0]
    else:
        # Accept "~/..." even when the shell did not expand it (e.g. when quoted).
        pubkey_path = pubkey_path.expanduser()

    # Validate path
    if not pubkey_path.exists():
        raise RuntimeError(f"File not found: {pubkey_path}")

    # Cheap guard against passing the private key (e.g. ~/.ssh/id_ed25519).
    # A ".pub" suffix always contains "pub", so checking the name suffices.
    if "pub" not in pubkey_path.name:
        raise RuntimeError(
            f"Expected a public key file (.pub), got: {pubkey_path}\n"
            "\n"
            "Make sure you're adding the PUBLIC key, not the private key."
        )

    # Read key content
    try:
        public_key = pubkey_path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError) as e:
        raise RuntimeError(f"Could not read key file: {e}") from e

    # Validate basic format (this also rejects "-----BEGIN ..." private keys).
    if not public_key.startswith(("ssh-", "ecdsa-", "sk-")):
        raise RuntimeError(
            f"Invalid SSH public key format in {pubkey_path}\n"
            "\n"
            "Expected OpenSSH format (e.g., 'ssh-ed25519 AAAAC3... user@host')"
        )

    # Default name: the key type without its prefix, e.g.
    # "ssh-ed25519" -> "ed25519", "ecdsa-sha2-nistp256" -> "nistp256".
    if name is None:
        name = _get_key_type(public_key).replace("ssh-", "").replace("ecdsa-sha2-", "")

    # Call API
    api_url, headers = _get_client()
    request_body = {
        "public_key": public_key,
        "name": name,
    }

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            response = client.post(
                f"{api_url}/v1/user/ssh-keys",
                json=request_body,
            )
            response.raise_for_status()
            key_data = response.json()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from e
        if e.response.status_code == 400:
            # Prefer the API's "detail" field; fall back to the raw body.
            try:
                detail = e.response.json().get("detail", e.response.text)
            except Exception:
                detail = e.response.text
            raise RuntimeError(f"Invalid key: {detail}") from e
        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e

    if json_output:
        return json.dumps(key_data, indent=2)

    key_type = _get_key_type(public_key)
    fingerprint = _get_key_fingerprint(public_key)

    return (
        f"✓ SSH key registered: {name}\n"
        f" Type: {key_type}\n"
        f" Fingerprint: {fingerprint}\n"
        f" Source: {pubkey_path}\n"
        "\n"
        "Your key will be installed in all workspaces you attach to."
    )
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def remove_ssh_key(key_id: str, json_output: bool = False) -> str:
    """Remove a registered SSH key.

    Args:
        key_id: UUID of the key to remove. Surrounding whitespace is ignored.
        json_output: Return JSON instead of formatted output.

    Returns:
        Formatted output string.

    Raises:
        RuntimeError: If ``key_id`` is empty, the key does not exist (404),
            authentication fails (401), any other HTTP error occurs, or the
            API is unreachable.
    """
    from urllib.parse import quote

    key_id = key_id.strip()
    # An empty ID would turn the request into a DELETE on the collection URL.
    if not key_id:
        raise RuntimeError("SSH key ID must not be empty")

    api_url, headers = _get_client()
    # Percent-encode the ID so characters like "/", "?" or "#" cannot change
    # which endpoint is hit; a well-formed UUID passes through unchanged.
    url = f"{api_url}/v1/user/ssh-keys/{quote(key_id, safe='')}"

    try:
        with httpx.Client(timeout=30.0, headers=headers) as client:
            response = client.delete(url)
            response.raise_for_status()
    except httpx.HTTPStatusError as e:
        if e.response.status_code == 401:
            raise RuntimeError("Not authenticated. Run: wafer login") from e
        if e.response.status_code == 404:
            raise RuntimeError(f"SSH key not found: {key_id}") from e
        raise RuntimeError(f"API error: {e.response.status_code} - {e.response.text}") from e
    except httpx.RequestError as e:
        raise RuntimeError(f"Could not reach API: {e}") from e

    if json_output:
        return json.dumps({"status": "deleted", "key_id": key_id}, indent=2)

    return f"✓ SSH key removed: {key_id}"
|