hud-python 0.4.36__py3-none-any.whl → 0.4.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic; see the package's registry page for more details.

Files changed (44)
  1. hud/agents/__init__.py +2 -0
  2. hud/agents/lite_llm.py +72 -0
  3. hud/agents/openai_chat_generic.py +21 -7
  4. hud/cli/__init__.py +19 -4
  5. hud/cli/build.py +17 -2
  6. hud/cli/dev.py +1 -1
  7. hud/cli/eval.py +93 -13
  8. hud/cli/flows/tasks.py +197 -65
  9. hud/cli/init.py +1 -1
  10. hud/cli/push.py +9 -0
  11. hud/cli/rl/__init__.py +14 -4
  12. hud/cli/rl/celebrate.py +187 -0
  13. hud/cli/rl/config.py +15 -8
  14. hud/cli/rl/local_runner.py +44 -20
  15. hud/cli/rl/remote_runner.py +164 -87
  16. hud/cli/rl/viewer.py +141 -0
  17. hud/cli/rl/wait_utils.py +89 -0
  18. hud/cli/utils/env_check.py +196 -0
  19. hud/cli/utils/source_hash.py +108 -0
  20. hud/clients/base.py +1 -1
  21. hud/clients/fastmcp.py +1 -1
  22. hud/otel/config.py +1 -1
  23. hud/otel/context.py +2 -2
  24. hud/rl/vllm_adapter.py +1 -1
  25. hud/server/server.py +84 -13
  26. hud/server/tests/test_add_tool.py +60 -0
  27. hud/server/tests/test_context.py +128 -0
  28. hud/server/tests/test_mcp_server_handlers.py +44 -0
  29. hud/server/tests/test_mcp_server_integration.py +405 -0
  30. hud/server/tests/test_mcp_server_more.py +247 -0
  31. hud/server/tests/test_run_wrapper.py +53 -0
  32. hud/server/tests/test_server_extra.py +166 -0
  33. hud/server/tests/test_sigterm_runner.py +78 -0
  34. hud/shared/hints.py +1 -1
  35. hud/telemetry/job.py +2 -2
  36. hud/types.py +9 -2
  37. hud/utils/tasks.py +32 -24
  38. hud/utils/tests/test_version.py +1 -1
  39. hud/version.py +1 -1
  40. {hud_python-0.4.36.dist-info → hud_python-0.4.38.dist-info}/METADATA +14 -12
  41. {hud_python-0.4.36.dist-info → hud_python-0.4.38.dist-info}/RECORD +44 -30
  42. {hud_python-0.4.36.dist-info → hud_python-0.4.38.dist-info}/WHEEL +0 -0
  43. {hud_python-0.4.36.dist-info → hud_python-0.4.38.dist-info}/entry_points.txt +0 -0
  44. {hud_python-0.4.36.dist-info → hud_python-0.4.38.dist-info}/licenses/LICENSE +0 -0
hud/cli/flows/tasks.py CHANGED
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import json
4
+ import logging
4
5
  import re
5
6
  from pathlib import Path
6
7
  from typing import TYPE_CHECKING, Any
@@ -8,10 +9,9 @@ from typing import TYPE_CHECKING, Any
8
9
  import typer
9
10
  import yaml
10
11
 
11
- from hud.cli.build import build_environment
12
12
  from hud.cli.push import push_environment
13
13
  from hud.cli.utils.docker import require_docker_running
14
- from hud.cli.utils.environment import is_environment_directory
14
+ from hud.cli.utils.env_check import ensure_built, find_environment_dir
15
15
  from hud.cli.utils.registry import extract_name_and_tag
16
16
  from hud.utils.hud_console import hud_console
17
17
  from hud.utils.tasks import load_tasks
@@ -20,6 +20,9 @@ if TYPE_CHECKING:
20
20
  from hud.types import Task
21
21
 
22
22
 
23
+ logger = logging.getLogger(__name__)
24
+
25
+
23
26
  def _is_remote_url(url: str) -> bool:
24
27
  """Match the remote url."""
25
28
  # See if a url is a remote url
@@ -53,62 +56,6 @@ def _validate_tasks(tasks: list[Task]) -> bool:
53
56
  return True
54
57
 
55
58
 
56
- def _find_environment_dir(tasks_path: Path) -> Path | None:
57
- """Find the environment directory related to a tasks file.
58
-
59
- Strategy:
60
- - Prefer a directory containing hud.lock.yaml
61
- - Fallback to a directory that looks like an environment (Dockerfile + pyproject.toml)
62
- - Search the tasks file directory, CWD, and a couple of parents
63
- """
64
- candidates: list[Path] = []
65
- cwd = Path.cwd()
66
- candidates.extend([tasks_path.parent, cwd])
67
-
68
- # Add parents (up to 2 levels for each)
69
- for base in list(candidates):
70
- p = base
71
- for _ in range(2):
72
- p = p.parent
73
- if p not in candidates:
74
- candidates.append(p)
75
-
76
- # Prefer those with hud.lock.yaml
77
- for d in candidates:
78
- if (d / "hud.lock.yaml").exists():
79
- return d
80
-
81
- # Otherwise, find a plausible environment dir
82
- for d in candidates:
83
- try:
84
- if is_environment_directory(d):
85
- return d
86
- except Exception as e:
87
- hud_console.debug(f"Skipping path {d}: {e}")
88
- continue
89
-
90
- return None
91
-
92
-
93
- def _ensure_built(env_dir: Path) -> dict[str, Any]:
94
- """Ensure the environment is built and a lock file exists; return lock data."""
95
- lock_path = env_dir / "hud.lock.yaml"
96
- if not lock_path.exists():
97
- hud_console.warning("No hud.lock.yaml found. The environment hasn't been built.")
98
- if not hud_console.confirm("Build the environment now (runs 'hud build')?", default=True):
99
- raise typer.Exit(1)
100
- # Check Docker availability before attempting a build
101
- require_docker_running()
102
- # Run build (non-interactive). If Docker isn't running, this will raise and stop the flow.
103
- # Force linux/amd64 platform to ensure compatibility during RL flows.
104
- build_environment(str(env_dir), platform="linux/amd64")
105
-
106
- # Load lock file
107
- with open(lock_path) as f:
108
- lock_data = yaml.safe_load(f) or {}
109
- return lock_data
110
-
111
-
112
59
  def _ensure_pushed(env_dir: Path, lock_data: dict[str, Any]) -> dict[str, Any]:
113
60
  """Ensure the environment is pushed to a registry; return updated lock data."""
114
61
  pushed = bool(lock_data.get("push"))
@@ -153,18 +100,106 @@ def _derive_remote_image(lock_data: dict[str, Any]) -> str:
153
100
  return f"{name}:{tag}"
154
101
 
155
102
 
103
def _extract_existing_images(tasks: list[Task]) -> set[str]:
    """Collect every ``Mcp-Image`` header value referenced by the tasks' MCP configs.

    Each task's ``mcp_config`` is walked recursively. Whenever a mapping has a
    ``"headers"`` dict, its ``"Mcp-Image"`` entry is recorded if it is a
    non-empty string. Lists are traversed element by element.

    Args:
        tasks: Tasks whose ``mcp_config`` attribute is inspected. Tasks with a
            falsy ``mcp_config`` are skipped.

    Returns:
        The set of distinct image references found (possibly empty).
    """
    images: set[str] = set()

    def _extract_from_obj(obj: Any) -> None:
        if isinstance(obj, dict):
            headers = obj.get("headers")
            if isinstance(headers, dict):
                mcp_image = headers.get("Mcp-Image")
                # Only record string values. A malformed config (e.g. a nested
                # dict/list here) would otherwise raise TypeError (unhashable).
                # A non-string value could also never be rewritten by the
                # raw-task updater, which only replaces str values.
                if isinstance(mcp_image, str) and mcp_image:
                    images.add(mcp_image)
            # Recurse into every value so arbitrarily nested server layouts
            # are covered.
            for value in obj.values():
                _extract_from_obj(value)
        elif isinstance(obj, list):
            for item in obj:
                _extract_from_obj(item)

    for task in tasks:
        if task.mcp_config:
            _extract_from_obj(task.mcp_config)

    return images
126
+
127
+
128
def _env_var_to_header_key(var_name: str) -> str:
    """Map an environment variable name to its ``Env-*`` HTTP header name.

    Each underscore-separated segment is capitalized (first letter upper,
    remainder lower). The segments are re-joined with hyphens behind an
    ``Env-`` prefix, e.g. ``OPENAI_API_KEY`` -> ``Env-Openai-Api-Key``.
    """
    segments = str(var_name).split("_")
    title_cased = map(str.capitalize, segments)
    return "Env-" + "-".join(title_cased)
135
+
136
+
137
def _extract_api_key_vars(lock_data: dict[str, Any]) -> set[str]:
    """Return the env var names declared under ``environment.variables.provided``.

    The lock file is treated as the authoritative list of variables the
    environment expects. ``HUD_API_KEY`` is always dropped because the
    Authorization header already carries it.

    Non-dict input yields an empty set. If the section is malformed, the error
    is logged at debug level and any names collected before the failure are
    kept.
    """
    if not isinstance(lock_data, dict):
        return set()

    names: set[str] = set()
    try:
        variables = (lock_data.get("environment") or {}).get("variables") or {}
        declared = variables.get("provided") or {}
        # Iterating a dict yields its keys; a list yields its items.
        names.update(str(entry) for entry in declared)
    except Exception as e:
        logger.debug("Failed to parse provided env vars from lock data: %s", e)
    names.discard("HUD_API_KEY")
    return names
155
+
156
+
157
def _extract_dotenv_api_key_vars(env_dir: Path) -> set[str]:
    """Scan ``env_dir/.env`` for variable names that look like secrets.

    Only names containing one of ``api``, ``key``, ``token``, ``secret`` or
    ``password`` (case-insensitive) are returned, to keep suggestions short.

    Parsing rules:
    - Blank lines, ``#`` comments and lines without ``=`` are ignored.
    - A leading ``export`` keyword (shell-style .env syntax) is stripped, so
      ``export OPENAI_API_KEY=...`` yields ``OPENAI_API_KEY``.
    - ``HUD_API_KEY`` is never returned, because the Authorization header
      already carries it.

    This is best-effort: a missing, unreadable or non-UTF-8 file yields an
    empty set rather than an error.

    Args:
        env_dir: Environment directory that may contain a ``.env`` file.

    Returns:
        Set of detected variable names. Values are never included.
    """
    dotenv_path = env_dir / ".env"
    detected: set[str] = set()
    if not dotenv_path.exists():
        return detected
    try:
        content = dotenv_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        # Best-effort only
        return detected

    secret_markers = ("api", "key", "token", "secret", "password")
    for raw_line in content.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        name = line.split("=", 1)[0].strip()
        # Shell-style .env files may prefix assignments with "export".
        # Without stripping it, the header/placeholder built from the name
        # would be invalid.
        words = name.split()
        if len(words) == 2 and words[0] == "export":
            name = words[1]
        lowered = name.lower()
        if any(marker in lowered for marker in secret_markers):
            detected.add(name)

    detected.discard("HUD_API_KEY")
    return detected
184
+
185
+
156
186
  def convert_tasks_to_remote(tasks_file: str) -> str:
157
187
  """Convert a local tasks file to remote MCP tasks and return new filename.
158
188
 
159
189
  Steps:
160
190
  1) Find env dir; ensure built (hud.lock.yaml), otherwise build
161
191
  2) Ensure pushed to registry, otherwise push
162
- 3) Create remote_[tasks].json with mcp_config pointing to mcp.hud.so and Mcp-Image
163
- 4) Return the new tasks file path
192
+ 3) Check for outdated images in existing task configurations
193
+ 4) Create remote_[tasks].json with mcp_config pointing to mcp.hud.so and Mcp-Image
194
+ 5) Return the new tasks file path
164
195
  """
165
196
  tasks_path = Path(tasks_file).resolve()
166
197
 
167
- tasks = load_tasks(str(tasks_path))
198
+ # Load validated tasks for decision-making (may resolve env vars)
199
+ tasks: list[Task] = load_tasks(str(tasks_path)) # type: ignore[assignment]
200
+
201
+ # Load raw tasks to preserve placeholders when writing back to disk
202
+ raw_tasks: list[dict[str, Any]] = load_tasks(str(tasks_path), raw=True) # type: ignore[assignment]
168
203
 
169
204
  # Ensure HUD_API_KEY is available: prefer process env, else load from env_dir/.env
170
205
  from hud.settings import settings
@@ -174,24 +209,119 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
174
209
  hud_console.info("Set it in your environment or run: hud set HUD_API_KEY=your-key-here")
175
210
  raise typer.Exit(1)
176
211
 
212
+ # Check if tasks already have remote URLs
213
+ already_remote = _validate_tasks(tasks)
214
+
215
+ # Extract existing images from tasks
216
+ existing_images = _extract_existing_images(tasks)
217
+
177
218
  # Load tasks (supports .json and .jsonl)
178
- if _validate_tasks(tasks):
219
+ if already_remote and not existing_images:
220
+ # Tasks are remote but have no image references - just return as-is
179
221
  return str(tasks_path)
180
222
 
181
223
  # Locate environment
182
- env_dir = _find_environment_dir(tasks_path)
224
+ env_dir = find_environment_dir(tasks_path)
183
225
  if not env_dir:
184
226
  hud_console.error("Could not locate an environment directory (Dockerfile + pyproject.toml)")
185
227
  hud_console.hint("Ensure you're in or near your environment folder before running 'hud rl'")
186
228
  raise typer.Exit(1)
187
229
 
188
230
  # Ensure built and pushed
189
- lock_data = _ensure_built(env_dir)
231
+ lock_data = ensure_built(env_dir, interactive=True)
190
232
  lock_data = _ensure_pushed(env_dir, lock_data)
191
233
 
192
234
  # Derive remote image name org/name:tag
193
235
  remote_image = _derive_remote_image(lock_data)
194
236
 
237
+ # Check if existing images are outdated
238
+ needs_update = False
239
+ should_update_image = False
240
+ if existing_images:
241
+ # Check if any existing image differs from the latest
242
+ for existing_img in existing_images:
243
+ if existing_img != remote_image:
244
+ hud_console.warning(f"Detected outdated image reference: {existing_img}")
245
+ hud_console.info(f"Latest pushed image: {remote_image}")
246
+ needs_update = True
247
+ break
248
+
249
+ if needs_update:
250
+ confirm_msg = "Update task configuration with the latest image?"
251
+ if hud_console.confirm(confirm_msg, default=True):
252
+ hud_console.info("Updating task configuration with latest image...")
253
+ should_update_image = True
254
+ else:
255
+ # If user doesn't want to update, just return the original file
256
+ if already_remote:
257
+ return str(tasks_path)
258
+ # Otherwise, continue with conversion but keep old images
259
+ remote_image = next(iter(existing_images)) # Use the first existing image
260
+
261
+ # If tasks are already remote and up-to-date (no update needed), return original file
262
+ if already_remote and not needs_update:
263
+ return str(tasks_path)
264
+
265
+ # If tasks are already remote and we just need to update the image
266
+ if already_remote and should_update_image:
267
+ # Update image references in-place on RAW tasks (preserve placeholders)
268
+ def _update_image_refs_raw(obj: Any) -> Any:
269
+ if isinstance(obj, dict):
270
+ new_obj = {}
271
+ for k, v in obj.items():
272
+ if k == "Mcp-Image" and isinstance(v, str) and v in existing_images:
273
+ new_obj[k] = remote_image
274
+ else:
275
+ new_obj[k] = _update_image_refs_raw(v)
276
+ return new_obj
277
+ elif isinstance(obj, list):
278
+ return [_update_image_refs_raw(item) for item in obj]
279
+ else:
280
+ return obj
281
+
282
+ updated_raw_tasks: list[dict[str, Any]] = []
283
+ for t in raw_tasks:
284
+ td = dict(t)
285
+ if "mcp_config" in td:
286
+ td["mcp_config"] = _update_image_refs_raw(td["mcp_config"])
287
+ updated_raw_tasks.append(td)
288
+
289
+ # Write updated file (preserve original format - check if it's .jsonl)
290
+ if tasks_path.suffix == ".jsonl":
291
+ with open(tasks_path, "w", encoding="utf-8") as f:
292
+ for task in updated_raw_tasks:
293
+ json.dump(task, f, ensure_ascii=False)
294
+ f.write("\n")
295
+ else:
296
+ with open(tasks_path, "w", encoding="utf-8") as f:
297
+ json.dump(updated_raw_tasks, f, ensure_ascii=False, indent=2)
298
+ f.write("\n")
299
+
300
+ hud_console.success(f"Updated {tasks_path.name} with latest image: {remote_image}")
301
+ return str(tasks_path)
302
+
303
+ # Extract additional API key headers from lock and suggest from .env
304
+ provided_keys = _extract_api_key_vars(lock_data)
305
+ dotenv_keys = _extract_dotenv_api_key_vars(env_dir)
306
+
307
+ # If .env contains API-like vars not in lock, offer to include them
308
+ missing = sorted(dotenv_keys - provided_keys)
309
+ if missing:
310
+ names_preview = ", ".join(missing)
311
+ prompt = (
312
+ f"Detected env vars in .env that look like API keys: {names_preview}.\n"
313
+ "Include them as remote headers (values will be ${VAR} placeholders)?"
314
+ )
315
+ if hud_console.confirm(prompt, default=True):
316
+ provided_keys.update(missing)
317
+
318
+ extra_api_key_headers: dict[str, str] = {}
319
+ for var_name in provided_keys:
320
+ if str(var_name).upper() == "HUD_API_KEY":
321
+ continue
322
+ header_key = _env_var_to_header_key(var_name)
323
+ extra_api_key_headers[header_key] = f"${{{var_name}}}"
324
+
195
325
  # Helper to strip extra fields from tool calls
196
326
  def _simplify_tool_call(tool: Any) -> Any:
197
327
  def _one(x: Any) -> dict[str, Any]:
@@ -229,6 +359,9 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
229
359
  },
230
360
  }
231
361
 
362
+ # Merge additional API key headers
363
+ item["mcp_config"]["hud"]["headers"].update(extra_api_key_headers)
364
+
232
365
  # Optional fields, omit Nones
233
366
  if t.setup_tool is not None:
234
367
  item["setup_tool"] = _simplify_tool_call(t.setup_tool)
@@ -243,7 +376,6 @@ def convert_tasks_to_remote(tasks_file: str) -> str:
243
376
 
244
377
  tasks_payload.append(item)
245
378
 
246
- # Write new file: remote_<name>.json (always JSON array)
247
379
  remote_name = f"remote_{tasks_path.stem}.json"
248
380
  remote_path = tasks_path.parent / remote_name
249
381
  with open(remote_path, "w", encoding="utf-8") as f:
hud/cli/init.py CHANGED
@@ -21,7 +21,7 @@ GITHUB_BRANCH = "main"
21
21
 
22
22
  PRESET_MAP: dict[str, str | None] = {
23
23
  "blank": "blank",
24
- "deep-research": "remote_browser",
24
+ "deep-research": "deepresearch",
25
25
  "browser": "browser",
26
26
  }
27
27
 
hud/cli/push.py CHANGED
@@ -11,6 +11,7 @@ import requests
11
11
  import typer
12
12
  import yaml
13
13
 
14
+ from hud.cli.utils.env_check import ensure_built
14
15
  from hud.utils.hud_console import HUDConsole
15
16
 
16
17
 
@@ -131,6 +132,14 @@ def push_environment(
131
132
 
132
133
  # Find hud.lock.yaml in specified directory
133
134
  env_dir = Path(directory)
135
+
136
+ # Ensure environment is built and up-to-date (hash-based); interactive prompt
137
+ try:
138
+ ensure_built(env_dir, interactive=True)
139
+ except typer.Exit:
140
+ raise
141
+ except Exception as e:
142
+ HUDConsole().debug(f"Skipping pre-push build check: {e}")
134
143
  lock_path = env_dir / "hud.lock.yaml"
135
144
 
136
145
  if not lock_path.exists():
hud/cli/rl/__init__.py CHANGED
@@ -25,7 +25,7 @@ def rl_command(
25
25
  ),
26
26
  model: str | None = typer.Argument(
27
27
  None,
28
- help="Model to train (default: interactive selection)",
28
+ help="Model to train from https://hud.so/models (default: interactive selection)",
29
29
  ),
30
30
  config_file: Path | None = typer.Option( # noqa: B008
31
31
  None,
@@ -72,6 +72,12 @@ def rl_command(
72
72
  "--local",
73
73
  help="Run training locally instead of using remote API server",
74
74
  ),
75
+ yes: bool = typer.Option(
76
+ False,
77
+ "--yes",
78
+ "-y",
79
+ help="Auto-accept all prompts and use defaults (lazy mode)",
80
+ ),
75
81
  # Internal flag
76
82
  skip_vllm_startup: bool = typer.Option(
77
83
  False,
@@ -122,8 +128,7 @@ def rl_command(
122
128
  try:
123
129
  from hud.cli.flows.tasks import convert_tasks_to_remote
124
130
 
125
- console.print("\n[cyan]Preparing remote training tasks...[/cyan]")
126
- console.print("[cyan](build/push if needed)[/cyan]")
131
+ console.print("[cyan]Preparing remote training tasks...[/cyan]")
127
132
  tasks_file = convert_tasks_to_remote(tasks_file)
128
133
  except typer.Exit:
129
134
  raise
@@ -137,7 +142,11 @@ def rl_command(
137
142
  from .remote_runner import run_remote_training
138
143
 
139
144
  run_remote_training(
140
- tasks_file=tasks_file, model=model, config_file=config_file, output_dir=output_dir
145
+ tasks_file=tasks_file,
146
+ model=model,
147
+ config_file=config_file,
148
+ output_dir=output_dir,
149
+ yes=yes,
141
150
  )
142
151
  return
143
152
  except Exception as e:
@@ -152,6 +161,7 @@ def rl_command(
152
161
  model=model,
153
162
  config_file=config_file,
154
163
  output_dir=output_dir,
164
+ yes=yes,
155
165
  restart=restart,
156
166
  verbose=verbose,
157
167
  no_ddp=no_ddp,
hud/cli/rl/celebrate.py ADDED
@@ -0,0 +1,187 @@
1
+ # ruff: noqa: S311
2
+ from __future__ import annotations
3
+
4
+ import random
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import TYPE_CHECKING, ClassVar
8
+
9
+ from rich.live import Live
10
+ from rich.text import Text
11
+
12
+ from hud.utils.hud_console import hud_console
13
+
14
+ if TYPE_CHECKING:
15
+ from rich.console import Console
16
+
17
+
18
@dataclass
class Particle:
    """Single confetti piece: position, velocity, glyph and Rich colour."""

    x: float
    y: float
    vx: float  # horizontal velocity, in cells per (1/fps)-second frame step
    vy: float  # vertical velocity; positive values move toward higher rows (down)
    char: str
    color: str

    def update(self, gravity: float = 0.5, fps: float = 30.0) -> None:
        """Advance the particle by one frame lasting ``1/fps`` seconds.

        ``x`` moves uniformly. ``gravity`` is added to ``vy`` once per call,
        i.e. per frame rather than scaled by the timestep, before ``y`` moves.
        """
        step = 1.0 / fps
        self.x += self.vx * step
        self.vy += gravity
        self.y += self.vy * step
35
+
36
+
37
class ConfettiSystem:
    """Tiny particle engine that animates confetti on a character grid.

    Coordinates are terminal cells with ``(0, 0)`` at the top-left, and ``y``
    grows downward (row 0 is rendered first). The palette and glyphs follow
    the "confetty" look.
    """

    # Palette and glyph set used for newly spawned particles (confetty style).
    COLORS: ClassVar[list[str]] = ["#a864fd", "#29cdff", "#78ff44", "#ff718d", "#fdff6a"]
    CHARS: ClassVar[list[str]] = ["█", "▓", "▒", "░", "▄", "▀"]

    def __init__(self, width: int, height: int) -> None:
        self.width = width
        self.height = height
        self.particles: list[Particle] = []

    def spawn_burst(self, num_particles: int = 75) -> None:
        """Add ``num_particles`` particles on the top row, near the centre.

        Start positions fall in the middle quarter of the width. ``vx`` is
        drawn from [-50, 50) and ``vy`` from [-25, 25), so some particles
        rise before gravity pulls them down. The glyph and colour are picked
        at random.
        """
        mid = self.width / 2
        for _ in range(num_particles):
            start_x = mid + (self.width / 4) * (random.random() - 0.5)
            horizontal = (random.random() - 0.5) * 100
            vertical = random.random() * 50 - 25  # negative values launch upward first
            glyph = random.choice(self.CHARS)
            tint = random.choice(self.COLORS)
            self.particles.append(
                Particle(x=start_x, y=0, vx=horizontal, vy=vertical, char=glyph, color=tint)
            )

    def update(self) -> None:
        """Step every particle one frame, then drop those that left the grid.

        A particle is kept while its ``x`` lies in ``[0, width)`` and its ``y``
        is above the bottom edge. Particles above the top (negative ``y``) are
        retained so they can fall back into view.
        """
        for piece in self.particles:
            piece.update()
        survivors = [
            piece
            for piece in self.particles
            if 0 <= piece.x < self.width and piece.y < self.height
        ]
        self.particles = survivors

    def render(self) -> str:
        """Return the frame as plain text: ``height`` lines of ``width`` characters."""
        occupied: dict[tuple[int, int], str] = {}
        for piece in self.particles:
            col, row = int(piece.x), int(piece.y)
            if 0 <= col < self.width and 0 <= row < self.height:
                # A later particle overwrites an earlier one in the same cell.
                occupied[(col, row)] = piece.char
        lines = (
            "".join(occupied.get((col, row), " ") for col in range(self.width))
            for row in range(self.height)
        )
        return "\n".join(lines)

    def render_with_colors(self) -> Text:
        """Return the frame as a Rich ``Text`` with each particle styled by its colour.

        Empty cells are plain spaces. Every row, including the last, ends with
        a newline.
        """
        occupied: dict[tuple[int, int], tuple[str, str]] = {}
        for piece in self.particles:
            col, row = int(piece.x), int(piece.y)
            if 0 <= col < self.width and 0 <= row < self.height:
                occupied[(col, row)] = (piece.char, piece.color)

        frame = Text()
        for row in range(self.height):
            for col in range(self.width):
                cell = occupied.get((col, row))
                if cell is None:
                    frame.append(" ")
                else:
                    glyph, style = cell
                    frame.append(glyph, style=style)
            frame.append("\n")
        return frame
122
+
123
+
124
def show_confetti(console: Console, seconds: float = 2.5) -> None:
    """Print a "Starting training!" banner, then animate two confetti bursts.

    The banner is printed first and held for 0.3 s. The animation then runs
    in a transient Rich ``Live`` region at 30 fps:
    - one 60-particle burst is spawned up front;
    - a second burst follows after 0.4 s (12 frames);
    - frames keep coming until at least ``seconds`` worth have been drawn
      *and* every particle has left the grid.

    Args:
        console: Rich console to print and animate on.
        seconds: Minimum animation duration in seconds.
    """
    banner = (
        "[bold green]🎉 Starting training! See your model on https://hud.so/models[/bold green]"
    )
    console.print(banner)
    time.sleep(0.3)  # let the banner register before the animation starts

    # Cap the canvas for performance and leave two rows for the banner.
    confetti = ConfettiSystem(
        min(console.size.width, 120),
        min(console.size.height - 2, 30),
    )

    fps = 30
    delay = 1.0 / fps
    min_frames = seconds * fps
    second_burst_at = int(fps * 0.4)

    confetti.spawn_burst(num_particles=60)

    with Live("", refresh_per_second=fps, console=console, transient=True) as live:
        frame = 0
        while frame < min_frames or confetti.particles:
            if frame == second_burst_at:
                confetti.spawn_burst(num_particles=60)
            confetti.update()
            live.update(confetti.render_with_colors())
            time.sleep(delay)
            frame += 1
167
+
168
+
169
def show_confetti_async(console: Console, seconds: float = 2.5) -> None:
    """Play the confetti animation on a daemon thread and return immediately.

    The thread is never joined, so training can begin right away. Because the
    thread is a daemon, it does not keep the process alive. If the animation
    raises, a plain "Launching training..." info line is printed instead.
    """
    import threading

    def _animate() -> None:
        try:
            show_confetti(console, seconds)
        except Exception:
            # Purely decorative: fall back to a plain status line instead of failing.
            hud_console.info("Launching training...")

    # Fire and forget; the caller proceeds without waiting for the animation.
    threading.Thread(target=_animate, daemon=True).start()
185
+
186
+
187
+ __all__ = ["show_confetti", "show_confetti_async"]
hud/cli/rl/config.py CHANGED
@@ -21,22 +21,29 @@ console = Console()
21
21
  def generate_config_interactive(
22
22
  model_name: str,
23
23
  presets: list[dict[str, Any]],
24
+ yes: bool = False,
24
25
  ) -> tuple[Config, float]:
25
26
  """Generate RL training configuration interactively."""
26
27
  # Validate model is a VL model
27
28
  validate_vl_model(model_name)
28
29
 
29
30
  # Display preset options
30
- display_preset_table(presets, 80.0) # Assuming A100 80GB
31
+ if not yes:
32
+ display_preset_table(presets, 80.0) # Assuming A100 80GB
31
33
 
32
34
  # Let user select preset
33
- preset_choice = hud_console.select(
34
- "Select a training configuration preset:",
35
- choices=[{"name": p["name"], "value": i} for i, p in enumerate(presets)],
36
- default=1 if len(presets) > 1 else 0, # Default to "Balanced" if available
37
- )
38
-
39
- selected_preset = presets[preset_choice] # type: ignore
35
+ if yes:
36
+ # Use default preset (Balanced if available, otherwise first)
37
+ preset_choice = 1 if len(presets) > 1 else 0
38
+ selected_preset = presets[preset_choice]
39
+ hud_console.info(f"Auto-selecting preset: {selected_preset['name']} (--yes mode)")
40
+ else:
41
+ preset_choice = hud_console.select(
42
+ "Select a training configuration preset:",
43
+ choices=[{"name": p["name"], "value": i} for i, p in enumerate(presets)],
44
+ default=1 if len(presets) > 1 else 0, # Default to "Balanced" if available
45
+ )
46
+ selected_preset = presets[preset_choice] # type: ignore
40
47
 
41
48
  # Use preset values directly
42
49
  max_steps_per_episode = selected_preset["max_steps_per_episode"]