hud-python 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic; see the package's registry page for more details.

hud/env/remote_client.py CHANGED
@@ -2,7 +2,9 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from base64 import b64decode
5
- from typing import TYPE_CHECKING, Any
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel
6
8
 
7
9
  from hud.env.client import Client
8
10
  from hud.exceptions import HudResponseError
@@ -10,13 +12,18 @@ from hud.server import make_request
10
12
  from hud.settings import settings
11
13
  from hud.types import EnvironmentStatus
12
14
  from hud.utils import ExecuteResult
13
-
14
- if TYPE_CHECKING:
15
- from hud.utils.config import FunctionConfig
15
+ from hud.utils.config import FunctionConfig
16
16
 
17
17
  logger = logging.getLogger("hud.env.remote_env_client")
18
18
 
19
19
 
20
+ class SetupRequest(BaseModel):
21
+ task_id: str | None = None
22
+ setup: FunctionConfig | None = None
23
+ config: dict[str, Any] | None = None
24
+ metadata: dict[str, Any] | None = None
25
+
26
+
20
27
  class RemoteClient(Client):
21
28
  """
22
29
  Remote environment client implementation.
@@ -183,6 +190,17 @@ class RemoteClient(Client):
183
190
 
184
191
  return data["result"], b64decode(data["stdout"]), b64decode(data["stderr"])
185
192
 
193
+ async def setup(self, setup_request: SetupRequest) -> dict[str, Any]:
194
+ """
195
+ Setup the environment.
196
+ """
197
+ return await make_request(
198
+ method="POST",
199
+ url=f"{settings.base_url}/v1/environments/{self.env_id}/reset",
200
+ json=setup_request.model_dump(),
201
+ api_key=settings.api_key,
202
+ )
203
+
186
204
  async def close(self) -> None:
187
205
  """
188
206
  Close the remote environment by making a request to the server.
hud/env/remote_docker_client.py CHANGED
@@ -20,10 +20,14 @@ if TYPE_CHECKING:
20
20
  logger = logging.getLogger("hud.env.remote_env_client")
21
21
 
22
22
 
23
- async def upload_bytes_to_presigned_url(presigned_url: str, data_bytes: bytes) -> None:
23
+ async def upload_bytes_to_presigned_url(
24
+ presigned_url: str,
25
+ data_bytes: bytes,
26
+ timeout: float = 600,
27
+ ) -> None:
24
28
  try:
25
29
  async with httpx.AsyncClient() as client:
26
- response = await client.put(presigned_url, content=data_bytes)
30
+ response = await client.put(presigned_url, content=data_bytes, timeout=timeout)
27
31
  response.raise_for_status()
28
32
  except httpx.HTTPStatusError as e:
29
33
  logger.exception("Failed to upload to presigned URL")
@@ -113,8 +117,8 @@ class RemoteDockerClient(DockerClient):
113
117
 
114
118
  logger.info("Creating remote environment")
115
119
 
116
- true_gym_id = await get_gym_id("local-docker")
117
- # true_gym_id = await get_gym_id("docker")
120
+ # true_gym_id = await get_gym_id("local-docker")
121
+ true_gym_id = await get_gym_id("docker")
118
122
 
119
123
  # augment metadata with dockerfile
120
124
  if "environment_config" not in metadata:
hud/gym.py CHANGED
@@ -9,13 +9,13 @@ from hud.env.local_docker_client import LocalDockerClient
9
9
  from hud.env.remote_client import RemoteClient
10
10
  from hud.env.remote_docker_client import RemoteDockerClient
11
11
  from hud.exceptions import GymMakeException
12
+ from hud.task import Task
12
13
  from hud.telemetry.context import get_current_task_run_id
13
14
  from hud.types import CustomGym, Gym
14
15
  from hud.utils.common import get_gym_id
15
16
 
16
17
  if TYPE_CHECKING:
17
18
  from hud.job import Job
18
- from hud.task import Task
19
19
 
20
20
  logger = logging.getLogger("hud.gym")
21
21
 
@@ -39,9 +39,11 @@ async def make(
39
39
  task = None
40
40
  if isinstance(env_src, str | CustomGym):
41
41
  gym = env_src
42
- else:
42
+ elif isinstance(env_src, Task):
43
43
  gym = env_src.gym
44
44
  task = env_src
45
+ else:
46
+ raise GymMakeException(f"Invalid gym source: {env_src}", {})
45
47
 
46
48
  effective_job_id = None
47
49
  if job is not None:
@@ -89,9 +91,18 @@ async def make(
89
91
 
90
92
  if gym.location == "local":
91
93
  logger.info("Creating local environment")
92
- client = await LocalDockerClient.create(uri)
94
+ if gym.host_config:
95
+ logger.info("Using host config: %s", gym.host_config)
96
+ client = await LocalDockerClient.create(uri, gym.host_config)
97
+ else:
98
+ client = await LocalDockerClient.create(uri)
99
+
93
100
  elif gym.location == "remote":
94
101
  logger.info("Creating remote environment")
102
+
103
+ if gym.host_config:
104
+ raise ValueError("host_config is not supported for remote environments")
105
+
95
106
  client = await RemoteDockerClient.create(
96
107
  image_uri=uri,
97
108
  job_id=effective_job_id,
@@ -105,7 +116,7 @@ async def make(
105
116
  logger.info("Setting source path %s", gym.image_or_build_context)
106
117
  client.set_source_path(gym.image_or_build_context)
107
118
  elif isinstance(gym, str):
108
- logger.info("Creating private environment")
119
+ logger.debug("Creating private environment")
109
120
  true_gym_id = await get_gym_id(gym)
110
121
  client, build_data = await RemoteClient.create(
111
122
  gym_id=true_gym_id,
hud/job.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
- import datetime
5
4
  import functools
6
5
  import inspect
7
6
  import logging
8
7
  import sys
9
8
  from collections.abc import Callable, Coroutine
9
+ from datetime import datetime
10
10
  from typing import TYPE_CHECKING, Any, TypeVar, cast
11
11
 
12
12
  from pydantic import BaseModel, PrivateAttr, TypeAdapter
@@ -18,12 +18,12 @@ from hud.settings import settings
18
18
  from hud.task import Task
19
19
  from hud.taskset import TaskSet
20
20
  from hud.trajectory import Trajectory
21
- from hud.utils.common import Observation
22
21
  from hud.utils.progress import StepProgressTracker
23
22
 
24
23
  if TYPE_CHECKING:
25
24
  from hud.adapters.common import Adapter
26
25
  from hud.agent.base import Agent
26
+ from hud.utils.common import Observation
27
27
 
28
28
  logger = logging.getLogger("hud.job")
29
29
 
@@ -44,7 +44,7 @@ class Job(BaseModel):
44
44
  id: str
45
45
  name: str
46
46
  metadata: dict[str, Any] | None = None
47
- created_at: datetime.datetime
47
+ created_at: datetime
48
48
  status: str
49
49
 
50
50
  # Internal cache for trajectories
@@ -164,13 +164,15 @@ async def create_job(
164
164
  # If not, we might need to make a subsequent GET request
165
165
  job_data = data # Adjust if the API response structure is different
166
166
 
167
+ created_at = datetime.fromisoformat(job_data["created_at"].replace("Z", "+00:00"))
168
+
167
169
  logger.info("View job at https://app.hud.so/jobs/%s.", job_data["id"])
168
170
 
169
171
  return Job(
170
172
  id=job_data["id"],
171
173
  name=job_data["name"],
172
174
  metadata=job_data.get("metadata", {}), # Ensure metadata is dict
173
- created_at=datetime.datetime.fromisoformat(job_data["created_at"]), # Parse datetime
175
+ created_at=created_at, # Parse datetime
174
176
  status=job_data["status"],
175
177
  )
176
178
 
@@ -273,7 +275,7 @@ async def _maybe_resample_action(
273
275
  decision = await response_agent.determine_response(response_text)
274
276
  if decision == "CONTINUE":
275
277
  logger.info("ResponseAgent indicated CONTINUE. Retrying...")
276
- obs = Observation(text="Please continue.")
278
+ obs.text = "Please continue."
277
279
  return obs, False
278
280
  elif decision == "CONTINUE":
279
281
  logger.warning("Max continue retries reached. Stopping despite CONTINUE.")
@@ -319,6 +321,12 @@ async def _execute_task(
319
321
  if agent_instance is None:
320
322
  raise RuntimeError("Agent could not be instantiated")
321
323
 
324
+ agent_name = agent_instance.name
325
+ logger.info("Using agent: %s", agent_name)
326
+ if task.metadata is None or not isinstance(task.metadata, dict):
327
+ task.metadata = {}
328
+ task.metadata["agent_name"] = agent_name
329
+
322
330
  # Environment creation with semaphore
323
331
  if env_creation_semaphore:
324
332
  async with env_creation_semaphore:
@@ -326,6 +334,9 @@ async def _execute_task(
326
334
  else:
327
335
  env = await gym.make(task, job=job)
328
336
 
337
+ if not env:
338
+ raise ValueError(f"Environment creation failed for task {task_id}")
339
+
329
340
  obs_tuple = await env.reset()
330
341
  if obs_tuple is None:
331
342
  raise ValueError(f"env.reset() returned None for task {task_id}")
@@ -333,24 +344,45 @@ async def _execute_task(
333
344
 
334
345
  step_error = None
335
346
 
347
+ resampled_actions = 0
348
+
336
349
  for step in range(max_steps_per_task):
337
350
  action, done = (None, False)
338
351
  try:
339
352
  # Agent prediction with semaphore
340
- if agent_predict_semaphore:
341
- async with agent_predict_semaphore:
353
+ try:
354
+ if agent_predict_semaphore:
355
+ async with agent_predict_semaphore:
356
+ action, done = await agent_instance.predict(obs)
357
+ else:
342
358
  action, done = await agent_instance.predict(obs)
343
- else:
344
- action, done = await agent_instance.predict(obs)
359
+ except Exception as e:
360
+ # if agent prediction fails, pass back the error to the agent
361
+ logger.exception("[TR: %s] Agent prediction failed: %s", task_id, e)
362
+ resampled_actions += 1
363
+ if resampled_actions > 5:
364
+ logger.warning(
365
+ "[TR: %s] Resampled action %d times. Stopping.",
366
+ task_id,
367
+ resampled_actions,
368
+ )
369
+ break
370
+ continue
345
371
 
346
372
  if tracker:
347
373
  tracker.increment_step(task_id)
348
374
 
349
- if action is None and not done:
350
- done = True
351
-
352
- if done and response_agent:
375
+ finish = False
376
+ if done and response_agent and action and len(action) > 0:
353
377
  obs, finish = await _maybe_resample_action(obs, action[-1], response_agent)
378
+ resampled_actions += 1
379
+ if resampled_actions > 5:
380
+ logger.warning(
381
+ "[TR: %s] Resampled action %d times. Stopping.",
382
+ task_id,
383
+ resampled_actions,
384
+ )
385
+ break
354
386
  if not finish:
355
387
  continue
356
388
 
@@ -359,14 +391,12 @@ async def _execute_task(
359
391
  terminated = True
360
392
  else:
361
393
  obs, _, terminated, _ = step_result
362
- if terminated or done:
394
+ if terminated or done or finish:
363
395
  break
364
396
 
365
397
  except Exception as agent_step_err:
366
398
  logger.exception(
367
- "[Job: %s/%s, Task: %s] Step %d Error: %s",
368
- job.name,
369
- job.id,
399
+ "[TR: %s] Step %d Error: %s",
370
400
  task_id,
371
401
  step + 1,
372
402
  agent_step_err,
@@ -379,12 +409,12 @@ async def _execute_task(
379
409
  "type": "step_error",
380
410
  "step": step + 1,
381
411
  "error": str(agent_step_err),
382
- "timestamp": datetime.datetime.now().isoformat(),
412
+ "timestamp": datetime.now().isoformat(),
383
413
  }
384
414
  )
385
415
  continue
386
416
  else:
387
- logger.warning("[Job: %s/%s, Task: %s] Max steps reached.", job.name, job.id, task_id)
417
+ logger.warning("[TR: %s] Max steps reached.", task_id)
388
418
 
389
419
  # --- Evaluate Task ---
390
420
  evaluation_result = None
@@ -399,9 +429,7 @@ async def _execute_task(
399
429
  # logger.info("Evaluation result: %s", evaluation_result)
400
430
  except Exception as eval_err:
401
431
  logger.exception(
402
- "[Job: %s/%s, Task: %s] Evaluation Error: %s",
403
- job.name,
404
- job.id,
432
+ "[TR: %s] Evaluation Error: %s",
405
433
  task_id,
406
434
  eval_err,
407
435
  )
@@ -413,12 +441,12 @@ async def _execute_task(
413
441
  "task_id": task_id,
414
442
  "type": "evaluation_error",
415
443
  "error": str(eval_err),
416
- "timestamp": datetime.datetime.now().isoformat(),
444
+ "timestamp": datetime.now().isoformat(),
417
445
  }
418
446
  )
419
447
 
420
448
  except Exception as e:
421
- logger.exception("[Job: %s/%s, Task: %s] Setup/Run Error: %s", job.name, job.id, task_id, e)
449
+ logger.exception("[TR: %s] Setup/Run Error: %s", task_id, e)
422
450
  status = "error"
423
451
  error_msg = str(e)
424
452
  # Store setup/initialization error in job
@@ -427,7 +455,7 @@ async def _execute_task(
427
455
  "task_id": task_id,
428
456
  "type": "setup_error",
429
457
  "error": str(e),
430
- "timestamp": datetime.datetime.now().isoformat(),
458
+ "timestamp": datetime.now().isoformat(),
431
459
  }
432
460
  )
433
461
 
@@ -438,24 +466,20 @@ async def _execute_task(
438
466
  try:
439
467
  await env.close()
440
468
  except Exception as close_err:
441
- logger.exception(
442
- "[Job: %s/%s, Task: %s] Close Error: %s", job.name, job.id, task_id, close_err
443
- )
469
+ logger.exception("[TR: %s] Close Error: %s", task_id, close_err)
444
470
  # Store environment close error in job
445
471
  job.errors.append(
446
472
  {
447
473
  "task_id": task_id,
448
474
  "type": "env_close_error",
449
475
  "error": str(close_err),
450
- "timestamp": datetime.datetime.now().isoformat(),
476
+ "timestamp": datetime.now().isoformat(),
451
477
  }
452
478
  )
453
479
 
454
480
  log_suffix = f" Error: {error_msg}" if status == "error" else f" Eval: {evaluation_result}"
455
481
  logger.info(
456
- "[Job: %s/%s, Task: %s] Finished local execution. Status: %s.%s",
457
- job.name,
458
- job.id,
482
+ "[TR: %s] Finished local execution. Status: %s.%s",
459
483
  task_id,
460
484
  status,
461
485
  log_suffix,
@@ -497,6 +521,7 @@ async def run_job(
497
521
  run_parallel: bool = True,
498
522
  job_metadata: dict[str, Any] | None = None,
499
523
  show_progress: bool = True,
524
+ verbose: bool = False,
500
525
  # Concurrency control with semaphores
501
526
  max_concurrent_env_creations: int | None = 30, # Limits gym.make calls
502
527
  max_concurrent_agent_predictions: int | None = None, # No limit on LLM calls
@@ -532,16 +557,20 @@ async def run_job(
532
557
  Returns:
533
558
  The created Job object with errors stored in job.errors.
534
559
  """
535
- hud_logger = logging.getLogger("hud")
536
- hud_logger.setLevel(logging.CRITICAL)
537
560
 
538
561
  tasks_to_run: list[Task] = []
539
562
  created_job: Job | None = None
540
563
 
564
+ # Get hud logger
565
+ if not verbose:
566
+ logger = logging.getLogger("hud")
567
+ logger.setLevel(logging.CRITICAL)
568
+ logger = logging.getLogger("hud.job")
569
+
541
570
  evalset_id = None
542
571
  if isinstance(task_or_taskset, TaskSet):
543
572
  evalset_id = task_or_taskset.id
544
- await task_or_taskset.fit(agent_cls)
573
+ task_or_taskset.fit(agent_cls)
545
574
 
546
575
  gym_id = None
547
576
  if isinstance(task_or_taskset, Task):
@@ -706,3 +735,39 @@ async def run_job(
706
735
  num_tasks,
707
736
  )
708
737
  return created_job
738
+
739
+
740
+ """
741
+ c7f85f7d-3730-4c9a-85a3-a1dc436c3bd2
742
+
743
+
744
+ de12c3cc-9d9c-4e90-82cc-1d71d30ede54
745
+ 59104743-0a63-4569-a8b5-1eda1a1b55ac
746
+ ff759429-056c-4cde-8851-11e26729ff03
747
+
748
+
749
+ 7b98ea22-e243-4eeb-a6db-79f4a76da2b3
750
+
751
+ 7aad3f7b-d74f-470d-826d-d817f95fdd67
752
+
753
+ e356ede6-074a-49ef-9fcd-69e5bcfbdec9
754
+
755
+ 26cd1192-3991-4d1b-b599-b2bed1bcb606
756
+
757
+ 31ece277-970f-4763-b0c8-bf19a56f56c7
758
+
759
+
760
+ f9b722a0-5f33-466b-bce0-8ece101f2bc6
761
+ 33d1af33-8952-4945-b901-229bcfd88354
762
+
763
+ 6c3d6557-e745-44ab-bc10-300180a81c79
764
+ 6c3d6557-e745-44ab-bc10-300180a81c79
765
+ 502e02b5-9939-4e57-91af-4fcbcb90a979
766
+
767
+ 7aad3f7b-d74f-470d-826d-d817f95fdd67
768
+
769
+
770
+ 31ece277-970f-4763-b0c8-bf19a56f56c7
771
+
772
+
773
+ e356ede6-074a-49ef-9fcd-69e5bcfbdec9"""
hud/server/requests.py CHANGED
@@ -6,6 +6,7 @@ from __future__ import annotations
6
6
 
7
7
  import asyncio
8
8
  import logging
9
+ import ssl
9
10
  import time
10
11
  from typing import Any
11
12
 
@@ -20,7 +21,7 @@ from hud.exceptions import (
20
21
 
21
22
  # Set up logger
22
23
  logger = logging.getLogger("hud.http")
23
- logger.setLevel(logging.DEBUG)
24
+ logger.setLevel(logging.INFO)
24
25
 
25
26
 
26
27
  # Long running requests can take up to 10 minutes.
@@ -37,7 +38,7 @@ async def _handle_retry(
37
38
  ) -> None:
38
39
  """Helper function to handle retry logic and logging."""
39
40
  retry_time = retry_delay * (2 ** (attempt - 1)) # Exponential backoff
40
- logger.warning(
41
+ logger.debug(
41
42
  "%s from %s, retrying in %.2f seconds (attempt %d/%d)",
42
43
  error_msg,
43
44
  url,
@@ -140,6 +141,12 @@ async def make_request(
140
141
  continue
141
142
  else:
142
143
  raise HudNetworkError(f"Network error: {e!s}") from None
144
+ except ssl.SSLError as e:
145
+ if attempt <= max_retries:
146
+ await _handle_retry(attempt, max_retries, retry_delay, url, f"SSL error: {e}")
147
+ continue
148
+ else:
149
+ raise HudNetworkError(f"SSL error: {e!s}") from None
143
150
  except Exception as e:
144
151
  raise HudRequestError(f"Unexpected error: {e!s}") from None
145
152
  raise HudRequestError(f"Request failed after {max_retries} retries with unknown error")
@@ -201,7 +208,7 @@ def make_request_sync(
201
208
  # Check if we got a retriable status code
202
209
  if response.status_code in retry_status_codes and attempt <= max_retries:
203
210
  retry_time = retry_delay * (2 ** (attempt - 1)) # Exponential backoff
204
- logger.warning(
211
+ logger.debug(
205
212
  "Received status %d from %s, retrying in %.2f seconds (attempt %d/%d)",
206
213
  response.status_code,
207
214
  url,
@@ -222,7 +229,7 @@ def make_request_sync(
222
229
  except httpx.RequestError as e:
223
230
  if attempt <= max_retries:
224
231
  retry_time = retry_delay * (2 ** (attempt - 1))
225
- logger.warning(
232
+ logger.debug(
226
233
  "Network error %s from %s, retrying in %.2f seconds (attempt %d/%d)",
227
234
  str(e),
228
235
  url,
@@ -234,6 +241,21 @@ def make_request_sync(
234
241
  continue
235
242
  else:
236
243
  raise HudNetworkError(f"Network error: {e!s}") from None
244
+ except ssl.SSLError as e:
245
+ if attempt <= max_retries:
246
+ retry_time = retry_delay * (2 ** (attempt - 1)) # Exponential backoff
247
+ logger.debug(
248
+ "SSL error %s from %s, retrying in %.2f seconds (attempt %d/%d)",
249
+ str(e),
250
+ url,
251
+ retry_time,
252
+ attempt,
253
+ max_retries,
254
+ )
255
+ time.sleep(retry_time)
256
+ continue
257
+ else:
258
+ raise HudNetworkError(f"SSL error: {e!s}") from None
237
259
  except Exception as e:
238
260
  raise HudRequestError(f"Unexpected error: {e!s}") from None
239
261
  raise HudRequestError(f"Request failed after {max_retries} retries with unknown error")
hud/settings.py CHANGED
@@ -15,7 +15,7 @@ class Settings(BaseSettings):
15
15
  model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="allow")
16
16
 
17
17
  base_url: str = Field(
18
- default="https://orcstaging.hud.so/hud-gym/api",
18
+ default="https://orchestration.hud.so/hud-gym/api",
19
19
  description="Base URL for the HUD API",
20
20
  validation_alias="base_url",
21
21
  )
@@ -44,6 +44,12 @@ class Settings(BaseSettings):
44
44
  validation_alias="TELEMETRY_ENABLED",
45
45
  )
46
46
 
47
+ fancy_logging: bool = Field(
48
+ default=True,
49
+ description="Enable fancy logging for the HUD SDK",
50
+ validation_alias="FANCY_LOGGING",
51
+ )
52
+
47
53
 
48
54
  # Create a singleton instance
49
55
  settings = Settings()
hud/task.py CHANGED
@@ -2,12 +2,12 @@ from __future__ import annotations
2
2
 
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from typing import TYPE_CHECKING, Any
5
+ from typing import TYPE_CHECKING, Any, Literal, cast
6
6
 
7
7
  from inspect_ai.util._sandbox import SandboxEnvironmentSpec
8
- from pydantic import BaseModel
8
+ from pydantic import BaseModel, Field
9
9
 
10
- from hud.types import CustomGym, Gym
10
+ from hud.types import CustomGym, Gym, MetadataKeys, SensitiveData
11
11
  from hud.utils.common import FunctionConfig, FunctionConfigs
12
12
 
13
13
  if TYPE_CHECKING:
@@ -40,28 +40,78 @@ class Task(BaseModel):
40
40
  Attributes:
41
41
  id: The remote task ID (optional if local-only)
42
42
  prompt: The task prompt or instruction
43
+ system_prompt: The system prompt for the evalset (optional)
43
44
  setup: Environment setup configuration (optional)
44
45
  evaluate: Configuration for evaluating responses
45
46
  metadata: Additional task metadata
47
+ sensitive_data: Sensitive data such as API keys, passwords, etc.
46
48
  choices: Multiple choice answer list (for Inspect compatibility)
47
49
  target: Ideal target output (for Inspect compatibility)
48
50
  files: Files that go along with the task (for Inspect compatibility)
49
51
  gym: Environment specification
50
52
  """
51
53
 
52
- id: str | None = None
53
- prompt: str
54
+ id: str | None = None # Remote task ID (optional if local-only)
55
+
56
+ prompt: str # Task prompt or instruction
57
+ system_prompt: str | None = None # System prompt for the evalset (optional)
58
+
59
+ gym: Gym | None = None # Environment specification
60
+
61
+ # Setup and evaluate configurations for the environment (environment specific)
54
62
  setup: FunctionConfigs | None = None
55
63
  evaluate: FunctionConfigs | None = None
56
- gym: Gym | None = None
64
+
65
+ # Overflow configuration for environments that don't conform to the standard
57
66
  config: dict[str, Any] | None = None
58
67
 
68
+ # Sensitive data such as API keys, passwords, etc.
69
+ sensitive_data: SensitiveData = Field(default_factory=dict)
70
+
71
+ # Metadata for the task evaluation, information about the agent (see MetadataKeys)
72
+ metadata: dict[MetadataKeys, Any] = Field(default_factory=dict)
73
+
74
+ # Description of the task, for extra information about its purpose and context
59
75
  description: str | None = None
60
76
 
61
77
  @classmethod
62
78
  def from_dict(cls, data: dict[str, Any]) -> Task:
63
79
  return cls(**data)
64
80
 
81
+ @classmethod
82
+ def from_serialized(cls, data: dict[str, Any]) -> Task:
83
+ gym_data = data.get("gym")
84
+ parsed_gym: Gym | None = gym_data
85
+
86
+ parsed_setup = [(param, entry) for param, entry in data.get("setup", [])]
87
+ parsed_evaluate = [(param, entry) for param, entry in data.get("evaluate", [])]
88
+
89
+ # Convert dict gym data to CustomGym if needed
90
+ if (
91
+ isinstance(gym_data, dict)
92
+ and gym_data.get("type") == "public"
93
+ and gym_data.get("location") in ("local", "remote")
94
+ and gym_data.get("image_or_build_context") is not None
95
+ ):
96
+ parsed_gym = CustomGym(
97
+ type=cast("Literal['public']", gym_data["type"]),
98
+ location=cast("Literal['local', 'remote']", gym_data["location"]),
99
+ image_or_build_context=Path(gym_data["image_or_build_context"]),
100
+ )
101
+
102
+ return cls(
103
+ id=data.get("id"),
104
+ prompt=data.get("prompt", ""),
105
+ system_prompt=data.get("system_prompt"),
106
+ setup=parsed_setup,
107
+ evaluate=parsed_evaluate,
108
+ gym=parsed_gym,
109
+ config=data.get("config"),
110
+ description=data.get("description"),
111
+ sensitive_data=data.get("sensitive_data", {}),
112
+ metadata=data.get("metadata", {}),
113
+ )
114
+
65
115
  @classmethod
66
116
  def from_inspect_sample(cls, sample: Sample) -> Task:
67
117
  """Create a Task from an Inspect dataset sample.
@@ -144,3 +194,31 @@ class Task(BaseModel):
144
194
  if self.gym is None:
145
195
  return
146
196
  self.gym = agent.transfer_gyms.get(self.gym, self.gym)
197
+
198
+ def serialize(self) -> dict[str, Any]:
199
+ if isinstance(self.setup, list):
200
+ parsed_setup = [[param, entry] for param, entry in self.setup]
201
+ else:
202
+ parsed_setup = self.setup
203
+ if isinstance(self.evaluate, list):
204
+ parsed_evaluate = [[param, entry] for param, entry in self.evaluate]
205
+ else:
206
+ parsed_evaluate = self.evaluate
207
+
208
+ if isinstance(self.gym, CustomGym):
209
+ parsed_gym = self.gym.model_dump()
210
+ parsed_gym["image_or_build_context"] = str(parsed_gym["image_or_build_context"])
211
+ else: # is ServerGym
212
+ parsed_gym = self.gym
213
+
214
+ return {
215
+ "id": self.id,
216
+ "prompt": self.prompt,
217
+ "config": self.config,
218
+ "description": self.description,
219
+ "setup": parsed_setup,
220
+ "evaluate": parsed_evaluate,
221
+ "gym": parsed_gym,
222
+ "sensitive_data": self.sensitive_data,
223
+ "metadata": self.metadata,
224
+ }