hud-python 0.4.21__py3-none-any.whl → 0.4.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/agents/base.py +37 -37
- hud/agents/claude.py +11 -6
- hud/agents/grounded_openai.py +282 -0
- hud/agents/misc/response_agent.py +3 -2
- hud/agents/openai.py +2 -2
- hud/agents/openai_chat_generic.py +3 -1
- hud/agents/tests/test_client.py +6 -1
- hud/agents/tests/test_grounded_openai_agent.py +155 -0
- hud/cli/__init__.py +34 -24
- hud/cli/analyze.py +27 -26
- hud/cli/build.py +50 -46
- hud/cli/debug.py +7 -7
- hud/cli/dev.py +107 -99
- hud/cli/eval.py +33 -31
- hud/cli/hf.py +53 -53
- hud/cli/init.py +28 -28
- hud/cli/list_func.py +22 -22
- hud/cli/pull.py +36 -36
- hud/cli/push.py +76 -74
- hud/cli/remove.py +42 -40
- hud/cli/rl/__init__.py +2 -2
- hud/cli/rl/init.py +41 -41
- hud/cli/rl/pod.py +97 -91
- hud/cli/rl/ssh.py +42 -40
- hud/cli/rl/train.py +75 -73
- hud/cli/rl/utils.py +10 -10
- hud/cli/tests/test_analyze.py +1 -1
- hud/cli/tests/test_analyze_metadata.py +2 -2
- hud/cli/tests/test_pull.py +45 -45
- hud/cli/tests/test_push.py +31 -29
- hud/cli/tests/test_registry.py +15 -15
- hud/cli/utils/environment.py +11 -11
- hud/cli/utils/interactive.py +18 -18
- hud/cli/utils/logging.py +12 -12
- hud/cli/utils/metadata.py +12 -12
- hud/cli/utils/registry.py +5 -5
- hud/cli/utils/runner.py +23 -23
- hud/cli/utils/server.py +16 -16
- hud/settings.py +6 -0
- hud/shared/hints.py +7 -7
- hud/tools/executors/tests/test_base_executor.py +1 -1
- hud/tools/executors/xdo.py +1 -1
- hud/tools/grounding/__init__.py +13 -0
- hud/tools/grounding/config.py +54 -0
- hud/tools/grounding/grounded_tool.py +314 -0
- hud/tools/grounding/grounder.py +302 -0
- hud/tools/grounding/tests/__init__.py +1 -0
- hud/tools/grounding/tests/test_grounded_tool.py +196 -0
- hud/tools/tests/test_playwright_tool.py +1 -1
- hud/tools/tests/test_tools_init.py +1 -1
- hud/tools/tests/test_utils.py +2 -2
- hud/types.py +4 -4
- hud/utils/__init__.py +3 -3
- hud/utils/agent_factories.py +86 -0
- hud/utils/{design.py → hud_console.py} +39 -33
- hud/utils/pretty_errors.py +6 -6
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.21.dist-info → hud_python-0.4.23.dist-info}/METADATA +3 -1
- {hud_python-0.4.21.dist-info → hud_python-0.4.23.dist-info}/RECORD +63 -54
- {hud_python-0.4.21.dist-info → hud_python-0.4.23.dist-info}/WHEEL +0 -0
- {hud_python-0.4.21.dist-info → hud_python-0.4.23.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.21.dist-info → hud_python-0.4.23.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/pod.py
CHANGED
|
@@ -15,11 +15,11 @@ from rich.live import Live
|
|
|
15
15
|
from rich.spinner import Spinner
|
|
16
16
|
|
|
17
17
|
from hud.settings import settings
|
|
18
|
-
from hud.utils.
|
|
18
|
+
from hud.utils.hud_console import HUDConsole
|
|
19
19
|
|
|
20
20
|
from .ssh import check_and_configure_ssh_key, connect_and_train
|
|
21
21
|
|
|
22
|
-
|
|
22
|
+
hud_console = HUDConsole()
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def parse_gpu_config(gpus: str) -> tuple[int, str]:
|
|
@@ -29,7 +29,7 @@ def parse_gpu_config(gpus: str) -> tuple[int, str]:
|
|
|
29
29
|
try:
|
|
30
30
|
count = int(count_str)
|
|
31
31
|
except ValueError as e:
|
|
32
|
-
|
|
32
|
+
hud_console.error(f"Invalid GPU count: {count_str}")
|
|
33
33
|
raise typer.Exit(1) from e
|
|
34
34
|
else:
|
|
35
35
|
# Default to 1 GPU if no count specified
|
|
@@ -65,7 +65,7 @@ async def create_and_connect_prime_pod(
|
|
|
65
65
|
is_json_file: bool = False,
|
|
66
66
|
) -> None:
|
|
67
67
|
"""Create a Prime Intellect pod and connect to it for training."""
|
|
68
|
-
|
|
68
|
+
hud_console.section_title("🌐 Creating Prime Intellect Pod")
|
|
69
69
|
|
|
70
70
|
create_cmd = [
|
|
71
71
|
"prime",
|
|
@@ -79,8 +79,8 @@ async def create_and_connect_prime_pod(
|
|
|
79
79
|
pod_name,
|
|
80
80
|
]
|
|
81
81
|
|
|
82
|
-
|
|
83
|
-
|
|
82
|
+
hud_console.info(f"Creating pod: {pod_name}")
|
|
83
|
+
hud_console.info(f"GPU configuration: {gpu_count}x {gpu_type}")
|
|
84
84
|
|
|
85
85
|
# Check for global team config first
|
|
86
86
|
has_global_team = False
|
|
@@ -106,21 +106,21 @@ async def create_and_connect_prime_pod(
|
|
|
106
106
|
break
|
|
107
107
|
|
|
108
108
|
# Display automated selections
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
109
|
+
hud_console.info("")
|
|
110
|
+
hud_console.info("Automated selections:")
|
|
111
|
+
hud_console.info(" Provider: Will select from supported providers")
|
|
112
|
+
hud_console.info(" Disk: Default size")
|
|
113
|
+
hud_console.info(" Image: cuda_12_4_pytorch_2_5")
|
|
114
114
|
if team_id:
|
|
115
|
-
|
|
115
|
+
hud_console.info(f" Team: {team_id}")
|
|
116
116
|
elif has_global_team:
|
|
117
|
-
|
|
117
|
+
hud_console.info(" Team: Using pre-configured team")
|
|
118
118
|
else:
|
|
119
|
-
|
|
120
|
-
|
|
119
|
+
hud_console.info(" Team: Personal Account")
|
|
120
|
+
hud_console.info("")
|
|
121
121
|
|
|
122
122
|
# First, get the provider list by running the command with minimal input
|
|
123
|
-
|
|
123
|
+
hud_console.info("Checking available providers...")
|
|
124
124
|
|
|
125
125
|
# Run command with just a newline to see provider list
|
|
126
126
|
provider_check = subprocess.run( # noqa: S603, ASYNC221
|
|
@@ -156,7 +156,7 @@ async def create_and_connect_prime_pod(
|
|
|
156
156
|
for provider in supported_providers:
|
|
157
157
|
if provider in provider_map:
|
|
158
158
|
provider_choice = provider_map[provider]
|
|
159
|
-
|
|
159
|
+
hud_console.info(f"Selected provider: {provider} (option {provider_choice})")
|
|
160
160
|
break
|
|
161
161
|
|
|
162
162
|
# Build inputs step by step for clarity
|
|
@@ -164,39 +164,39 @@ async def create_and_connect_prime_pod(
|
|
|
164
164
|
image_choice = "7" # cuda_12_4_pytorch_2_5
|
|
165
165
|
|
|
166
166
|
# Log what we're doing
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
167
|
+
hud_console.debug("Pod creation configuration:")
|
|
168
|
+
hud_console.debug(f" Team ID provided: {team_id}")
|
|
169
|
+
hud_console.debug(f" Global team detected: {has_global_team}")
|
|
170
170
|
|
|
171
171
|
if team_id:
|
|
172
172
|
# Explicit team ID provided, select Custom Team ID (option 3)
|
|
173
173
|
team_choice = "3"
|
|
174
174
|
# Fixed: confirmation should be lowercase 'y'
|
|
175
175
|
inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\n{team_id}\ny\n"
|
|
176
|
-
|
|
176
|
+
hud_console.debug(f" Using explicit team ID: option {team_choice} with ID {team_id}")
|
|
177
177
|
elif has_global_team:
|
|
178
178
|
# When team is pre-configured, it shows as option 2 - select it
|
|
179
179
|
team_choice = "2"
|
|
180
180
|
# Fixed: confirmation should be lowercase 'y' and come after team selection
|
|
181
181
|
inputs = f"{provider_choice}\n{disk_size}\n{image_choice}\n{team_choice}\ny\n"
|
|
182
|
-
|
|
182
|
+
hud_console.debug(f" Using pre-configured team: option {team_choice}")
|
|
183
183
|
else:
|
|
184
184
|
# Personal account (option 1) - just press enter to accept default [1]
|
|
185
185
|
inputs = (
|
|
186
186
|
f"{provider_choice}\n{disk_size}\n{image_choice}\n\ny\n" # Empty line for default [1]
|
|
187
187
|
)
|
|
188
|
-
|
|
188
|
+
hud_console.debug(" Using personal account: default option [1]")
|
|
189
189
|
|
|
190
|
-
|
|
190
|
+
hud_console.debug(
|
|
191
191
|
f" Input sequence: provider={provider_choice}, disk={disk_size or 'default'}, image={image_choice}, team={team_choice if 'team_choice' in locals() else 'default'}" # noqa: E501
|
|
192
192
|
)
|
|
193
193
|
|
|
194
194
|
# Show found providers
|
|
195
195
|
if provider_lines:
|
|
196
|
-
|
|
197
|
-
|
|
196
|
+
hud_console.info("")
|
|
197
|
+
hud_console.info("Found providers:")
|
|
198
198
|
for pl in provider_lines[:5]: # Show first 5
|
|
199
|
-
|
|
199
|
+
hud_console.info(f" {pl}")
|
|
200
200
|
|
|
201
201
|
try:
|
|
202
202
|
console = Console()
|
|
@@ -214,7 +214,7 @@ async def create_and_connect_prime_pod(
|
|
|
214
214
|
)
|
|
215
215
|
|
|
216
216
|
if result.returncode != 0:
|
|
217
|
-
|
|
217
|
+
hud_console.error("Failed to create pod")
|
|
218
218
|
|
|
219
219
|
# Parse output for better error reporting
|
|
220
220
|
output_lines = result.stdout.strip().split("\n") if result.stdout else []
|
|
@@ -222,60 +222,62 @@ async def create_and_connect_prime_pod(
|
|
|
222
222
|
# Look for provider prices
|
|
223
223
|
for line in output_lines:
|
|
224
224
|
if "$" in line and "/hr" in line:
|
|
225
|
-
|
|
225
|
+
hud_console.info(f"Provider option: {line.strip()}")
|
|
226
226
|
|
|
227
227
|
# Check for team selection error
|
|
228
228
|
if "invalid selection" in result.stdout.lower():
|
|
229
|
-
|
|
229
|
+
hud_console.error("Team selection failed")
|
|
230
230
|
# Find and display the team selection section
|
|
231
231
|
for i, line in enumerate(output_lines):
|
|
232
232
|
if "Select Team:" in line:
|
|
233
|
-
|
|
233
|
+
hud_console.info("Team selection options:")
|
|
234
234
|
# Show next few lines
|
|
235
235
|
for j in range(i, min(i + 6, len(output_lines))):
|
|
236
|
-
|
|
236
|
+
hud_console.info(f" {output_lines[j]}")
|
|
237
237
|
break
|
|
238
238
|
|
|
239
|
-
|
|
240
|
-
|
|
239
|
+
hud_console.info("")
|
|
240
|
+
hud_console.hint(
|
|
241
241
|
"The Prime CLI interface may have changed. Try running the command manually:"
|
|
242
242
|
)
|
|
243
|
-
|
|
243
|
+
hud_console.command_example(
|
|
244
244
|
f"prime pods create --gpu-type {gpu_type} --gpu-count {gpu_count} --name {pod_name}" # noqa: E501
|
|
245
245
|
)
|
|
246
246
|
|
|
247
247
|
# Show error details
|
|
248
248
|
if result.stderr:
|
|
249
|
-
|
|
249
|
+
hud_console.error("Error output:")
|
|
250
250
|
for line in result.stderr.strip().split("\n"):
|
|
251
|
-
|
|
251
|
+
hud_console.error(f" {line}")
|
|
252
252
|
|
|
253
253
|
# Show last part of stdout for context
|
|
254
254
|
if result.stdout:
|
|
255
|
-
|
|
255
|
+
hud_console.info("Command output:")
|
|
256
256
|
# Show last 15 lines for brevity
|
|
257
257
|
for line in output_lines[-15:]:
|
|
258
|
-
|
|
258
|
+
hud_console.info(f" {line}")
|
|
259
259
|
|
|
260
260
|
if "max_price" in str(result.stderr) or "max_price" in str(result.stdout):
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
261
|
+
hud_console.warning("")
|
|
262
|
+
hud_console.warning("The selected provider requires a maximum price limit.")
|
|
263
|
+
hud_console.info("This is a known limitation with some providers.")
|
|
264
|
+
hud_console.info("")
|
|
265
|
+
hud_console.hint("Workarounds:")
|
|
266
|
+
hud_console.info("1. Run the command manually and select a different provider")
|
|
267
|
+
hud_console.info(
|
|
268
|
+
"2. Try again later when datacrunch (usually cheapest) is available"
|
|
269
|
+
)
|
|
270
|
+
hud_console.info("3. Use the Prime web interface: https://app.primeintellect.ai")
|
|
271
|
+
|
|
272
|
+
hud_console.info("")
|
|
273
|
+
hud_console.info("Debug info:")
|
|
274
|
+
hud_console.info(f" Command: {' '.join(create_cmd)}")
|
|
275
|
+
hud_console.info(f" Pod name: {pod_name}")
|
|
276
|
+
hud_console.info(f" Team ID: {'Provided' if team_id else 'Not provided'}")
|
|
277
|
+
hud_console.info(f" Global team detected: {has_global_team}")
|
|
276
278
|
|
|
277
279
|
# Show the exact inputs we sent
|
|
278
|
-
|
|
280
|
+
hud_console.info(" Inputs sent (in order):")
|
|
279
281
|
input_parts = inputs.strip().split("\n")
|
|
280
282
|
input_labels = [
|
|
281
283
|
"Provider selection",
|
|
@@ -287,9 +289,9 @@ async def create_and_connect_prime_pod(
|
|
|
287
289
|
]
|
|
288
290
|
for i, (part, label) in enumerate(zip(input_parts, input_labels, strict=False)):
|
|
289
291
|
if part:
|
|
290
|
-
|
|
292
|
+
hud_console.info(f" {i + 1}. {label}: '{part}'")
|
|
291
293
|
else:
|
|
292
|
-
|
|
294
|
+
hud_console.info(f" {i + 1}. {label}: [Enter/default]")
|
|
293
295
|
|
|
294
296
|
raise typer.Exit(1)
|
|
295
297
|
|
|
@@ -304,18 +306,18 @@ async def create_and_connect_prime_pod(
|
|
|
304
306
|
break
|
|
305
307
|
|
|
306
308
|
if not pod_id:
|
|
307
|
-
|
|
308
|
-
|
|
309
|
+
hud_console.error("Could not extract pod ID from output")
|
|
310
|
+
hud_console.info(f"Output: {result.stdout}")
|
|
309
311
|
raise typer.Exit(1)
|
|
310
312
|
|
|
311
|
-
|
|
313
|
+
hud_console.success(f"Created pod: {pod_id}")
|
|
312
314
|
|
|
313
315
|
# Poll for pod status
|
|
314
316
|
ssh_info = await poll_pod_status(pod_id)
|
|
315
317
|
|
|
316
318
|
if ssh_info:
|
|
317
|
-
|
|
318
|
-
|
|
319
|
+
hud_console.success("Pod is ready!")
|
|
320
|
+
hud_console.info(f"SSH: {ssh_info}")
|
|
319
321
|
|
|
320
322
|
# Check if SSH key is configured globally
|
|
321
323
|
ssh_key_configured = await check_and_configure_ssh_key()
|
|
@@ -335,25 +337,27 @@ async def create_and_connect_prime_pod(
|
|
|
335
337
|
)
|
|
336
338
|
else:
|
|
337
339
|
# Manual fallback
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
340
|
+
hud_console.section_title("📋 Manual Connection Required")
|
|
341
|
+
hud_console.info("SSH key configuration failed. Connect manually:")
|
|
342
|
+
hud_console.info("")
|
|
343
|
+
hud_console.info("1. Download the SSH key from:")
|
|
344
|
+
hud_console.info(" https://app.primeintellect.ai/dashboard/profile")
|
|
345
|
+
hud_console.info("")
|
|
346
|
+
hud_console.info("2. Set permissions:")
|
|
347
|
+
hud_console.command_example("chmod 400 /path/to/prime-key.pem", "")
|
|
348
|
+
hud_console.info("")
|
|
349
|
+
hud_console.info("3. Connect to your instance:")
|
|
350
|
+
hud_console.command_example(f"ssh -i /path/to/prime-key.pem {ssh_info}", "")
|
|
351
|
+
hud_console.info("")
|
|
352
|
+
hud_console.info("4. Run these commands:")
|
|
353
|
+
hud_console.command_example("pip install verifiers hud-vf-gym", "")
|
|
354
|
+
hud_console.command_example(f"prime env install {image}", "")
|
|
353
355
|
|
|
354
356
|
# Build training command with env vars
|
|
355
357
|
if settings.wandb_api_key:
|
|
356
|
-
|
|
358
|
+
hud_console.command_example(
|
|
359
|
+
f"export WANDB_API_KEY={settings.wandb_api_key}", ""
|
|
360
|
+
)
|
|
357
361
|
|
|
358
362
|
train_cmd = f"""vf-train hud-vf-gym \\
|
|
359
363
|
--model {model} \\
|
|
@@ -362,15 +366,17 @@ async def create_and_connect_prime_pod(
|
|
|
362
366
|
--run-name hud-rl-{pod_id[:8]} \\
|
|
363
367
|
--wandb-project hud-rl"""
|
|
364
368
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
369
|
+
hud_console.command_example(train_cmd, "")
|
|
370
|
+
hud_console.info("")
|
|
371
|
+
hud_console.warning(
|
|
372
|
+
f"Remember to terminate when done: prime pods terminate {pod_id}"
|
|
373
|
+
)
|
|
368
374
|
else:
|
|
369
|
-
|
|
375
|
+
hud_console.error("Pod failed to become active")
|
|
370
376
|
raise typer.Exit(1)
|
|
371
377
|
|
|
372
378
|
except subprocess.CalledProcessError as e:
|
|
373
|
-
|
|
379
|
+
hud_console.error(f"Failed to create pod: {e}")
|
|
374
380
|
raise typer.Exit(1) from e
|
|
375
381
|
|
|
376
382
|
|
|
@@ -433,7 +439,7 @@ async def poll_pod_status(pod_id: str) -> str | None:
|
|
|
433
439
|
if ssh_value and ssh_value.strip() and ssh_value.strip() != "N/A":
|
|
434
440
|
# Stop the spinner before logging
|
|
435
441
|
live.stop()
|
|
436
|
-
|
|
442
|
+
hud_console.success(f"SSH is available: {ssh_value}")
|
|
437
443
|
return ssh_value
|
|
438
444
|
|
|
439
445
|
time.sleep(10) # Wait 10 seconds # noqa: ASYNC251
|
|
@@ -444,8 +450,8 @@ async def poll_pod_status(pod_id: str) -> str | None:
|
|
|
444
450
|
time.sleep(10) # noqa: ASYNC251
|
|
445
451
|
attempt += 1
|
|
446
452
|
|
|
447
|
-
# Spinner is done, now we can use
|
|
448
|
-
|
|
453
|
+
# Spinner is done, now we can use hud_console.error
|
|
454
|
+
hud_console.error("Timeout: Pod did not become ready within 20 minutes")
|
|
449
455
|
return None
|
|
450
456
|
|
|
451
457
|
|
|
@@ -464,11 +470,11 @@ async def run_prime_training(
|
|
|
464
470
|
"""Run training on Prime Intellect infrastructure."""
|
|
465
471
|
# Check API key
|
|
466
472
|
if not settings.prime_api_key:
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
473
|
+
hud_console.error("Prime API key not found")
|
|
474
|
+
hud_console.info("Set your Prime API key:")
|
|
475
|
+
hud_console.info(" export PRIME_API_KEY='your-api-key'")
|
|
476
|
+
hud_console.info(" # or")
|
|
477
|
+
hud_console.info(" prime auth")
|
|
472
478
|
raise typer.Exit(1)
|
|
473
479
|
|
|
474
480
|
# Parse GPU configuration
|
hud/cli/rl/ssh.py
CHANGED
|
@@ -9,9 +9,9 @@ from pathlib import Path
|
|
|
9
9
|
import typer
|
|
10
10
|
|
|
11
11
|
from hud.settings import settings
|
|
12
|
-
from hud.utils.
|
|
12
|
+
from hud.utils.hud_console import HUDConsole
|
|
13
13
|
|
|
14
|
-
|
|
14
|
+
hud_console = HUDConsole()
|
|
15
15
|
|
|
16
16
|
|
|
17
17
|
async def check_and_configure_ssh_key() -> bool:
|
|
@@ -48,32 +48,32 @@ async def check_and_configure_ssh_key() -> bool:
|
|
|
48
48
|
# If SSH key is configured, verify it exists
|
|
49
49
|
if ssh_key_path:
|
|
50
50
|
if Path(ssh_key_path).expanduser().exists():
|
|
51
|
-
|
|
51
|
+
hud_console.info(f"Using configured SSH key: {ssh_key_path}")
|
|
52
52
|
return True
|
|
53
53
|
else:
|
|
54
|
-
|
|
54
|
+
hud_console.warning(f"Configured SSH key not found: {ssh_key_path}")
|
|
55
55
|
|
|
56
56
|
# Prompt for SSH key
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
57
|
+
hud_console.section_title("🔑 SSH Key Configuration")
|
|
58
|
+
hud_console.info("Prime Intellect requires an SSH key for pod access.")
|
|
59
|
+
hud_console.info("")
|
|
60
|
+
hud_console.info("If you don't have a key:")
|
|
61
|
+
hud_console.info("1. Visit https://app.primeintellect.ai/dashboard/profile")
|
|
62
|
+
hud_console.info("2. Generate or upload your SSH key")
|
|
63
|
+
hud_console.info("3. Download the private key file")
|
|
64
|
+
hud_console.info("")
|
|
65
65
|
|
|
66
66
|
key_path = typer.prompt("Enter path to your Prime SSH private key (e.g., ~/.ssh/prime-key.pem)")
|
|
67
67
|
key_path = Path(key_path).expanduser()
|
|
68
68
|
|
|
69
69
|
if not key_path.exists():
|
|
70
|
-
|
|
70
|
+
hud_console.error(f"File not found: {key_path}")
|
|
71
71
|
return False
|
|
72
72
|
|
|
73
73
|
# Set permissions if not Windows
|
|
74
74
|
if os.name != "nt":
|
|
75
75
|
subprocess.run(["chmod", "400", str(key_path)]) # noqa: S603, S607, ASYNC221
|
|
76
|
-
|
|
76
|
+
hud_console.success("Set proper permissions on key file")
|
|
77
77
|
|
|
78
78
|
# Configure the SSH key globally
|
|
79
79
|
result = subprocess.run( # noqa: S603, ASYNC221
|
|
@@ -83,12 +83,12 @@ async def check_and_configure_ssh_key() -> bool:
|
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
if result.returncode == 0:
|
|
86
|
-
|
|
86
|
+
hud_console.success("SSH key configured successfully")
|
|
87
87
|
return True
|
|
88
88
|
else:
|
|
89
|
-
|
|
89
|
+
hud_console.error("Failed to configure SSH key")
|
|
90
90
|
if result.stderr:
|
|
91
|
-
|
|
91
|
+
hud_console.error(f"Error: {result.stderr}")
|
|
92
92
|
return False
|
|
93
93
|
|
|
94
94
|
|
|
@@ -104,7 +104,7 @@ async def connect_and_train(
|
|
|
104
104
|
is_json_file: bool = False,
|
|
105
105
|
) -> None:
|
|
106
106
|
"""Connect to the pod via SSH and run training commands."""
|
|
107
|
-
|
|
107
|
+
hud_console.section_title("🚀 Starting Remote Training")
|
|
108
108
|
|
|
109
109
|
# Parse SSH info to get host and port
|
|
110
110
|
# Format is like "root@65.108.33.78 -p 1234"
|
|
@@ -135,19 +135,19 @@ async def connect_and_train(
|
|
|
135
135
|
break
|
|
136
136
|
|
|
137
137
|
if not ssh_key_path:
|
|
138
|
-
|
|
138
|
+
hud_console.error("SSH key path not configured")
|
|
139
139
|
raise typer.Exit(1)
|
|
140
140
|
|
|
141
141
|
# Verify SSH key exists
|
|
142
142
|
ssh_key_path = Path(ssh_key_path).expanduser()
|
|
143
143
|
if not ssh_key_path.exists():
|
|
144
|
-
|
|
144
|
+
hud_console.error(f"SSH key not found: {ssh_key_path}")
|
|
145
145
|
raise typer.Exit(1)
|
|
146
146
|
|
|
147
|
-
|
|
147
|
+
hud_console.info(f"Using SSH key: {ssh_key_path}")
|
|
148
148
|
|
|
149
149
|
# First, copy the config file to the pod using scp
|
|
150
|
-
|
|
150
|
+
hud_console.info("Copying config file to pod...")
|
|
151
151
|
try:
|
|
152
152
|
# On Windows, we need to ensure proper path formatting
|
|
153
153
|
config_path = str(config).replace("\\", "/")
|
|
@@ -164,22 +164,24 @@ async def connect_and_train(
|
|
|
164
164
|
config_path,
|
|
165
165
|
f"{ssh_user_host}:/root/config.yaml",
|
|
166
166
|
]
|
|
167
|
-
|
|
167
|
+
hud_console.debug(f"Running: {' '.join(scp_cmd)}")
|
|
168
168
|
subprocess.run(scp_cmd, check=True) # noqa: S603, ASYNC221
|
|
169
|
-
|
|
169
|
+
hud_console.success("Config file copied")
|
|
170
170
|
except subprocess.CalledProcessError as e:
|
|
171
|
-
|
|
171
|
+
hud_console.error(f"Failed to copy config file: {e}")
|
|
172
172
|
if os.name == "nt": # Windows
|
|
173
|
-
|
|
174
|
-
|
|
173
|
+
hud_console.info("Make sure OpenSSH is installed. On Windows 10+, it's built-in.")
|
|
174
|
+
hud_console.info(
|
|
175
|
+
"If using older Windows, install Git for Windows which includes SSH/SCP."
|
|
176
|
+
)
|
|
175
177
|
else:
|
|
176
|
-
|
|
178
|
+
hud_console.info("Make sure scp is installed and in your PATH")
|
|
177
179
|
raise typer.Exit(1) from e
|
|
178
180
|
|
|
179
181
|
# If dataset is a JSON file, copy it too
|
|
180
182
|
remote_dataset = dataset # Default to unchanged
|
|
181
183
|
if is_json_file:
|
|
182
|
-
|
|
184
|
+
hud_console.info("Copying task file to pod...")
|
|
183
185
|
try:
|
|
184
186
|
# On Windows, we need to ensure proper path formatting
|
|
185
187
|
dataset_path = str(dataset).replace("\\", "/")
|
|
@@ -200,16 +202,16 @@ async def connect_and_train(
|
|
|
200
202
|
dataset_path,
|
|
201
203
|
f"{ssh_user_host}:{remote_dataset}",
|
|
202
204
|
]
|
|
203
|
-
|
|
205
|
+
hud_console.debug(f"Running: {' '.join(scp_cmd)}")
|
|
204
206
|
subprocess.run(scp_cmd, check=True) # noqa: S603, ASYNC221
|
|
205
|
-
|
|
207
|
+
hud_console.success(f"Task file copied to {remote_dataset}")
|
|
206
208
|
except subprocess.CalledProcessError as e:
|
|
207
|
-
|
|
209
|
+
hud_console.error(f"Failed to copy task file: {e}")
|
|
208
210
|
raise typer.Exit(1) from e
|
|
209
211
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
212
|
+
hud_console.info("Setting up environment and starting training...")
|
|
213
|
+
hud_console.info("This will take a few minutes for initial setup, then training will begin.")
|
|
214
|
+
hud_console.info("")
|
|
213
215
|
|
|
214
216
|
# Build environment exports
|
|
215
217
|
env_exports = []
|
|
@@ -311,10 +313,10 @@ async def connect_and_train(
|
|
|
311
313
|
subprocess.run(ssh_cmd, check=True) # noqa: S603, ASYNC221
|
|
312
314
|
|
|
313
315
|
except subprocess.CalledProcessError as e:
|
|
314
|
-
|
|
316
|
+
hud_console.error(f"Training failed: {e}")
|
|
315
317
|
raise typer.Exit(1) from e
|
|
316
318
|
except KeyboardInterrupt:
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
319
|
+
hud_console.warning("Training interrupted by user")
|
|
320
|
+
hud_console.info(f"To reconnect: prime pods ssh {pod_id}")
|
|
321
|
+
hud_console.info(f"To check status: prime pods status {pod_id}")
|
|
322
|
+
hud_console.info(f"To terminate: prime pods terminate {pod_id}")
|