wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/GUIDE.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Wafer CLI Guide
|
|
2
|
+
|
|
3
|
+
GPU development primitives for LLM agents.
|
|
4
|
+
|
|
5
|
+
## Quick Start: Cloud GPU (No Setup)
|
|
6
|
+
|
|
7
|
+
Run code on cloud GPUs instantly with workspaces:
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
wafer login # One-time auth
|
|
11
|
+
wafer workspaces create dev --gpu B200 # Create workspace (NVIDIA B200)
|
|
12
|
+
wafer workspaces exec dev -- python -c "import torch; print(torch.cuda.get_device_name(0))"
|
|
13
|
+
wafer workspaces sync dev ./my-project # Sync files
|
|
14
|
+
wafer workspaces exec dev -- python train.py
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
**Available GPUs:**
|
|
18
|
+
|
|
19
|
+
- `MI300X` - AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
20
|
+
- `B200` - NVIDIA Blackwell B200 (180GB HBM3e, CUDA) - default
|
|
21
|
+
|
|
22
|
+
## Documentation Lookup
|
|
23
|
+
|
|
24
|
+
Answer GPU programming questions from indexed documentation.
|
|
25
|
+
|
|
26
|
+
```bash
|
|
27
|
+
# Download corpus (one-time)
|
|
28
|
+
wafer corpus download cuda
|
|
29
|
+
wafer corpus download cutlass
|
|
30
|
+
wafer corpus download hip
|
|
31
|
+
|
|
32
|
+
# Query documentation
|
|
33
|
+
wafer agent -t ask-docs --corpus cuda "What is warp divergence?"
|
|
34
|
+
wafer agent -t ask-docs --corpus cutlass "What is a TiledMma?"
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Trace Analysis
|
|
38
|
+
|
|
39
|
+
Analyze performance traces from NCU, NSYS, or PyTorch profiler.
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
# AI-assisted analysis
|
|
43
|
+
wafer agent -t trace-analyze --args trace=./profile.ncu-rep "Why is this kernel slow?"
|
|
44
|
+
wafer agent -t trace-analyze --args trace=./trace.json "What's the bottleneck?"
|
|
45
|
+
|
|
46
|
+
# Direct trace queries (PyTorch/Perfetto JSON)
|
|
47
|
+
wafer nvidia perfetto tables trace.json
|
|
48
|
+
wafer nvidia perfetto query trace.json \
|
|
49
|
+
"SELECT name, dur/1e6 as ms FROM slice WHERE cat='kernel' ORDER BY dur DESC LIMIT 10"
|
|
50
|
+
|
|
51
|
+
# NCU/NSYS analysis
|
|
52
|
+
wafer nvidia ncu analyze profile.ncu-rep
|
|
53
|
+
wafer nvidia nsys analyze profile.nsys-rep
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Kernel Evaluation
|
|
57
|
+
|
|
58
|
+
Test kernel correctness and measure speedup against a reference.
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
# Using workspaces (no target setup required):
|
|
62
|
+
wafer workspaces create dev --gpu B200
|
|
63
|
+
wafer workspaces exec --sync ./my-kernel dev -- python test_kernel.py
|
|
64
|
+
|
|
65
|
+
# Or using configured targets (for your own hardware):
|
|
66
|
+
wafer evaluate make-template ./my-kernel
|
|
67
|
+
wafer evaluate \
|
|
68
|
+
--impl ./my-kernel/kernel.py \
|
|
69
|
+
--reference ./my-kernel/reference.py \
|
|
70
|
+
--test-cases ./my-kernel/test_cases.json \
|
|
71
|
+
--target <target-name>
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
For target setup, see `wafer config targets --help`.
|
|
75
|
+
|
|
76
|
+
## Kernel Optimization (AI-assisted)
|
|
77
|
+
|
|
78
|
+
Iteratively optimize a kernel with evaluation feedback.
|
|
79
|
+
|
|
80
|
+
```bash
|
|
81
|
+
wafer agent -t optimize-kernel \
|
|
82
|
+
--args kernel=./my_kernel.cu \
|
|
83
|
+
--args target=H100 \
|
|
84
|
+
"Optimize this GEMM for memory bandwidth"
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
## Workspaces
|
|
88
|
+
|
|
89
|
+
Cloud GPU environments with no setup required.
|
|
90
|
+
|
|
91
|
+
**Available GPUs:**
|
|
92
|
+
|
|
93
|
+
- `MI300X` - AMD Instinct MI300X (192GB HBM3, ROCm)
|
|
94
|
+
- `B200` - NVIDIA Blackwell B200 (180GB HBM3e, CUDA) - default
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
wafer workspaces create dev --gpu B200 --wait # NVIDIA B200
|
|
98
|
+
wafer workspaces create amd-dev --gpu MI300X # AMD MI300X
|
|
99
|
+
wafer workspaces list # List all
|
|
100
|
+
wafer workspaces sync dev ./project # Sync files
|
|
101
|
+
wafer workspaces exec dev -- ./run.sh # Run commands
|
|
102
|
+
wafer workspaces ssh dev # Interactive SSH
|
|
103
|
+
wafer workspaces delete dev # Cleanup
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
See `wafer workspaces --help` for details.
|
|
107
|
+
|
|
108
|
+
## Command Reference
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
wafer corpus list|download|path # Manage documentation corpora
|
|
112
|
+
wafer workspaces # Cloud GPU environments (no setup)
|
|
113
|
+
wafer evaluate # Test kernel correctness/performance
|
|
114
|
+
wafer nvidia ncu|nsys|perfetto # NVIDIA profiling tools
|
|
115
|
+
wafer amd isa|rocprof-compute # AMD profiling tools
|
|
116
|
+
wafer agent -t <template> # AI-assisted workflows
|
|
117
|
+
wafer config targets # Configure your own GPU targets
|
|
118
|
+
```
|
wafer/__init__.py
ADDED
wafer/analytics.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
"""PostHog analytics for Wafer CLI.
|
|
2
|
+
|
|
3
|
+
Tracks CLI command usage and user activity for product analytics.
|
|
4
|
+
Mirrors the analytics implementation in apps/wevin-extension/src/services/analytics.ts.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from .analytics import track_command, identify_user, shutdown_analytics
|
|
8
|
+
|
|
9
|
+
# Track a command execution
|
|
10
|
+
track_command("evaluate", {"subcommand": "kernelbench", "outcome": "success"})
|
|
11
|
+
|
|
12
|
+
# Identify user after login
|
|
13
|
+
identify_user("user-id", "user@example.com")
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import atexit
|
|
17
|
+
import platform
|
|
18
|
+
import uuid
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
# PostHog configuration - same as wevin-extension
|
|
23
|
+
POSTHOG_API_KEY = "phc_9eDjkY72ud9o4l1mA1Gr1dnRT1yx71rP3XY9z66teFh"
|
|
24
|
+
POSTHOG_HOST = "https://us.i.posthog.com"
|
|
25
|
+
|
|
26
|
+
# Anonymous ID storage
|
|
27
|
+
ANONYMOUS_ID_FILE = Path.home() / ".wafer" / ".analytics_id"
|
|
28
|
+
|
|
29
|
+
# Global state
|
|
30
|
+
_posthog_client: Any = None
|
|
31
|
+
_distinct_id: str | None = None
|
|
32
|
+
_initialized: bool = False
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_anonymous_id() -> str:
    """Return the persisted anonymous ID, minting and saving one on first use."""
    if not ANONYMOUS_ID_FILE.exists():
        # First run on this machine: create a fresh ID and persist it.
        fresh_id = f"anon_{uuid.uuid4().hex}"
        ANONYMOUS_ID_FILE.parent.mkdir(parents=True, exist_ok=True)
        ANONYMOUS_ID_FILE.write_text(fresh_id)
        return fresh_id
    return ANONYMOUS_ID_FILE.read_text().strip()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_user_id_from_credentials() -> tuple[str | None, str | None]:
    """Resolve (user_id, email) from stored credentials.

    Returns:
        Tuple of (user_id, email); both may be None when not logged in.
    """
    # Imported lazily to avoid a circular import with .auth.
    from .auth import load_credentials, verify_token

    creds = load_credentials()
    if not creds:
        return None, None

    try:
        info = verify_token(creds.access_token)
        return info.user_id, info.email or creds.email
    except Exception:
        # Verification failed; best effort: fall back to the stored email.
        return None, creds.email
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _is_analytics_enabled() -> bool:
    """Return whether analytics is enabled; defaults to True on any failure."""
    from .global_config import get_preferences

    try:
        return getattr(get_preferences(), "analytics_enabled", True)
    except Exception:
        # Unreadable preferences must not change the default (enabled).
        return True
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def init_analytics() -> bool:
    """Initialize PostHog client.

    Idempotent: only the first call attempts setup; later calls just report
    whether a client exists. Returns False when analytics is disabled in
    preferences, the `posthog` package is not installed, or any setup error
    occurs — analytics must never break the CLI.

    Returns:
        True if initialization succeeded, False otherwise.
    """
    global _posthog_client, _distinct_id, _initialized

    # Fast path: setup already attempted this process.
    if _initialized:
        return _posthog_client is not None

    _initialized = True

    # Check if analytics is enabled (user preference, defaults to on)
    if not _is_analytics_enabled():
        return False

    try:
        from posthog import Posthog

        _posthog_client = Posthog(
            api_key=POSTHOG_API_KEY,
            host=POSTHOG_HOST,
            # Flush immediately for CLI - commands are short-lived
            flush_at=1,
            flush_interval=1,
            # Disable debug logging
            debug=False,
        )

        # Set up distinct ID - prefer authenticated user, fall back to anonymous
        user_id, email = _get_user_id_from_credentials()
        if user_id:
            _distinct_id = user_id
            # Identify the user with their email
            if email:
                _posthog_client.identify(
                    distinct_id=user_id,
                    properties={
                        "email": email,
                        "auth_provider": "github",
                    },
                )
        else:
            _distinct_id = _get_anonymous_id()

        # Register shutdown handler to flush events
        atexit.register(shutdown_analytics)

        return True

    except ImportError:
        # PostHog not installed - analytics disabled
        return False
    except Exception:
        # Any other error - fail silently, don't break CLI
        return False
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def shutdown_analytics() -> None:
    """Flush pending events and tear down the PostHog client."""
    global _posthog_client

    client = _posthog_client
    if client is None:
        return

    try:
        client.flush()
        client.shutdown()
    except Exception:
        # Never let analytics teardown surface an error to the user.
        pass
    _posthog_client = None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def identify_user(user_id: str, email: str | None = None) -> None:
    """Associate subsequent events with a logged-in user.

    Args:
        user_id: Supabase user ID
        email: User's email address
    """
    global _distinct_id

    if not init_analytics() or _posthog_client is None:
        return

    _distinct_id = user_id

    props: dict[str, Any] = {"auth_provider": "github"}
    if email:
        props["email"] = email

    try:
        _posthog_client.identify(distinct_id=user_id, properties=props)
        _posthog_client.flush()
    except Exception:
        # Analytics must never break the CLI.
        pass
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def reset_user_identity() -> None:
    """Reset user identity after logout.

    Reverts the module-level distinct ID to the persisted anonymous ID so
    subsequent events are no longer attributed to the logged-out user.
    """
    global _distinct_id

    # Fall back to the anonymous, machine-local identity.
    _distinct_id = _get_anonymous_id()
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def get_distinct_id() -> str:
    """Return the distinct ID used for event attribution, resolving it lazily."""
    global _distinct_id

    if _distinct_id is None:
        # Prefer the authenticated user; otherwise the anonymous ID.
        resolved, _ = _get_user_id_from_credentials()
        _distinct_id = resolved if resolved else _get_anonymous_id()

    return _distinct_id
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _get_cli_version() -> str:
|
|
206
|
+
"""Get CLI version from package metadata."""
|
|
207
|
+
try:
|
|
208
|
+
from importlib.metadata import version
|
|
209
|
+
|
|
210
|
+
return version("wafer-cli")
|
|
211
|
+
except Exception:
|
|
212
|
+
return "unknown"
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _get_base_properties() -> dict[str, Any]:
    """Build the property set attached to every analytics event."""
    props: dict[str, Any] = {
        "platform": "cli",
        "tool_id": "cli",
        "cli_version": _get_cli_version(),
    }
    # Environment details for segmenting events.
    props["os"] = platform.system().lower()
    props["os_version"] = platform.release()
    props["python_version"] = platform.python_version()
    return props
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def track_event(event_name: str, properties: dict[str, Any] | None = None) -> None:
    """Capture a single analytics event.

    Args:
        event_name: Name of the event to track
        properties: Additional properties to include
    """
    if not init_analytics() or _posthog_client is None:
        return

    try:
        # Caller-supplied properties override the base set on key collision.
        merged = {**_get_base_properties(), **(properties or {})}
        _posthog_client.capture(
            distinct_id=get_distinct_id(),
            event=event_name,
            properties=merged,
        )
    except Exception:
        # Never let analytics failures surface to the user.
        pass
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def track_command(
    command: str,
    subcommand: str | None = None,
    outcome: str = "success",
    duration_ms: int | None = None,
    properties: dict[str, Any] | None = None,
) -> None:
    """Record one CLI command execution.

    This event counts towards DAU in the internal dashboard.

    Args:
        command: The main command name (e.g., "evaluate", "agent")
        subcommand: Optional subcommand (e.g., "kernelbench")
        outcome: "success" or "error"
        duration_ms: Command execution time in milliseconds
        properties: Additional properties to include
    """
    payload: dict[str, Any] = {"command": command, "outcome": outcome}

    # Optional fields are added only when present.
    if subcommand:
        payload["subcommand"] = subcommand
    if duration_ms is not None:
        payload["duration_ms"] = duration_ms
    if properties:
        payload.update(properties)

    track_event("cli_command_executed", payload)
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def track_login(user_id: str, email: str | None = None) -> None:
    """Record a successful login: identify the user, then emit the sign-in event.

    Args:
        user_id: Supabase user ID
        email: User's email address
    """
    # Identification must happen first so the event is attributed correctly.
    identify_user(user_id, email)
    track_event("cli_user_signed_in", {"user_id": user_id})
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def track_logout() -> None:
    """Record a logout, then revert tracking to the anonymous identity."""
    track_event("cli_user_signed_out")
    reset_user_identity()
|
wafer/api_client.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""Wafer API client for remote GPU operations.
|
|
2
|
+
|
|
3
|
+
Thin client that calls wafer-api endpoints instead of direct SSH.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import base64
|
|
7
|
+
import sys
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from .global_config import get_api_url # noqa: F401 - re-exported for backwards compat
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True)
class PushResult:
    """Result of pushing files to GPU."""

    # Workspace identifier returned by the API (used for later job requests).
    workspace_id: str
    # Workspace path on the remote side, as reported by the API.
    workspace_path: str
    # Paths of the files that were uploaded, as reported by the API.
    files_uploaded: list[str]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class ApiConfig:
    """API client configuration."""

    # Fallback URL; module functions resolve the real URL via get_api_url().
    base_url: str = "http://localhost:8000"  # Only used if ApiConfig is instantiated directly
    # Request timeout in seconds.
    timeout: float = 60.0
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _get_auth_headers() -> dict[str, str]:
    """Return auth headers from stored credentials.

    The import is deferred to avoid a circular import with .auth.
    """
    from .auth import get_auth_headers as _load_headers

    return _load_headers()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def push_directory(local_path: Path, workspace_name: str | None = None) -> PushResult:
    """Push local directory to GPU via wafer-api.

    Args:
        local_path: Local directory to upload
        workspace_name: Optional workspace name (defaults to directory name)

    Returns:
        PushResult with workspace_id and uploaded files

    Raises:
        FileNotFoundError: If local_path doesn't exist
        ValueError: If local_path is not a directory
        httpx.HTTPError: If API request fails
    """
    if not local_path.exists():
        raise FileNotFoundError(f"Path not found: {local_path}")
    if not local_path.is_dir():
        raise ValueError(f"Not a directory: {local_path}")

    # Reuse the shared helper instead of duplicating the walk/encode logic
    # (the original inlined an identical copy of _collect_files).
    request_body = {
        "files": _collect_files(local_path),
        "workspace_name": workspace_name or local_path.name,
    }

    # Call API with the stored auth headers.
    api_url = get_api_url()
    headers = _get_auth_headers()
    with httpx.Client(timeout=60.0, headers=headers) as client:
        response = client.post(f"{api_url}/v1/gpu/push", json=request_body)
        response.raise_for_status()
        data = response.json()

    return PushResult(
        workspace_id=data["workspace_id"],
        workspace_path=data["workspace_path"],
        files_uploaded=data["files_uploaded"],
    )
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _collect_files(local_path: Path) -> list[dict]:
|
|
93
|
+
"""Collect files from directory as base64-encoded dicts."""
|
|
94
|
+
files = []
|
|
95
|
+
for file_path in local_path.rglob("*"):
|
|
96
|
+
if file_path.is_file():
|
|
97
|
+
relative_path = file_path.relative_to(local_path)
|
|
98
|
+
content = file_path.read_bytes()
|
|
99
|
+
files.append({
|
|
100
|
+
"path": str(relative_path),
|
|
101
|
+
"content": base64.b64encode(content).decode(),
|
|
102
|
+
})
|
|
103
|
+
return files
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def run_command_stream(
    command: str,
    upload_dir: Path | None = None,
    workspace_id: str | None = None,
    gpu_id: int | None = None,
    gpu_count: int = 1,
    docker_image: str | None = None,
    docker_entrypoint: str | None = None,
    pull_image: bool = False,
    require_hardware_counters: bool = False,
    target: str | None = None,
) -> int:
    """Run command on GPU via wafer-api, streaming output.

    Two modes (mutually exclusive):
    - upload_dir: Upload files and run (stateless, high-level)
    - workspace_id: Use existing workspace (low-level)

    Note: if both are passed, upload_dir takes precedence and workspace_id
    is silently ignored (the if/elif below enforces this).

    Args:
        command: Command to execute inside container
        upload_dir: Directory to upload (stateless mode)
        workspace_id: Workspace ID from push (low-level mode)
        gpu_id: GPU ID to use (optional)
        gpu_count: Number of GPUs needed (1-8, default 1)
        docker_image: Docker image override (optional)
        docker_entrypoint: Docker entrypoint override (optional, e.g., "bash")
        pull_image: Pull image if not available (optional, default False)
        require_hardware_counters: Require baremetal for ncu profiling (optional)
        target: Target name to use (optional, defaults to user's default)

    Returns:
        Exit code (0 = success, non-zero = failure)

    Raises:
        httpx.HTTPError: If API request fails
    """
    request_body: dict = {
        "command": command,
    }

    # Add files or workspace_id (mutually exclusive; upload_dir wins if both set)
    if upload_dir is not None:
        files = _collect_files(upload_dir)
        request_body["files"] = files
        request_body["workspace_name"] = upload_dir.name
    elif workspace_id is not None:
        request_body["workspace_id"] = workspace_id
    # else: no files, no workspace (run command in temp workspace)

    # Optional knobs are only sent when they differ from the server defaults.
    if gpu_id is not None:
        request_body["gpu_id"] = gpu_id
    if gpu_count > 1:
        request_body["gpu_count"] = gpu_count
    if docker_image is not None:
        request_body["docker_image"] = docker_image
    if docker_entrypoint is not None:
        request_body["docker_entrypoint"] = docker_entrypoint
    if pull_image:
        request_body["pull_image"] = True
    if require_hardware_counters:
        request_body["require_hardware_counters"] = True
    if target is not None:
        request_body["target"] = target

    api_url = get_api_url()
    headers = _get_auth_headers()
    exit_code = 0

    with httpx.Client(timeout=None, headers=headers) as client:  # No timeout for streaming
        with client.stream("POST", f"{api_url}/v1/gpu/jobs", json=request_body) as response:
            response.raise_for_status()

            # The server emits Server-Sent Events; blank keep-alive lines are skipped.
            for line in response.iter_lines():
                if not line:
                    continue

                # Parse SSE format: "data: <content>"
                if line.startswith("data: "):
                    content = line[6:]  # Strip "data: " prefix

                    if content == "[DONE]":
                        break
                    elif content.startswith("[ERROR]"):
                        print(content[8:], file=sys.stderr)  # Strip "[ERROR] "
                        exit_code = 1
                        break
                    else:
                        # Regular job output: forward to stdout verbatim.
                        print(content)

    return exit_code
|