hud-python 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hud-python might be problematic.

hud/job.py CHANGED
@@ -1,17 +1,27 @@
  from __future__ import annotations

+ import asyncio
  import datetime
  import functools
  import inspect
  import logging
- from collections.abc import Callable
- from typing import Any, TypeVar, cast
+ import sys
+ from collections.abc import Callable, Coroutine
+ from typing import TYPE_CHECKING, Any, TypeVar, cast

- from pydantic import BaseModel, TypeAdapter
+ from pydantic import BaseModel, PrivateAttr, TypeAdapter

+ from hud import gym
  from hud.server import make_request
  from hud.settings import settings
+ from hud.task import Task
+ from hud.taskset import TaskSet
  from hud.trajectory import Trajectory
+ from hud.utils.progress import StepProgressTracker
+
+ if TYPE_CHECKING:
+     from hud.adapters.common import Adapter
+     from hud.agent.base import Agent

  logger = logging.getLogger("hud.job")

@@ -25,7 +35,7 @@ class Job(BaseModel):
      """
      A job represents a collection of related trajectories.
      It holds metadata and provides methods to interact with job data.
-     Instances should typically be obtained via `create_job` or `load_job`.
+     Instances should typically be obtained via `create_job`, `load_job`, or the new `run_job`.
      """

      id: str
@@ -34,23 +44,85 @@ class Job(BaseModel):
      created_at: datetime.datetime
      status: str

-     async def load_trajectories(self, *, api_key: str | None = None) -> list[Trajectory]:
+     # Internal cache for trajectories
+     _trajectories: list[Trajectory] | None = PrivateAttr(default=None)
+     # Store execution errors for debugging
+     errors: list[dict[str, Any]] = []
+
+     async def load_trajectories(
+         self, *, api_key: str | None = None, force_reload: bool = False
+     ) -> list[Trajectory]:
          """
          Loads the trajectories associated with this job.
+         Uses cached results unless force_reload is True.
+
+         Args:
+             api_key: Optional API key.
+             force_reload: If True, fetches trajectories from the API even if cached.

          Returns:
              List[Trajectory]: The trajectories in the job
          """
+         if self._trajectories is not None and not force_reload:
+             logger.debug("Returning cached trajectories for Job %s", self.id)
+             return self._trajectories
+
+         logger.debug("Fetching trajectories for Job %s from API...", self.id)
          api_key = api_key or settings.api_key

-         data = await make_request(
-             method="GET",
-             url=f"{settings.base_url}/v2/jobs/{self.id}/trajectories",
-             api_key=api_key,
-         )
+         try:
+             data = await make_request(
+                 method="GET",
+                 url=f"{settings.base_url}/v2/jobs/{self.id}/trajectories",
+                 api_key=api_key,
+             )
+             self._trajectories = TypeAdapter(list[Trajectory]).validate_python(data)
+             logger.debug("Loaded %d trajectories for Job %s", len(self._trajectories), self.id)
+             return self._trajectories
+         except Exception as e:
+             logger.exception("Failed to load trajectories for Job %s: %s", self.id, e)
+             self._trajectories = None  # Ensure cache is cleared on error
+             return []  # Return empty list on error
+
+     async def get_analytics(self, *, force_reload: bool = False) -> dict[str, Any]:
+         """
+         Calculates and returns analytics for the job based on its trajectories.
+
+         Args:
+             force_reload: If True, re-fetches trajectories before calculating.
+
+         Returns:
+             Dictionary containing analytics (e.g., task_count, avg_reward).
+         """
+         trajectories = await self.load_trajectories(force_reload=force_reload)

-         return TypeAdapter(list[Trajectory]).validate_python(data)
+         task_count = len(trajectories)
+         if task_count == 0:
+             return {"task_count": 0, "avg_reward": None, "success_rate": None}  # Or other default
+
+         total_reward = 0
+         successful_tasks = 0
+         valid_rewards = 0

+         for traj in trajectories:
+             # Example: Assume reward is numeric and success is reward >= 1.0
+             # Adjust based on actual trajectory data structure and evaluation logic
+             if isinstance(traj.reward, int | float):
+                 total_reward += traj.reward
+                 valid_rewards += 1
+                 if traj.reward >= 1.0:
+                     successful_tasks += 1
+             # Add more complex logic here if needed based on traj.evaluation_result or metadata
+
+         avg_reward = (total_reward / valid_rewards) if valid_rewards > 0 else None
+         success_rate = (successful_tasks / task_count) * 100 if task_count > 0 else None
+
+         return {
+             "task_count": task_count,
+             "avg_reward": avg_reward,
+             "success_rate": success_rate,
+             # Add other relevant stats here
+         }

  async def create_job(name: str, gym_id: str | None = None,
                       evalset_id: str | None = None,
@@ -84,7 +156,9 @@ async def create_job(name: str, gym_id: str | None = None,
      # or at least the necessary fields (id, name, metadata, created_at, status)
      # If not, we might need to make a subsequent GET request
      job_data = data  # Adjust if the API response structure is different
-
+
+     logger.info("[HUD] View job at https://app.hud.so/jobs/%s.", job_data["id"])
+
      return Job(
          id=job_data["id"],
          name=job_data["name"],
@@ -183,3 +257,337 @@ def get_active_job() -> Job | None:
          frame = frame.f_back

      return None
+
+ # --- Moved helper functions from runner.py ---
+
+ async def _execute_task(
+     agent_cls: type[Agent],
+     adapter_cls: type[Adapter] | None,
+     agent_kwargs: dict[str, Any] | None,
+     adapter_kwargs: dict[str, Any] | None,
+     task: Task,
+     job_name: str,
+     task_id: str,
+     max_steps_per_task: int,
+     job: Job,
+     tracker: StepProgressTracker | None = None,
+     # Use semaphores instead of rate limiter
+     env_creation_semaphore: asyncio.Semaphore | None = None,
+     agent_predict_semaphore: asyncio.Semaphore | None = None,
+ ) -> None:
+     """Helper function to instantiate/run/evaluate a single task, with concurrency limits via
+     semaphores."""
+     if tracker:
+         tracker.start_task(task_id)
+     env = None
+     agent_instance: Agent | None = None
+     status = "error"
+     error_msg = "Initialization failed"
+     try:
+         adapter_instance = None
+         if adapter_cls:
+             adapter_instance = adapter_cls(**(adapter_kwargs or {}))
+         agent_instance = agent_cls(adapter=adapter_instance, **(agent_kwargs or {}))
+         if agent_instance is None:
+             raise RuntimeError("Agent could not be instantiated")
+
+         # Environment creation with semaphore
+         if env_creation_semaphore:
+             async with env_creation_semaphore:
+                 env = await gym.make(task, job=job)
+         else:
+             env = await gym.make(task, job=job)
+
+         obs_tuple = await env.reset()
+         if obs_tuple is None:
+             raise ValueError(f"env.reset() returned None for task {task_id}")
+         obs, _ = obs_tuple
+
+         step_error = None
+         for step in range(max_steps_per_task):
+             action, done = (None, False)
+             try:
+                 # Agent prediction with semaphore
+                 if agent_predict_semaphore:
+                     async with agent_predict_semaphore:
+                         action, done = await agent_instance.predict(obs)
+                 else:
+                     action, done = await agent_instance.predict(obs)
+
+                 if tracker:
+                     tracker.increment_step(task_id)
+
+                 if action is None and not done:
+                     done = True
+
+                 step_result = await env.step(action)
+                 if step_result is None:
+                     terminated = True
+                 else:
+                     obs, _, terminated, _ = step_result
+                 if terminated or done:
+                     break
+
+             except Exception as agent_step_err:
+                 logger.exception("[Job: %s/%s, Task: %s] Step %d Error: %s", job.name, job.id,
+                                  task_id, step + 1, agent_step_err)
+                 step_error = f"Error at step {step + 1}: {agent_step_err}"
+                 # Store step error in job
+                 job.errors.append({
+                     "task_id": task_id,
+                     "type": "step_error",
+                     "step": step + 1,
+                     "error": str(agent_step_err),
+                     "timestamp": datetime.datetime.now().isoformat()
+                 })
+                 break
+         else:
+             logger.warning("[Job: %s/%s, Task: %s] Max steps reached.", job.name, job.id, task_id)
+
+         # --- Evaluate Task ---
+         evaluation_result = None
+         if step_error:
+             status = "error"
+             error_msg = step_error
+         else:
+             try:
+                 evaluation_result = await env.evaluate()
+                 status = "completed"
+                 error_msg = None
+             except Exception as eval_err:
+                 logger.exception("[Job: %s/%s, Task: %s] Evaluation Error: %s", job.name,
+                                  job.id, task_id, eval_err)
+                 status = "error"
+                 error_msg = f"Evaluation failed: {eval_err}"
+                 # Store evaluation error in job
+                 job.errors.append({
+                     "task_id": task_id,
+                     "type": "evaluation_error",
+                     "error": str(eval_err),
+                     "timestamp": datetime.datetime.now().isoformat()
+                 })
+
+     except Exception as e:
+         logger.exception("[Job: %s/%s, Task: %s] Setup/Run Error: %s", job.name, job.id, task_id, e)
+         status = "error"
+         error_msg = str(e)
+         # Store setup/initialization error in job
+         job.errors.append({
+             "task_id": task_id,
+             "type": "setup_error",
+             "error": str(e),
+             "timestamp": datetime.datetime.now().isoformat()
+         })
+
+     finally:
+         if tracker:
+             tracker.finish_task(task_id)
+         if env:
+             try:
+                 await env.close()
+             except Exception as close_err:
+                 logger.exception("[Job: %s/%s, Task: %s] Close Error: %s", job.name, job.id,
+                                  task_id, close_err)
+                 # Store environment close error in job
+                 job.errors.append({
+                     "task_id": task_id,
+                     "type": "env_close_error",
+                     "error": str(close_err),
+                     "timestamp": datetime.datetime.now().isoformat()
+                 })
+
+         log_suffix = f" Error: {error_msg}" if status == "error" else f" Eval: {evaluation_result}"
+         logger.info("[Job: %s/%s, Task: %s] Finished local execution. Status: %s.%s", job.name,
+                     job.id, task_id, status, log_suffix)
+
+ async def _progress_monitor(tracker: StepProgressTracker, interval: float = 1.0) -> None:
+     """Coroutine to periodically display progress using the tracker."""
+     try:
+         while not tracker.is_finished():
+             sys.stderr.write(f"\r{tracker.display()}")
+             sys.stderr.flush()
+             await asyncio.sleep(interval)
+         sys.stderr.write(f"\r{tracker.display()}\n")
+         sys.stderr.flush()
+         logger.debug("Progress monitor finished.")
+     except asyncio.CancelledError:
+         sys.stderr.write("\nProgress monitor cancelled.\n")
+         sys.stderr.flush()
+         logger.debug("Progress monitor cancelled.")
+     except Exception as e:
+         sys.stderr.write(f"\nProgress monitor error: {e}\n")
+         sys.stderr.flush()
+         logger.exception("Progress monitor error: %s", e)
+
+
+ # --- New run_job function ---
+
+ async def run_job(
+     agent_cls: type[Agent],
+     task_or_taskset: Task | TaskSet,
+     job_name: str,
+     adapter_cls: type[Adapter] | None = None,
+     agent_kwargs: dict[str, Any] | None = None,
+     adapter_kwargs: dict[str, Any] | None = None,
+     max_steps_per_task: int = 20,
+     run_parallel: bool = True,
+     job_metadata: dict[str, Any] | None = None,
+     show_progress: bool = True,
+     # Concurrency control with semaphores
+     max_concurrent_env_creations: int | None = 30,  # Limits env.make calls
+     max_concurrent_agent_predictions: int | None = 30,  # Limits agent.predict calls
+     max_concurrent_tasks: int | None = 30,  # Limits overall task concurrency
+ ) -> Job:
+     """
+     Creates Job, executes tasks locally, linking them to the Job.
+     Instantiates agent/adapter per task. Shows step-based progress.
+
+     Controls concurrency in three ways:
+     1. Limits concurrent environment creations
+     2. Limits concurrent agent predictions
+     3. Limits overall concurrent tasks (when run_parallel=True)
+
+     All concurrency controls use semaphores for reliability.
+     Tracks all errors that occur during execution in job.errors.
+
+     Args:
+         agent_cls: Agent class to instantiate.
+         task_or_taskset: Task or TaskSet to run.
+         job_name: Name for the Job.
+         adapter_cls: Optional Adapter class.
+         agent_kwargs: Optional kwargs for agent constructor.
+         adapter_kwargs: Optional kwargs for adapter constructor.
+         max_steps_per_task: Step limit per task.
+         run_parallel: Run TaskSet tasks concurrently if True (limited by max_concurrent_tasks).
+         job_metadata: Metadata for the created Job.
+         show_progress: Display the step-based progress tracker.
+         max_concurrent_env_creations: Max concurrent environment creation calls.
+         max_concurrent_agent_predictions: Max concurrent agent prediction calls.
+         max_concurrent_tasks: Max number of tasks to run actively at the same time.
+
+     Returns:
+         The created Job object with errors stored in job.errors.
+     """
+     tasks_to_run: list[Task] = []
+     created_job: Job | None = None
+
+     # --- Create Job ---
+     try:
+         logger.info("Creating job with name: '%s'", job_name)
+         created_job = await create_job(name=job_name, metadata=job_metadata)
+         logger.info("Created job with ID: %s", created_job.id)
+     except Exception as e:
+         logger.exception("Failed to create job '%s': %s", job_name, e)
+         raise
+
+     # --- Task Setup ---
+     is_taskset = isinstance(task_or_taskset, TaskSet)
+     if is_taskset:
+         tasks_to_run = task_or_taskset.tasks if task_or_taskset.tasks else []
+     elif isinstance(task_or_taskset, Task):
+         tasks_to_run = [task_or_taskset]
+         run_parallel = False
+     else:
+         raise TypeError("task_or_taskset must be either a Task or a TaskSet")
+
+     if not tasks_to_run:
+         logger.warning("Job '%s' (%s): No tasks found to run.", created_job.name, created_job.id)
+         return created_job
+
+     task_ids = [(str(task.id) if task.id else f"task_{i}") for i, task in enumerate(tasks_to_run)]
+     num_tasks = len(tasks_to_run)
+
+     # --- Create semaphores for concurrency control ---
+     env_creation_sema = None
+     if max_concurrent_env_creations and max_concurrent_env_creations > 0:
+         env_creation_sema = asyncio.Semaphore(max_concurrent_env_creations)
+         logger.info("Limiting concurrent environment creations to %d.",
+                     max_concurrent_env_creations)
+
+     agent_predict_sema = None
+     if max_concurrent_agent_predictions and max_concurrent_agent_predictions > 0:
+         agent_predict_sema = asyncio.Semaphore(max_concurrent_agent_predictions)
+         logger.info("Limiting concurrent agent predictions to %d.",
+                     max_concurrent_agent_predictions)
+
+     task_execution_sema = None
+     effective_concurrency = num_tasks  # Default to running all if parallel
+     if run_parallel and max_concurrent_tasks and max_concurrent_tasks > 0:
+         effective_concurrency = min(num_tasks, max_concurrent_tasks)
+         task_execution_sema = asyncio.Semaphore(effective_concurrency)
+         logger.info("Limiting concurrent task executions to %d.", effective_concurrency)
+     elif not run_parallel:
+         effective_concurrency = 1  # Sequential means concurrency of 1
+
+     # --- Instantiate Tracker & Start Monitor ---
+     tracker = None
+     monitor_task = None
+     if show_progress and num_tasks > 0:
+         tracker = StepProgressTracker(total_tasks=num_tasks, max_steps_per_task=max_steps_per_task)
+         monitor_task = asyncio.create_task(_progress_monitor(tracker))
+
+     # --- Execute Tasks ---
+     job_desc_suffix = f" (Job ID: {created_job.id})"
+
+     async def task_wrapper(task_coro: Coroutine, semaphore: asyncio.Semaphore | None) -> None:
+         if semaphore:
+             async with semaphore:
+                 await task_coro
+         else:
+             await task_coro
+
+     try:
+         if run_parallel and is_taskset:
+             logger.info("Job '%s'%s: Running %d tasks with concurrency %d.", created_job.name,
+                         job_desc_suffix, num_tasks, effective_concurrency)
+
+             task_coroutines = [
+                 _execute_task(
+                     agent_cls=agent_cls, adapter_cls=adapter_cls, agent_kwargs=agent_kwargs,
+                     adapter_kwargs=adapter_kwargs, task=task, job_name=created_job.name,
+                     task_id=task_id,
+                     max_steps_per_task=max_steps_per_task, job=created_job, tracker=tracker,
+                     env_creation_semaphore=env_creation_sema,
+                     agent_predict_semaphore=agent_predict_sema,
+                 )
+                 for task, task_id in zip(tasks_to_run, task_ids, strict=True)
+             ]
+
+             # Wrap coroutines with semaphore management if limiting concurrency
+             wrapped_tasks = [
+                 task_wrapper(coro, task_execution_sema)
+                 for i, coro in enumerate(task_coroutines)
+             ]
+
+             # Run all wrapped tasks
+             await asyncio.gather(*wrapped_tasks)
+
+         else:
+             # SEQUENTIAL (or single task)
+             logger.info("Job '%s'%s: Running %d tasks sequentially.", created_job.name,
+                         job_desc_suffix, num_tasks)
+             for i, task in enumerate(tasks_to_run):
+                 task_id = task_ids[i]
+                 await _execute_task(
+                     agent_cls=agent_cls, adapter_cls=adapter_cls, agent_kwargs=agent_kwargs,
+                     adapter_kwargs=adapter_kwargs, task=task, job_name=created_job.name,
+                     task_id=task_id,
+                     max_steps_per_task=max_steps_per_task, job=created_job, tracker=tracker,
+                     env_creation_semaphore=env_creation_sema,
+                     agent_predict_semaphore=agent_predict_sema,
+                 )
+
+     finally:
+         # Ensure monitor task is stopped and awaited cleanly
+         if monitor_task is not None and not monitor_task.done():
+             monitor_task.cancel()
+             try:
+                 await monitor_task
+             except asyncio.CancelledError:
+                 pass
+             except Exception as e:
+                 logger.error("Error awaiting progress monitor task: %s", e)
+
+     logger.info("Job '%s'%s finished local execution phase for %d tasks.", created_job.name,
+                 job_desc_suffix, num_tasks)
+     return created_job
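The new hunk folds the old runner helpers into hud/job.py: run_job creates a Job, executes a Task or TaskSet locally under three semaphore limits (environment creation, agent prediction, overall task concurrency), and records failures on the new job.errors list. A minimal sketch of how the new entry point might be driven, assuming an Agent subclass such as ClaudeAgent is importable from hud.agent and that "my-taskset" names a loadable taskset (both are assumptions, not shown in this diff):

import asyncio

from hud.agent import ClaudeAgent  # assumed Agent subclass; any Agent implementation should work
from hud.job import run_job
from hud.taskset import load_taskset


async def main() -> None:
    taskset = await load_taskset("my-taskset")  # hypothetical taskset id
    job = await run_job(
        agent_cls=ClaudeAgent,
        task_or_taskset=taskset,
        job_name="nightly-eval",
        max_steps_per_task=20,
        max_concurrent_tasks=10,           # overall task parallelism
        max_concurrent_env_creations=5,    # throttle gym.make calls
        max_concurrent_agent_predictions=5,
    )
    # Execution errors are collected on the Job itself (new `errors` field)
    for err in job.errors:
        print(err["task_id"], err["type"], err["error"])
    # Analytics are derived from trajectories fetched once and cached on the instance
    print(await job.get_analytics())


asyncio.run(main())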
hud/task.py CHANGED
@@ -5,8 +5,7 @@ from typing import TYPE_CHECKING, Any
  from pydantic import BaseModel

  from hud.types import CustomGym, Gym
- from hud.utils import HudStyleConfig
- from hud.utils.config import HudStyleConfigs
+ from hud.utils.common import HudStyleConfig, HudStyleConfigs

  if TYPE_CHECKING:
      from inspect_ai.dataset import Sample
@@ -35,7 +34,7 @@ class Task(BaseModel):

      The setup and evaluate configurations can be in several formats:
      - String (function name): "chrome.maximize"
-     - String (function with args): "chrome.activate_tab 5"
+     - Tuple (function with args): ("chrome.activate_tab", 5)
      - Dict: {"function": "chrome.navigate", "args": ["https://example.com"]}
      - List of the above: ["chrome.maximize", {"function": "chrome.navigate", "args": ["https://example.com"]}]

@@ -68,15 +67,15 @@ class Task(BaseModel):
      @classmethod
      def from_inspect_sample(cls, sample: Sample) -> Task:
          """Create a Task from an Inspect dataset sample.
-         The task's sandbox is a local ubuntu container using the standard controller.
-         Files will be copied to the user directory
+         Automatically detects if a CustomGym (docker) or QA Gym is needed based on sample.sandbox.
+         Configures evaluation using 'response_includes' or 'match_all' based on sample.target.

          Args:
              sample: An Inspect dataset Sample object

          Returns:
              Task instance
-
+
          The Inspect Sample has these fields:
          - input (str | list[ChatMessage]): The input to be submitted to the model
          - choices (list[str] | None): Optional multiple choice answer list
@@ -87,10 +86,8 @@ class Task(BaseModel):
          - files (dict[str, str] | None): Optional files that go with the sample
          - setup (str | None): Optional setup script to run for sample
          """
-         # Extract the input as prompt
          prompt = sample.input
-         if isinstance(prompt, list):  # Handle ChatMessage format
-             # Convert chat message list to a string representation
+         if isinstance(prompt, list):
              prompt_parts = []
              for message in prompt:
                  role = message.role
@@ -98,36 +95,50 @@ class Task(BaseModel):
                  prompt_parts.append(f"{role.capitalize()}: {content}")
              prompt = "\n\n".join(prompt_parts)

-         # Map sandbox from Inspect to our envspec
+         evaluate_config = None
+         if sample.target:
+             if isinstance(sample.target, str):
+                 evaluate_config = ("response_includes", [sample.target])
+             elif isinstance(sample.target, list):
+                 evaluate_config = ("match_all", sample.target)
+
+         task_gym: Gym | None = None
+         task_setup: HudStyleConfigs | None = None
+
          sandbox = sample.sandbox
          dockerfile = None
+         use_qa_gym = True
+
          if sandbox:
              if isinstance(sandbox, str):
-                 if sandbox != "docker":
-                     raise ValueError("docker is the only supported sandbox")
+                 if sandbox == "docker":
+                     dockerfile = UBUNTU_DOCKERFILE
+                     use_qa_gym = False
              elif isinstance(sandbox, tuple) and len(sandbox) == 2:
                  sandbox_type, sandbox_config = sandbox
-                 if sandbox_type != "docker":
-                     raise ValueError("docker is the only supported sandbox")
-                 dockerfile = sandbox_config
-             else:
-                 raise ValueError("Invalid sandbox configuration")
-
-         gym = CustomGym(
-             dockerfile=dockerfile or UBUNTU_DOCKERFILE,
-             location="local",
-         )
+                 if sandbox_type == "docker":
+                     dockerfile = sandbox_config
+                     use_qa_gym = False
+
+         if use_qa_gym:
+             task_gym = "qa"
+             task_setup = None
+         else:
+             task_gym = CustomGym(
+                 dockerfile=dockerfile or UBUNTU_DOCKERFILE,
+                 location="local",
+             )
+             task_setup = [x for x in convert_inspect_setup(sample.setup)] if sample.setup else None
+             # TODO: Handle sample.files for CustomGym case if needed
+

          return cls(
-             id=str(sample.id) if sample.id else None,
+             id=None,
              prompt=prompt,
-             setup=[x for x in convert_inspect_setup(sample.setup)] if sample.setup else [],
+             setup=task_setup,
              metadata=sample.metadata,
              choices=sample.choices,
-             target=sample.target,
-             gym=gym,
+             evaluate=evaluate_config,
+             gym=task_gym,
+             # files=sample.files,  # TODO: Decide how/if to handle files
          )
-
-     def convert_sdk01(self) -> None:
-         self.setup = [HudStyleConfig(function="reset", args=[{"task_id": self.id}])]
-         self.evaluate = [HudStyleConfig(function="evaluate", args=[])]
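With this change, Task.from_inspect_sample no longer rejects non-docker samples: anything without a docker sandbox is routed to the server-side "qa" gym, and sample.target is mapped to a shorthand evaluate config. A rough sketch of the two paths, assuming inspect_ai's Sample accepts the fields listed in the docstring above (input, target, sandbox, setup):

from inspect_ai.dataset import Sample

from hud.task import Task

# QA path: no sandbox, string target -> gym is "qa",
# evaluate is built as ("response_includes", ["Paris"]).
qa_task = Task.from_inspect_sample(Sample(input="What is the capital of France?", target="Paris"))

# Docker path: sandbox="docker" -> CustomGym with the default Ubuntu Dockerfile,
# list target -> ("match_all", [...]), setup script converted via convert_inspect_setup.
docker_task = Task.from_inspect_sample(
    Sample(
        input="Create /tmp/out.txt containing the word done",
        target=["done"],
        sandbox="docker",
        setup="touch /tmp/ready",
    )
)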
hud/taskset.py CHANGED
@@ -9,6 +9,8 @@ from hud.settings import settings
  from hud.task import Task

  if TYPE_CHECKING:
+     from collections.abc import Iterator
+
      from inspect_ai.dataset import Dataset


@@ -49,6 +51,12 @@ class TaskSet(BaseModel):
          """
          return len(self.tasks)

+     def __iter__(self) -> Iterator[Task]:
+         """
+         Returns an iterator over the tasks in the taskset.
+         """
+         return iter(self.tasks)
+

  async def load_taskset(taskset_id: str, api_key: str | None = None) -> TaskSet:
      """
hud/types.py CHANGED
@@ -44,9 +44,6 @@ class CustomGym(BaseModel):
          # Read the Dockerfile content
          self.dockerfile = dockerfile_path.read_text()

- # Strings are identifiers for gyms on the HUD server
- Gym = CustomGym | str
-
  class EnvironmentStatus(str, enum.Enum):
      """
      Status of the environment.
@@ -63,3 +60,8 @@ class EnvironmentStatus(str, enum.Enum):
      COMPLETED = "completed"
      ERROR = "error"

+ # Available HUD gyms
+ ServerGym = Literal["qa", "hud-browser", "hud-ubuntu", "OSWorld-Ubuntu"]
+
+ # Gyms can be either custom or server-side
+ Gym = CustomGym | ServerGym
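The Gym alias moves below EnvironmentStatus and narrows from CustomGym | str to CustomGym | ServerGym, so arbitrary strings no longer type-check as gyms; only the four server gym names do. A quick illustration of the narrowed union, assuming CustomGym can still be constructed from just dockerfile and location as shown in the diff:

from hud.types import CustomGym, Gym


def describe(gym: Gym) -> str:
    # CustomGym instances carry their own Dockerfile; everything else must be one of
    # the ServerGym literals: "qa", "hud-browser", "hud-ubuntu", "OSWorld-Ubuntu".
    if isinstance(gym, CustomGym):
        return f"custom gym (location={gym.location})"
    return f"server gym '{gym}'"


print(describe("hud-browser"))
print(describe(CustomGym(dockerfile="FROM ubuntu:22.04", location="local")))
# describe("some-random-string") would now be flagged by a type checker,
# since plain str is no longer part of the Gym union.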
hud/utils/common.py CHANGED
@@ -3,16 +3,46 @@ from __future__ import annotations
  import io
  import logging
  import tarfile
- from typing import TYPE_CHECKING, TypedDict
+ from typing import TYPE_CHECKING, Any, TypedDict
+
+ from pydantic import BaseModel

  from hud.server.requests import make_request
  from hud.settings import settings

  if TYPE_CHECKING:
+     from collections.abc import Iterator
      from pathlib import Path

  logger = logging.getLogger("hud.utils.common")

+ class HudStyleConfig(BaseModel):
+     function: str  # Format: "x.y.z"
+     args: list[Any]  # Must be json serializable
+
+     id: str | None = None  # Optional id for remote execution
+
+     def __len__(self) -> int:
+         return len(self.args)
+
+     def __getitem__(self, index: int) -> Any:
+         return self.args[index]
+
+     def __iter__(self) -> Iterator[Any]:
+         return iter(self.args)
+
+     def __str__(self) -> str:
+         return f"{self.function}: {', '.join(str(arg) for arg in self.args)}"
+
+ # Type alias for the shorthand config, which just converts to function name and args
+ ShorthandConfig = tuple[str | dict[str, Any] | list[str] | list[dict[str, Any]], ...]
+
+ # Type alias for multiple config formats
+ HudStyleConfigs = (
+     ShorthandConfig | HudStyleConfig | list[HudStyleConfig] | list[ShorthandConfig]
+     | dict[str, Any] | str
+ )
+
  class ExecuteResult(TypedDict):
      """
      Result of an execute command.
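HudStyleConfig and the HudStyleConfigs alias now live in hud.utils.common (matching the new import in hud/task.py above), and the added dunders make a config behave like a small sequence of its args. A short sketch using only what the diff defines:

from hud.utils.common import HudStyleConfig

config = HudStyleConfig(function="chrome.navigate", args=["https://example.com"])

print(str(config))   # "chrome.navigate: https://example.com"
print(len(config))   # 1, the number of args
print(config[0])     # "https://example.com"
print(list(config))  # ["https://example.com"], via the new __iter__

# Shorthand forms accepted by HudStyleConfigs include a bare string such as "chrome.maximize",
# a ("function", arg, ...) tuple, or a {"function": ..., "args": [...]} dict.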