hud-python 0.4.35__py3-none-any.whl → 0.4.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hud-python might be problematic.
- hud/agents/__init__.py +2 -0
- hud/agents/lite_llm.py +72 -0
- hud/agents/openai_chat_generic.py +21 -7
- hud/agents/tests/test_claude.py +32 -7
- hud/agents/tests/test_openai.py +29 -6
- hud/cli/__init__.py +228 -79
- hud/cli/build.py +26 -6
- hud/cli/dev.py +21 -40
- hud/cli/eval.py +96 -15
- hud/cli/flows/tasks.py +198 -65
- hud/cli/init.py +222 -629
- hud/cli/pull.py +6 -0
- hud/cli/push.py +11 -1
- hud/cli/rl/__init__.py +14 -4
- hud/cli/rl/celebrate.py +187 -0
- hud/cli/rl/config.py +15 -8
- hud/cli/rl/local_runner.py +44 -20
- hud/cli/rl/remote_runner.py +166 -87
- hud/cli/rl/viewer.py +141 -0
- hud/cli/rl/wait_utils.py +89 -0
- hud/cli/tests/test_build.py +3 -27
- hud/cli/tests/test_mcp_server.py +1 -12
- hud/cli/utils/config.py +85 -0
- hud/cli/utils/docker.py +21 -39
- hud/cli/utils/env_check.py +196 -0
- hud/cli/utils/environment.py +4 -3
- hud/cli/utils/interactive.py +2 -1
- hud/cli/utils/local_runner.py +204 -0
- hud/cli/utils/metadata.py +3 -1
- hud/cli/utils/package_runner.py +292 -0
- hud/cli/utils/remote_runner.py +4 -1
- hud/cli/utils/source_hash.py +108 -0
- hud/clients/base.py +1 -1
- hud/clients/fastmcp.py +1 -1
- hud/clients/mcp_use.py +30 -7
- hud/datasets/parallel.py +3 -1
- hud/datasets/runner.py +4 -1
- hud/otel/config.py +1 -1
- hud/otel/context.py +40 -6
- hud/rl/buffer.py +3 -0
- hud/rl/tests/test_learner.py +1 -1
- hud/rl/vllm_adapter.py +1 -1
- hud/server/server.py +234 -7
- hud/server/tests/test_add_tool.py +60 -0
- hud/server/tests/test_context.py +128 -0
- hud/server/tests/test_mcp_server_handlers.py +44 -0
- hud/server/tests/test_mcp_server_integration.py +405 -0
- hud/server/tests/test_mcp_server_more.py +247 -0
- hud/server/tests/test_run_wrapper.py +53 -0
- hud/server/tests/test_server_extra.py +166 -0
- hud/server/tests/test_sigterm_runner.py +78 -0
- hud/settings.py +38 -0
- hud/shared/hints.py +2 -2
- hud/telemetry/job.py +2 -2
- hud/types.py +9 -2
- hud/utils/tasks.py +32 -24
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/METADATA +43 -23
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/RECORD +63 -46
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/WHEEL +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.35.dist-info → hud_python-0.4.37.dist-info}/licenses/LICENSE +0 -0
hud/cli/pull.py
CHANGED
@@ -154,6 +154,9 @@ def pull_environment(
     # Check for API key (not required for pulling, but good to inform)
     if not settings.api_key:
         hud_console.info("No HUD API key set (pulling from public registry)")
+        hud_console.info(
+            "Set it in your environment or run: hud set HUD_API_KEY=your-key-here"
+        )

     lock_data = fetch_lock_from_registry(target)

@@ -166,6 +169,9 @@ def pull_environment(
         hud_console.info(
             "Not found in HUD registry (try setting HUD_API_KEY for private environments)"  # noqa: E501
         )
+        hud_console.info(
+            "Set it in your environment or run: hud set HUD_API_KEY=your-key-here"
+        )
     else:
         hud_console.info("Not found in HUD registry, treating as Docker image")

hud/cli/push.py
CHANGED
@@ -11,6 +11,7 @@ import requests
 import typer
 import yaml

+from hud.cli.utils.env_check import ensure_built
 from hud.utils.hud_console import HUDConsole


@@ -131,6 +132,14 @@ def push_environment(

     # Find hud.lock.yaml in specified directory
     env_dir = Path(directory)
+
+    # Ensure environment is built and up-to-date (hash-based); interactive prompt
+    try:
+        ensure_built(env_dir, interactive=True)
+    except typer.Exit:
+        raise
+    except Exception as e:
+        HUDConsole().debug(f"Skipping pre-push build check: {e}")
     lock_path = env_dir / "hud.lock.yaml"

     if not lock_path.exists():
@@ -144,7 +153,7 @@ def push_environment(
         hud_console.warning("A HUD API key is required to push environments.")
         hud_console.info("\nTo get started:")
         hud_console.info("1. Get your API key at: https://hud.so/settings")
-        hud_console.
+        hud_console.info("Set it in your environment or run: hud set HUD_API_KEY=your-key-here")
         hud_console.command_example("hud push", "Try again")
         hud_console.info("")
         raise typer.Exit(1)
@@ -414,6 +423,7 @@ def push_environment(
         hud_console.error("Authentication failed")
         hud_console.info("Check your HUD_API_KEY is valid")
         hud_console.info("Get a new key at: https://hud.so/settings")
+        hud_console.info("Set it in your environment or run: hud set HUD_API_KEY=your-key-here")
     elif response.status_code == 403:
         hud_console.error("Permission denied")
         hud_console.info("You may not have access to push to this namespace")
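For context, a minimal sketch of the pre-push check that push_environment now runs before reading hud.lock.yaml (the directory path below is illustrative; ensure_built and its interactive flag come from the new hud/cli/utils/env_check.py added in this release):

    from pathlib import Path

    import typer

    from hud.cli.utils.env_check import ensure_built

    env_dir = Path("environments/my-env")  # hypothetical environment directory
    try:
        # Hash-based staleness check; prompts interactively if a rebuild looks needed
        ensure_built(env_dir, interactive=True)
    except typer.Exit:
        raise  # user declined the rebuild, abort the push
    except Exception as exc:
        # Mirrors push_environment: a failed check is logged, not fatal
        print(f"Skipping pre-push build check: {exc}")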
hud/cli/rl/__init__.py
CHANGED
@@ -25,7 +25,7 @@ def rl_command(
     ),
     model: str | None = typer.Argument(
         None,
-        help="Model to train (default: interactive selection)",
+        help="Model to train from https://hud.so/models (default: interactive selection)",
     ),
     config_file: Path | None = typer.Option(  # noqa: B008
         None,
@@ -72,6 +72,12 @@ def rl_command(
         "--local",
         help="Run training locally instead of using remote API server",
     ),
+    yes: bool = typer.Option(
+        False,
+        "--yes",
+        "-y",
+        help="Auto-accept all prompts and use defaults (lazy mode)",
+    ),
     # Internal flag
     skip_vllm_startup: bool = typer.Option(
         False,
@@ -122,8 +128,7 @@ def rl_command(
         try:
             from hud.cli.flows.tasks import convert_tasks_to_remote

-            console.print("
-            console.print("[cyan](build/push if needed)[/cyan]")
+            console.print("[cyan]Preparing remote training tasks...[/cyan]")
             tasks_file = convert_tasks_to_remote(tasks_file)
         except typer.Exit:
             raise
@@ -137,7 +142,11 @@ def rl_command(
         from .remote_runner import run_remote_training

         run_remote_training(
-            tasks_file=tasks_file,
+            tasks_file=tasks_file,
+            model=model,
+            config_file=config_file,
+            output_dir=output_dir,
+            yes=yes,
         )
         return
     except Exception as e:
@@ -152,6 +161,7 @@ def rl_command(
         model=model,
         config_file=config_file,
         output_dir=output_dir,
+        yes=yes,
         restart=restart,
         verbose=verbose,
         no_ddp=no_ddp,
hud/cli/rl/celebrate.py
ADDED
@@ -0,0 +1,187 @@
+# ruff: noqa: S311
+from __future__ import annotations
+
+import random
+import time
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, ClassVar
+
+from rich.live import Live
+from rich.text import Text
+
+from hud.utils.hud_console import hud_console
+
+if TYPE_CHECKING:
+    from rich.console import Console
+
+
+@dataclass
+class Particle:
+    """A confetti particle with physics."""
+
+    x: float
+    y: float
+    vx: float  # velocity x
+    vy: float  # velocity y
+    char: str
+    color: str
+
+    def update(self, gravity: float = 0.5, fps: float = 30.0) -> None:
+        """Update particle position and velocity."""
+        dt = 1.0 / fps
+        self.x += self.vx * dt
+        self.vy += gravity  # Apply gravity
+        self.y += self.vy * dt
+
+
+class ConfettiSystem:
+    """Minimal confetti system inspired by confetty."""
+
+    # Confetty-style colors
+    COLORS: ClassVar[list[str]] = ["#a864fd", "#29cdff", "#78ff44", "#ff718d", "#fdff6a"]
+    # Confetty-style characters
+    CHARS: ClassVar[list[str]] = ["█", "▓", "▒", "░", "▄", "▀"]
+
+    def __init__(self, width: int, height: int) -> None:
+        self.width = width
+        self.height = height
+        self.particles: list[Particle] = []
+
+    def spawn_burst(self, num_particles: int = 75) -> None:
+        """Spawn a burst of confetti particles from the top center."""
+        center_x = self.width / 2
+
+        for _ in range(num_particles):
+            # Start from top center with some horizontal spread
+            x = center_x + (self.width / 4) * (random.random() - 0.5)
+            y = 0
+
+            # Random velocities - horizontal spread and upward/slight downward initial velocity
+            vx = (random.random() - 0.5) * 100
+            vy = random.random() * 50 - 25  # Some go up first
+
+            particle = Particle(
+                x=x,
+                y=y,
+                vx=vx,
+                vy=vy,
+                char=random.choice(self.CHARS),
+                color=random.choice(self.COLORS),
+            )
+            self.particles.append(particle)
+
+    def update(self) -> None:
+        """Update all particles and remove off-screen ones."""
+        # Update physics
+        for particle in self.particles:
+            particle.update()
+
+        # Remove particles that are off-screen
+        self.particles = [p for p in self.particles if 0 <= p.x < self.width and p.y < self.height]
+
+    def render(self) -> str:
+        """Render the particle system to a string."""
+        # Create empty grid
+        grid = [[" " for _ in range(self.width)] for _ in range(self.height)]
+
+        # Place particles
+        for particle in self.particles:
+            x, y = int(particle.x), int(particle.y)
+            if 0 <= x < self.width and 0 <= y < self.height:
+                grid[y][x] = particle.char
+
+        # Convert to string
+        return "\n".join("".join(row) for row in grid)
+
+    def render_with_colors(self) -> Text:
+        """Render the particle system with colors for Rich."""
+        text = Text()
+
+        # Create empty grid with color info
+        grid: list[list[tuple[str, str] | None]] = [
+            [None for _ in range(self.width)] for _ in range(self.height)
+        ]
+
+        # Place particles with their colors
+        for particle in self.particles:
+            x, y = int(particle.x), int(particle.y)
+            if 0 <= x < self.width and 0 <= y < self.height:
+                grid[y][x] = (particle.char, particle.color)
+
+        # Build colored text
+        for row in grid:
+            for cell in row:
+                if cell:
+                    char, color = cell
+                    text.append(char, style=color)
+                else:
+                    text.append(" ")
+            text.append("\n")
+
+        return text
+
+
+def show_confetti(console: Console, seconds: float = 2.5) -> None:
+    """Display celebratory confetti animation inspired by confetty.
+
+    Shows "Starting training!" message first, then creates two bursts of
+    falling confetti particles that fall away completely.
+
+    Args:
+        console: Rich console instance
+        seconds: Duration to show confetti
+    """
+    # Show celebratory message first
+    console.print(
+        "[bold green]🎉 Starting training! See your model on https://hud.so/models[/bold green]"
+    )
+    time.sleep(0.3)  # Brief pause to see the message
+
+    width = min(console.size.width, 120)  # Cap width for performance
+    height = min(console.size.height - 2, 30)  # Leave room for message
+
+    # Create confetti system
+    system = ConfettiSystem(width, height)
+
+    fps = 30
+    frame_time = 1.0 / fps
+
+    # First burst at the beginning
+    system.spawn_burst(num_particles=60)
+
+    # Track when to spawn second burst
+    second_burst_frame = int(fps * 0.4)  # Second burst after 0.4 seconds
+
+    with Live("", refresh_per_second=fps, console=console, transient=True) as live:
+        frame = 0
+        # Keep running until all particles have fallen off screen
+        while frame < seconds * fps or len(system.particles) > 0:
+            # Spawn second burst
+            if frame == second_burst_frame:
+                system.spawn_burst(num_particles=60)
+
+            system.update()
+            live.update(system.render_with_colors())
+            time.sleep(frame_time)
+            frame += 1
+
+
+def show_confetti_async(console: Console, seconds: float = 2.5) -> None:
+    """Non-blocking confetti animation that runs in a background thread.
+
+    The animation will run independently while training starts immediately.
+    """
+    import threading
+
+    def _run_confetti() -> None:
+        try:
+            show_confetti(console, seconds)
+        except Exception:
+            hud_console.info("Launching training...")
+
+    thread = threading.Thread(target=_run_confetti, daemon=True)
+    thread.start()
+    # Don't wait - let training start immediately while confetti plays
+
+
+__all__ = ["show_confetti", "show_confetti_async"]
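The new module is self-contained; a short usage sketch (not part of the package) showing how a caller could trigger the animation without blocking:

    from rich.console import Console

    from hud.cli.rl.celebrate import show_confetti_async

    console = Console()
    # Returns immediately; the confetti plays in a daemon thread while work continues.
    show_confetti_async(console, seconds=2.5)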
hud/cli/rl/config.py
CHANGED
@@ -21,22 +21,29 @@ console = Console()
 def generate_config_interactive(
     model_name: str,
     presets: list[dict[str, Any]],
+    yes: bool = False,
 ) -> tuple[Config, float]:
     """Generate RL training configuration interactively."""
     # Validate model is a VL model
     validate_vl_model(model_name)

     # Display preset options
-
+    if not yes:
+        display_preset_table(presets, 80.0)  # Assuming A100 80GB

     # Let user select preset
-
-
-
-
-
-
-
+    if yes:
+        # Use default preset (Balanced if available, otherwise first)
+        preset_choice = 1 if len(presets) > 1 else 0
+        selected_preset = presets[preset_choice]
+        hud_console.info(f"Auto-selecting preset: {selected_preset['name']} (--yes mode)")
+    else:
+        preset_choice = hud_console.select(
+            "Select a training configuration preset:",
+            choices=[{"name": p["name"], "value": i} for i, p in enumerate(presets)],
+            default=1 if len(presets) > 1 else 0,  # Default to "Balanced" if available
+        )
+        selected_preset = presets[preset_choice]  # type: ignore

     # Use preset values directly
     max_steps_per_episode = selected_preset["max_steps_per_episode"]
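Illustration of the new default-preset rule in isolation (the preset names below are hypothetical; real presets also carry fields such as max_steps_per_episode that generate_config_interactive reads afterwards):

    presets = [{"name": "Fast"}, {"name": "Balanced"}, {"name": "Quality"}]  # hypothetical names

    # With --yes, index 1 ("Balanced") is chosen whenever more than one preset exists.
    preset_choice = 1 if len(presets) > 1 else 0
    selected_preset = presets[preset_choice]
    print(f"Auto-selecting preset: {selected_preset['name']} (--yes mode)")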
hud/cli/rl/local_runner.py
CHANGED
@@ -30,6 +30,7 @@ def run_local_training(
     model: str | None,
     config_file: Path | None,
     output_dir: str,
+    yes: bool,
     restart: bool,
     verbose: bool,
     no_ddp: bool,
@@ -63,8 +64,11 @@ def run_local_training(
         try:
             import typer

-            if not
-
+            if not yes:
+                if not typer.confirm("\nDo you want to continue anyway?", default=False):
+                    raise typer.Exit(1)
+            else:
+                hud_console.warning("Auto-continuing despite Python 3.13+ (--yes mode)")
         except Exception as e:
             hud_console.warning(f"Failed to confirm: {e}")
             return
@@ -113,7 +117,13 @@ def run_local_training(
         try:
            import typer

-
+            if yes:
+                continue_training = True
+                hud_console.info("Auto-continuing with healthy GPUs only (--yes mode)")
+            else:
+                continue_training = typer.confirm(
+                    "\nContinue with healthy GPUs only?", default=True
+                )
         except Exception:
             continue_training = True

@@ -200,21 +210,25 @@ def run_local_training(

     # Step 3: Model selection (if not provided)
     if model is None and not config_file:
-
-        "
-
-
-
-
-
-
-
-
+        if yes:
+            model = "Qwen/Qwen2.5-VL-3B-Instruct"  # Default model in yes mode
+            hud_console.info(f"Auto-selecting model: {model} (--yes mode)")
+        else:
+            model = hud_console.select(
+                "Select a model for RL training:",
+                choices=[
+                    {
+                        "name": "Qwen 2.5 VL 3B (Recommended - Vision-Language)",
+                        "value": "Qwen/Qwen2.5-VL-3B-Instruct",
+                    },
+                    {"name": "Custom model", "value": "custom"},
+                ],
+                default=0,
+            )

-
-
-
+            if model == "custom":
+                console.print("Enter the model name (HuggingFace ID):")
+                model = input().strip()

     # Validate model is a VL model (whether provided via CLI or selected)
     if model:
@@ -277,6 +291,7 @@ def run_local_training(
     config, estimated_memory = generate_config_interactive(
         model_name=model,
         presets=presets,
+        yes=yes,
     )

     # Step 5: Save temporary config and display summary
@@ -288,8 +303,8 @@ def run_local_training(
     # Display configuration summary
     display_config_summary(config, len(tasks), gpu_info, estimated_memory)

-    # Step 6: Ask for confirmation (skip if config was provided)
-    if not config_file:
+    # Step 6: Ask for confirmation (skip if config was provided or in yes mode)
+    if not config_file and not yes:
         console.print("\n[bold yellow]Options:[/bold yellow]")
         console.print("  • Type [green]'start'[/green] to begin training")
         console.print("  • Type [cyan]'edit'[/cyan] to open config in your editor")
@@ -346,7 +361,12 @@ def run_local_training(
         try:
             import typer

-            if
+            if yes:
+                # Always save in yes mode
+                config_path = Path("rl_config.json")
+                save_config(config, config_path)
+                hud_console.info("Auto-saved configuration (--yes mode)")
+            elif typer.confirm("Save this configuration for later?", default=True):
                 config_path = Path("rl_config.json")
                 save_config(config, config_path)
         except Exception as e:
@@ -367,6 +387,10 @@ def run_local_training(
             console.print(
                 "[red]Invalid choice. Type 'start', 'edit', or 'cancel':[/red] ", end=""
             )
+    elif yes:
+        # In yes mode, auto-start training
+        hud_console.info("Auto-starting training (--yes mode)")
+        config = load_config(temp_config_path)
     else:
         console.print("\n[dim]Using provided configuration file...[/dim]")
         config = load_config(temp_config_path)