hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +70 -5
- hud/agents/base.py +238 -500
- hud/agents/claude.py +236 -247
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +264 -0
- hud/agents/gemini_cua.py +324 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +48 -36
- hud/agents/openai.py +282 -296
- hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
- hud/agents/operator.py +199 -0
- hud/agents/resolver.py +70 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +381 -214
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +377 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +493 -546
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +699 -113
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +889 -732
- hud/cli/eval.py +793 -667
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/pull.py +1 -1
- hud/cli/push.py +38 -13
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +110 -8
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push.py +1 -1
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +70 -1
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +45 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +326 -0
- hud/datasets/runner.py +198 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +52 -0
- hud/environment/connection.py +258 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +137 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +835 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +263 -0
- hud/environment/scenarios.py +620 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +205 -0
- hud/environment/tests/test_environment.py +593 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +242 -0
- hud/environment/tests/test_scenarios.py +1086 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +727 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +187 -0
- hud/eval/manager.py +533 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +372 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +291 -0
- hud/eval/types.py +65 -0
- hud/eval/utils.py +194 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +308 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +165 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +18 -2
- hud/tools/agent.py +223 -0
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +36 -3
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +194 -56
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +89 -18
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.13.dist-info/METADATA +264 -0
- hud_python-0.5.13.dist-info/RECORD +305 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/cli/rl/vllm.py
DELETED
|
@@ -1,177 +0,0 @@
|
|
|
1
|
-
"""vLLM server management utilities."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import asyncio
|
|
6
|
-
import logging
|
|
7
|
-
import os
|
|
8
|
-
import subprocess
|
|
9
|
-
import time
|
|
10
|
-
from pathlib import Path
|
|
11
|
-
|
|
12
|
-
import httpx
|
|
13
|
-
from rich.console import Console
|
|
14
|
-
|
|
15
|
-
from hud.utils.hud_console import HUDConsole
|
|
16
|
-
|
|
17
|
-
# Module-wide output channels: a stdlib logger (wrapped by HUD's console
# helper for error reporting) and a Rich console for coloured status lines.
logger: logging.Logger = logging.getLogger(__name__)
hud_console: HUDConsole = HUDConsole(logger)

console: Console = Console()
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def get_vllm_args(model_name: str, chat_template_path: Path | None = None) -> list[str]:
|
|
24
|
-
"""Get common vLLM server arguments for both local and remote deployments."""
|
|
25
|
-
args = [
|
|
26
|
-
"serve",
|
|
27
|
-
model_name,
|
|
28
|
-
"--api-key",
|
|
29
|
-
"token-abc123",
|
|
30
|
-
"--host",
|
|
31
|
-
"0.0.0.0", # noqa: S104
|
|
32
|
-
"--port",
|
|
33
|
-
"8000",
|
|
34
|
-
"--tensor-parallel-size",
|
|
35
|
-
"1",
|
|
36
|
-
"--trust-remote-code",
|
|
37
|
-
"--max-model-len",
|
|
38
|
-
"16384",
|
|
39
|
-
"--enable-lora",
|
|
40
|
-
"--max-lora-rank",
|
|
41
|
-
"64",
|
|
42
|
-
"--max-cpu-loras",
|
|
43
|
-
"4",
|
|
44
|
-
"--enable-auto-tool-choice",
|
|
45
|
-
"--tool-call-parser",
|
|
46
|
-
"hermes",
|
|
47
|
-
"--disable-log-requests",
|
|
48
|
-
"--dtype",
|
|
49
|
-
"auto",
|
|
50
|
-
]
|
|
51
|
-
|
|
52
|
-
# Add chat template if provided
|
|
53
|
-
if chat_template_path and chat_template_path.exists():
|
|
54
|
-
args.extend(["--chat-template", str(chat_template_path.absolute())])
|
|
55
|
-
|
|
56
|
-
return args
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def check_vllm_server() -> bool:
    """Probe the local vLLM server's ``/health`` endpoint.

    Returns ``True`` only when the endpoint answers with HTTP 200. Any failure
    (connection refused, timeout, or any other exception) yields ``False``
    instead of propagating, so this can be used as a cheap liveness check.
    """
    health_url = "http://localhost:8000/health"
    try:
        status = httpx.get(health_url, timeout=2.0).status_code
        return status == 200
    except Exception:  # best-effort probe: any failure means "not running"
        return False
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def _run_best_effort(cmd: list[str], *, capture: bool = False) -> str:
    """Run *cmd* without raising; return its stdout when *capture* is set, else ``""``."""
    try:
        result = subprocess.run(cmd, capture_output=capture, text=True, check=False)  # noqa: S603
    except OSError as e:  # e.g. ``pkill`` / ``lsof`` / ``kill`` not installed
        hud_console.error(f"Failed to run {cmd[0]!r} while killing vLLM server: {e}")
        return ""
    if capture and result.stdout:
        return result.stdout
    return ""


def kill_vllm_server() -> None:
    """Stop any running vLLM server processes, best effort.

    Three strategies run in order. Each one is attempted even if an earlier
    one failed, for example because ``pkill`` or ``lsof`` is not installed:

    1. The PID recorded in ``/tmp/vllm_server.pid`` gets ``SIGTERM``, then
       ``SIGKILL`` two seconds later. The PID file is removed afterwards, even
       when its contents could not be parsed.
    2. ``pkill -f`` for the ``vllm serve`` and
       ``vllm.entrypoints.openai.api_server`` command lines.
    3. ``kill -9`` for every PID that ``lsof`` reports on port 8000.

    Failures of individual steps are reported through ``hud_console`` rather
    than raised.
    """
    pid_file = Path("/tmp/vllm_server.pid")  # noqa: S108
    if pid_file.exists():
        try:
            pid = int(pid_file.read_text().strip())
            # ``kill`` treats 0 / negative values as process groups (``-1`` means
            # "every process I own"), so never forward them from a corrupt file.
            if pid <= 1:
                raise ValueError(f"invalid PID {pid} in {pid_file}")
        except (OSError, ValueError) as e:
            hud_console.error(f"Failed to kill vLLM server: {e}")
        else:
            _run_best_effort(["kill", "-TERM", str(pid)])
            time.sleep(2)
            # Force kill in case the graceful shutdown did not finish in time.
            _run_best_effort(["kill", "-9", str(pid)])
        # Remove the file unconditionally so a stale/corrupt PID is not retried forever.
        try:
            pid_file.unlink(missing_ok=True)
        except OSError as e:
            hud_console.error(f"Failed to kill vLLM server: {e}")

    # Also try to kill by process name.
    for pattern in ("vllm serve", "vllm.entrypoints.openai.api_server"):
        _run_best_effort(["pkill", "-f", pattern])  # noqa: S607
    time.sleep(2)

    # NOTE(review): this kills *any* process holding port 8000, not only vLLM.
    for port_pid in _run_best_effort(["lsof", "-ti:8000"], capture=True).split():  # noqa: S607
        _run_best_effort(["kill", "-9", port_pid])  # noqa: S607

    console.print("[yellow]Killed existing vLLM server processes[/yellow]")
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def start_vllm_server(model_name: str, gpu_index: int = 1, restart: bool = False) -> None:
    """Start a vLLM server for *model_name* as a detached background process.

    Args:
        model_name: Model identifier forwarded to ``vllm serve``.
        gpu_index: Value exported as ``CUDA_VISIBLE_DEVICES`` for the server.
        restart: When true, call ``kill_vllm_server`` first and wait three
            seconds before starting.

    If a server already answers the health check, nothing new is started.
    Otherwise the server runs via ``uv run vllm serve ...``. Its combined
    stdout/stderr goes to ``/tmp/vllm_server.log`` and its PID is written to
    ``/tmp/vllm_server.pid``. The function returns without waiting for the
    server to become healthy.
    """
    log_path = Path("/tmp/vllm_server.log")  # noqa: S108
    pid_file = Path("/tmp/vllm_server.pid")  # noqa: S108

    if restart:
        kill_vllm_server()
        time.sleep(3)

    # Check if already running
    if check_vllm_server():
        console.print("[green]vLLM server is already running[/green]")
        return

    console.print(f"[cyan]Starting vLLM server with {model_name} on GPU {gpu_index}...[/cyan]")

    # Set up environment variables
    env = os.environ.copy()
    env.update(
        {
            "CUDA_VISIBLE_DEVICES": str(gpu_index),
            "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
            "TOKENIZERS_PARALLELISM": "false",
            "VLLM_LOGGING_LEVEL": "INFO",  # INFO rather than DEBUG to reduce noise
            "CUDA_LAUNCH_BLOCKING": "1",  # Better error messages
        }
    )

    # hud/cli/rl/vllm.py -> hud/rl/chat_template.jinja
    chat_template_path = Path(__file__).parent.parent.parent / "rl" / "chat_template.jinja"

    cmd = ["uv", "run", "vllm", *get_vllm_args(model_name, chat_template_path)]

    # The child inherits its own copy of the log file descriptor, so closing
    # ours when the ``with`` block exits does not stop its logging.
    with log_path.open("w", encoding="utf-8") as log_file:
        process = subprocess.Popen(  # noqa: S603
            cmd,
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            # Run in its own session/process group so terminal signals aimed at
            # the CLI do not reach the server. This replaces
            # ``preexec_fn=os.setpgrp``, which the subprocess docs warn is unsafe
            # when the parent process has threads.
            start_new_session=True,
            cwd=Path.cwd(),  # Use current working directory
        )

    console.print("[yellow]vLLM server starting in background...[/yellow]")
    console.print(f"[yellow]Process ID: {process.pid}[/yellow]")
    console.print(f"[yellow]Check logs at: {log_path}[/yellow]")

    # Save PID for later management (read back by kill_vllm_server).
    pid_file.write_text(str(process.pid))
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
async def wait_for_vllm_server(timeout: int = 360) -> bool:  # noqa: ASYNC109
    """Poll the local vLLM ``/health`` endpoint until it answers HTTP 200.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Returns:
        ``True`` as soon as ``/health`` returns 200, or ``False`` if *timeout*
        elapses first.

    Connection errors are expected while the server is still loading. They
    are logged at debug level instead of being printed as errors on every poll.
    """
    health_url = "http://localhost:8000/health"
    poll_interval = 2

    # Describe the real limit instead of a hard-coded "6 minutes".
    minutes, seconds = divmod(timeout, 60)
    if seconds == 0 and minutes > 0:
        limit = f"{minutes} minute{'' if minutes == 1 else 's'}"
    else:
        limit = f"{timeout} seconds"

    start_time = time.monotonic()  # monotonic: immune to wall-clock adjustments
    console.print(f"[yellow]Waiting for vLLM server to be ready (up to {limit})...[/yellow]")

    async with httpx.AsyncClient() as client:
        while time.monotonic() - start_time < timeout:
            try:
                response = await client.get(health_url, timeout=2.0)
            except httpx.HTTPError as e:
                # Normal during startup (connection refused, timeouts, ...).
                logger.debug("vLLM health check not ready yet: %s", e)
            except Exception as e:
                hud_console.error(f"Failed to connect to vLLM server: {e}")
            else:
                if response.status_code == 200:
                    console.print("[green]✅ vLLM server is ready![/green]")
                    return True

            await asyncio.sleep(poll_interval)
            elapsed = int(time.monotonic() - start_time)
            console.print(f"[yellow]Waiting... ({elapsed}s / {timeout}s)[/yellow]", end="\r")

    console.print("\n[red]❌ vLLM server failed to start within timeout[/red]")
    console.print("[yellow]Check /tmp/vllm_server.log for details[/yellow]")
    return False
|
hud/cli/rl/wait_utils.py
DELETED
|
@@ -1,89 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import contextlib
|
|
4
|
-
import os
|
|
5
|
-
import select
|
|
6
|
-
import sys
|
|
7
|
-
import threading
|
|
8
|
-
import time as _time
|
|
9
|
-
from typing import TYPE_CHECKING, TextIO
|
|
10
|
-
|
|
11
|
-
from watchfiles import watch
|
|
12
|
-
|
|
13
|
-
if TYPE_CHECKING:
|
|
14
|
-
from pathlib import Path
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def wait_for_enter_cancel_or_change(file_path: Path) -> tuple[bool, bool, bool]:
|
|
18
|
-
"""Block until Enter (start), 'q' (cancel), or file change.
|
|
19
|
-
|
|
20
|
-
Returns (start_training, cancelled, changed).
|
|
21
|
-
- start_training: True if Enter (or any non-'q' line on POSIX) was received
|
|
22
|
-
- cancelled: True if 'q' was received or Ctrl-C
|
|
23
|
-
- changed: True if the file changed on disk
|
|
24
|
-
"""
|
|
25
|
-
start_training = False
|
|
26
|
-
cancelled = False
|
|
27
|
-
changed = False
|
|
28
|
-
|
|
29
|
-
stop_evt: threading.Event = threading.Event()
|
|
30
|
-
changed_evt: threading.Event = threading.Event()
|
|
31
|
-
|
|
32
|
-
def _watcher() -> None:
|
|
33
|
-
with contextlib.suppress(Exception):
|
|
34
|
-
for _ in watch(file_path, stop_event=stop_evt, debounce=200):
|
|
35
|
-
changed_evt.set()
|
|
36
|
-
break
|
|
37
|
-
|
|
38
|
-
t = threading.Thread(target=_watcher, daemon=True)
|
|
39
|
-
t.start()
|
|
40
|
-
|
|
41
|
-
try:
|
|
42
|
-
if os.name == "nt":
|
|
43
|
-
import msvcrt # type: ignore[attr-defined]
|
|
44
|
-
|
|
45
|
-
while True:
|
|
46
|
-
if changed_evt.is_set():
|
|
47
|
-
changed = True
|
|
48
|
-
break
|
|
49
|
-
|
|
50
|
-
if msvcrt.kbhit():
|
|
51
|
-
ch = msvcrt.getwch()
|
|
52
|
-
if ch in ("\r", "\n"):
|
|
53
|
-
start_training = True
|
|
54
|
-
break
|
|
55
|
-
if ch.lower() == "q":
|
|
56
|
-
cancelled = True
|
|
57
|
-
break
|
|
58
|
-
_time.sleep(0.15)
|
|
59
|
-
else:
|
|
60
|
-
while True:
|
|
61
|
-
if changed_evt.is_set():
|
|
62
|
-
changed = True
|
|
63
|
-
break
|
|
64
|
-
|
|
65
|
-
rlist, _, _ = select.select([sys.stdin], [], [], 0.25)
|
|
66
|
-
if rlist:
|
|
67
|
-
line = sys.stdin.readline()
|
|
68
|
-
if line is None:
|
|
69
|
-
continue
|
|
70
|
-
stripped = line.strip().lower()
|
|
71
|
-
if stripped == "q":
|
|
72
|
-
cancelled = True
|
|
73
|
-
break
|
|
74
|
-
# Any other (including empty) => start
|
|
75
|
-
start_training = True
|
|
76
|
-
break
|
|
77
|
-
_time.sleep(0.05)
|
|
78
|
-
|
|
79
|
-
except KeyboardInterrupt:
|
|
80
|
-
cancelled = True
|
|
81
|
-
finally:
|
|
82
|
-
stop_evt.set()
|
|
83
|
-
with contextlib.suppress(Exception):
|
|
84
|
-
t.join(timeout=1)
|
|
85
|
-
|
|
86
|
-
return start_training, cancelled, changed
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
# Names exported by ``from ... import *``: only the interactive wait helper.
__all__ = ["wait_for_enter_cancel_or_change"]
|