hud-python 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python has been flagged as potentially problematic; see the package's registry page for more details.
- hud/__init__.py +3 -2
- hud/adapters/__init__.py +2 -1
- hud/adapters/claude/adapter.py +15 -2
- hud/adapters/common/types.py +7 -3
- hud/adapters/operator/adapter.py +10 -6
- hud/agent/__init__.py +2 -1
- hud/agent/claude.py +22 -2
- hud/agent/langchain.py +198 -0
- hud/agent/operator.py +35 -17
- hud/env/docker_client.py +1 -1
- hud/env/environment.py +182 -9
- hud/env/local_docker_client.py +3 -1
- hud/env/remote_client.py +4 -0
- hud/gym.py +3 -3
- hud/job.py +420 -12
- hud/task.py +41 -30
- hud/taskset.py +8 -0
- hud/types.py +5 -3
- hud/utils/common.py +31 -1
- hud/utils/config.py +2 -93
- hud/utils/progress.py +136 -0
- {hud_python-0.2.0.dist-info → hud_python-0.2.2.dist-info}/METADATA +52 -39
- hud_python-0.2.2.dist-info/RECORD +46 -0
- hud_python-0.2.0.dist-info/RECORD +0 -44
- {hud_python-0.2.0.dist-info → hud_python-0.2.2.dist-info}/WHEEL +0 -0
- {hud_python-0.2.0.dist-info → hud_python-0.2.2.dist-info}/licenses/LICENSE +0 -0
hud/job.py
CHANGED
|
@@ -1,17 +1,27 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import datetime
|
|
4
5
|
import functools
|
|
5
6
|
import inspect
|
|
6
7
|
import logging
|
|
7
|
-
|
|
8
|
-
from
|
|
8
|
+
import sys
|
|
9
|
+
from collections.abc import Callable, Coroutine
|
|
10
|
+
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
|
9
11
|
|
|
10
|
-
from pydantic import BaseModel, TypeAdapter
|
|
12
|
+
from pydantic import BaseModel, PrivateAttr, TypeAdapter
|
|
11
13
|
|
|
14
|
+
from hud import gym
|
|
12
15
|
from hud.server import make_request
|
|
13
16
|
from hud.settings import settings
|
|
17
|
+
from hud.task import Task
|
|
18
|
+
from hud.taskset import TaskSet
|
|
14
19
|
from hud.trajectory import Trajectory
|
|
20
|
+
from hud.utils.progress import StepProgressTracker
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from hud.adapters.common import Adapter
|
|
24
|
+
from hud.agent.base import Agent
|
|
15
25
|
|
|
16
26
|
logger = logging.getLogger("hud.job")
|
|
17
27
|
|
|
@@ -25,7 +35,7 @@ class Job(BaseModel):
|
|
|
25
35
|
"""
|
|
26
36
|
A job represents a collection of related trajectories.
|
|
27
37
|
It holds metadata and provides methods to interact with job data.
|
|
28
|
-
Instances should typically be obtained via `create_job` or `
|
|
38
|
+
Instances should typically be obtained via `create_job`, `load_job`, or the new `run_job`.
|
|
29
39
|
"""
|
|
30
40
|
|
|
31
41
|
id: str
|
|
@@ -34,23 +44,85 @@ class Job(BaseModel):
|
|
|
34
44
|
created_at: datetime.datetime
|
|
35
45
|
status: str
|
|
36
46
|
|
|
37
|
-
|
|
47
|
+
# Internal cache for trajectories
|
|
48
|
+
_trajectories: list[Trajectory] | None = PrivateAttr(default=None)
|
|
49
|
+
# Store execution errors for debugging
|
|
50
|
+
errors: list[dict[str, Any]] = []
|
|
51
|
+
|
|
52
|
+
async def load_trajectories(
    self, *, api_key: str | None = None, force_reload: bool = False
) -> list[Trajectory]:
    """
    Return this job's trajectories, fetching them from the HUD API when needed.

    The fetched list is cached on the instance (``self._trajectories``) and is
    returned directly on later calls unless ``force_reload`` is true.

    Args:
        api_key: API key for the request; ``settings.api_key`` is used when omitted.
        force_reload: Ignore the cached list and query the API again.

    Returns:
        list[Trajectory]: The job's trajectories. If fetching or validating
        fails, the error is logged, the cache is reset to ``None`` and an empty
        list is returned instead of raising.
    """
    if not force_reload and self._trajectories is not None:
        logger.debug("Returning cached trajectories for Job %s", self.id)
        return self._trajectories

    logger.debug("Fetching trajectories for Job %s from API...", self.id)
    effective_key = api_key or settings.api_key

    try:
        endpoint = f"{settings.base_url}/v2/jobs/{self.id}/trajectories"
        payload = await make_request(method="GET", url=endpoint, api_key=effective_key)
        self._trajectories = TypeAdapter(list[Trajectory]).validate_python(payload)
        logger.debug("Loaded %d trajectories for Job %s", len(self._trajectories), self.id)
    except Exception as exc:
        logger.exception("Failed to load trajectories for Job %s: %s", self.id, exc)
        # Drop any stale cache so the next call retries the request.
        self._trajectories = None
        return []
    return self._trajectories
|
|
86
|
+
|
|
87
|
+
async def get_analytics(self, *, force_reload: bool = False) -> dict[str, Any]:
|
|
88
|
+
"""
|
|
89
|
+
Calculates and returns analytics for the job based on its trajectories.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
force_reload: If True, re-fetches trajectories before calculating.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Dictionary containing analytics (e.g., task_count, avg_reward).
|
|
96
|
+
"""
|
|
97
|
+
trajectories = await self.load_trajectories(force_reload=force_reload)
|
|
51
98
|
|
|
52
|
-
|
|
99
|
+
task_count = len(trajectories)
|
|
100
|
+
if task_count == 0:
|
|
101
|
+
return {"task_count": 0, "avg_reward": None, "success_rate": None} # Or other default
|
|
102
|
+
|
|
103
|
+
total_reward = 0
|
|
104
|
+
successful_tasks = 0
|
|
105
|
+
valid_rewards = 0
|
|
53
106
|
|
|
107
|
+
for traj in trajectories:
|
|
108
|
+
# Example: Assume reward is numeric and success is reward >= 1.0
|
|
109
|
+
# Adjust based on actual trajectory data structure and evaluation logic
|
|
110
|
+
if isinstance(traj.reward, int | float):
|
|
111
|
+
total_reward += traj.reward
|
|
112
|
+
valid_rewards += 1
|
|
113
|
+
if traj.reward >= 1.0:
|
|
114
|
+
successful_tasks += 1
|
|
115
|
+
# Add more complex logic here if needed based on traj.evaluation_result or metadata
|
|
116
|
+
|
|
117
|
+
avg_reward = (total_reward / valid_rewards) if valid_rewards > 0 else None
|
|
118
|
+
success_rate = (successful_tasks / task_count) * 100 if task_count > 0 else None
|
|
119
|
+
|
|
120
|
+
return {
|
|
121
|
+
"task_count": task_count,
|
|
122
|
+
"avg_reward": avg_reward,
|
|
123
|
+
"success_rate": success_rate,
|
|
124
|
+
# Add other relevant stats here
|
|
125
|
+
}
|
|
54
126
|
|
|
55
127
|
async def create_job(name: str, gym_id: str | None = None,
|
|
56
128
|
evalset_id: str | None = None,
|
|
@@ -84,7 +156,9 @@ async def create_job(name: str, gym_id: str | None = None,
|
|
|
84
156
|
# or at least the necessary fields (id, name, metadata, created_at, status)
|
|
85
157
|
# If not, we might need to make a subsequent GET request
|
|
86
158
|
job_data = data # Adjust if the API response structure is different
|
|
87
|
-
|
|
159
|
+
|
|
160
|
+
logger.info("[HUD] View job at https://app.hud.so/jobs/%s.", job_data["id"])
|
|
161
|
+
|
|
88
162
|
return Job(
|
|
89
163
|
id=job_data["id"],
|
|
90
164
|
name=job_data["name"],
|
|
@@ -183,3 +257,337 @@ def get_active_job() -> Job | None:
|
|
|
183
257
|
frame = frame.f_back
|
|
184
258
|
|
|
185
259
|
return None
|
|
260
|
+
|
|
261
|
+
# --- Task-execution helpers used by run_job (moved here from runner.py) ---
|
|
262
|
+
|
|
263
|
+
async def _execute_task(
    agent_cls: type[Agent],
    adapter_cls: type[Adapter] | None,
    agent_kwargs: dict[str, Any] | None,
    adapter_kwargs: dict[str, Any] | None,
    task: Task,
    job_name: str,
    task_id: str,
    max_steps_per_task: int,
    job: Job,
    tracker: StepProgressTracker | None = None,
    # Use semaphores instead of rate limiter
    env_creation_semaphore: asyncio.Semaphore | None = None,
    agent_predict_semaphore: asyncio.Semaphore | None = None,
) -> None:
    """
    Run a single task locally and record its outcome on ``job``.

    The flow is:
    - build the adapter (if any) and the agent;
    - create and reset the environment;
    - alternate ``agent.predict`` / ``env.step`` for at most ``max_steps_per_task`` steps;
    - evaluate (skipped if a step failed);
    - finally close the environment.

    ``env_creation_semaphore`` (if given) is held around ``gym.make``, and
    ``agent_predict_semaphore`` (if given) around each ``agent.predict`` call.

    Errors raised while building the agent, stepping, evaluating or closing the
    environment are logged rather than propagated. Each one is appended to
    ``job.errors`` as a dict with these keys:
    - ``task_id``;
    - ``type``: ``"setup_error"``, ``"step_error"``, ``"evaluation_error"`` or
      ``"env_close_error"``;
    - ``error``;
    - ``timestamp``: a naive local-time ISO string.

    Step errors additionally carry the 1-based ``step``.

    Args:
        agent_cls: Agent class, instantiated as ``agent_cls(adapter=..., **agent_kwargs)``.
        adapter_cls: Optional adapter class, instantiated with ``adapter_kwargs``.
        agent_kwargs: Extra keyword arguments for the agent constructor.
        adapter_kwargs: Extra keyword arguments for the adapter constructor.
        task: Task passed to ``gym.make``.
        job_name: Not used here; log lines read ``job.name`` instead.
        task_id: Identifier used for progress tracking, logs and error entries.
        max_steps_per_task: Upper bound on predict/step iterations.
        job: Job the environment is linked to and errors are recorded on.
        tracker: Optional progress tracker notified on start, each step and finish.
        env_creation_semaphore: Optional limit on concurrent environment creation.
        agent_predict_semaphore: Optional limit on concurrent agent predictions.
    """
    if tracker:
        tracker.start_task(task_id)

    def _record_error(kind: str, err: BaseException, **extra: Any) -> None:
        """Append one structured error entry for this task to ``job.errors``."""
        entry: dict[str, Any] = {"task_id": task_id, "type": kind, **extra}
        entry["error"] = str(err)
        entry["timestamp"] = datetime.datetime.now().isoformat()
        job.errors.append(entry)

    async def _guarded(
        sema: asyncio.Semaphore | None, start: Callable[[], Coroutine[Any, Any, Any]]
    ) -> Any:
        """Await ``start()``, holding ``sema`` for the duration when one is given."""
        if not sema:
            return await start()
        async with sema:
            return await start()

    env = None
    agent_instance: Agent | None = None
    status = "error"
    error_msg = "Initialization failed"
    try:
        adapter_instance = adapter_cls(**(adapter_kwargs or {})) if adapter_cls else None
        agent_instance = agent_cls(adapter=adapter_instance, **(agent_kwargs or {}))
        if agent_instance is None:
            raise RuntimeError("Agent could not be instantiated")

        env = await _guarded(env_creation_semaphore, functools.partial(gym.make, task, job=job))

        reset_result = await env.reset()
        if reset_result is None:
            raise ValueError(f"env.reset() returned None for task {task_id}")
        obs, _ = reset_result

        step_error = None
        for step in range(max_steps_per_task):
            step_number = step + 1
            try:
                action, done = await _guarded(
                    agent_predict_semaphore, functools.partial(agent_instance.predict, obs)
                )

                if tracker:
                    tracker.increment_step(task_id)

                # A missing action ends the episode after this step (env.step still receives None).
                if action is None and not done:
                    done = True

                step_result = await env.step(action)
                if step_result is not None:
                    obs, _, terminated, _ = step_result
                else:
                    terminated = True
                if terminated or done:
                    break
            except Exception as agent_step_err:
                logger.exception(
                    "[Job: %s/%s, Task: %s] Step %d Error: %s",
                    job.name, job.id, task_id, step_number, agent_step_err,
                )
                step_error = f"Error at step {step_number}: {agent_step_err}"
                _record_error("step_error", agent_step_err, step=step_number)
                break
        else:
            logger.warning("[Job: %s/%s, Task: %s] Max steps reached.", job.name, job.id, task_id)

        # --- Evaluate Task ---
        evaluation_result = None
        if step_error:
            status, error_msg = "error", step_error
        else:
            try:
                evaluation_result = await env.evaluate()
                status, error_msg = "completed", None
            except Exception as eval_err:
                logger.exception(
                    "[Job: %s/%s, Task: %s] Evaluation Error: %s",
                    job.name, job.id, task_id, eval_err,
                )
                status, error_msg = "error", f"Evaluation failed: {eval_err}"
                _record_error("evaluation_error", eval_err)

    except Exception as e:
        logger.exception("[Job: %s/%s, Task: %s] Setup/Run Error: %s", job.name, job.id, task_id, e)
        status, error_msg = "error", str(e)
        _record_error("setup_error", e)

    finally:
        if tracker:
            tracker.finish_task(task_id)
        if env:
            try:
                await env.close()
            except Exception as close_err:
                logger.exception(
                    "[Job: %s/%s, Task: %s] Close Error: %s",
                    job.name, job.id, task_id, close_err,
                )
                _record_error("env_close_error", close_err)

    if status == "error":
        log_suffix = f" Error: {error_msg}"
    else:
        log_suffix = f" Eval: {evaluation_result}"
    logger.info(
        "[Job: %s/%s, Task: %s] Finished local execution. Status: %s.%s",
        job.name, job.id, task_id, status, log_suffix,
    )
|
|
402
|
+
|
|
403
|
+
async def _progress_monitor(tracker: StepProgressTracker, interval: float = 1.0) -> None:
    """
    Periodically redraw the tracker's progress line on stderr until all tasks finish.

    The line is rewritten in place (carriage return) every ``interval`` seconds.
    Once ``tracker.is_finished()`` is true, a final line ending in a newline is
    written and the coroutine returns.

    Cancellation is reported on stderr and then re-raised, so whoever cancelled
    the monitor observes a cancelled task. Swallowing ``CancelledError`` would
    break asyncio cancellation semantics, e.g. timeouts and task groups. Any
    other error is reported and logged but not raised, because the progress
    display is best-effort.

    Args:
        tracker: Progress source providing ``is_finished()`` and ``display()``.
        interval: Seconds to sleep between redraws.

    Raises:
        asyncio.CancelledError: Propagated when the monitor task is cancelled.
    """
    try:
        while not tracker.is_finished():
            sys.stderr.write(f"\r{tracker.display()}")
            sys.stderr.flush()
            await asyncio.sleep(interval)
        sys.stderr.write(f"\r{tracker.display()}\n")
        sys.stderr.flush()
        logger.debug("Progress monitor finished.")
    except asyncio.CancelledError:
        sys.stderr.write("\nProgress monitor cancelled.\n")
        sys.stderr.flush()
        logger.debug("Progress monitor cancelled.")
        # Re-raise so the task is actually marked cancelled for the awaiting caller.
        raise
    except Exception as e:
        sys.stderr.write(f"\nProgress monitor error: {e}\n")
        sys.stderr.flush()
        logger.exception("Progress monitor error: %s", e)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
# --- run_job: create a Job and execute its tasks locally ---
|
|
424
|
+
|
|
425
|
+
async def run_job(
    agent_cls: type[Agent],
    task_or_taskset: Task | TaskSet,
    job_name: str,
    adapter_cls: type[Adapter] | None = None,
    agent_kwargs: dict[str, Any] | None = None,
    adapter_kwargs: dict[str, Any] | None = None,
    max_steps_per_task: int = 20,
    run_parallel: bool = True,
    job_metadata: dict[str, Any] | None = None,
    show_progress: bool = True,
    # Concurrency control with semaphores
    max_concurrent_env_creations: int | None = 30,  # Limits env.make calls
    max_concurrent_agent_predictions: int | None = 30,  # Limits agent.predict calls
    max_concurrent_tasks: int | None = 30,  # Limits overall task concurrency
) -> Job:
    """
    Create a Job and execute the given task(s) locally, linking every run to it.

    A fresh agent (and adapter, if given) is instantiated per task, and an
    optional step-based progress line is shown. A single ``Task`` always runs on
    its own; a ``TaskSet`` runs concurrently when ``run_parallel`` is true and
    sequentially otherwise.

    Concurrency is bounded in three independent ways. Each is disabled by
    passing ``None`` or a value <= 0:
      1. ``max_concurrent_env_creations`` caps simultaneous environment creations.
      2. ``max_concurrent_agent_predictions`` caps simultaneous agent predictions.
      3. ``max_concurrent_tasks`` caps how many tasks are active at once
         (parallel mode only).

    Per-task failures do not abort the run; the task runner records them in
    ``job.errors``.

    Args:
        agent_cls: Agent class to instantiate.
        task_or_taskset: Task or TaskSet to run.
        job_name: Name for the Job.
        adapter_cls: Optional Adapter class.
        agent_kwargs: Optional kwargs for agent constructor.
        adapter_kwargs: Optional kwargs for adapter constructor.
        max_steps_per_task: Step limit per task.
        run_parallel: Run TaskSet tasks concurrently if True (limited by max_concurrent_tasks).
        job_metadata: Metadata for the created Job.
        show_progress: Display the step-based progress tracker.
        max_concurrent_env_creations: Max concurrent environment creation calls.
        max_concurrent_agent_predictions: Max concurrent agent prediction calls.
        max_concurrent_tasks: Max number of tasks to run actively at the same time.

    Returns:
        The created Job object with errors stored in job.errors.

    Raises:
        TypeError: If ``task_or_taskset`` is neither a Task nor a TaskSet. This
            is checked before the Job is created, so no orphan Job is left behind.
        Exception: Anything raised by ``create_job`` is logged and re-raised.
    """
    # --- Task Setup (validate input before creating anything server-side) ---
    is_taskset = isinstance(task_or_taskset, TaskSet)
    if is_taskset:
        # Snapshot the list so later mutation of the TaskSet cannot desync it from task_ids.
        tasks_to_run: list[Task] = list(task_or_taskset.tasks or [])
    elif isinstance(task_or_taskset, Task):
        tasks_to_run = [task_or_taskset]
        run_parallel = False
    else:
        raise TypeError("task_or_taskset must be either a Task or a TaskSet")

    # --- Create Job ---
    try:
        logger.info("Creating job with name: '%s'", job_name)
        created_job = await create_job(name=job_name, metadata=job_metadata)
        logger.info("Created job with ID: %s", created_job.id)
    except Exception as e:
        logger.exception("Failed to create job '%s': %s", job_name, e)
        raise

    if not tasks_to_run:
        logger.warning("Job '%s' (%s): No tasks found to run.", created_job.name, created_job.id)
        return created_job

    task_ids = [str(task.id) if task.id else f"task_{i}" for i, task in enumerate(tasks_to_run)]
    num_tasks = len(tasks_to_run)

    # --- Create semaphores for concurrency control ---
    env_creation_sema: asyncio.Semaphore | None = None
    if max_concurrent_env_creations and max_concurrent_env_creations > 0:
        env_creation_sema = asyncio.Semaphore(max_concurrent_env_creations)
        logger.info("Limiting concurrent environment creations to %d.",
                    max_concurrent_env_creations)

    agent_predict_sema: asyncio.Semaphore | None = None
    if max_concurrent_agent_predictions and max_concurrent_agent_predictions > 0:
        agent_predict_sema = asyncio.Semaphore(max_concurrent_agent_predictions)
        logger.info("Limiting concurrent agent predictions to %d.",
                    max_concurrent_agent_predictions)

    task_execution_sema: asyncio.Semaphore | None = None
    effective_concurrency = num_tasks  # Default to running all if parallel
    if run_parallel and max_concurrent_tasks and max_concurrent_tasks > 0:
        effective_concurrency = min(num_tasks, max_concurrent_tasks)
        task_execution_sema = asyncio.Semaphore(effective_concurrency)
        logger.info("Limiting concurrent task executions to %d.", effective_concurrency)
    elif not run_parallel:
        effective_concurrency = 1  # Sequential means concurrency of 1

    # --- Instantiate Tracker & Start Monitor ---
    tracker: StepProgressTracker | None = None
    monitor_task: asyncio.Task[None] | None = None
    # How long to let the monitor finish on its own after all tasks are done
    # (twice its default 1-second poll interval).
    monitor_grace_seconds = 2.0
    if show_progress:
        tracker = StepProgressTracker(total_tasks=num_tasks, max_steps_per_task=max_steps_per_task)
        monitor_task = asyncio.create_task(_progress_monitor(tracker))

    # --- Execute Tasks ---
    job_desc_suffix = f" (Job ID: {created_job.id})"

    def _task_coro(task: Task, task_id: str) -> Coroutine[Any, Any, None]:
        """Build the _execute_task coroutine for one task with the shared run settings."""
        return _execute_task(
            agent_cls=agent_cls,
            adapter_cls=adapter_cls,
            agent_kwargs=agent_kwargs,
            adapter_kwargs=adapter_kwargs,
            task=task,
            job_name=created_job.name,
            task_id=task_id,
            max_steps_per_task=max_steps_per_task,
            job=created_job,
            tracker=tracker,
            env_creation_semaphore=env_creation_sema,
            agent_predict_semaphore=agent_predict_sema,
        )

    async def _run_limited(task: Task, task_id: str) -> None:
        """Run one task, holding the task-level semaphore (if any) for its whole duration."""
        # The coroutine is created only once a slot is held, so queued tasks do not
        # sit around as un-awaited coroutine objects.
        if task_execution_sema:
            async with task_execution_sema:
                await _task_coro(task, task_id)
        else:
            await _task_coro(task, task_id)

    execution_finished = False
    try:
        if run_parallel and is_taskset:
            logger.info("Job '%s'%s: Running %d tasks with concurrency %d.", created_job.name,
                        job_desc_suffix, num_tasks, effective_concurrency)
            await asyncio.gather(
                *(
                    _run_limited(task, task_id)
                    for task, task_id in zip(tasks_to_run, task_ids, strict=True)
                )
            )
        else:
            # SEQUENTIAL (or single task)
            logger.info("Job '%s'%s: Running %d tasks sequentially.", created_job.name,
                        job_desc_suffix, num_tasks)
            for task, task_id in zip(tasks_to_run, task_ids, strict=True):
                await _task_coro(task, task_id)
        execution_finished = True

    finally:
        # Ensure monitor task is stopped and awaited cleanly
        if monitor_task is not None and not monitor_task.done():
            if execution_finished:
                # _execute_task calls tracker.finish_task for every task, so the monitor
                # exits on its next poll and prints the final progress line. Give it a
                # bounded grace period instead of cancelling it mid-sleep, which would
                # hide the final state.
                await asyncio.wait({monitor_task}, timeout=monitor_grace_seconds)
            if not monitor_task.done():
                monitor_task.cancel()
            try:
                await monitor_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error("Error awaiting progress monitor task: %s", e)

    logger.info("Job '%s'%s finished local execution phase for %d tasks.", created_job.name,
                job_desc_suffix, num_tasks)
    return created_job
|
hud/task.py
CHANGED
|
@@ -5,8 +5,7 @@ from typing import TYPE_CHECKING, Any
|
|
|
5
5
|
from pydantic import BaseModel
|
|
6
6
|
|
|
7
7
|
from hud.types import CustomGym, Gym
|
|
8
|
-
from hud.utils import HudStyleConfig
|
|
9
|
-
from hud.utils.config import HudStyleConfigs
|
|
8
|
+
from hud.utils.common import HudStyleConfig, HudStyleConfigs
|
|
10
9
|
|
|
11
10
|
if TYPE_CHECKING:
|
|
12
11
|
from inspect_ai.dataset import Sample
|
|
@@ -35,7 +34,7 @@ class Task(BaseModel):
|
|
|
35
34
|
|
|
36
35
|
The setup and evaluate configurations can be in several formats:
|
|
37
36
|
- String (function name): "chrome.maximize"
|
|
38
|
-
-
|
|
37
|
+
- Tuple (function with args): ("chrome.activate_tab", 5)
|
|
39
38
|
- Dict: {"function": "chrome.navigate", "args": ["https://example.com"]}
|
|
40
39
|
- List of the above: ["chrome.maximize", {"function": "chrome.navigate", "args": ["https://example.com"]}]
|
|
41
40
|
|
|
@@ -68,15 +67,15 @@ class Task(BaseModel):
|
|
|
68
67
|
@classmethod
|
|
69
68
|
def from_inspect_sample(cls, sample: Sample) -> Task:
|
|
70
69
|
"""Create a Task from an Inspect dataset sample.
|
|
71
|
-
|
|
72
|
-
|
|
70
|
+
Automatically detects if a CustomGym (docker) or QA Gym is needed based on sample.sandbox.
|
|
71
|
+
Configures evaluation using 'response_includes' or 'match_all' based on sample.target.
|
|
73
72
|
|
|
74
73
|
Args:
|
|
75
74
|
sample: An Inspect dataset Sample object
|
|
76
75
|
|
|
77
76
|
Returns:
|
|
78
77
|
Task instance
|
|
79
|
-
|
|
78
|
+
|
|
80
79
|
The Inspect Sample has these fields:
|
|
81
80
|
- input (str | list[ChatMessage]): The input to be submitted to the model
|
|
82
81
|
- choices (list[str] | None): Optional multiple choice answer list
|
|
@@ -87,10 +86,8 @@ class Task(BaseModel):
|
|
|
87
86
|
- files (dict[str, str] | None): Optional files that go with the sample
|
|
88
87
|
- setup (str | None): Optional setup script to run for sample
|
|
89
88
|
"""
|
|
90
|
-
# Extract the input as prompt
|
|
91
89
|
prompt = sample.input
|
|
92
|
-
if isinstance(prompt, list):
|
|
93
|
-
# Convert chat message list to a string representation
|
|
90
|
+
if isinstance(prompt, list):
|
|
94
91
|
prompt_parts = []
|
|
95
92
|
for message in prompt:
|
|
96
93
|
role = message.role
|
|
@@ -98,36 +95,50 @@ class Task(BaseModel):
|
|
|
98
95
|
prompt_parts.append(f"{role.capitalize()}: {content}")
|
|
99
96
|
prompt = "\n\n".join(prompt_parts)
|
|
100
97
|
|
|
101
|
-
|
|
98
|
+
evaluate_config = None
|
|
99
|
+
if sample.target:
|
|
100
|
+
if isinstance(sample.target, str):
|
|
101
|
+
evaluate_config = ("response_includes", [sample.target])
|
|
102
|
+
elif isinstance(sample.target, list):
|
|
103
|
+
evaluate_config = ("match_all", sample.target)
|
|
104
|
+
|
|
105
|
+
task_gym: Gym | None = None
|
|
106
|
+
task_setup: HudStyleConfigs | None = None
|
|
107
|
+
|
|
102
108
|
sandbox = sample.sandbox
|
|
103
109
|
dockerfile = None
|
|
110
|
+
use_qa_gym = True
|
|
111
|
+
|
|
104
112
|
if sandbox:
|
|
105
113
|
if isinstance(sandbox, str):
|
|
106
|
-
if sandbox
|
|
107
|
-
|
|
114
|
+
if sandbox == "docker":
|
|
115
|
+
dockerfile = UBUNTU_DOCKERFILE
|
|
116
|
+
use_qa_gym = False
|
|
108
117
|
elif isinstance(sandbox, tuple) and len(sandbox) == 2:
|
|
109
118
|
sandbox_type, sandbox_config = sandbox
|
|
110
|
-
if sandbox_type
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
119
|
+
if sandbox_type == "docker":
|
|
120
|
+
dockerfile = sandbox_config
|
|
121
|
+
use_qa_gym = False
|
|
122
|
+
|
|
123
|
+
if use_qa_gym:
|
|
124
|
+
task_gym = "qa"
|
|
125
|
+
task_setup = None
|
|
126
|
+
else:
|
|
127
|
+
task_gym = CustomGym(
|
|
128
|
+
dockerfile=dockerfile or UBUNTU_DOCKERFILE,
|
|
129
|
+
location="local",
|
|
130
|
+
)
|
|
131
|
+
task_setup = [x for x in convert_inspect_setup(sample.setup)] if sample.setup else None
|
|
132
|
+
# TODO: Handle sample.files for CustomGym case if needed
|
|
133
|
+
|
|
120
134
|
|
|
121
135
|
return cls(
|
|
122
|
-
id=
|
|
136
|
+
id=None,
|
|
123
137
|
prompt=prompt,
|
|
124
|
-
setup=
|
|
138
|
+
setup=task_setup,
|
|
125
139
|
metadata=sample.metadata,
|
|
126
140
|
choices=sample.choices,
|
|
127
|
-
|
|
128
|
-
gym=
|
|
141
|
+
evaluate=evaluate_config,
|
|
142
|
+
gym=task_gym,
|
|
143
|
+
# files=sample.files, # TODO: Decide how/if to handle files
|
|
129
144
|
)
|
|
130
|
-
|
|
131
|
-
def convert_sdk01(self) -> None:
|
|
132
|
-
self.setup = [HudStyleConfig(function="reset", args=[{"task_id": self.id}])]
|
|
133
|
-
self.evaluate = [HudStyleConfig(function="evaluate", args=[])]
|
hud/taskset.py
CHANGED
|
@@ -9,6 +9,8 @@ from hud.settings import settings
|
|
|
9
9
|
from hud.task import Task
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
|
+
from collections.abc import Iterator
|
|
13
|
+
|
|
12
14
|
from inspect_ai.dataset import Dataset
|
|
13
15
|
|
|
14
16
|
|
|
@@ -49,6 +51,12 @@ class TaskSet(BaseModel):
|
|
|
49
51
|
"""
|
|
50
52
|
return len(self.tasks)
|
|
51
53
|
|
|
54
|
+
def __iter__(self) -> Iterator[Task]:
    """
    Iterate over the tasks held by this taskset, in stored order.

    NOTE(review): this replaces pydantic ``BaseModel``'s default iteration over
    ``(field_name, value)`` pairs, so ``dict(taskset)`` will no longer yield a
    field mapping; use ``model_dump()`` for that.
    """
    task_list = self.tasks
    return iter(task_list)
|
|
59
|
+
|
|
52
60
|
|
|
53
61
|
async def load_taskset(taskset_id: str, api_key: str | None = None) -> TaskSet:
|
|
54
62
|
"""
|
hud/types.py
CHANGED
|
@@ -44,9 +44,6 @@ class CustomGym(BaseModel):
|
|
|
44
44
|
# Read the Dockerfile content
|
|
45
45
|
self.dockerfile = dockerfile_path.read_text()
|
|
46
46
|
|
|
47
|
-
# Strings are identifiers for gyms on the HUD server
|
|
48
|
-
Gym = CustomGym | str
|
|
49
|
-
|
|
50
47
|
class EnvironmentStatus(str, enum.Enum):
|
|
51
48
|
"""
|
|
52
49
|
Status of the environment.
|
|
@@ -63,3 +60,8 @@ class EnvironmentStatus(str, enum.Enum):
|
|
|
63
60
|
COMPLETED = "completed"
|
|
64
61
|
ERROR = "error"
|
|
65
62
|
|
|
63
|
+
# Identifiers of the gyms hosted on the HUD server.
# NOTE(review): narrower than the previous `CustomGym | str`; any other server gym
# id will now fail validation, so confirm this list is complete.
ServerGym = Literal[
    "qa",
    "hud-browser",
    "hud-ubuntu",
    "OSWorld-Ubuntu",
]

# A gym is either a locally described CustomGym or one of the server-side identifiers.
Gym = CustomGym | ServerGym
|
hud/utils/common.py
CHANGED
|
@@ -3,16 +3,46 @@ from __future__ import annotations
|
|
|
3
3
|
import io
|
|
4
4
|
import logging
|
|
5
5
|
import tarfile
|
|
6
|
-
from typing import TYPE_CHECKING, TypedDict
|
|
6
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
7
9
|
|
|
8
10
|
from hud.server.requests import make_request
|
|
9
11
|
from hud.settings import settings
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Iterator
|
|
12
15
|
from pathlib import Path
|
|
13
16
|
|
|
14
17
|
logger = logging.getLogger("hud.utils.common")
|
|
15
18
|
|
|
19
|
+
class HudStyleConfig(BaseModel):
    """
    One configuration call: a dotted function name and its positional arguments.

    Instances behave as a sequence-like view over ``args``: ``len()``, indexing
    and iteration all delegate to the ``args`` list.

    NOTE(review): defining ``__iter__`` replaces pydantic ``BaseModel``'s own
    iteration over ``(field_name, value)`` pairs, so ``dict(config)`` will not
    produce a field mapping; use ``model_dump()`` for that.
    """

    function: str  # Format: "x.y.z"
    args: list[Any]  # Must be json serializable

    id: str | None = None  # Optional id for remote execution

    def __len__(self) -> int:
        """Return the number of arguments."""
        arg_list = self.args
        return len(arg_list)

    def __getitem__(self, index: int) -> Any:
        """Return the argument at ``index`` (ordinary list indexing rules apply)."""
        arg_list = self.args
        return arg_list[index]

    def __iter__(self) -> Iterator[Any]:
        """Iterate over the arguments in order."""
        arg_list = self.args
        return iter(arg_list)

    def __str__(self) -> str:
        """Render as ``"<function>: <arg1>, <arg2>, ..."``."""
        rendered_args = ", ".join(map(str, self.args))
        return f"{self.function}: {rendered_args}"
|
|
36
|
+
|
|
37
|
+
# Shorthand form: a tuple that is converted into a function name plus its
# arguments, e.g. ("chrome.activate_tab", 5).
ShorthandConfig = tuple[
    str | dict[str, Any] | list[str] | list[dict[str, Any]],
    ...,
]

# Every accepted spelling of a setup/evaluate configuration.
HudStyleConfigs = (
    ShorthandConfig
    | HudStyleConfig
    | list[HudStyleConfig]
    | list[ShorthandConfig]
    | dict[str, Any]
    | str
)
|
|
45
|
+
|
|
16
46
|
class ExecuteResult(TypedDict):
|
|
17
47
|
"""
|
|
18
48
|
Result of an execute command.
|