hud-python 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (192) hide show
  1. hud/__init__.py +22 -89
  2. hud/agents/__init__.py +15 -0
  3. hud/agents/art.py +101 -0
  4. hud/agents/base.py +599 -0
  5. hud/{mcp → agents}/claude.py +373 -321
  6. hud/{mcp → agents}/langchain.py +250 -250
  7. hud/agents/misc/__init__.py +7 -0
  8. hud/{agent → agents}/misc/response_agent.py +80 -80
  9. hud/{mcp → agents}/openai.py +352 -334
  10. hud/agents/openai_chat_generic.py +154 -0
  11. hud/{mcp → agents}/tests/__init__.py +1 -1
  12. hud/agents/tests/test_base.py +742 -0
  13. hud/agents/tests/test_claude.py +324 -0
  14. hud/{mcp → agents}/tests/test_client.py +363 -324
  15. hud/{mcp → agents}/tests/test_openai.py +237 -238
  16. hud/cli/__init__.py +617 -0
  17. hud/cli/__main__.py +8 -0
  18. hud/cli/analyze.py +371 -0
  19. hud/cli/analyze_metadata.py +230 -0
  20. hud/cli/build.py +427 -0
  21. hud/cli/clone.py +185 -0
  22. hud/cli/cursor.py +92 -0
  23. hud/cli/debug.py +392 -0
  24. hud/cli/docker_utils.py +83 -0
  25. hud/cli/init.py +281 -0
  26. hud/cli/interactive.py +353 -0
  27. hud/cli/mcp_server.py +756 -0
  28. hud/cli/pull.py +336 -0
  29. hud/cli/push.py +370 -0
  30. hud/cli/remote_runner.py +311 -0
  31. hud/cli/runner.py +160 -0
  32. hud/cli/tests/__init__.py +3 -0
  33. hud/cli/tests/test_analyze.py +284 -0
  34. hud/cli/tests/test_cli_init.py +265 -0
  35. hud/cli/tests/test_cli_main.py +27 -0
  36. hud/cli/tests/test_clone.py +142 -0
  37. hud/cli/tests/test_cursor.py +253 -0
  38. hud/cli/tests/test_debug.py +453 -0
  39. hud/cli/tests/test_mcp_server.py +139 -0
  40. hud/cli/tests/test_utils.py +388 -0
  41. hud/cli/utils.py +263 -0
  42. hud/clients/README.md +143 -0
  43. hud/clients/__init__.py +16 -0
  44. hud/clients/base.py +379 -0
  45. hud/clients/fastmcp.py +222 -0
  46. hud/clients/mcp_use.py +278 -0
  47. hud/clients/tests/__init__.py +1 -0
  48. hud/clients/tests/test_client_integration.py +111 -0
  49. hud/clients/tests/test_fastmcp.py +342 -0
  50. hud/clients/tests/test_protocol.py +188 -0
  51. hud/clients/utils/__init__.py +1 -0
  52. hud/clients/utils/retry_transport.py +160 -0
  53. hud/datasets.py +322 -192
  54. hud/misc/__init__.py +1 -0
  55. hud/{agent → misc}/claude_plays_pokemon.py +292 -283
  56. hud/otel/__init__.py +35 -0
  57. hud/otel/collector.py +142 -0
  58. hud/otel/config.py +164 -0
  59. hud/otel/context.py +536 -0
  60. hud/otel/exporters.py +366 -0
  61. hud/otel/instrumentation.py +97 -0
  62. hud/otel/processors.py +118 -0
  63. hud/otel/tests/__init__.py +1 -0
  64. hud/otel/tests/test_processors.py +197 -0
  65. hud/server/__init__.py +5 -5
  66. hud/server/context.py +114 -0
  67. hud/server/helper/__init__.py +5 -0
  68. hud/server/low_level.py +132 -0
  69. hud/server/server.py +166 -0
  70. hud/server/tests/__init__.py +3 -0
  71. hud/settings.py +73 -79
  72. hud/shared/__init__.py +5 -0
  73. hud/{exceptions.py → shared/exceptions.py} +180 -180
  74. hud/{server → shared}/requests.py +264 -264
  75. hud/shared/tests/test_exceptions.py +157 -0
  76. hud/{server → shared}/tests/test_requests.py +275 -275
  77. hud/telemetry/__init__.py +25 -30
  78. hud/telemetry/instrument.py +379 -0
  79. hud/telemetry/job.py +309 -141
  80. hud/telemetry/replay.py +74 -0
  81. hud/telemetry/trace.py +83 -0
  82. hud/tools/__init__.py +33 -34
  83. hud/tools/base.py +365 -65
  84. hud/tools/bash.py +161 -137
  85. hud/tools/computer/__init__.py +15 -13
  86. hud/tools/computer/anthropic.py +437 -420
  87. hud/tools/computer/hud.py +376 -334
  88. hud/tools/computer/openai.py +295 -292
  89. hud/tools/computer/settings.py +82 -0
  90. hud/tools/edit.py +314 -290
  91. hud/tools/executors/__init__.py +30 -30
  92. hud/tools/executors/base.py +539 -532
  93. hud/tools/executors/pyautogui.py +621 -619
  94. hud/tools/executors/tests/__init__.py +1 -1
  95. hud/tools/executors/tests/test_base_executor.py +338 -338
  96. hud/tools/executors/tests/test_pyautogui_executor.py +165 -165
  97. hud/tools/executors/xdo.py +511 -503
  98. hud/tools/{playwright_tool.py → playwright.py} +412 -379
  99. hud/tools/tests/__init__.py +3 -3
  100. hud/tools/tests/test_base.py +282 -0
  101. hud/tools/tests/test_bash.py +158 -152
  102. hud/tools/tests/test_bash_extended.py +197 -0
  103. hud/tools/tests/test_computer.py +425 -52
  104. hud/tools/tests/test_computer_actions.py +34 -34
  105. hud/tools/tests/test_edit.py +259 -240
  106. hud/tools/tests/test_init.py +27 -27
  107. hud/tools/tests/test_playwright_tool.py +183 -183
  108. hud/tools/tests/test_tools.py +145 -157
  109. hud/tools/tests/test_utils.py +156 -156
  110. hud/tools/types.py +72 -0
  111. hud/tools/utils.py +50 -50
  112. hud/types.py +136 -89
  113. hud/utils/__init__.py +10 -16
  114. hud/utils/async_utils.py +65 -0
  115. hud/utils/design.py +168 -0
  116. hud/utils/mcp.py +55 -0
  117. hud/utils/progress.py +149 -149
  118. hud/utils/telemetry.py +66 -66
  119. hud/utils/tests/test_async_utils.py +173 -0
  120. hud/utils/tests/test_init.py +17 -21
  121. hud/utils/tests/test_progress.py +261 -225
  122. hud/utils/tests/test_telemetry.py +82 -37
  123. hud/utils/tests/test_version.py +8 -8
  124. hud/version.py +7 -7
  125. hud_python-0.4.1.dist-info/METADATA +476 -0
  126. hud_python-0.4.1.dist-info/RECORD +132 -0
  127. hud_python-0.4.1.dist-info/entry_points.txt +3 -0
  128. {hud_python-0.3.5.dist-info → hud_python-0.4.1.dist-info}/licenses/LICENSE +21 -21
  129. hud/adapters/__init__.py +0 -8
  130. hud/adapters/claude/__init__.py +0 -5
  131. hud/adapters/claude/adapter.py +0 -180
  132. hud/adapters/claude/tests/__init__.py +0 -1
  133. hud/adapters/claude/tests/test_adapter.py +0 -519
  134. hud/adapters/common/__init__.py +0 -6
  135. hud/adapters/common/adapter.py +0 -178
  136. hud/adapters/common/tests/test_adapter.py +0 -289
  137. hud/adapters/common/types.py +0 -446
  138. hud/adapters/operator/__init__.py +0 -5
  139. hud/adapters/operator/adapter.py +0 -108
  140. hud/adapters/operator/tests/__init__.py +0 -1
  141. hud/adapters/operator/tests/test_adapter.py +0 -370
  142. hud/agent/__init__.py +0 -19
  143. hud/agent/base.py +0 -126
  144. hud/agent/claude.py +0 -271
  145. hud/agent/langchain.py +0 -215
  146. hud/agent/misc/__init__.py +0 -3
  147. hud/agent/operator.py +0 -268
  148. hud/agent/tests/__init__.py +0 -1
  149. hud/agent/tests/test_base.py +0 -202
  150. hud/env/__init__.py +0 -11
  151. hud/env/client.py +0 -35
  152. hud/env/docker_client.py +0 -349
  153. hud/env/environment.py +0 -446
  154. hud/env/local_docker_client.py +0 -358
  155. hud/env/remote_client.py +0 -212
  156. hud/env/remote_docker_client.py +0 -292
  157. hud/gym.py +0 -130
  158. hud/job.py +0 -773
  159. hud/mcp/__init__.py +0 -17
  160. hud/mcp/base.py +0 -631
  161. hud/mcp/client.py +0 -312
  162. hud/mcp/tests/test_base.py +0 -512
  163. hud/mcp/tests/test_claude.py +0 -294
  164. hud/task.py +0 -149
  165. hud/taskset.py +0 -237
  166. hud/telemetry/_trace.py +0 -347
  167. hud/telemetry/context.py +0 -230
  168. hud/telemetry/exporter.py +0 -575
  169. hud/telemetry/instrumentation/__init__.py +0 -3
  170. hud/telemetry/instrumentation/mcp.py +0 -259
  171. hud/telemetry/instrumentation/registry.py +0 -59
  172. hud/telemetry/mcp_models.py +0 -270
  173. hud/telemetry/tests/__init__.py +0 -1
  174. hud/telemetry/tests/test_context.py +0 -210
  175. hud/telemetry/tests/test_trace.py +0 -312
  176. hud/tools/helper/README.md +0 -56
  177. hud/tools/helper/__init__.py +0 -9
  178. hud/tools/helper/mcp_server.py +0 -78
  179. hud/tools/helper/server_initialization.py +0 -115
  180. hud/tools/helper/utils.py +0 -58
  181. hud/trajectory.py +0 -94
  182. hud/utils/agent.py +0 -37
  183. hud/utils/common.py +0 -256
  184. hud/utils/config.py +0 -120
  185. hud/utils/deprecation.py +0 -115
  186. hud/utils/misc.py +0 -53
  187. hud/utils/tests/test_common.py +0 -277
  188. hud/utils/tests/test_config.py +0 -129
  189. hud_python-0.3.5.dist-info/METADATA +0 -284
  190. hud_python-0.3.5.dist-info/RECORD +0 -120
  191. /hud/{adapters/common → shared}/tests/__init__.py +0 -0
  192. {hud_python-0.3.5.dist-info → hud_python-0.4.1.dist-info}/WHEEL +0 -0
hud/job.py DELETED
@@ -1,773 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import functools
5
- import inspect
6
- import logging
7
- import sys
8
- from collections.abc import Callable, Coroutine
9
- from datetime import datetime
10
- from typing import TYPE_CHECKING, Any, TypeVar, cast
11
-
12
- from pydantic import BaseModel, PrivateAttr, TypeAdapter
13
-
14
- import hud.server
15
- from hud import Response, gym
16
- from hud.agent import ResponseAgent
17
- from hud.settings import settings
18
- from hud.task import Task
19
- from hud.taskset import TaskSet
20
- from hud.trajectory import Trajectory
21
- from hud.utils.progress import StepProgressTracker
22
-
23
- if TYPE_CHECKING:
24
- from hud.adapters.common import Adapter
25
- from hud.agent.base import Agent
26
- from hud.utils.common import Observation
27
-
28
# Module-level logger; every message from this module is emitted under "hud.job".
logger = logging.getLogger("hud.job")

# Callable type variable so the `job` decorator can return the type it was given.
T = TypeVar("T", bound=Callable)

# Jobs created by the `job` decorator that are still running, keyed by call id.
_ACTIVE_JOBS: dict[str, Job] = {}
35
-
36
-
37
class Job(BaseModel):
    """
    A named collection of related trajectories plus its server-side metadata.

    Instances are normally obtained through `create_job`, `load_job` or `run_job`
    rather than being constructed by hand.
    """

    id: str
    name: str
    metadata: dict[str, Any] | None = None
    created_at: datetime
    status: str

    # Result of the last successful `load_trajectories` call (None = nothing cached)
    _trajectories: list[Trajectory] | None = PrivateAttr(default=None)
    # Error records appended while tasks belonging to this job are executed
    errors: list[dict[str, Any]] = []

    async def load_trajectories(
        self, *, api_key: str | None = None, force_reload: bool = False
    ) -> list[Trajectory]:
        """
        Return the trajectories that belong to this job.

        A previously fetched list is reused unless `force_reload` is set.

        Args:
            api_key: Optional API key; `settings.api_key` is used when falsy.
            force_reload: If True, always fetch from the API and replace the cache.

        Returns:
            List[Trajectory]: The job's trajectories. When the request or the
            validation fails, the failure is logged, the cache is cleared and an
            empty list is returned instead of raising.
        """
        cached = self._trajectories
        if cached is not None and not force_reload:
            logger.debug("Returning cached trajectories for Job %s", self.id)
            return cached

        logger.debug("Fetching trajectories for Job %s from API...", self.id)
        api_key = api_key or settings.api_key

        try:
            payload = await hud.server.make_request(
                method="GET",
                url=f"{settings.base_url}/v2/jobs/{self.id}/trajectories",
                api_key=api_key,
            )
            loaded = TypeAdapter(list[Trajectory]).validate_python(payload)
            self._trajectories = loaded
            logger.debug("Loaded %d trajectories for Job %s", len(loaded), self.id)
            return loaded
        except Exception as e:
            logger.exception("Failed to load trajectories for Job %s: %s", self.id, e)
            # Never keep a stale or partial cache around after a failure
            self._trajectories = None
            return []

    async def get_analytics(self, *, force_reload: bool = False) -> dict[str, Any]:
        """
        Summarise the job's trajectories.

        Args:
            force_reload: If True, re-fetches trajectories before calculating.

        Returns:
            Dictionary with `task_count`, `avg_reward` (mean over trajectories that
            carry a numeric reward, or None if there are none) and `success_rate`
            (percentage of all trajectories whose numeric reward is >= 1.0, or
            None when the job has no trajectories).
        """
        trajectories = await self.load_trajectories(force_reload=force_reload)
        if not trajectories:
            return {"task_count": 0, "avg_reward": None, "success_rate": None}

        task_count = len(trajectories)
        reward_sum = 0
        scored = 0
        passed = 0
        for trajectory in trajectories:
            reward = trajectory.reward
            # Trajectories without a numeric reward are ignored for the average
            if not isinstance(reward, int | float):
                continue
            reward_sum += reward
            scored += 1
            if reward >= 1.0:
                passed += 1

        return {
            "task_count": task_count,
            "avg_reward": (reward_sum / scored) if scored else None,
            # task_count is non-zero at this point, so the rate is always defined
            "success_rate": (passed / task_count) * 100,
        }
129
-
130
-
131
async def create_job(
    name: str,
    gym_id: str | None = None,
    evalset_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> Job:
    """
    Creates a new job through the API and returns it as a `Job`.

    Args:
        name: The name of the job
        gym_id: Optional gym identifier sent along with the creation request
        evalset_id: Optional evalset identifier sent along with the creation request
        metadata: Metadata for the job; an empty dict is sent when omitted

    Returns:
        Job: The created job instance, built from the API response

    Raises:
        KeyError: If the response lacks `id`, `name`, `status` or `created_at`
    """
    job_data = await hud.server.make_request(
        method="POST",
        url=f"{settings.base_url}/v2/jobs",
        json={
            "name": name,
            "metadata": metadata or {},
            "gym_id": gym_id,
            "evalset_id": evalset_id,
        },
        api_key=settings.api_key,
    )

    # The response is assumed to carry the full job record (id, name, metadata,
    # created_at, status); if the backend ever returns less, a follow-up GET
    # would be needed here.

    # Normalise a trailing "Z" to an explicit UTC offset; `fromisoformat` only
    # accepts the "Z" suffix itself on newer Python versions.
    created_at = datetime.fromisoformat(job_data["created_at"].replace("Z", "+00:00"))

    logger.info("View job at https://app.hud.so/jobs/%s.", job_data["id"])

    return Job(
        id=job_data["id"],
        name=job_data["name"],
        # `or {}` also covers an explicit null in the response, which a plain
        # `.get("metadata", {})` would pass through as None
        metadata=job_data.get("metadata") or {},
        created_at=created_at,
        status=job_data["status"],
    )
178
-
179
-
180
async def load_job(job_id: str, api_key: str | None = None) -> Job:
    """
    Fetches an existing job from the API and validates it into a `Job`.

    Args:
        job_id: The ID of the job to retrieve
        api_key: Optional API key; `settings.api_key` is used when falsy

    Returns:
        Job: The retrieved job instance

    Raises:
        ValueError: If the API returns an empty (falsy) payload for `job_id`
    """
    resolved_key = api_key or settings.api_key
    payload = await hud.server.make_request(
        method="GET",
        url=f"{settings.base_url}/v2/jobs/{job_id}",
        api_key=resolved_key,
    )

    if payload:
        # Let pydantic validate the raw response and build the model
        return Job.model_validate(payload)

    raise ValueError(f"Job {job_id} not found")
203
-
204
-
205
def job(name: str, metadata: dict[str, Any] | None = None) -> Callable[[T], T]:
    """
    Decorator to automatically create and associate a job with all environments
    created within the decorated (async) function.

    A new job is created on every call of the decorated function. While the call
    is running, the job is registered in `_ACTIVE_JOBS` and its registry key is
    exposed as the local variable `_job_call_id` of the wrapper frame, which is
    what `get_active_job` searches the call stack for.

    Args:
        name: The name of the job
        metadata: Additional metadata for the job

    Returns:
        A decorator function that creates a job and associates it with environments
    """

    def decorator(func: T) -> T:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            created = await create_job(name=name, metadata=metadata)

            # The key must be unique per *call*: the wrapper object is the same for
            # every invocation, so keying on it would let overlapping calls share
            # one entry and remove each other's job on exit. The job id differs
            # for every created job.
            call_id = f"{func.__module__}.{func.__qualname__}_{created.id}"
            _ACTIVE_JOBS[call_id] = created

            # A genuine local variable shows up in this frame's `f_locals` without
            # depending on how the interpreter treats writes to that mapping.
            _job_call_id = call_id  # noqa: F841 - read via frame inspection

            try:
                return await func(*args, **kwargs)
            finally:
                # Always unregister, even when the decorated function raises
                _ACTIVE_JOBS.pop(call_id, None)

        return cast("T", wrapper)

    return decorator
244
-
245
-
246
def get_active_job() -> Job | None:
    """
    Look up the job registered by an enclosing `job`-decorated call, if any.

    The call stack is walked outwards from this function; the first frame whose
    locals contain `_job_call_id` with a key that is still present in
    `_ACTIVE_JOBS` determines the result.

    Returns:
        The active job or None if no job is active
    """
    current = inspect.currentframe()
    while current:
        scope = current.f_locals
        if "_job_call_id" in scope:
            call_id = scope["_job_call_id"]
            # A marker whose job was already unregistered is skipped
            if call_id in _ACTIVE_JOBS:
                return _ACTIVE_JOBS[call_id]
        current = current.f_back

    return None
264
-
265
-
266
async def _maybe_resample_action(
    obs: Observation, action: Any, response_agent: ResponseAgent
) -> tuple[Observation, bool]:
    """
    Decide whether a final "response" action should end the episode.

    For an action of type "response" with non-empty text, `response_agent` is
    asked for a verdict. On "CONTINUE" the observation text is replaced with a
    prompt to keep going and the episode is reported as not finished.

    Returns:
        `(obs, finished)`: `finished` is False only for a "CONTINUE" verdict and
        True in every other case (including errors raised by the response agent).
    """
    payload = action.model_dump() if isinstance(action, Response) else action

    if not (isinstance(payload, dict) and payload.get("type") == "response"):
        return obs, True

    response_text = payload.get("text", "")
    if not (response_agent and response_text):
        return obs, True

    try:
        verdict = await response_agent.determine_response(response_text)
        if verdict == "CONTINUE":
            logger.info("ResponseAgent indicated CONTINUE. for message: %s", response_text)
            obs.text = "Yes, please continue."
            return obs, False
        logger.warning("ResponseAgent indicated STOP for message: %s", response_text)
    except Exception as e:
        # A failing response agent must not abort the task; treat it as "stop"
        logger.warning("Error using ResponseAgent: %s", e)

    return obs, True
285
-
286
-
287
def _record_job_error(
    job: Job,
    task_id: str,
    error_type: str,
    error: BaseException,
    *,
    step: int | None = None,
) -> None:
    """Append a timestamped error record for `task_id` to `job.errors`.

    `step` (1-based) is only included in the record when it is given.
    """
    record: dict[str, Any] = {"task_id": task_id, "type": error_type}
    if step is not None:
        record["step"] = step
    record["error"] = str(error)
    record["timestamp"] = datetime.now().isoformat()
    job.errors.append(record)


async def _execute_task(
    agent_cls: type[Agent],
    adapter_cls: type[Adapter] | None,
    agent_kwargs: dict[str, Any] | None,
    adapter_kwargs: dict[str, Any] | None,
    task: Task,
    job_name: str,
    task_id: str,
    max_steps_per_task: int,
    job: Job,
    tracker: StepProgressTracker | None = None,
    auto_reply_question: bool = False,
    # Use semaphores instead of rate limiter
    env_creation_semaphore: asyncio.Semaphore | None = None,
    agent_predict_semaphore: asyncio.Semaphore | None = None,
) -> None:
    """Instantiate, run and evaluate a single task, with concurrency limits via
    semaphores.

    Exceptions raised during setup, stepping, evaluation or environment shutdown
    are logged and recorded in `job.errors` rather than propagated.

    Args:
        agent_cls: Agent class; instantiated once for this task.
        adapter_cls: Optional adapter class passed to the agent as `adapter`.
        agent_kwargs: Extra keyword arguments for the agent constructor.
        adapter_kwargs: Keyword arguments for the adapter constructor.
        task: Task to run; its `metadata["agent_name"]` is set to the agent name.
        job_name: Name of the job (not used inside this helper).
        task_id: Identifier used for logging, progress tracking and error records.
        max_steps_per_task: Upper bound on loop iterations for this task.
        job: Job the environment is created for and errors are recorded on.
        tracker: Optional progress tracker notified on start, step and finish.
        auto_reply_question: If True, a `ResponseAgent` decides whether a final
            response action should end the episode.
        env_creation_semaphore: Optional limit around `gym.make`.
        agent_predict_semaphore: Optional limit around `agent.predict`.
    """
    if tracker:
        tracker.start_task(task_id)
    env = None
    agent_instance: Agent | None = None
    status = "error"
    error_msg = "Initialization failed"
    try:
        response_agent = ResponseAgent() if auto_reply_question else None

        adapter_instance = None
        if adapter_cls:
            adapter_instance = adapter_cls(**(adapter_kwargs or {}))
        agent_instance = agent_cls(
            adapter=adapter_instance,
            **(agent_kwargs or {}),
        )
        if agent_instance is None:
            raise RuntimeError("Agent could not be instantiated")

        agent_name = agent_instance.name
        logger.info("Using agent: %s", agent_name)
        if task.metadata is None or not isinstance(task.metadata, dict):
            task.metadata = {}
        task.metadata["agent_name"] = agent_name

        # Environment creation with semaphore
        if env_creation_semaphore:
            async with env_creation_semaphore:
                env = await gym.make(task, job=job)
        else:
            env = await gym.make(task, job=job)

        if not env:
            raise ValueError(f"Environment creation failed for task {task_id}")

        obs_tuple = await env.reset()
        if obs_tuple is None:
            raise ValueError(f"env.reset() returned None for task {task_id}")
        obs, _ = obs_tuple

        step_error = None

        # Shared budget for failed predictions and response-agent consultations
        resampled_actions = 0

        for step in range(max_steps_per_task):
            action, done = (None, False)
            try:
                # Agent prediction with semaphore
                try:
                    if agent_predict_semaphore:
                        async with agent_predict_semaphore:
                            action, done = await agent_instance.predict(obs)
                    else:
                        action, done = await agent_instance.predict(obs)
                except Exception as e:
                    # A failed prediction is retried on the next iteration with
                    # the same observation, up to the shared budget
                    logger.exception("[TR: %s] Agent prediction failed: %s", task_id, e)
                    resampled_actions += 1
                    if resampled_actions > 5:
                        logger.warning(
                            "[TR: %s] Resampled action %d times. Stopping.",
                            task_id,
                            resampled_actions,
                        )
                        break
                    continue

                if tracker:
                    tracker.increment_step(task_id)

                finish = False
                if done and response_agent and action and len(action) > 0:
                    obs, finish = await _maybe_resample_action(obs, action[-1], response_agent)
                    resampled_actions += 1
                    if resampled_actions > 5:
                        logger.warning(
                            "[TR: %s] Resampled action %d times. Stopping.",
                            task_id,
                            resampled_actions,
                        )
                        break
                    if not finish:
                        # Response agent asked to continue: predict again without
                        # sending the response action to the environment
                        continue

                step_result = await env.step(action)
                if step_result is None:
                    terminated = True
                else:
                    obs, _, terminated, _ = step_result
                if terminated or done or finish:
                    break

            except Exception as agent_step_err:
                logger.exception(
                    "[TR: %s] Step %d Error: %s",
                    task_id,
                    step + 1,
                    agent_step_err,
                )
                step_error = f"Error at step {step + 1}: {agent_step_err}"
                # Store step error in job
                _record_job_error(job, task_id, "step_error", agent_step_err, step=step + 1)
                continue
        else:
            # Only reached when the loop ran out of iterations without a break
            logger.warning("[TR: %s] Max steps reached.", task_id)

        # --- Evaluate Task ---
        evaluation_result = None
        if step_error:
            # Any step error (even one followed by successful steps) skips evaluation
            status = "error"
            error_msg = step_error
        else:
            try:
                evaluation_result = await env.evaluate()
                status = "completed"
                error_msg = None
            except Exception as eval_err:
                logger.exception(
                    "[TR: %s] Evaluation Error: %s",
                    task_id,
                    eval_err,
                )
                status = "error"
                error_msg = f"Evaluation failed: {eval_err}"
                # Store evaluation error in job
                _record_job_error(job, task_id, "evaluation_error", eval_err)

    except Exception as e:
        logger.exception("[TR: %s] Setup/Run Error: %s", task_id, e)
        status = "error"
        error_msg = str(e)
        # Store setup/initialization error in job
        _record_job_error(job, task_id, "setup_error", e)

    finally:
        if tracker:
            tracker.finish_task(task_id)
        if env:
            try:
                await env.close()
            except Exception as close_err:
                logger.exception("[TR: %s] Close Error: %s", task_id, close_err)
                # Store environment close error in job
                _record_job_error(job, task_id, "env_close_error", close_err)

        # `evaluation_result` is only read when status is not "error", i.e. after
        # it has been assigned by a successful evaluation
        log_suffix = f" Error: {error_msg}" if status == "error" else f" Eval: {evaluation_result}"
        logger.info(
            "[TR: %s] Finished local execution. Status: %s.%s",
            task_id,
            status,
            log_suffix,
        )
487
-
488
-
489
async def _progress_monitor(tracker: StepProgressTracker, interval: float = 1.0) -> None:
    """Coroutine to periodically display progress using the tracker.

    The tracker's display is rewritten on one stderr line every `interval`
    seconds until `tracker.is_finished()` is true, then written one last time
    followed by a newline.

    Raises:
        asyncio.CancelledError: Re-raised after the cancellation has been
            reported, so the surrounding task is marked as cancelled.
    """
    try:
        while not tracker.is_finished():
            # "\r" returns to the start of the line so the display updates in place
            sys.stderr.write(f"\r{tracker.display()}")
            sys.stderr.flush()
            await asyncio.sleep(interval)
        sys.stderr.write(f"\r{tracker.display()}\n")
        sys.stderr.flush()
        logger.debug("Progress monitor finished.")
    except asyncio.CancelledError:
        sys.stderr.write("\nProgress monitor cancelled.\n")
        sys.stderr.flush()
        logger.debug("Progress monitor cancelled.")
        # Cancellation must propagate; swallowing it hides the cancel from asyncio
        raise
    except Exception as e:
        # Display problems are reported but never allowed to fail the job
        sys.stderr.write(f"\nProgress monitor error: {e}\n")
        sys.stderr.flush()
        logger.exception("Progress monitor error: %s", e)
507
-
508
-
509
- # --- New run_job function ---
510
-
511
-
512
async def run_job(
    agent_cls: type[Agent],
    task_or_taskset: Task | TaskSet,
    job_name: str,
    auto_reply_question: bool = False,
    adapter_cls: type[Adapter] | None = None,
    agent_kwargs: dict[str, Any] | None = None,
    adapter_kwargs: dict[str, Any] | None = None,
    max_steps_per_task: int = 20,
    run_parallel: bool = True,
    job_metadata: dict[str, Any] | None = None,
    show_progress: bool = True,
    verbose: bool = False,
    # Concurrency control with semaphores
    max_concurrent_env_creations: int | None = 30,  # Limits gym.make calls
    max_concurrent_agent_predictions: int | None = None,  # No limit on LLM calls
    max_concurrent_tasks: int | None = 30,  # Limits overall task concurrency
) -> Job:
    """
    Creates Job, executes tasks locally, linking them to the Job.
    Instantiates agent/adapter per task. Shows step-based progress.

    Controls concurrency in three ways:
    1. Limits concurrent environment creations
    2. Limits concurrent agent predictions
    3. Limits overall concurrent tasks (when run_parallel=True)

    All concurrency controls use semaphores for reliability.
    Tracks all errors that occur during execution in job.errors.

    Args:
        agent_cls: Agent class to instantiate.
        task_or_taskset: Task or TaskSet to run.
        job_name: Name for the Job.
        auto_reply_question: Forwarded to each task run; when True a ResponseAgent
            decides whether an agent's final response ends the task.
        adapter_cls: Optional Adapter class.
        agent_kwargs: Optional kwargs for agent constructor.
        adapter_kwargs: Optional kwargs for adapter constructor.
        max_steps_per_task: Step limit per task.
        run_parallel: Run TaskSet tasks concurrently if True (limited by max_concurrent_tasks).
            A single Task always runs sequentially.
        job_metadata: Metadata for the created Job.
        show_progress: Display the step-based progress tracker.
        verbose: If False, the "hud" logger is set to CRITICAL for the rest of
            the process (the previous level is not restored).
        max_concurrent_env_creations: Max concurrent environment creation calls.
        max_concurrent_agent_predictions: Max concurrent agent prediction calls.
        max_concurrent_tasks: Max number of tasks to run actively at the same time.

    Returns:
        The created Job object with errors stored in job.errors.

    Raises:
        TypeError: If `task_or_taskset` is neither a Task nor a TaskSet. This is
            checked before any job is created.
    """
    # Get hud logger
    if not verbose:
        logging.getLogger("hud").setLevel(logging.CRITICAL)
    logger = logging.getLogger("hud.job")

    # --- Task Setup ---
    # Validate the input first so that bad arguments never create a remote job.
    is_taskset = isinstance(task_or_taskset, TaskSet)
    if not is_taskset and not isinstance(task_or_taskset, Task):
        raise TypeError("task_or_taskset must be either a Task or a TaskSet")

    evalset_id = None
    if is_taskset:
        evalset_id = task_or_taskset.id
        task_or_taskset.fit(agent_cls)
        tasks_to_run: list[Task] = task_or_taskset.tasks or []
        # An empty TaskSet has no first task to take the gym from
        first_gym = tasks_to_run[0].gym if tasks_to_run else None
    else:
        tasks_to_run = [task_or_taskset]
        run_parallel = False
        first_gym = task_or_taskset.gym
    # Only string gym identifiers are sent along with the job
    gym_id = first_gym if isinstance(first_gym, str) else None

    # --- Create Job ---
    try:
        logger.info("Creating job with name: '%s'", job_name)
        created_job = await create_job(
            name=job_name,
            metadata=job_metadata,
            evalset_id=evalset_id,
            gym_id=gym_id,
        )
    except Exception as e:
        logger.exception("Failed to create job '%s': %s", job_name, e)
        raise

    if not tasks_to_run:
        logger.warning("Job '%s' (%s): No tasks found to run.", created_job.name, created_job.id)
        return created_job

    task_ids = [(str(task.id) if task.id else f"task_{i}") for i, task in enumerate(tasks_to_run)]
    num_tasks = len(tasks_to_run)

    # --- Create semaphores for concurrency control ---
    env_creation_sema = None
    if max_concurrent_env_creations and max_concurrent_env_creations > 0:
        env_creation_sema = asyncio.Semaphore(max_concurrent_env_creations)
        logger.info(
            "Limiting concurrent environment creations to %d.", max_concurrent_env_creations
        )

    agent_predict_sema = None
    if max_concurrent_agent_predictions and max_concurrent_agent_predictions > 0:
        agent_predict_sema = asyncio.Semaphore(max_concurrent_agent_predictions)
        logger.info(
            "Limiting concurrent agent predictions to %d.", max_concurrent_agent_predictions
        )
    else:
        logger.info("No limit on concurrent agent predictions.")

    task_execution_sema = None
    effective_concurrency = num_tasks  # Default to running all if parallel
    if run_parallel and max_concurrent_tasks and max_concurrent_tasks > 0:
        effective_concurrency = min(num_tasks, max_concurrent_tasks)
        task_execution_sema = asyncio.Semaphore(effective_concurrency)
        logger.info("Limiting concurrent task executions to %d.", effective_concurrency)
    elif not run_parallel:
        effective_concurrency = 1  # Sequential means concurrency of 1

    # --- Instantiate Tracker & Start Monitor ---
    tracker = None
    monitor_task = None
    if show_progress and num_tasks > 0:
        tracker = StepProgressTracker(total_tasks=num_tasks, max_steps_per_task=max_steps_per_task)
        monitor_task = asyncio.create_task(_progress_monitor(tracker))

    # --- Execute Tasks ---
    job_desc_suffix = f" (Job ID: {created_job.id})"

    def run_one(task: Task, task_id: str) -> Coroutine[Any, Any, None]:
        # Single place that binds the per-job settings to one task execution
        return _execute_task(
            agent_cls=agent_cls,
            adapter_cls=adapter_cls,
            agent_kwargs=agent_kwargs,
            adapter_kwargs=adapter_kwargs,
            task=task,
            job_name=created_job.name,
            task_id=task_id,
            max_steps_per_task=max_steps_per_task,
            job=created_job,
            tracker=tracker,
            env_creation_semaphore=env_creation_sema,
            agent_predict_semaphore=agent_predict_sema,
            auto_reply_question=auto_reply_question,
        )

    async def task_wrapper(task_coro: Coroutine, semaphore: asyncio.Semaphore | None) -> None:
        if semaphore:
            async with semaphore:
                await task_coro
        else:
            await task_coro

    try:
        if run_parallel and is_taskset:
            logger.info(
                "Job '%s'%s: Running %d tasks with concurrency %d.",
                created_job.name,
                job_desc_suffix,
                num_tasks,
                effective_concurrency,
            )

            # Wrap coroutines with semaphore management if limiting concurrency
            wrapped_tasks = [
                task_wrapper(run_one(task, task_id), task_execution_sema)
                for task, task_id in zip(tasks_to_run, task_ids, strict=True)
            ]

            # Run all wrapped tasks
            await asyncio.gather(*wrapped_tasks)

        else:
            # SEQUENTIAL (or single task)
            logger.info(
                "Job '%s'%s: Running %d tasks sequentially.",
                created_job.name,
                job_desc_suffix,
                num_tasks,
            )
            for task, task_id in zip(tasks_to_run, task_ids, strict=True):
                await run_one(task, task_id)

    finally:
        # Ensure monitor task is stopped and awaited cleanly
        if monitor_task is not None and not monitor_task.done():
            monitor_task.cancel()
            try:
                await monitor_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                logger.error("Error awaiting progress monitor task: %s", e)

    logger.info(
        "Job '%s'%s finished local execution phase for %d tasks.",
        created_job.name,
        job_desc_suffix,
        num_tasks,
    )
    return created_job
738
-
739
-
740
- """
741
- c7f85f7d-3730-4c9a-85a3-a1dc436c3bd2
742
-
743
-
744
- de12c3cc-9d9c-4e90-82cc-1d71d30ede54
745
- 59104743-0a63-4569-a8b5-1eda1a1b55ac
746
- ff759429-056c-4cde-8851-11e26729ff03
747
-
748
-
749
- 7b98ea22-e243-4eeb-a6db-79f4a76da2b3
750
-
751
- 7aad3f7b-d74f-470d-826d-d817f95fdd67
752
-
753
- e356ede6-074a-49ef-9fcd-69e5bcfbdec9
754
-
755
- 26cd1192-3991-4d1b-b599-b2bed1bcb606
756
-
757
- 31ece277-970f-4763-b0c8-bf19a56f56c7
758
-
759
-
760
- f9b722a0-5f33-466b-bce0-8ece101f2bc6
761
- 33d1af33-8952-4945-b901-229bcfd88354
762
-
763
- 6c3d6557-e745-44ab-bc10-300180a81c79
764
- 6c3d6557-e745-44ab-bc10-300180a81c79
765
- 502e02b5-9939-4e57-91af-4fcbcb90a979
766
-
767
- 7aad3f7b-d74f-470d-826d-d817f95fdd67
768
-
769
-
770
- 31ece277-970f-4763-b0c8-bf19a56f56c7
771
-
772
-
773
- e356ede6-074a-49ef-9fcd-69e5bcfbdec9"""