agentexec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentexec/__init__.py +73 -0
- agentexec/activity/__init__.py +50 -0
- agentexec/activity/models.py +294 -0
- agentexec/activity/schemas.py +70 -0
- agentexec/activity/tracker.py +267 -0
- agentexec/config.py +72 -0
- agentexec/core/__init__.py +0 -0
- agentexec/core/models.py +23 -0
- agentexec/core/queue.py +109 -0
- agentexec/core/redis_client.py +40 -0
- agentexec/core/task.py +132 -0
- agentexec/core/worker.py +304 -0
- agentexec/runners/__init__.py +13 -0
- agentexec/runners/base.py +135 -0
- agentexec/runners/openai.py +237 -0
- agentexec-0.1.0.dist-info/METADATA +370 -0
- agentexec-0.1.0.dist-info/RECORD +18 -0
- agentexec-0.1.0.dist-info/WHEEL +4 -0
agentexec/core/worker.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import multiprocessing as mp
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from multiprocessing.synchronize import Event as EventClass
|
|
8
|
+
from typing import Any, ClassVar, Generic, TypeVar, cast
|
|
9
|
+
|
|
10
|
+
from sqlalchemy import Engine
|
|
11
|
+
from sqlalchemy.orm import Session, sessionmaker
|
|
12
|
+
|
|
13
|
+
from agentexec.config import CONF
|
|
14
|
+
from agentexec.core.task import Task, TaskHandler
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Names exported by ``from agentexec.core.worker import *``.
__all__ = ["Worker", "WorkerPool", "get_worker_session"]
|
|
24
|
+
|
|
25
|
+
T = TypeVar("T")


class classproperty(Generic[T]):
    """Read-only property evaluated against the class rather than an instance.

    Accessing the attribute on the class (e.g. ``Worker.session``) or on an
    instance calls the wrapped function with the owning class. The wrapped
    function's ``__doc__`` and ``__name__`` are copied onto the descriptor so
    introspection (``help()``, docs tooling) shows the property's own docs.
    """

    def __init__(self, func: Callable[[Any], T]) -> None:
        self.func = func
        # Preserve the decorated function's documentation on the descriptor.
        self.__doc__ = getattr(func, "__doc__", None)
        self.__name__ = getattr(func, "__name__", type(self).__name__)

    def __get__(self, obj: Any, owner: type | None = None) -> T:
        """Return ``func(owner)``.

        ``owner`` is normally supplied by Python's attribute lookup; it falls
        back to ``type(obj)`` when the descriptor is invoked directly as
        ``descriptor.__get__(instance)``.
        """
        if owner is None:
            owner = type(obj)
        return self.func(owner)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _build_session(engine: Engine) -> Session:
    """Open a new SQLAlchemy ``Session`` bound to *engine*.

    The session is created with ``autoflush=False`` and ``autocommit=False``.
    A fresh ``sessionmaker`` is built on each call; nothing here closes the
    session, so the caller owns it and must close it.
    """
    factory = sessionmaker(bind=engine, autoflush=False, autocommit=False)
    session: Session = factory()
    return session
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_worker_session() -> Session:
    """Return the database session of the worker running in this process.

    Thin convenience wrapper over ``Worker.session`` for use in task handlers.

    Returns:
        The current worker's SQLAlchemy session.

    Raises:
        RuntimeError: If no worker has been started in this process
            (``Worker.current`` is unset).
    """
    session = Worker.session
    return session
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Worker:
    """Individual worker process with isolated state.

    Each worker owns a SQLAlchemy session and an asyncio event loop, both
    created in ``__init__``. It repeatedly dequeues tasks and dispatches them
    to the registered handlers until the shared shutdown event is set.

    ``run_in_process`` is the multiprocessing entry point. It publishes the
    running instance as ``Worker.current`` so handler code can reach the
    worker's session via ``Worker.session`` / ``get_worker_session()``.
    """

    # Worker running in this process; set by run_in_process(), None elsewhere.
    current: ClassVar[Worker | None] = None

    # Instance attributes
    _worker_id: int
    _handlers: dict[str, TaskHandler]
    _shutdown_event: EventClass
    _session: Session
    _loop: asyncio.AbstractEventLoop

    def __init__(
        self,
        worker_id: int,
        engine: Engine,
        handlers: dict[str, TaskHandler],
        shutdown_event: EventClass,
    ):
        """Initialize worker with isolated state.

        Args:
            worker_id: Identifier for this worker (used in log messages).
            engine: SQLAlchemy engine to create the worker's session from.
            handlers: Task handler registry keyed by task name (only read here).
            shutdown_event: Multiprocessing event; the run loop exits once it is set.
        """
        self._worker_id = worker_id
        self._handlers = handlers
        self._shutdown_event = shutdown_event

        # Create worker-owned resources
        self._session = _build_session(engine)
        # Dedicated loop, installed as the current loop, used to drive async handlers.
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

    @classmethod
    def run_in_process(
        cls,
        worker_id: int,
        engine: Engine,
        handlers: dict[str, TaskHandler],
        shutdown_event: EventClass,
    ) -> None:
        """Entry point for running a worker in a new process.

        Creates a Worker instance, registers it as ``Worker.current`` and runs
        it until shutdown. Used as the ``target`` of ``multiprocessing.Process``.

        Args:
            worker_id: Identifier for this worker.
            engine: SQLAlchemy engine to create the session from.
            handlers: Task handler registry.
            shutdown_event: Multiprocessing event for coordinated shutdown.
        """
        # With the "fork" start method the child inherits the parent's connection
        # pool. Replace it WITHOUT closing the parent's connections so this process
        # never shares a DB connection with the parent (SQLAlchemy's recommended
        # recipe for multiprocessing).
        engine.dispose(close=False)
        instance = cls(worker_id, engine, handlers, shutdown_event)
        Worker.current = instance
        instance.run()

    @classproperty
    def session(cls) -> Session:
        """Get the current worker's database session.

        Returns:
            Session from the current worker

        Raises:
            RuntimeError: If called outside of a worker process
        """
        if cls.current is None:
            raise RuntimeError("Worker.session called outside of worker context")

        return cls.current._session

    def run(self) -> None:
        """Poll the queue and process tasks until the shutdown event is set.

        Exceptions escaping the loop are logged and re-raised. The session and
        event loop are released on every exit path.
        """
        # Imported lazily, presumably to avoid an import cycle with the queue module.
        from agentexec.core.queue import dequeue

        logger.info("Worker %s starting", self._worker_id)

        try:
            while not self._shutdown_event.is_set():
                # dequeue() returns None when no task is available; presumably it
                # blocks with a timeout so the shutdown event is re-checked regularly.
                if (task := dequeue()) is not None:
                    self._process_task(task)
        except Exception as e:
            logger.exception("Worker %s fatal error: %s", self._worker_id, e)
            raise
        finally:
            self._cleanup()

    def _process_task(self, task: Task) -> None:
        """Run the handler registered for ``task.task_name``.

        Marks the task started, calls the handler, and runs the returned
        coroutine to completion on the worker's event loop if there is one.
        It then marks the task completed. Any exception raised while running
        the task is logged and recorded via ``task.errored`` so the worker
        keeps running.

        Args:
            task: Task to process.

        Raises:
            RuntimeError: If no handler is registered for the task name.
        """
        logger.info("Worker %s processing task: %s", self._worker_id, task.task_name)

        # Look the handler up OUTSIDE the try: a KeyError raised by the handler
        # itself must be treated as a task failure, not as a missing handler.
        handler = self._handlers.get(task.task_name)
        if handler is None:
            logger.error("No handler registered for task: %s", task.task_name)
            raise RuntimeError(f"No handler for task: {task.task_name}")

        try:
            task.started()

            # Dispatch on the returned value rather than the callable's type, so
            # async defs, partials of async functions and objects with an async
            # __call__ all get awaited.
            result = handler(**task.handler_kwargs)
            if asyncio.iscoroutine(result):
                self._loop.run_until_complete(result)

            task.completed()
            logger.info("Worker %s completed task: %s", self._worker_id, task.task_name)
        except Exception as e:
            # Log first so the failure is visible even if recording it fails.
            logger.exception(
                "Worker %s error executing task %s: %s", self._worker_id, task.task_name, e
            )
            task.errored(e)

    def _cleanup(self) -> None:
        """Release the worker's session and event loop on shutdown."""
        try:
            self._session.close()
            logger.debug("Worker %s closed database session", self._worker_id)
        finally:
            # Always close the loop, even if closing the session failed.
            self._loop.close()
            logger.info("Worker %s shutting down", self._worker_id)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class WorkerPool:
    """Manages a pool of worker processes for background task execution.

    The WorkerPool coordinates multiple worker processes, handles task handler
    registration, and manages graceful shutdown. Each worker runs in a separate
    process with its own isolated state (session, event loop).

    This allows multiple pools to coexist if needed, each managing their own
    set of workers and handlers.

    Example:
        from sqlalchemy import create_engine

        engine = create_engine("sqlite:///agents.db")
        pool = WorkerPool(engine=engine)

        @pool.task("research_company")
        async def research_company(payload: dict, agent_id: str):
            company_name = payload["company_name"]
            # Task implementation
            pass

        pool.start()
    """

    _engine: Engine
    _handlers: dict[str, TaskHandler]
    _processes: list[mp.Process]
    _shutdown_event: EventClass

    def __init__(self, engine: Engine) -> None:
        """Initialize the worker pool.

        Args:
            engine: SQLAlchemy engine for activity tracking.
                Each worker will create its own session from this engine.

        Configuration is loaded from agentexec.CONF:
            - num_workers: Number of worker processes to spawn
            - queue_name: Name of the Redis list to use as task queue
            - redis_url: Redis connection URL (via get_redis())
        """
        self._engine = engine
        self._handlers: dict[str, TaskHandler] = {}
        self._processes: list[mp.Process] = []
        self._shutdown_event: EventClass = mp.Event()

    def task(self, name: str) -> Callable[[TaskHandler], TaskHandler]:
        """Decorator to register a task handler.

        Handlers must be registered before ``start()``. Each worker receives
        the handler registry when its process is created, so handlers added
        later are not seen by workers that are already running.

        Args:
            name: Task type name that will be used when enqueueing.

        Returns:
            Decorator function that registers the handler and returns it unchanged.

        Example:
            @pool.task("research_company")
            async def research_company(payload: dict, agent_id: str):
                company_name = payload["company_name"]
                # Implementation
                pass
        """

        def decorator(func: TaskHandler) -> TaskHandler:
            if self._processes:
                logger.warning(
                    "Task handler %r registered after workers started; "
                    "running workers will not see it",
                    name,
                )
            self._handlers[name] = func
            logger.info("Registered task handler: %s", name)
            return func

        return decorator

    def start(self) -> None:
        """Start the worker processes.

        Spawns N worker processes (configured via num_workers) that will poll
        the Redis queue and execute registered task handlers.

        Each worker runs independently with its own session and event loop.
        The pool may be started again after ``shutdown()``.
        """
        # shutdown() leaves the event set; without clearing it, a restarted pool's
        # workers would see the shutdown signal and exit immediately.
        if not self._processes:
            self._shutdown_event.clear()

        logger.info("Starting %s worker processes", CONF.num_workers)

        for worker_id in range(CONF.num_workers):
            process = mp.Process(
                target=Worker.run_in_process,
                args=(worker_id, self._engine, self._handlers, self._shutdown_event),
                daemon=False,
            )
            process.start()
            self._processes.append(process)
            logger.info("Started worker process %s (PID: %s)", worker_id, process.pid)

    def shutdown(self, timeout: int | None = None) -> None:
        """Gracefully shutdown all worker processes.

        Signals every worker to stop, then waits up to ``timeout`` seconds in
        total for all of them to exit. Workers still alive at the deadline are
        terminated; any that survive ``terminate()`` for 5 more seconds are killed.

        Args:
            timeout: Maximum total seconds to wait for workers to finish their
                current task. Defaults to CONF.graceful_shutdown_timeout.
        """
        import time

        if timeout is None:
            timeout = CONF.graceful_shutdown_timeout

        logger.info("Initiating graceful shutdown of worker pool")
        self._shutdown_event.set()

        # All workers receive the signal at once, so they share a single deadline
        # instead of each process being granted the full timeout in turn.
        deadline = time.monotonic() + timeout
        for process in self._processes:
            process.join(timeout=max(0.0, deadline - time.monotonic()))
            if process.is_alive():
                logger.warning(f"Worker process {process.pid} did not stop gracefully, terminating")
                process.terminate()
                process.join(timeout=5)
                if process.is_alive():
                    logger.error("Worker process %s did not exit after terminate(), killing", process.pid)
                    process.kill()
                    process.join(timeout=5)

        self._processes.clear()
        logger.info("Worker pool shutdown complete")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Agent runners with activity tracking and lifecycle management."""
|
|
2
|
+
|
|
3
|
+
from agentexec.runners.base import BaseAgentRunner
|
|
4
|
+
|
|
5
|
+
__all__ = ["BaseAgentRunner"]

# OpenAIRunner needs the optional `agents` package; export it only when its
# module imports cleanly.
try:
    from agentexec.runners.openai import OpenAIRunner
except ImportError:
    pass
else:
    __all__.append("OpenAIRunner")
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any
|
|
3
|
+
from abc import ABC
|
|
4
|
+
import logging
|
|
5
|
+
import uuid
|
|
6
|
+
|
|
7
|
+
from agentexec import activity
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseAgentRunner(ABC):
    """Base class for framework-specific agent runners.

    A runner is bound to a single ``agent_id`` and exposes two helper namespaces:

    - ``prompts``: instruction snippets (``prompts.report_status`` and
      ``prompts.wrap_up``) that callers can splice into agent instructions or
      use as the follow-up message when recovering from a max-turns failure.
    - ``tools``: runner-provided tools such as ``tools.report_status``, which
      forwards progress updates for ``agent_id`` to ``agentexec.activity``.

    It also stores the max-turns recovery settings (``max_turns_recovery`` and
    ``recovery_turns``) for subclasses to act on. No abstract methods are
    declared; subclasses add their own run entry points and may replace
    ``tools`` with a framework-specific namespace.
    """

    agent_id: uuid.UUID
    max_turns_recovery: bool
    recovery_turns: int

    prompts: _RunnerPrompts
    tools: _RunnerTools

    def __init__(
        self,
        agent_id: uuid.UUID,
        *,
        max_turns_recovery: bool = True,
        recovery_turns: int = 5,
        wrap_up_prompt: str | None = None,
        report_status_prompt: str | None = None,
    ) -> None:
        """Bind the runner to an agent and store its configuration.

        Args:
            agent_id: Identifier of the agent whose activity this runner reports.
            max_turns_recovery: Whether subclasses should attempt a wrap-up run
                when the agent exceeds its turn limit.
            recovery_turns: Turn budget for that wrap-up run.
            wrap_up_prompt: Replaces the default ``prompts.wrap_up`` text when not None.
            report_status_prompt: Replaces the default ``prompts.report_status``
                text when not None.
        """
        self.agent_id = agent_id
        self.max_turns_recovery = max_turns_recovery
        self.recovery_turns = recovery_turns

        # Prompt snippets; a None override keeps the class-level default text.
        self.prompts = _RunnerPrompts(
            wrap_up=wrap_up_prompt,
            report_status=report_status_prompt,
        )
        # Tool namespace with this runner's agent id pre-bound.
        self.tools = _RunnerTools(agent_id)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _RunnerPrompts:
    """Prompt snippets exposed to callers as ``runner.prompts``.

    Defaults live on the class. An instance attribute is set only for values
    explicitly overridden in the constructor; ``None`` means "keep the default",
    while any other value (including an empty string) replaces it.
    """

    report_status: str = (
        "Using report_activity tool:\n"
        " - Always report your current activity before you start a new step using the report_activity tool. \n"
        " - Include a brief message about the task and context you are operating in (10 words or less). \n"
        " - Don't use internal data or underlying system info and instead focus on what the user would care about. \n"
        " - This informs the user of your progress as they have no visibility into your internal operations. \n"
        " - You can call multiple tools in parallel per step so don't waste an entire step on this.\n"
        " - Call it at the top of your list of the next round of tool uses; we should be careful to minimize turns used.\n"
    )
    wrap_up: str = "Please summarize your findings."

    def __init__(
        self,
        *,
        wrap_up: str | None = None,
        report_status: str | None = None,
    ) -> None:
        overrides = (("wrap_up", wrap_up), ("report_status", report_status))
        for attr_name, override in overrides:
            if override is not None:
                setattr(self, attr_name, override)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class _RunnerTools:
    """Framework-agnostic tools exposed to callers as ``runner.tools``.

    Each tool is a plain function closed over the ``agent_id`` passed to the
    constructor. Framework-specific runners subclass this namespace and wrap
    the plain functions with their own tool decorators.
    """

    _agent_id: uuid.UUID

    def __init__(self, agent_id: uuid.UUID) -> None:
        self._agent_id = agent_id

    @property
    def report_status(self) -> Any:
        """Build the progress-reporting tool for this namespace's agent.

        The returned ``report_activity`` function forwards ``message`` and
        ``percentage`` to ``activity.update`` for the ``agent_id`` captured
        when this namespace was constructed, then returns ``"Status updated"``.
        A new function object is created on every access.

        Subclasses may override this property to wrap the function with a
        framework-specific tool decorator.

        Returns:
            The plain ``report_activity`` function.
        """
        bound_agent_id = self._agent_id

        # NOTE: the name, signature and docstring below are what an LLM tool
        # wrapper exposes to the model; keep them stable.
        def report_activity(message: str, percentage: int) -> str:
            """Report progress and status updates.

            Use this tool to report your progress as you work through the task.

            Args:
                message: A brief description of what you're currently doing
                percentage: Your estimated completion percentage (0-100)

            Returns:
                Confirmation message
            """
            activity.update(
                agent_id=bound_agent_id,
                message=message,
                completion_percentage=percentage,
            )
            return "Status updated"

        return report_activity
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import uuid
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from agents import Agent, MaxTurnsExceeded, Runner, function_tool
|
|
6
|
+
from agents.items import TResponseInputItem
|
|
7
|
+
from agents.result import RunResult, RunResultStreaming
|
|
8
|
+
|
|
9
|
+
from agentexec.runners.base import BaseAgentRunner, _RunnerTools
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _extract_input(e: MaxTurnsExceeded) -> list[TResponseInputItem]:
    """Rebuild the full conversation history from a ``MaxTurnsExceeded`` exception.

    The original run input comes first: a list is copied, and a string is
    wrapped as a single user message. It is followed by every item the run
    generated before hitting the turn limit, converted with ``to_input_item()``.

    Args:
        e: The MaxTurnsExceeded exception instance.

    Returns:
        A new list of TResponseInputItem suitable as input for a follow-up run.

    Raises:
        MaxTurnsExceeded: Re-raises ``e`` when it carries no run data, because
            the history cannot be reconstructed.
    """
    run_data = e.run_data
    if not run_data:
        logger.error("No run data available in MaxTurnsExceeded exception")
        # Raise ``e`` explicitly: a bare ``raise`` only works while an exception
        # is being handled and otherwise fails with RuntimeError.
        raise e

    history = run_data.input
    final_input: list[TResponseInputItem] = (
        list(history) if isinstance(history, list) else [{"role": "user", "content": history}]
    )

    # Append everything the interrupted run produced.
    final_input.extend(item.to_input_item() for item in run_data.new_items)

    return final_input
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class _OpenAIRunnerTools(_RunnerTools):
    """Tool namespace whose tools are wrapped as OpenAI Agents SDK function tools."""

    @property
    def report_status(self) -> Any:
        """Return the base ``report_activity`` function wrapped with ``function_tool``.

        A new tool object is built on every access.
        """
        plain_tool = super().report_status
        return function_tool(plain_tool)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class OpenAIRunner(BaseAgentRunner):
    """Runner for the OpenAI Agents SDK with optional max-turns recovery.

    Wraps ``agents.Runner`` and adds:

    - A ``report_status`` tool (``runner.tools.report_status``) wrapped with
      ``function_tool`` and pre-bound to this runner's ``agent_id``.
    - Max-turns recovery. When ``max_turns_recovery`` is enabled and the run
      raises ``MaxTurnsExceeded``, the conversation so far plus
      ``prompts.wrap_up`` is replayed for at most ``recovery_turns`` turns.

    Note: ``run``/``run_streamed`` do not write activity records themselves;
    progress is recorded only when the agent calls the ``report_status`` tool.

    Example:
        runner = agentexec.OpenAIRunner(
            agent_id,
            max_turns_recovery=True,
            wrap_up_prompt="Please summarize your findings.",
        )

        agent = Agent(
            name="Research Agent",
            instructions=f"Research companies. {runner.prompts.report_status}",
            tools=[runner.tools.report_status],
            model="gpt-4o",
        )

        result = await runner.run(agent, "Research Acme Corp", max_turns=15)
    """

    def __init__(
        self,
        agent_id: uuid.UUID,
        *,
        max_turns_recovery: bool = False,
        wrap_up_prompt: str | None = None,
        recovery_turns: int = 5,
        report_status_prompt: str | None = None,
    ) -> None:
        """Initialize the OpenAI runner.

        Args:
            agent_id: UUID for tracking this agent's activity.
            max_turns_recovery: Enable automatic recovery when max turns exceeded.
            wrap_up_prompt: Prompt to use for recovery run.
            recovery_turns: Number of turns allowed for recovery.
            report_status_prompt: Instruction snippet about using the status tool.
        """
        super().__init__(
            agent_id,
            max_turns_recovery=max_turns_recovery,
            recovery_turns=recovery_turns,
            wrap_up_prompt=wrap_up_prompt,
            report_status_prompt=report_status_prompt,
        )
        # Override with OpenAI-specific tools
        self.tools = _OpenAIRunnerTools(self.agent_id)

    def _build_recovery_input(self, e: MaxTurnsExceeded) -> list[TResponseInputItem]:
        """Return the interrupted conversation followed by the wrap-up prompt."""
        recovery_input = _extract_input(e)
        recovery_input.append(
            {
                "role": "user",
                "content": self.prompts.wrap_up,
            }
        )
        return recovery_input

    @staticmethod
    async def _consume_stream(
        result: RunResultStreaming,
        forwarder: Callable | None,
    ) -> RunResultStreaming:
        """Drain ``result.stream_events()``, awaiting ``forwarder`` on each event."""
        async for event in result.stream_events():
            if forwarder:
                await forwarder(event)
        return result

    async def run(
        self,
        agent: Agent[Any],
        input: str | list[TResponseInputItem],
        max_turns: int = 10,
        context: Any | None = None,
    ) -> RunResult:
        """Run the agent, optionally recovering from ``MaxTurnsExceeded``.

        Args:
            agent: Agent instance.
            input: User input/prompt for the agent.
            max_turns: Maximum number of agent iterations.
            context: Optional context for the agent run.

        Returns:
            The result of the original run or, if recovery kicked in, of the
            wrap-up run.

        Raises:
            MaxTurnsExceeded: If the turn limit is hit with recovery disabled,
                or the recovery run also exceeds ``recovery_turns``.
        """
        # TODO match method signature of Runner.run
        try:
            return await Runner.run(
                agent,
                input,
                max_turns=max_turns,
                context=context,
            )
        except MaxTurnsExceeded as e:
            if not self.max_turns_recovery:
                raise

            logger.info("Max turns exceeded, attempting recovery")
            return await Runner.run(
                agent,
                self._build_recovery_input(e),
                max_turns=self.recovery_turns,
                context=context,
            )

    async def run_streamed(
        self,
        agent: Agent[Any],
        input: str | list[TResponseInputItem],
        max_turns: int = 10,
        context: Any | None = None,
        forwarder: Callable | None = None,
    ) -> RunResultStreaming:
        """Run the agent in streaming mode, forwarding each event as it arrives.

        The run is driven to completion inside this coroutine. Every event
        from ``stream_events()`` is awaited through ``forwarder`` (if given),
        and the finished ``RunResultStreaming`` is returned. The event stream
        is already consumed here, so callers should observe events through
        ``forwarder`` rather than by iterating the returned result.

        Args:
            agent: Agent instance.
            input: User input/prompt for the agent.
            max_turns: Maximum number of agent iterations.
            context: Optional context for the agent run.
            forwarder: Optional async callable awaited with each stream event.

        Returns:
            The completed streaming result (original run or recovery run).

        Raises:
            MaxTurnsExceeded: If the turn limit is hit with recovery disabled,
                or the recovery run also exceeds ``recovery_turns``.

        Example:
            async def forward(event):
                print(event)

            result = await runner.run_streamed(agent, "Research XYZ", forwarder=forward)
        """
        # TODO match method signature of Runner.run_streamed
        # TODO I want to defer the `await` to the caller side
        # TODO forwarder is just a placeholder but we need to come up with a solution for that functionality
        try:
            return await self._consume_stream(
                Runner.run_streamed(
                    agent,
                    input,
                    max_turns=max_turns,
                    context=context,
                ),
                forwarder,
            )
        except MaxTurnsExceeded as e:
            if not self.max_turns_recovery:
                raise

            logger.info("Max turns exceeded, attempting recovery")
            return await self._consume_stream(
                Runner.run_streamed(
                    agent,
                    self._build_recovery_input(e),
                    max_turns=self.recovery_turns,
                    context=context,
                ),
                forwarder,
            )
|