hud-python 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (50) hide show
  1. hud/__init__.py +22 -2
  2. hud/adapters/claude/adapter.py +9 -2
  3. hud/adapters/claude/tests/__init__.py +1 -0
  4. hud/adapters/claude/tests/test_adapter.py +519 -0
  5. hud/adapters/common/types.py +5 -1
  6. hud/adapters/operator/adapter.py +4 -0
  7. hud/adapters/operator/tests/__init__.py +1 -0
  8. hud/adapters/operator/tests/test_adapter.py +370 -0
  9. hud/agent/__init__.py +4 -0
  10. hud/agent/base.py +18 -2
  11. hud/agent/claude.py +20 -17
  12. hud/agent/claude_plays_pokemon.py +282 -0
  13. hud/agent/langchain.py +12 -7
  14. hud/agent/misc/__init__.py +3 -0
  15. hud/agent/misc/response_agent.py +80 -0
  16. hud/agent/operator.py +27 -19
  17. hud/agent/tests/__init__.py +1 -0
  18. hud/agent/tests/test_base.py +202 -0
  19. hud/env/docker_client.py +28 -18
  20. hud/env/environment.py +32 -16
  21. hud/env/local_docker_client.py +83 -42
  22. hud/env/remote_client.py +1 -3
  23. hud/env/remote_docker_client.py +72 -15
  24. hud/exceptions.py +12 -0
  25. hud/gym.py +71 -53
  26. hud/job.py +52 -7
  27. hud/settings.py +6 -0
  28. hud/task.py +45 -33
  29. hud/taskset.py +44 -4
  30. hud/telemetry/__init__.py +21 -0
  31. hud/telemetry/_trace.py +173 -0
  32. hud/telemetry/context.py +193 -0
  33. hud/telemetry/exporter.py +417 -0
  34. hud/telemetry/instrumentation/__init__.py +3 -0
  35. hud/telemetry/instrumentation/mcp.py +498 -0
  36. hud/telemetry/instrumentation/registry.py +59 -0
  37. hud/telemetry/mcp_models.py +331 -0
  38. hud/telemetry/tests/__init__.py +1 -0
  39. hud/telemetry/tests/test_context.py +203 -0
  40. hud/telemetry/tests/test_trace.py +270 -0
  41. hud/types.py +10 -26
  42. hud/utils/common.py +22 -2
  43. hud/utils/misc.py +53 -0
  44. hud/utils/tests/test_version.py +1 -1
  45. hud/version.py +7 -0
  46. {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/METADATA +90 -22
  47. hud_python-0.2.5.dist-info/RECORD +84 -0
  48. hud_python-0.2.4.dist-info/RECORD +0 -62
  49. {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/WHEEL +0 -0
  50. {hud_python-0.2.4.dist-info → hud_python-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from base64 import b64decode, b64encode
5
- from typing import Any
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import httpx
6
8
 
7
9
  from hud.env.docker_client import DockerClient
8
10
  from hud.exceptions import HudResponseError
@@ -10,11 +12,27 @@ from hud.server import make_request
10
12
  from hud.settings import settings
11
13
  from hud.types import EnvironmentStatus
12
14
  from hud.utils import ExecuteResult
13
- from hud.utils.common import get_gym_id
15
+ from hud.utils.common import directory_to_zip_bytes, get_gym_id
16
+
17
+ if TYPE_CHECKING:
18
+ from pathlib import Path
14
19
 
15
20
  logger = logging.getLogger("hud.env.remote_env_client")
16
21
 
17
22
 
23
+ async def upload_bytes_to_presigned_url(presigned_url: str, data_bytes: bytes) -> None:
24
+ try:
25
+ async with httpx.AsyncClient() as client:
26
+ response = await client.put(presigned_url, content=data_bytes)
27
+ response.raise_for_status()
28
+ except httpx.HTTPStatusError as e:
29
+ logger.exception("Failed to upload to presigned URL")
30
+ raise HudResponseError(message=f"Failed to upload to presigned URL: {e}") from e
31
+ except httpx.RequestError as e:
32
+ logger.exception("Network error uploading to presigned URL")
33
+ raise HudResponseError(message=f"Network error uploading to presigned URL: {e}") from e
34
+
35
+
18
36
  class RemoteDockerClient(DockerClient):
19
37
  """
20
38
  Remote environment client implementation.
@@ -22,21 +40,64 @@ class RemoteDockerClient(DockerClient):
22
40
  Uses the HUD API to manage a remote environment.
23
41
  """
24
42
 
43
+ @classmethod
44
+ async def build_image(cls, build_context: Path) -> tuple[str, dict[str, Any]]:
45
+ """
46
+ Build an image from a build context.
47
+ """
48
+ # create the presigned url by making a POST request to /v2/builds
49
+ logger.info("Creating build")
50
+ response = await make_request(
51
+ method="POST",
52
+ url=f"{settings.base_url}/v2/builds",
53
+ api_key=settings.api_key,
54
+ )
55
+ logger.info("Build created")
56
+ presigned_url = response["presigned_url"]
57
+
58
+ # List files in the build context
59
+ files = list(build_context.glob("**/*"))
60
+ logger.info("Found %d files in build context %s", len(files), build_context)
61
+
62
+ if len(files) == 0:
63
+ raise HudResponseError(message="Build context is empty")
64
+
65
+ # zip the build context
66
+ logger.info("Zipping build context")
67
+ zip_bytes = directory_to_zip_bytes(build_context)
68
+ logger.info("Created zip archive of size %d kb", len(zip_bytes) // 1024)
69
+ # upload the zip bytes to the presigned url
70
+ logger.info("Uploading build context")
71
+ await upload_bytes_to_presigned_url(presigned_url, zip_bytes)
72
+ logger.info("Build context uploaded")
73
+
74
+ # start the build and return uri and logs
75
+ logger.info("Starting build")
76
+ response = await make_request(
77
+ method="POST",
78
+ url=f"{settings.base_url}/v2/builds/{response['id']}/start",
79
+ api_key=settings.api_key,
80
+ )
81
+ logger.info("Build completed")
82
+
83
+ return response["uri"], {"logs": response["logs"]}
84
+
25
85
  @classmethod
26
86
  async def create(
27
87
  cls,
28
- dockerfile: str,
88
+ image_uri: str,
29
89
  *,
30
90
  job_id: str | None = None,
31
91
  task_id: str | None = None,
32
92
  metadata: dict[str, Any] | None = None,
33
- ) -> tuple[RemoteDockerClient, dict[str, Any]]:
93
+ ) -> RemoteDockerClient:
34
94
  """
35
- Creates a remote environment client from a dockerfile or gym_id.
95
+ Creates a remote environment client from an image.
36
96
 
37
97
  Args:
38
- dockerfile: The dockerfile content to build the environment
39
- gym_id: The gym_id of the environment to create
98
+ image_uri: The image uri to create the environment from
99
+ job_id: The job_id of the environment to create
100
+ task_id: The task_id of the environment to create
40
101
  metadata: Metadata to associate with the environment
41
102
 
42
103
  Returns:
@@ -52,13 +113,14 @@ class RemoteDockerClient(DockerClient):
52
113
 
53
114
  logger.info("Creating remote environment")
54
115
 
55
- true_gym_id = await get_gym_id("docker")
116
+ true_gym_id = await get_gym_id("local-docker")
117
+ # true_gym_id = await get_gym_id("docker")
56
118
 
57
119
  # augment metadata with dockerfile
58
120
  if "environment_config" not in metadata:
59
121
  metadata["environment_config"] = {}
60
122
 
61
- metadata["environment_config"]["dockerfile"] = dockerfile
123
+ metadata["environment_config"]["image_uri"] = image_uri
62
124
 
63
125
  # Create a new environment via the HUD API
64
126
  response = await make_request(
@@ -85,12 +147,7 @@ class RemoteDockerClient(DockerClient):
85
147
  response_json=response,
86
148
  )
87
149
 
88
- # Create the controller instance
89
- controller = cls(env_id)
90
-
91
- build_metadata = response.get("metadata", {})
92
-
93
- return controller, build_metadata
150
+ return cls(env_id)
94
151
 
95
152
  def __init__(self, env_id: str) -> None:
96
153
  """
hud/exceptions.py CHANGED
@@ -165,3 +165,15 @@ class HudNetworkError(HudException):
165
165
  This exception is raised when there are issues with the network
166
166
  connection, DNS resolution, or other network-related problems.
167
167
  """
168
+
169
+
170
+ class GymMakeException(HudException):
171
+ """Raised when environment creation or setup fails, includes context data."""
172
+
173
+ def __init__(self, message: str, data: dict[str, Any]) -> None:
174
+ super().__init__(message)
175
+ self.data = data
176
+
177
+ def __str__(self) -> str:
178
+ base = super().__str__()
179
+ return f"{base} | Data: {self.data}"
hud/gym.py CHANGED
@@ -8,6 +8,8 @@ from hud.env.environment import Environment
8
8
  from hud.env.local_docker_client import LocalDockerClient
9
9
  from hud.env.remote_client import RemoteClient
10
10
  from hud.env.remote_docker_client import RemoteDockerClient
11
+ from hud.exceptions import GymMakeException
12
+ from hud.telemetry.context import get_current_task_run_id
11
13
  from hud.types import CustomGym, Gym
12
14
  from hud.utils.common import get_gym_id
13
15
 
@@ -34,17 +36,19 @@ async def make(
34
36
  job_id: ID of job to associate with this environment (deprecated, use job instead)
35
37
  metadata: Additional metadata for the environment
36
38
  """
37
- if metadata is None:
38
- metadata = {}
39
+ task = None
40
+ if isinstance(env_src, str | CustomGym):
41
+ gym = env_src
42
+ else:
43
+ gym = env_src.gym
44
+ task = env_src
39
45
 
40
- # Handle job parameter
41
46
  effective_job_id = None
42
47
  if job is not None:
43
48
  effective_job_id = job.id
44
49
  elif job_id is not None:
45
50
  effective_job_id = job_id
46
51
  else:
47
- # Try to get an active job from the decorator context
48
52
  try:
49
53
  import hud.job
50
54
 
@@ -52,59 +56,73 @@ async def make(
52
56
  if active_job:
53
57
  effective_job_id = active_job.id
54
58
  except ImportError:
55
- pass # Module not available, skip
56
-
57
- gym = None
58
- task = None
59
- if isinstance(env_src, str | CustomGym):
60
- gym = env_src
61
- else:
62
- gym = env_src.gym
63
- task = env_src
59
+ pass
60
+
61
+ build_data = {}
62
+ try:
63
+ metadata_copy = {} if metadata is None else metadata.copy()
64
+
65
+ current_task_run_id = get_current_task_run_id()
66
+ if current_task_run_id:
67
+ metadata_copy["task_run_id"] = current_task_run_id
68
+ logger.debug(
69
+ "Passing task_run_id %s from hud.telemetry context to environment metadata.",
70
+ current_task_run_id,
71
+ )
64
72
 
65
- if isinstance(gym, CustomGym):
66
- # Create the environment (depending on location)
67
- if gym.dockerfile is None:
68
- raise ValueError("Dockerfile is required for custom environments")
69
- if gym.location == "local":
70
- logger.info("Creating local environment")
71
- client, build_data = await LocalDockerClient.create(gym.dockerfile)
72
- elif gym.location == "remote":
73
- logger.info("Creating remote environment")
74
- client, build_data = await RemoteDockerClient.create(
75
- dockerfile=gym.dockerfile,
73
+ if isinstance(gym, CustomGym):
74
+ if isinstance(gym.image_or_build_context, str):
75
+ uri = gym.image_or_build_context
76
+ elif isinstance(gym.image_or_build_context, Path):
77
+ if gym.location == "local":
78
+ uri, build_data = await LocalDockerClient.build_image(
79
+ gym.image_or_build_context
80
+ )
81
+ elif gym.location == "remote":
82
+ uri, build_data = await RemoteDockerClient.build_image(
83
+ gym.image_or_build_context
84
+ )
85
+ else:
86
+ raise ValueError(f"Invalid environment location: {gym.location}")
87
+ else:
88
+ raise ValueError(f"Invalid image or build context: {gym.image_or_build_context}")
89
+
90
+ if gym.location == "local":
91
+ logger.info("Creating local environment")
92
+ client = await LocalDockerClient.create(uri)
93
+ elif gym.location == "remote":
94
+ logger.info("Creating remote environment")
95
+ client = await RemoteDockerClient.create(
96
+ image_uri=uri,
97
+ job_id=effective_job_id,
98
+ task_id=task.id if task else None,
99
+ metadata=metadata_copy,
100
+ )
101
+ else:
102
+ raise ValueError(f"Invalid environment location: {gym.location}")
103
+
104
+ if isinstance(gym.image_or_build_context, Path):
105
+ logger.info("Setting source path %s", gym.image_or_build_context)
106
+ client.set_source_path(gym.image_or_build_context)
107
+ elif isinstance(gym, str):
108
+ logger.info("Creating private environment")
109
+ true_gym_id = await get_gym_id(gym)
110
+ client, build_data = await RemoteClient.create(
111
+ gym_id=true_gym_id,
76
112
  job_id=effective_job_id,
77
113
  task_id=task.id if task else None,
78
- metadata=metadata,
114
+ metadata=metadata_copy,
79
115
  )
80
116
  else:
81
- raise ValueError(f"Invalid environment location: {gym.location}")
82
-
83
- # Set up the environment with a source path
84
- if gym.controller_source_dir:
85
- logger.info("Setting source path")
86
- client.set_source_path(Path(gym.controller_source_dir))
87
- elif isinstance(gym, str):
88
- logger.info("Creating private environment")
89
- # Note: the gym_name_or_id is a unique identifier, but it is not a true
90
- # gym_id for the purposes of building the environment
91
- # we therefore fetch the gym_id from the HUD API here
92
- true_gym_id = await get_gym_id(gym)
93
-
94
- # Create the environment
95
- client, build_data = await RemoteClient.create(
96
- gym_id=true_gym_id,
97
- job_id=effective_job_id,
98
- task_id=task.id if task else None,
99
- metadata=metadata,
100
- )
101
- else:
102
- raise ValueError(f"Invalid gym source: {gym}")
117
+ raise ValueError(f"Invalid gym source: {gym}")
103
118
 
104
- # Create the environment itself
105
- environment = Environment(client=client, metadata=metadata, task=task, build_data=build_data)
106
-
107
- if task:
108
- await environment._setup()
119
+ environment = Environment(
120
+ client=client, metadata=metadata_copy, task=task, build_data=build_data
121
+ )
109
122
 
110
- return environment
123
+ if task:
124
+ await environment._setup()
125
+ return environment
126
+ except Exception as e:
127
+ build_data["exception"] = str(e)
128
+ raise GymMakeException("Failed to create environment", build_data) from e
hud/job.py CHANGED
@@ -12,11 +12,13 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
12
12
  from pydantic import BaseModel, PrivateAttr, TypeAdapter
13
13
 
14
14
  import hud.server
15
- from hud import gym
15
+ from hud import Response, gym
16
+ from hud.agent import ResponseAgent
16
17
  from hud.settings import settings
17
18
  from hud.task import Task
18
19
  from hud.taskset import TaskSet
19
20
  from hud.trajectory import Trajectory
21
+ from hud.utils.common import Observation
20
22
  from hud.utils.progress import StepProgressTracker
21
23
 
22
24
  if TYPE_CHECKING:
@@ -162,7 +164,7 @@ async def create_job(
162
164
  # If not, we might need to make a subsequent GET request
163
165
  job_data = data # Adjust if the API response structure is different
164
166
 
165
- logger.info("[HUD] View job at https://app.hud.so/jobs/%s.", job_data["id"])
167
+ logger.info("View job at https://app.hud.so/jobs/%s.", job_data["id"])
166
168
 
167
169
  return Job(
168
170
  id=job_data["id"],
@@ -259,6 +261,27 @@ def get_active_job() -> Job | None:
259
261
  return None
260
262
 
261
263
 
264
+ async def _maybe_resample_action(
265
+ obs: Observation, action: Any, response_agent: ResponseAgent
266
+ ) -> tuple[Observation, bool]:
267
+ if isinstance(action, Response):
268
+ action = action.model_dump()
269
+ if isinstance(action, dict) and action.get("type") == "response":
270
+ response_text = action.get("text", "")
271
+ if response_agent and response_text:
272
+ try:
273
+ decision = await response_agent.determine_response(response_text)
274
+ if decision == "CONTINUE":
275
+ logger.info("ResponseAgent indicated CONTINUE. Retrying...")
276
+ obs = Observation(text="Please continue.")
277
+ return obs, False
278
+ elif decision == "CONTINUE":
279
+ logger.warning("Max continue retries reached. Stopping despite CONTINUE.")
280
+ except Exception as e:
281
+ logger.warning("Error using ResponseAgent: %s", e)
282
+ return obs, True
283
+
284
+
262
285
  async def _execute_task(
263
286
  agent_cls: type[Agent],
264
287
  adapter_cls: type[Adapter] | None,
@@ -270,6 +293,7 @@ async def _execute_task(
270
293
  max_steps_per_task: int,
271
294
  job: Job,
272
295
  tracker: StepProgressTracker | None = None,
296
+ auto_reply_question: bool = False,
273
297
  # Use semaphores instead of rate limiter
274
298
  env_creation_semaphore: asyncio.Semaphore | None = None,
275
299
  agent_predict_semaphore: asyncio.Semaphore | None = None,
@@ -283,10 +307,15 @@ async def _execute_task(
283
307
  status = "error"
284
308
  error_msg = "Initialization failed"
285
309
  try:
310
+ response_agent = ResponseAgent() if auto_reply_question else None
311
+
286
312
  adapter_instance = None
287
313
  if adapter_cls:
288
314
  adapter_instance = adapter_cls(**(adapter_kwargs or {}))
289
- agent_instance = agent_cls(adapter=adapter_instance, **(agent_kwargs or {}))
315
+ agent_instance = agent_cls(
316
+ adapter=adapter_instance,
317
+ **(agent_kwargs or {}),
318
+ )
290
319
  if agent_instance is None:
291
320
  raise RuntimeError("Agent could not be instantiated")
292
321
 
@@ -303,6 +332,7 @@ async def _execute_task(
303
332
  obs, _ = obs_tuple
304
333
 
305
334
  step_error = None
335
+
306
336
  for step in range(max_steps_per_task):
307
337
  action, done = (None, False)
308
338
  try:
@@ -319,6 +349,11 @@ async def _execute_task(
319
349
  if action is None and not done:
320
350
  done = True
321
351
 
352
+ if done and response_agent:
353
+ obs, finish = await _maybe_resample_action(obs, action[-1], response_agent)
354
+ if not finish:
355
+ continue
356
+
322
357
  step_result = await env.step(action)
323
358
  if step_result is None:
324
359
  terminated = True
@@ -347,7 +382,7 @@ async def _execute_task(
347
382
  "timestamp": datetime.datetime.now().isoformat(),
348
383
  }
349
384
  )
350
- break
385
+ continue
351
386
  else:
352
387
  logger.warning("[Job: %s/%s, Task: %s] Max steps reached.", job.name, job.id, task_id)
353
388
 
@@ -361,6 +396,7 @@ async def _execute_task(
361
396
  evaluation_result = await env.evaluate()
362
397
  status = "completed"
363
398
  error_msg = None
399
+ # logger.info("Evaluation result: %s", evaluation_result)
364
400
  except Exception as eval_err:
365
401
  logger.exception(
366
402
  "[Job: %s/%s, Task: %s] Evaluation Error: %s",
@@ -453,6 +489,7 @@ async def run_job(
453
489
  agent_cls: type[Agent],
454
490
  task_or_taskset: Task | TaskSet,
455
491
  job_name: str,
492
+ auto_reply_question: bool = False,
456
493
  adapter_cls: type[Adapter] | None = None,
457
494
  agent_kwargs: dict[str, Any] | None = None,
458
495
  adapter_kwargs: dict[str, Any] | None = None,
@@ -461,8 +498,8 @@ async def run_job(
461
498
  job_metadata: dict[str, Any] | None = None,
462
499
  show_progress: bool = True,
463
500
  # Concurrency control with semaphores
464
- max_concurrent_env_creations: int | None = 30, # Limits env.make calls
465
- max_concurrent_agent_predictions: int | None = 30, # Limits agent.predict calls
501
+ max_concurrent_env_creations: int | None = 30, # Limits gym.make calls
502
+ max_concurrent_agent_predictions: int | None = None, # No limit on LLM calls
466
503
  max_concurrent_tasks: int | None = 30, # Limits overall task concurrency
467
504
  ) -> Job:
468
505
  """
@@ -495,12 +532,16 @@ async def run_job(
495
532
  Returns:
496
533
  The created Job object with errors stored in job.errors.
497
534
  """
535
+ hud_logger = logging.getLogger("hud")
536
+ hud_logger.setLevel(logging.CRITICAL)
537
+
498
538
  tasks_to_run: list[Task] = []
499
539
  created_job: Job | None = None
500
540
 
501
541
  evalset_id = None
502
542
  if isinstance(task_or_taskset, TaskSet):
503
543
  evalset_id = task_or_taskset.id
544
+ await task_or_taskset.fit(agent_cls)
504
545
 
505
546
  gym_id = None
506
547
  if isinstance(task_or_taskset, Task):
@@ -519,7 +560,7 @@ async def run_job(
519
560
  evalset_id=evalset_id,
520
561
  gym_id=gym_id,
521
562
  )
522
- logger.info("Created job with ID: %s", created_job.id)
563
+ # logger.info("Created job with ID: %s", created_job.id)
523
564
  except Exception as e:
524
565
  logger.exception("Failed to create job '%s': %s", job_name, e)
525
566
  raise
@@ -555,6 +596,8 @@ async def run_job(
555
596
  logger.info(
556
597
  "Limiting concurrent agent predictions to %d.", max_concurrent_agent_predictions
557
598
  )
599
+ else:
600
+ logger.info("No limit on concurrent agent predictions.")
558
601
 
559
602
  task_execution_sema = None
560
603
  effective_concurrency = num_tasks # Default to running all if parallel
@@ -606,6 +649,7 @@ async def run_job(
606
649
  tracker=tracker,
607
650
  env_creation_semaphore=env_creation_sema,
608
651
  agent_predict_semaphore=agent_predict_sema,
652
+ auto_reply_question=auto_reply_question,
609
653
  )
610
654
  for task, task_id in zip(tasks_to_run, task_ids, strict=True)
611
655
  ]
@@ -641,6 +685,7 @@ async def run_job(
641
685
  tracker=tracker,
642
686
  env_creation_semaphore=env_creation_sema,
643
687
  agent_predict_semaphore=agent_predict_sema,
688
+ auto_reply_question=auto_reply_question,
644
689
  )
645
690
 
646
691
  finally:
hud/settings.py CHANGED
@@ -38,6 +38,12 @@ class Settings(BaseSettings):
38
38
  validation_alias="OPENAI_API_KEY",
39
39
  )
40
40
 
41
+ telemetry_enabled: bool = Field(
42
+ default=True,
43
+ description="Enable telemetry for the HUD SDK",
44
+ validation_alias="TELEMETRY_ENABLED",
45
+ )
46
+
41
47
 
42
48
  # Create a singleton instance
43
49
  settings = Settings()
hud/task.py CHANGED
@@ -1,7 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import tempfile
4
+ from pathlib import Path
3
5
  from typing import TYPE_CHECKING, Any
4
6
 
7
+ from inspect_ai.util._sandbox import SandboxEnvironmentSpec
5
8
  from pydantic import BaseModel
6
9
 
7
10
  from hud.types import CustomGym, Gym
@@ -10,11 +13,7 @@ from hud.utils.common import FunctionConfig, FunctionConfigs
10
13
  if TYPE_CHECKING:
11
14
  from inspect_ai.dataset import Sample
12
15
 
13
- # Environment specifications:
14
- # These represent the environment as a whole, including both the controller
15
- # and the environment type (eg, what os, which services are running)
16
-
17
- UBUNTU_DOCKERFILE = "ubuntu:latest"
16
+ from hud.agent import Agent
18
17
 
19
18
 
20
19
  def convert_inspect_setup(setup: str) -> list[FunctionConfig]:
@@ -57,6 +56,12 @@ class Task(BaseModel):
57
56
  gym: Gym | None = None
58
57
  config: dict[str, Any] | None = None
59
58
 
59
+ description: str | None = None
60
+
61
+ @classmethod
62
+ def from_dict(cls, data: dict[str, Any]) -> Task:
63
+ return cls(**data)
64
+
60
65
  @classmethod
61
66
  def from_inspect_sample(cls, sample: Sample) -> Task:
62
67
  """Create a Task from an Inspect dataset sample.
@@ -91,38 +96,37 @@ class Task(BaseModel):
91
96
  evaluate_config = None
92
97
  if sample.target:
93
98
  if isinstance(sample.target, str):
94
- evaluate_config = ("response_includes", [sample.target])
99
+ evaluate_config = FunctionConfig(function="response_includes", args=[sample.target])
95
100
  elif isinstance(sample.target, list):
96
- evaluate_config = ("match_all", sample.target)
101
+ evaluate_config = FunctionConfig(function="match_all", args=sample.target)
97
102
 
98
- task_gym: Gym | None = None
99
- task_setup: FunctionConfigs | None = None
103
+ task_setup: FunctionConfigs | None = (
104
+ convert_inspect_setup(sample.setup) if sample.setup else None
105
+ )
100
106
 
101
107
  sandbox = sample.sandbox
102
- dockerfile = None
103
- use_qa_gym = True
104
-
105
- if sandbox:
106
- if isinstance(sandbox, str):
107
- if sandbox == "docker":
108
- dockerfile = UBUNTU_DOCKERFILE
109
- use_qa_gym = False
110
- elif isinstance(sandbox, tuple) and len(sandbox) == 2:
111
- sandbox_type, sandbox_config = sandbox
112
- if sandbox_type == "docker":
113
- dockerfile = sandbox_config
114
- use_qa_gym = False
115
-
116
- if use_qa_gym:
117
- task_gym = "qa"
118
- task_setup = None
119
- else:
120
- task_gym = CustomGym(
121
- dockerfile=dockerfile or UBUNTU_DOCKERFILE,
122
- location="local",
123
- )
124
- task_setup = [x for x in convert_inspect_setup(sample.setup)] if sample.setup else None
125
- # TODO: Handle sample.files for CustomGym case if needed
108
+
109
+ match sandbox:
110
+ case "docker":
111
+ task_gym = CustomGym(
112
+ image_or_build_context="ubuntu:latest",
113
+ location="local",
114
+ )
115
+ case SandboxEnvironmentSpec(type="docker", config=str()):
116
+ # create temp dir and put dockerfile there, then use that path
117
+ temp_dir = tempfile.mkdtemp()
118
+ temp_dir_path = Path(temp_dir)
119
+ dockerfile_path = temp_dir_path / "Dockerfile"
120
+ dockerfile_path.write_text(sandbox.config)
121
+ task_gym = CustomGym(
122
+ image_or_build_context=temp_dir_path,
123
+ location="local",
124
+ )
125
+ case None:
126
+ task_gym = "qa"
127
+ task_setup = None
128
+ case _:
129
+ raise ValueError(f"Unsupported sandbox type: {sandbox}")
126
130
 
127
131
  return cls(
128
132
  id=None,
@@ -132,3 +136,11 @@ class Task(BaseModel):
132
136
  gym=task_gym,
133
137
  # files=sample.files, # TODO: Decide how/if to handle files
134
138
  )
139
+
140
+ async def fit(self, agent: Agent | type[Agent]) -> None:
141
+ if isinstance(agent, type):
142
+ agent = agent()
143
+
144
+ if self.gym is None:
145
+ return
146
+ self.gym = agent.transfer_gyms.get(self.gym, self.gym)