agentex-sdk 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentex/_constants.py +3 -3
- agentex/_version.py +1 -1
- agentex/lib/adk/_modules/acp.py +43 -5
- agentex/lib/adk/providers/_modules/openai.py +15 -0
- agentex/lib/cli/handlers/deploy_handlers.py +4 -1
- agentex/lib/cli/templates/temporal/environments.yaml.j2 +1 -1
- agentex/lib/core/services/adk/acp/acp.py +85 -20
- agentex/lib/core/services/adk/providers/openai.py +149 -25
- agentex/lib/core/temporal/activities/adk/acp/acp_activities.py +20 -0
- agentex/lib/core/temporal/activities/adk/providers/openai_activities.py +265 -149
- agentex/lib/core/temporal/workers/worker.py +23 -2
- agentex/lib/sdk/fastacp/base/base_acp_server.py +22 -2
- agentex/lib/sdk/fastacp/base/constants.py +24 -0
- agentex/lib/types/acp.py +20 -0
- agentex/resources/agents.py +3 -0
- agentex/resources/tasks.py +4 -4
- agentex/types/agent.py +7 -1
- agentex/types/task.py +2 -0
- {agentex_sdk-0.4.11.dist-info → agentex_sdk-0.4.13.dist-info}/METADATA +1 -1
- {agentex_sdk-0.4.11.dist-info → agentex_sdk-0.4.13.dist-info}/RECORD +23 -22
- {agentex_sdk-0.4.11.dist-info → agentex_sdk-0.4.13.dist-info}/WHEEL +0 -0
- {agentex_sdk-0.4.11.dist-info → agentex_sdk-0.4.13.dist-info}/entry_points.txt +0 -0
- {agentex_sdk-0.4.11.dist-info → agentex_sdk-0.4.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,34 +1,42 @@
|
|
1
1
|
# Standard library imports
|
2
2
|
import base64
|
3
|
-
from collections.abc import Callable
|
4
|
-
from contextlib import AsyncExitStack, asynccontextmanager
|
5
3
|
from enum import Enum
|
6
4
|
from typing import Any, Literal, Optional
|
7
|
-
|
8
|
-
from
|
5
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
6
|
+
from collections.abc import Callable
|
9
7
|
|
10
8
|
import cloudpickle
|
11
|
-
from
|
9
|
+
from mcp import StdioServerParameters
|
10
|
+
from agents import RunResult, RunContextWrapper, RunResultStreaming
|
11
|
+
from pydantic import Field, PrivateAttr
|
12
|
+
from agents.mcp import MCPServerStdio, MCPServerStdioParams
|
13
|
+
from temporalio import activity
|
14
|
+
from agents.tool import (
|
15
|
+
ComputerTool as OAIComputerTool,
|
16
|
+
FunctionTool as OAIFunctionTool,
|
17
|
+
WebSearchTool as OAIWebSearchTool,
|
18
|
+
FileSearchTool as OAIFileSearchTool,
|
19
|
+
LocalShellTool as OAILocalShellTool,
|
20
|
+
CodeInterpreterTool as OAICodeInterpreterTool,
|
21
|
+
ImageGenerationTool as OAIImageGenerationTool,
|
22
|
+
)
|
12
23
|
from agents.guardrail import InputGuardrail, OutputGuardrail
|
13
24
|
from agents.exceptions import InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered
|
14
|
-
from agents.mcp import MCPServerStdio, MCPServerStdioParams
|
15
25
|
from agents.model_settings import ModelSettings as OAIModelSettings
|
16
|
-
from agents.tool import FunctionTool as OAIFunctionTool
|
17
|
-
from mcp import StdioServerParameters
|
18
|
-
from openai.types.responses.response_includable import ResponseIncludable
|
19
26
|
from openai.types.shared.reasoning import Reasoning
|
20
|
-
from
|
21
|
-
|
27
|
+
from openai.types.responses.response_includable import ResponseIncludable
|
28
|
+
|
29
|
+
from agentex.lib.utils import logging
|
30
|
+
|
31
|
+
# Third-party imports
|
32
|
+
from agentex.lib.types.tracing import BaseModelWithTraceParams
|
22
33
|
|
23
34
|
# Local imports
|
24
35
|
from agentex.lib.types.agent_results import (
|
25
36
|
SerializableRunResult,
|
26
37
|
SerializableRunResultStreaming,
|
27
38
|
)
|
28
|
-
|
29
|
-
# Third-party imports
|
30
|
-
from agentex.lib.types.tracing import BaseModelWithTraceParams
|
31
|
-
from agentex.lib.utils import logging
|
39
|
+
from agentex.lib.core.services.adk.providers.openai import OpenAIService
|
32
40
|
|
33
41
|
logger = logging.make_logger(__name__)
|
34
42
|
|
@@ -42,6 +50,147 @@ class OpenAIActivityName(str, Enum):
|
|
42
50
|
RUN_AGENT_STREAMED_AUTO_SEND = "run_agent_streamed_auto_send"
|
43
51
|
|
44
52
|
|
53
|
+
class WebSearchTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``WebSearchTool``.

    Stores only plain data; ``to_oai_function_tool`` builds the real tool,
    forwarding just the fields that are not None.
    """

    # Location hint as a plain dict; forwarded unchanged when set.
    user_location: Optional[dict[str, Any]] = None
    search_context_size: Optional[Literal["low", "medium", "high"]] = "medium"

    def to_oai_function_tool(self) -> OAIWebSearchTool:
        """Return an OpenAI ``WebSearchTool`` built from the non-None fields."""
        candidates = {
            "user_location": self.user_location,
            "search_context_size": self.search_context_size,
        }
        # Omit unset values so the OpenAI tool's own defaults apply.
        provided = {key: value for key, value in candidates.items() if value is not None}
        return OAIWebSearchTool(**provided)
|
66
|
+
|
67
|
+
|
68
|
+
class FileSearchTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``FileSearchTool``.

    All fields are plain data and are passed through unchanged (including
    None values) when the real tool is built.
    """

    vector_store_ids: list[str]
    max_num_results: Optional[int] = None
    include_search_results: bool = False
    ranking_options: Optional[dict[str, Any]] = None
    filters: Optional[dict[str, Any]] = None

    def to_oai_function_tool(self) -> OAIFileSearchTool:
        """Return an OpenAI ``FileSearchTool`` carrying every field of this wrapper."""
        settings = dict(
            vector_store_ids=self.vector_store_ids,
            max_num_results=self.max_num_results,
            include_search_results=self.include_search_results,
            ranking_options=self.ranking_options,
            filters=self.filters,
        )
        return OAIFileSearchTool(**settings)
|
85
|
+
|
86
|
+
|
87
|
+
class ComputerTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``ComputerTool``.

    The computer object and the optional safety-check callback are not plain
    data, so each is kept twice: as a base64-encoded cloudpickle string in a
    public field (``computer_serialized`` / ``on_safety_check_serialized``) and
    as a live object in a private attribute used to build the real tool.

    Args:
        computer: Live computer object. When given it is serialized into
            ``computer_serialized``; otherwise it is rebuilt from
            ``computer_serialized`` if that string is non-empty.
        on_safety_check: Optional safety-check callable, handled the same way
            via ``on_safety_check_serialized``.
        **data: Remaining model fields, forwarded to the base constructor.
    """

    # We need to serialize the computer object and safety check function
    computer_serialized: str = Field(default="", description="Serialized computer object")
    on_safety_check_serialized: str = Field(default="", description="Serialized safety check function")

    # Default to None: a bare ``PrivateAttr()`` leaves the attribute unset when
    # the (optional) argument is omitted, so reading it later raised
    # AttributeError in ``to_oai_function_tool``.
    _computer: Any = PrivateAttr(default=None)
    _on_safety_check: Optional[Callable] = PrivateAttr(default=None)

    def __init__(
        self,
        *,
        computer: Any = None,
        on_safety_check: Optional[Callable] = None,
        **data,
    ):
        super().__init__(**data)
        if computer is not None:
            self.computer_serialized = self._serialize_callable(computer)
            self._computer = computer
        elif self.computer_serialized:
            self._computer = self._deserialize_callable(self.computer_serialized)

        if on_safety_check is not None:
            self.on_safety_check_serialized = self._serialize_callable(on_safety_check)
            self._on_safety_check = on_safety_check
        elif self.on_safety_check_serialized:
            self._on_safety_check = self._deserialize_callable(self.on_safety_check_serialized)

    @classmethod
    def _deserialize_callable(cls, serialized: str) -> Any:
        """Decode a base64 cloudpickle string back into the original object.

        NOTE(review): ``cloudpickle.loads`` can execute arbitrary code; only
        pass strings produced by ``_serialize_callable`` from a trusted source.
        """
        encoded = serialized.encode()
        serialized_bytes = base64.b64decode(encoded)
        return cloudpickle.loads(serialized_bytes)

    @classmethod
    def _serialize_callable(cls, func: Any) -> str:
        """Pickle ``func`` with cloudpickle and return it as a base64 string."""
        serialized_bytes = cloudpickle.dumps(func)
        encoded = base64.b64encode(serialized_bytes)
        return encoded.decode()

    def to_oai_function_tool(self) -> OAIComputerTool:
        """Build the OpenAI ``ComputerTool`` from the live objects.

        Raises:
            ValueError: if neither ``computer`` nor ``computer_serialized``
                was provided, so there is no computer to wrap.
        """
        if self._computer is None:
            raise ValueError("One of `computer` or `computer_serialized` should be set")
        return OAIComputerTool(
            computer=self._computer,
            on_safety_check=self._on_safety_check,
        )
|
134
|
+
|
135
|
+
|
136
|
+
class CodeInterpreterTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``CodeInterpreterTool``.

    Only the tool configuration dict is stored; it is handed unchanged to the
    real tool by ``to_oai_function_tool``.
    """

    # default_factory gives every instance its own fresh config dict.
    tool_config: dict[str, Any] = Field(
        default_factory=lambda: {"type": "code_interpreter"},
        description="Tool configuration dict",
    )

    def to_oai_function_tool(self) -> OAICodeInterpreterTool:
        """Return an OpenAI ``CodeInterpreterTool`` built from ``tool_config``."""
        config = self.tool_config
        return OAICodeInterpreterTool(tool_config=config)
|
145
|
+
|
146
|
+
|
147
|
+
class ImageGenerationTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``ImageGenerationTool``.

    Only the tool configuration dict is stored; it is handed unchanged to the
    real tool by ``to_oai_function_tool``.
    """

    # default_factory gives every instance its own fresh config dict.
    tool_config: dict[str, Any] = Field(
        default_factory=lambda: {"type": "image_generation"},
        description="Tool configuration dict",
    )

    def to_oai_function_tool(self) -> OAIImageGenerationTool:
        """Return an OpenAI ``ImageGenerationTool`` built from ``tool_config``."""
        config = self.tool_config
        return OAIImageGenerationTool(tool_config=config)
|
156
|
+
|
157
|
+
|
158
|
+
class LocalShellTool(BaseModelWithTraceParams):
    """Temporal-compatible wrapper for the OpenAI Agents ``LocalShellTool``.

    The executor is not plain data, so it is kept both as a base64-encoded
    cloudpickle string (``executor_serialized``) and as a live private object
    used to build the real tool.

    Args:
        executor: Live executor object. When given it is serialized into
            ``executor_serialized``; otherwise it is rebuilt from
            ``executor_serialized`` if that string is non-empty.
        **data: Remaining model fields, forwarded to the base constructor.
    """

    executor_serialized: str = Field(default="", description="Serialized LocalShellExecutor object")

    # Default to None: a bare ``PrivateAttr()`` leaves the attribute unset when
    # no executor is supplied, which surfaced as an opaque AttributeError.
    _executor: Any = PrivateAttr(default=None)

    def __init__(
        self,
        *,
        executor: Any = None,
        **data,
    ):
        super().__init__(**data)
        if executor is not None:
            self.executor_serialized = self._serialize_callable(executor)
            self._executor = executor
        elif self.executor_serialized:
            self._executor = self._deserialize_callable(self.executor_serialized)

    @classmethod
    def _deserialize_callable(cls, serialized: str) -> Any:
        """Decode a base64 cloudpickle string back into the original object.

        NOTE(review): ``cloudpickle.loads`` can execute arbitrary code; only
        pass strings produced by ``_serialize_callable`` from a trusted source.
        """
        encoded = serialized.encode()
        serialized_bytes = base64.b64decode(encoded)
        return cloudpickle.loads(serialized_bytes)

    @classmethod
    def _serialize_callable(cls, func: Any) -> str:
        """Pickle ``func`` with cloudpickle and return it as a base64 string."""
        serialized_bytes = cloudpickle.dumps(func)
        encoded = base64.b64encode(serialized_bytes)
        return encoded.decode()

    def to_oai_function_tool(self) -> OAILocalShellTool:
        """Build the OpenAI ``LocalShellTool`` from the live executor.

        Raises:
            ValueError: if neither ``executor`` nor ``executor_serialized``
                was provided.
        """
        if self._executor is None:
            raise ValueError("One of `executor` or `executor_serialized` should be set")
        return OAILocalShellTool(executor=self._executor)
|
192
|
+
|
193
|
+
|
45
194
|
class FunctionTool(BaseModelWithTraceParams):
|
46
195
|
name: str
|
47
196
|
description: str
|
@@ -79,22 +228,16 @@ class FunctionTool(BaseModelWithTraceParams):
|
|
79
228
|
super().__init__(**data)
|
80
229
|
if not on_invoke_tool:
|
81
230
|
if not self.on_invoke_tool_serialized:
|
82
|
-
raise ValueError(
|
83
|
-
"One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set"
|
84
|
-
)
|
231
|
+
raise ValueError("One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set")
|
85
232
|
else:
|
86
|
-
on_invoke_tool = self._deserialize_callable(
|
87
|
-
self.on_invoke_tool_serialized
|
88
|
-
)
|
233
|
+
on_invoke_tool = self._deserialize_callable(self.on_invoke_tool_serialized)
|
89
234
|
else:
|
90
235
|
self.on_invoke_tool_serialized = self._serialize_callable(on_invoke_tool)
|
91
236
|
|
92
237
|
self._on_invoke_tool = on_invoke_tool
|
93
238
|
|
94
239
|
@classmethod
|
95
|
-
def _deserialize_callable(
|
96
|
-
cls, serialized: str
|
97
|
-
) -> Callable[[RunContextWrapper, str], Any]:
|
240
|
+
def _deserialize_callable(cls, serialized: str) -> Callable[[RunContextWrapper, str], Any]:
|
98
241
|
encoded = serialized.encode()
|
99
242
|
serialized_bytes = base64.b64decode(encoded)
|
100
243
|
return cloudpickle.loads(serialized_bytes)
|
@@ -108,11 +251,9 @@ class FunctionTool(BaseModelWithTraceParams):
|
|
108
251
|
@property
|
109
252
|
def on_invoke_tool(self) -> Callable[[RunContextWrapper, str], Any]:
|
110
253
|
if self._on_invoke_tool is None and self.on_invoke_tool_serialized:
|
111
|
-
self._on_invoke_tool = self._deserialize_callable(
|
112
|
-
self.on_invoke_tool_serialized
|
113
|
-
)
|
254
|
+
self._on_invoke_tool = self._deserialize_callable(self.on_invoke_tool_serialized)
|
114
255
|
return self._on_invoke_tool
|
115
|
-
|
256
|
+
|
116
257
|
@on_invoke_tool.setter
|
117
258
|
def on_invoke_tool(self, value: Callable[[RunContextWrapper, str], Any]):
|
118
259
|
self.on_invoke_tool_serialized = self._serialize_callable(value)
|
@@ -135,8 +276,9 @@ class FunctionTool(BaseModelWithTraceParams):
|
|
135
276
|
|
136
277
|
|
137
278
|
class TemporalInputGuardrail(BaseModelWithTraceParams):
|
138
|
-
"""Temporal-compatible wrapper for InputGuardrail with function
|
279
|
+
"""Temporal-compatible wrapper for InputGuardrail with function
|
139
280
|
serialization."""
|
281
|
+
|
140
282
|
name: str
|
141
283
|
_guardrail_function: Callable = PrivateAttr()
|
142
284
|
guardrail_function_serialized: str = Field(
|
@@ -157,19 +299,12 @@ class TemporalInputGuardrail(BaseModelWithTraceParams):
|
|
157
299
|
super().__init__(**data)
|
158
300
|
if not guardrail_function:
|
159
301
|
if not self.guardrail_function_serialized:
|
160
|
-
raise ValueError(
|
161
|
-
"One of `guardrail_function` or "
|
162
|
-
"`guardrail_function_serialized` should be set"
|
163
|
-
)
|
302
|
+
raise ValueError("One of `guardrail_function` or `guardrail_function_serialized` should be set")
|
164
303
|
else:
|
165
|
-
guardrail_function = self._deserialize_callable(
|
166
|
-
self.guardrail_function_serialized
|
167
|
-
)
|
304
|
+
guardrail_function = self._deserialize_callable(self.guardrail_function_serialized)
|
168
305
|
else:
|
169
|
-
self.guardrail_function_serialized = self._serialize_callable(
|
170
|
-
|
171
|
-
)
|
172
|
-
|
306
|
+
self.guardrail_function_serialized = self._serialize_callable(guardrail_function)
|
307
|
+
|
173
308
|
self._guardrail_function = guardrail_function
|
174
309
|
|
175
310
|
@classmethod
|
@@ -186,13 +321,10 @@ class TemporalInputGuardrail(BaseModelWithTraceParams):
|
|
186
321
|
|
187
322
|
@property
|
188
323
|
def guardrail_function(self) -> Callable:
|
189
|
-
if
|
190
|
-
|
191
|
-
self._guardrail_function = self._deserialize_callable(
|
192
|
-
self.guardrail_function_serialized
|
193
|
-
)
|
324
|
+
if self._guardrail_function is None and self.guardrail_function_serialized:
|
325
|
+
self._guardrail_function = self._deserialize_callable(self.guardrail_function_serialized)
|
194
326
|
return self._guardrail_function
|
195
|
-
|
327
|
+
|
196
328
|
@guardrail_function.setter
|
197
329
|
def guardrail_function(self, value: Callable):
|
198
330
|
self.guardrail_function_serialized = self._serialize_callable(value)
|
@@ -200,15 +332,13 @@ class TemporalInputGuardrail(BaseModelWithTraceParams):
|
|
200
332
|
|
201
333
|
def to_oai_input_guardrail(self) -> InputGuardrail:
|
202
334
|
"""Convert to OpenAI InputGuardrail."""
|
203
|
-
return InputGuardrail(
|
204
|
-
guardrail_function=self.guardrail_function,
|
205
|
-
name=self.name
|
206
|
-
)
|
335
|
+
return InputGuardrail(guardrail_function=self.guardrail_function, name=self.name)
|
207
336
|
|
208
337
|
|
209
338
|
class TemporalOutputGuardrail(BaseModelWithTraceParams):
|
210
|
-
"""Temporal-compatible wrapper for OutputGuardrail with function
|
339
|
+
"""Temporal-compatible wrapper for OutputGuardrail with function
|
211
340
|
serialization."""
|
341
|
+
|
212
342
|
name: str
|
213
343
|
_guardrail_function: Callable = PrivateAttr()
|
214
344
|
guardrail_function_serialized: str = Field(
|
@@ -229,19 +359,12 @@ class TemporalOutputGuardrail(BaseModelWithTraceParams):
|
|
229
359
|
super().__init__(**data)
|
230
360
|
if not guardrail_function:
|
231
361
|
if not self.guardrail_function_serialized:
|
232
|
-
raise ValueError(
|
233
|
-
"One of `guardrail_function` or "
|
234
|
-
"`guardrail_function_serialized` should be set"
|
235
|
-
)
|
362
|
+
raise ValueError("One of `guardrail_function` or `guardrail_function_serialized` should be set")
|
236
363
|
else:
|
237
|
-
guardrail_function = self._deserialize_callable(
|
238
|
-
self.guardrail_function_serialized
|
239
|
-
)
|
364
|
+
guardrail_function = self._deserialize_callable(self.guardrail_function_serialized)
|
240
365
|
else:
|
241
|
-
self.guardrail_function_serialized = self._serialize_callable(
|
242
|
-
|
243
|
-
)
|
244
|
-
|
366
|
+
self.guardrail_function_serialized = self._serialize_callable(guardrail_function)
|
367
|
+
|
245
368
|
self._guardrail_function = guardrail_function
|
246
369
|
|
247
370
|
@classmethod
|
@@ -258,13 +381,10 @@ class TemporalOutputGuardrail(BaseModelWithTraceParams):
|
|
258
381
|
|
259
382
|
@property
|
260
383
|
def guardrail_function(self) -> Callable:
|
261
|
-
if
|
262
|
-
|
263
|
-
self._guardrail_function = self._deserialize_callable(
|
264
|
-
self.guardrail_function_serialized
|
265
|
-
)
|
384
|
+
if self._guardrail_function is None and self.guardrail_function_serialized:
|
385
|
+
self._guardrail_function = self._deserialize_callable(self.guardrail_function_serialized)
|
266
386
|
return self._guardrail_function
|
267
|
-
|
387
|
+
|
268
388
|
@guardrail_function.setter
|
269
389
|
def guardrail_function(self, value: Callable):
|
270
390
|
self.guardrail_function_serialized = self._serialize_callable(value)
|
@@ -272,10 +392,7 @@ class TemporalOutputGuardrail(BaseModelWithTraceParams):
|
|
272
392
|
|
273
393
|
def to_oai_output_guardrail(self) -> OutputGuardrail:
|
274
394
|
"""Convert to OpenAI OutputGuardrail."""
|
275
|
-
return OutputGuardrail(
|
276
|
-
guardrail_function=self.guardrail_function,
|
277
|
-
name=self.name
|
278
|
-
)
|
395
|
+
return OutputGuardrail(guardrail_function=self.guardrail_function, name=self.name)
|
279
396
|
|
280
397
|
|
281
398
|
class ModelSettings(BaseModelWithTraceParams):
|
@@ -297,9 +414,7 @@ class ModelSettings(BaseModelWithTraceParams):
|
|
297
414
|
extra_args: dict[str, Any] | None = None
|
298
415
|
|
299
416
|
def to_oai_model_settings(self) -> OAIModelSettings:
|
300
|
-
return OAIModelSettings(
|
301
|
-
**self.model_dump(exclude=["trace_id", "parent_span_id"])
|
302
|
-
)
|
417
|
+
return OAIModelSettings(**self.model_dump(exclude=["trace_id", "parent_span_id"]))
|
303
418
|
|
304
419
|
|
305
420
|
class RunAgentParams(BaseModelWithTraceParams):
|
@@ -313,12 +428,24 @@ class RunAgentParams(BaseModelWithTraceParams):
|
|
313
428
|
handoffs: list["RunAgentParams"] | None = None
|
314
429
|
model: str | None = None
|
315
430
|
model_settings: ModelSettings | None = None
|
316
|
-
tools:
|
431
|
+
tools: (
|
432
|
+
list[
|
433
|
+
FunctionTool
|
434
|
+
| WebSearchTool
|
435
|
+
| FileSearchTool
|
436
|
+
| ComputerTool
|
437
|
+
| CodeInterpreterTool
|
438
|
+
| ImageGenerationTool
|
439
|
+
| LocalShellTool
|
440
|
+
]
|
441
|
+
| None
|
442
|
+
) = None
|
317
443
|
output_type: Any = None
|
318
444
|
tool_use_behavior: Literal["run_llm_again", "stop_on_first_tool"] = "run_llm_again"
|
319
445
|
mcp_timeout_seconds: int | None = None
|
320
446
|
input_guardrails: list[TemporalInputGuardrail] | None = None
|
321
447
|
output_guardrails: list[TemporalOutputGuardrail] | None = None
|
448
|
+
max_turns: int | None = None
|
322
449
|
|
323
450
|
|
324
451
|
class RunAgentAutoSendParams(RunAgentParams):
|
@@ -365,11 +492,11 @@ class OpenAIActivities:
|
|
365
492
|
input_guardrails = None
|
366
493
|
if params.input_guardrails:
|
367
494
|
input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails]
|
368
|
-
|
495
|
+
|
369
496
|
output_guardrails = None
|
370
497
|
if params.output_guardrails:
|
371
498
|
output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails]
|
372
|
-
|
499
|
+
|
373
500
|
result = await self._openai_service.run_agent(
|
374
501
|
input_list=params.input_list,
|
375
502
|
mcp_server_params=params.mcp_server_params,
|
@@ -386,23 +513,23 @@ class OpenAIActivities:
|
|
386
513
|
tool_use_behavior=params.tool_use_behavior,
|
387
514
|
input_guardrails=input_guardrails,
|
388
515
|
output_guardrails=output_guardrails,
|
516
|
+
mcp_timeout_seconds=params.mcp_timeout_seconds,
|
517
|
+
max_turns=params.max_turns,
|
389
518
|
)
|
390
519
|
return self._to_serializable_run_result(result)
|
391
520
|
|
392
521
|
@activity.defn(name=OpenAIActivityName.RUN_AGENT_AUTO_SEND)
|
393
|
-
async def run_agent_auto_send(
|
394
|
-
self, params: RunAgentAutoSendParams
|
395
|
-
) -> SerializableRunResult:
|
522
|
+
async def run_agent_auto_send(self, params: RunAgentAutoSendParams) -> SerializableRunResult:
|
396
523
|
"""Run an agent with automatic TaskMessage creation."""
|
397
524
|
# Convert Temporal guardrails to OpenAI guardrails
|
398
525
|
input_guardrails = None
|
399
526
|
if params.input_guardrails:
|
400
527
|
input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails]
|
401
|
-
|
528
|
+
|
402
529
|
output_guardrails = None
|
403
530
|
if params.output_guardrails:
|
404
531
|
output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails]
|
405
|
-
|
532
|
+
|
406
533
|
try:
|
407
534
|
result = await self._openai_service.run_agent_auto_send(
|
408
535
|
task_id=params.task_id,
|
@@ -421,65 +548,60 @@ class OpenAIActivities:
|
|
421
548
|
tool_use_behavior=params.tool_use_behavior,
|
422
549
|
input_guardrails=input_guardrails,
|
423
550
|
output_guardrails=output_guardrails,
|
551
|
+
mcp_timeout_seconds=params.mcp_timeout_seconds,
|
552
|
+
max_turns=params.max_turns,
|
424
553
|
)
|
425
554
|
return self._to_serializable_run_result(result)
|
426
555
|
except InputGuardrailTripwireTriggered as e:
|
427
556
|
# Handle guardrail trigger gracefully
|
428
|
-
rejection_message =
|
429
|
-
|
557
|
+
rejection_message = (
|
558
|
+
"I'm sorry, but I cannot process this request due to a guardrail. Please try a different question."
|
559
|
+
)
|
560
|
+
|
430
561
|
# Try to extract rejection message from the guardrail result
|
431
|
-
if hasattr(e,
|
432
|
-
output_info = getattr(e.guardrail_result.output,
|
433
|
-
if isinstance(output_info, dict) and
|
434
|
-
rejection_message = output_info[
|
435
|
-
|
562
|
+
if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"):
|
563
|
+
output_info = getattr(e.guardrail_result.output, "output_info", {})
|
564
|
+
if isinstance(output_info, dict) and "rejection_message" in output_info:
|
565
|
+
rejection_message = output_info["rejection_message"]
|
566
|
+
|
436
567
|
# Build the final input list with the rejection message
|
437
568
|
final_input_list = list(params.input_list or [])
|
438
|
-
final_input_list.append({
|
439
|
-
|
440
|
-
|
441
|
-
})
|
442
|
-
|
443
|
-
return SerializableRunResult(
|
444
|
-
final_output=rejection_message,
|
445
|
-
final_input_list=final_input_list
|
446
|
-
)
|
569
|
+
final_input_list.append({"role": "assistant", "content": rejection_message})
|
570
|
+
|
571
|
+
return SerializableRunResult(final_output=rejection_message, final_input_list=final_input_list)
|
447
572
|
except OutputGuardrailTripwireTriggered as e:
|
448
573
|
# Handle output guardrail trigger gracefully
|
449
|
-
rejection_message =
|
450
|
-
|
574
|
+
rejection_message = (
|
575
|
+
"I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question."
|
576
|
+
)
|
577
|
+
|
451
578
|
# Try to extract rejection message from the guardrail result
|
452
|
-
if hasattr(e,
|
453
|
-
output_info = getattr(e.guardrail_result.output,
|
454
|
-
if isinstance(output_info, dict) and
|
455
|
-
rejection_message = output_info[
|
456
|
-
|
579
|
+
if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"):
|
580
|
+
output_info = getattr(e.guardrail_result.output, "output_info", {})
|
581
|
+
if isinstance(output_info, dict) and "rejection_message" in output_info:
|
582
|
+
rejection_message = output_info["rejection_message"]
|
583
|
+
|
457
584
|
# Build the final input list with the rejection message
|
458
585
|
final_input_list = list(params.input_list or [])
|
459
|
-
final_input_list.append({
|
460
|
-
|
461
|
-
|
462
|
-
})
|
463
|
-
|
464
|
-
return SerializableRunResult(
|
465
|
-
final_output=rejection_message,
|
466
|
-
final_input_list=final_input_list
|
467
|
-
)
|
586
|
+
final_input_list.append({"role": "assistant", "content": rejection_message})
|
587
|
+
|
588
|
+
return SerializableRunResult(final_output=rejection_message, final_input_list=final_input_list)
|
468
589
|
|
469
590
|
@activity.defn(name=OpenAIActivityName.RUN_AGENT_STREAMED_AUTO_SEND)
|
470
591
|
async def run_agent_streamed_auto_send(
|
471
592
|
self, params: RunAgentStreamedAutoSendParams
|
472
593
|
) -> SerializableRunResultStreaming:
|
473
594
|
"""Run an agent with streaming and automatic TaskMessage creation."""
|
595
|
+
|
474
596
|
# Convert Temporal guardrails to OpenAI guardrails
|
475
597
|
input_guardrails = None
|
476
598
|
if params.input_guardrails:
|
477
599
|
input_guardrails = [g.to_oai_input_guardrail() for g in params.input_guardrails]
|
478
|
-
|
600
|
+
|
479
601
|
output_guardrails = None
|
480
602
|
if params.output_guardrails:
|
481
603
|
output_guardrails = [g.to_oai_output_guardrail() for g in params.output_guardrails]
|
482
|
-
|
604
|
+
|
483
605
|
try:
|
484
606
|
result = await self._openai_service.run_agent_streamed_auto_send(
|
485
607
|
task_id=params.task_id,
|
@@ -498,50 +620,44 @@ class OpenAIActivities:
|
|
498
620
|
tool_use_behavior=params.tool_use_behavior,
|
499
621
|
input_guardrails=input_guardrails,
|
500
622
|
output_guardrails=output_guardrails,
|
623
|
+
mcp_timeout_seconds=params.mcp_timeout_seconds,
|
624
|
+
max_turns=params.max_turns,
|
501
625
|
)
|
502
626
|
return self._to_serializable_run_result_streaming(result)
|
503
627
|
except InputGuardrailTripwireTriggered as e:
|
504
628
|
# Handle guardrail trigger gracefully
|
505
|
-
rejection_message =
|
506
|
-
|
629
|
+
rejection_message = (
|
630
|
+
"I'm sorry, but I cannot process this request due to a guardrail. Please try a different question."
|
631
|
+
)
|
632
|
+
|
507
633
|
# Try to extract rejection message from the guardrail result
|
508
|
-
if hasattr(e,
|
509
|
-
output_info = getattr(e.guardrail_result.output,
|
510
|
-
if isinstance(output_info, dict) and
|
511
|
-
rejection_message = output_info[
|
512
|
-
|
634
|
+
if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"):
|
635
|
+
output_info = getattr(e.guardrail_result.output, "output_info", {})
|
636
|
+
if isinstance(output_info, dict) and "rejection_message" in output_info:
|
637
|
+
rejection_message = output_info["rejection_message"]
|
638
|
+
|
513
639
|
# Build the final input list with the rejection message
|
514
640
|
final_input_list = list(params.input_list or [])
|
515
|
-
final_input_list.append({
|
516
|
-
|
517
|
-
|
518
|
-
})
|
519
|
-
|
520
|
-
return SerializableRunResultStreaming(
|
521
|
-
final_output=rejection_message,
|
522
|
-
final_input_list=final_input_list
|
523
|
-
)
|
641
|
+
final_input_list.append({"role": "assistant", "content": rejection_message})
|
642
|
+
|
643
|
+
return SerializableRunResultStreaming(final_output=rejection_message, final_input_list=final_input_list)
|
524
644
|
except OutputGuardrailTripwireTriggered as e:
|
525
645
|
# Handle output guardrail trigger gracefully
|
526
|
-
rejection_message =
|
527
|
-
|
646
|
+
rejection_message = (
|
647
|
+
"I'm sorry, but I cannot provide this response due to a guardrail. Please try a different question."
|
648
|
+
)
|
649
|
+
|
528
650
|
# Try to extract rejection message from the guardrail result
|
529
|
-
if hasattr(e,
|
530
|
-
output_info = getattr(e.guardrail_result.output,
|
531
|
-
if isinstance(output_info, dict) and
|
532
|
-
rejection_message = output_info[
|
533
|
-
|
651
|
+
if hasattr(e, "guardrail_result") and hasattr(e.guardrail_result, "output"):
|
652
|
+
output_info = getattr(e.guardrail_result.output, "output_info", {})
|
653
|
+
if isinstance(output_info, dict) and "rejection_message" in output_info:
|
654
|
+
rejection_message = output_info["rejection_message"]
|
655
|
+
|
534
656
|
# Build the final input list with the rejection message
|
535
657
|
final_input_list = list(params.input_list or [])
|
536
|
-
final_input_list.append({
|
537
|
-
|
538
|
-
|
539
|
-
})
|
540
|
-
|
541
|
-
return SerializableRunResultStreaming(
|
542
|
-
final_output=rejection_message,
|
543
|
-
final_input_list=final_input_list
|
544
|
-
)
|
658
|
+
final_input_list.append({"role": "assistant", "content": rejection_message})
|
659
|
+
|
660
|
+
return SerializableRunResultStreaming(final_output=rejection_message, final_input_list=final_input_list)
|
545
661
|
|
546
662
|
@staticmethod
|
547
663
|
def _to_serializable_run_result(result: RunResult) -> SerializableRunResult:
|
@@ -4,7 +4,7 @@ import os
|
|
4
4
|
import uuid
|
5
5
|
from collections.abc import Callable
|
6
6
|
from concurrent.futures import ThreadPoolExecutor
|
7
|
-
from typing import Any
|
7
|
+
from typing import Any, overload
|
8
8
|
|
9
9
|
from aiohttp import web
|
10
10
|
from temporalio.client import Client
|
@@ -99,10 +99,28 @@ class AgentexWorker:
|
|
99
99
|
self.healthy = False
|
100
100
|
self.health_check_port = health_check_port
|
101
101
|
|
102
|
+
@overload
|
102
103
|
async def run(
|
103
104
|
self,
|
104
105
|
activities: list[Callable],
|
106
|
+
*,
|
105
107
|
workflow: type,
|
108
|
+
) -> None: ...
|
109
|
+
|
110
|
+
@overload
|
111
|
+
async def run(
|
112
|
+
self,
|
113
|
+
activities: list[Callable],
|
114
|
+
*,
|
115
|
+
workflows: list[type],
|
116
|
+
) -> None: ...
|
117
|
+
|
118
|
+
async def run(
|
119
|
+
self,
|
120
|
+
activities: list[Callable],
|
121
|
+
*,
|
122
|
+
workflow: type | None = None,
|
123
|
+
workflows: list[type] | None = None,
|
106
124
|
):
|
107
125
|
await self.start_health_check_server()
|
108
126
|
await self._register_agent()
|
@@ -115,11 +133,14 @@ class AgentexWorker:
|
|
115
133
|
if debug_enabled:
|
116
134
|
logger.info("🐛 [WORKER] Temporal debug mode enabled - deadlock detection disabled")
|
117
135
|
|
136
|
+
if workflow is None and workflows is None:
|
137
|
+
raise ValueError("Either workflow or workflows must be provided")
|
138
|
+
|
118
139
|
worker = Worker(
|
119
140
|
client=temporal_client,
|
120
141
|
task_queue=self.task_queue,
|
121
142
|
activity_executor=ThreadPoolExecutor(max_workers=self.max_workers),
|
122
|
-
workflows=[workflow],
|
143
|
+
workflows=[workflow] if workflows is None else workflows,
|
123
144
|
activities=activities,
|
124
145
|
workflow_runner=UnsandboxedWorkflowRunner(),
|
125
146
|
max_concurrent_activities=self.max_concurrent_activities,
|