zeroshot-agentic-workflows 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zeroshot_agentic_workflows-0.1.5/PKG-INFO +14 -0
- zeroshot_agentic_workflows-0.1.5/README.md +3 -0
- zeroshot_agentic_workflows-0.1.5/pyproject.toml +19 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/__init__.py +67 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/agent_service.py +216 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/decorators.py +302 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/factory.py +46 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/param_mapper.py +56 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/prompt_utils.py +87 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/py.typed +1 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/service_ollama.py +89 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/service_openai.py +81 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/session.py +232 -0
- zeroshot_agentic_workflows-0.1.5/src/zeroshot_agentic_workflows/session_factory.py +22 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: zeroshot-agentic-workflows
|
|
3
|
+
Version: 0.1.5
|
|
4
|
+
Summary: Framework-agnostic agent workflow building blocks for Zeroshot Python packages.
|
|
5
|
+
License-Expression: MIT
|
|
6
|
+
Requires-Dist: zeroshot-commons==0.1.5
|
|
7
|
+
Requires-Dist: openai-agents>=0.0.7
|
|
8
|
+
Requires-Dist: pyyaml>=6.0
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# zeroshot-agentic-workflows
|
|
13
|
+
|
|
14
|
+
Framework-agnostic agent workflow building blocks for Zeroshot Python packages.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "zeroshot-agentic-workflows"
|
|
3
|
+
version = "0.1.5"
|
|
4
|
+
description = "Framework-agnostic agent workflow building blocks for Zeroshot Python packages."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
license = "MIT"
|
|
8
|
+
dependencies = [
|
|
9
|
+
"zeroshot-commons==0.1.5",
|
|
10
|
+
"openai-agents>=0.0.7",
|
|
11
|
+
"pyyaml>=6.0",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[tool.uv.sources]
|
|
15
|
+
zeroshot-commons = { workspace = true }
|
|
16
|
+
|
|
17
|
+
[build-system]
|
|
18
|
+
requires = ["uv_build>=0.11.8,<0.12"]
|
|
19
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Agent workflow building blocks for Zeroshot Python packages."""
|
|
2
|
+
|
|
3
|
+
from .agent_service import (
|
|
4
|
+
AgentConfig,
|
|
5
|
+
AgentRunConfig,
|
|
6
|
+
AgentRunResult,
|
|
7
|
+
AgentType,
|
|
8
|
+
AiAgentConfig,
|
|
9
|
+
AiAgentProvider,
|
|
10
|
+
AiAgentService,
|
|
11
|
+
AiAgentServiceLocal,
|
|
12
|
+
ConsensusRunResult,
|
|
13
|
+
ConsensusStrategy,
|
|
14
|
+
)
|
|
15
|
+
from .decorators import agent, agentic_workflow, consensus_agent
|
|
16
|
+
from .factory import AiAgentFactory
|
|
17
|
+
from .param_mapper import AgentParameterMapper
|
|
18
|
+
from .prompt_utils import (
|
|
19
|
+
ParsedPrompt,
|
|
20
|
+
PromptFrontmatter,
|
|
21
|
+
generate_tools_reference,
|
|
22
|
+
parse_prompt_frontmatter,
|
|
23
|
+
)
|
|
24
|
+
from .session import (
|
|
25
|
+
CONVERSATION_SESSION_REPOSITORY,
|
|
26
|
+
ConversationItemModel,
|
|
27
|
+
ConversationMessage,
|
|
28
|
+
ConversationSessionModel,
|
|
29
|
+
ConversationSessionRepository,
|
|
30
|
+
InMemoryConversationSessionRepository,
|
|
31
|
+
RepositorySession,
|
|
32
|
+
SessionItem,
|
|
33
|
+
SessionNotFoundError,
|
|
34
|
+
)
|
|
35
|
+
from .session_factory import AiSessionFactory
|
|
36
|
+
|
|
37
|
+
# Public API of the package: every name re-exported by the imports above.
# Keep this list in sync when adding or removing re-exports.
__all__ = [
    "CONVERSATION_SESSION_REPOSITORY",
    "AgentConfig",
    "AgentParameterMapper",
    "AgentRunConfig",
    "AgentRunResult",
    "AgentType",
    "AiAgentConfig",
    "AiAgentFactory",
    "AiAgentProvider",
    "AiAgentService",
    "AiAgentServiceLocal",
    "AiSessionFactory",
    "ConsensusRunResult",
    "ConsensusStrategy",
    "ConversationItemModel",
    "ConversationMessage",
    "ConversationSessionModel",
    "ConversationSessionRepository",
    "InMemoryConversationSessionRepository",
    "ParsedPrompt",
    "PromptFrontmatter",
    "RepositorySession",
    "SessionItem",
    "SessionNotFoundError",
    "agent",
    "agentic_workflow",
    "consensus_agent",
    "generate_tools_reference",
    "parse_prompt_frontmatter",
]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from enum import StrEnum
|
|
6
|
+
from typing import Any, ClassVar, Protocol, TypeVar
|
|
7
|
+
|
|
8
|
+
# Module-level type variable used in method signatures (e.g. AiAgentService,
# AiAgentServiceLocal) and as ConsensusRunResult's generic base parameter.
# Classes declared with PEP 695 syntax (``class X[T]``) bind their own T.
T = TypeVar("T")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True, slots=True)
class AgentRunConfig:
    """Per-invocation settings passed to ``run_agent`` / ``create_and_run``."""

    # Agent input text; the decorators in this package pass a JSON-encoded object.
    input: str
    # Run context; services forward it to the runner only when not None.
    context: dict[str, Any] | None = None
    # Conversation session; presumably exposes an async ``add_items`` (the local
    # mock service calls it) — TODO confirm for other implementations.
    session: Any | None = None
    # Turn limit forwarded to the runner when not None.
    max_turns: int | None = None
    # Filled from @agent's ``branch_param``.
    # NOTE(review): no consumer of this field is visible here.
    branch: str | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True, slots=True)
class AgentRunResult[T]:
    """Outcome of a single agent run."""

    # Final output of the run; failure paths in this package pass None here.
    output: T
    success: bool
    # Error message accompanying a failed run.
    error: str | None = None
    # Provider-specific result object (e.g. the ``Runner.run`` result).
    raw_result: Any | None = None
    # Working directory reported by the run; the local mock service fills it
    # from ``set_mock_working_dir``.
    working_dir: str | None = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ConsensusStrategy(StrEnum):
    """How ``consensus_agent`` combines the outputs of its runs."""

    # The most frequent (JSON-serialized) output among successful runs wins.
    MAJORITY = "majority"
    # Every successful run must produce the same output.
    UNANIMOUS = "unanimous"
    # A user-supplied judge callable picks the result.
    JUDGE = "judge"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True, slots=True)
class ConsensusRunResult(AgentRunResult[T]):
    """Result of a consensus invocation: the chosen output plus per-run details."""

    # Every individual run, including failed ones.
    runs: list[AgentRunResult[T]] = field(default_factory=list)
    # Fraction of successful runs matching the chosen output; 0.0 when no
    # output was agreed on and for the JUDGE strategy.
    agreement: float = 0.0
    # Number of runs requested.
    total_runs: int = 0
    # Number of runs that reported success.
    successful_runs: int = 0
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass(frozen=True, slots=True)
class AgentConfig[T]:
    """Static definition of an agent: name, instructions, model and tooling."""

    # Agent name; also the key AiAgentServiceLocal uses for canned responses.
    name: str
    # Instructions handed to the SDK agent (prompt text).
    instructions: str
    # Model override; services fall back to their default model when None
    # (as the Ollama service does).
    model: str | None = None
    tools: list[Any] = field(default_factory=list)
    # Structured output type; passed as ``output_type`` to the SDK agent.
    output_schema: Any | None = None
    # Plain dict of model settings; the Ollama service converts it with
    # ``ModelSettings(**model_settings)``.
    model_settings: dict[str, Any] | None = None
    # NOTE(review): not read by any code visible in this package chunk.
    input_guardrails: list[Any] | None = None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass(frozen=True, slots=True)
class AgentType[T]:
    """Handle returned by ``create_agent``; carries the config used to run it later."""

    config: AgentConfig[T]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AiAgentService(Protocol):
    """Structural interface implemented by the agent backends (local mock, Ollama, ...)."""

    # Return an ``AgentType`` handle for ``config`` (the implementations visible
    # in this package simply wrap the config).
    def create_agent(self, config: AgentConfig[T]) -> AgentType[T]: ...

    # Run a previously created agent with per-call settings.
    async def run_agent(
        self,
        agent: AgentType[T],
        config: AgentRunConfig,
    ) -> AgentRunResult[T]: ...

    # Build and run an agent in one call; this is what the decorators use.
    async def create_and_run(
        self,
        agent_config: AgentConfig[T],
        run_config: AgentRunConfig,
    ) -> AgentRunResult[T]: ...
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class AiAgentProvider(StrEnum):
    """Backend selected by ``AiAgentFactory`` when not running locally."""

    OPENAI = "openai"
    OLLAMA = "ollama"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass(frozen=True, slots=True)
class AiAgentConfig:
    """Settings consumed by ``AiAgentFactory.make_agent_service``."""

    # True selects the in-process mock (AiAgentServiceLocal); provider is then ignored.
    local: bool
    provider: AiAgentProvider = AiAgentProvider.OPENAI
    # Required for the OPENAI provider; the factory raises ValueError when empty.
    openai_api_token: str | None = None
    ollama_base_url: str = "http://localhost:11434"
    # Overrides the factory's per-provider default model when set.
    default_model: str | None = None
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AiAgentServiceLocal:
    """In-process stand-in for a real ``AiAgentService``, intended for tests.

    All overrides live in class-level state, so every instance (including the
    ``get_instance`` singleton) sees the same canned data. For a given agent
    name a run resolves, in order: a registered error, the next queued
    response, the most recently consumed queued response, and finally a
    generated default.
    """

    _instance: ClassVar[AiAgentServiceLocal | None] = None
    # Queued outputs per agent name, consumed front-first.
    _responses_by_agent_name: ClassVar[dict[str, list[Any]]] = {}
    # Last consumed queued output per agent name; replayed once the queue drains.
    _last_response_by_agent_name: ClassVar[dict[str, Any]] = {}
    _errors_by_agent_name: ClassVar[dict[str, str]] = {}
    # Reported as ``working_dir`` on every successful mock result.
    _mock_working_dir: ClassVar[str | None] = None

    @classmethod
    def get_instance(cls) -> AiAgentServiceLocal:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def set_response(cls, agent_name: str, output: Any) -> None:
        """Queue one output to be returned by a future run of ``agent_name``."""
        cls._responses_by_agent_name.setdefault(agent_name, []).append(output)

    @classmethod
    def set_responses(cls, agent_name: str, outputs: list[Any]) -> None:
        """Queue several outputs for ``agent_name`` after anything already queued."""
        cls._responses_by_agent_name.setdefault(agent_name, []).extend(outputs)

    @classmethod
    def set_mock_working_dir(cls, directory: str | None) -> None:
        """Set the ``working_dir`` reported on successful mock results."""
        cls._mock_working_dir = directory

    @classmethod
    def set_error(cls, agent_name: str, error_message: str) -> None:
        """Make every run of ``agent_name`` fail with ``error_message``."""
        cls._errors_by_agent_name[agent_name] = error_message

    @classmethod
    def clear_responses(cls) -> None:
        """Drop queued and remembered responses (errors are kept)."""
        cls._responses_by_agent_name.clear()
        cls._last_response_by_agent_name.clear()

    @classmethod
    def clear_errors(cls) -> None:
        """Drop all registered errors."""
        cls._errors_by_agent_name.clear()

    @classmethod
    def clear_all_overrides(cls) -> None:
        """Reset responses, errors and the mock working directory."""
        cls._responses_by_agent_name.clear()
        cls._last_response_by_agent_name.clear()
        cls._errors_by_agent_name.clear()
        cls._mock_working_dir = None

    def create_agent(self, config: AgentConfig[T]) -> AgentType[T]:
        """Wrap ``config`` in an ``AgentType`` handle."""
        return AgentType(config=config)

    async def run_agent(
        self,
        agent: AgentType[T],
        config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Run a previously created agent handle."""
        return await self._execute_agent(agent.config, config)

    async def create_and_run(
        self,
        agent_config: AgentConfig[T],
        run_config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Build and run an agent in one step."""
        return await self._execute_agent(agent_config, run_config)

    async def _execute_agent(
        self,
        agent_config: AgentConfig[T],
        run_config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Resolve the canned result and, on success, record the exchange in the session."""
        result = await self._get_agent_result(agent_config)

        if run_config.session is not None and result.success:
            await run_config.session.add_items(
                [
                    {"role": "user", "content": run_config.input},
                    {
                        "role": "assistant",
                        "status": "completed",
                        "content": [
                            {
                                "type": "output_text",
                                "text": self._output_to_text(result.output),
                            }
                        ],
                    },
                ]
            )

        return result

    @staticmethod
    def _output_to_text(output: Any) -> str:
        """Render an agent output as transcript text.

        Strings pass through unchanged; anything else is JSON-encoded.
        ``default=str`` keeps non-JSON-native outputs (e.g. model or dataclass
        instances used as canned responses) from raising ``TypeError`` and
        aborting the mock run.
        """
        if isinstance(output, str):
            return output
        return json.dumps(output, default=str)

    async def _get_agent_result(self, agent_config: AgentConfig[T]) -> AgentRunResult[T]:
        """Resolve the canned outcome for ``agent_config.name`` (order in class docstring)."""
        agent_name = agent_config.name

        error = self._errors_by_agent_name.get(agent_name)
        if error is not None:
            return AgentRunResult(
                success=False,
                error=error,
                output=None,  # type: ignore[arg-type]
            )

        responses = self._responses_by_agent_name.get(agent_name)
        if responses:
            response = responses.pop(0)
            # Remember it so later runs keep returning it once the queue drains.
            self._last_response_by_agent_name[agent_name] = response
            return AgentRunResult(
                success=True,
                output=response,
                working_dir=self._mock_working_dir,
            )

        if agent_name in self._last_response_by_agent_name:
            return AgentRunResult(
                success=True,
                output=self._last_response_by_agent_name[agent_name],
                working_dir=self._mock_working_dir,
            )

        return AgentRunResult(
            success=True,
            output=self._generate_default_response(agent_config),
            working_dir=self._mock_working_dir,
        )

    def _generate_default_response(self, config: AgentConfig[T]) -> T:
        """Fallback output: ``{}`` for structured agents, otherwise a placeholder string."""
        if config.output_schema is not None:
            return {}  # type: ignore[return-value]
        return f"Mock response for {config.name}"  # type: ignore[return-value]
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import functools
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from collections import Counter
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from .agent_service import (
|
|
14
|
+
AgentConfig,
|
|
15
|
+
AgentRunConfig,
|
|
16
|
+
AgentRunResult,
|
|
17
|
+
AiAgentService,
|
|
18
|
+
ConsensusRunResult,
|
|
19
|
+
ConsensusStrategy,
|
|
20
|
+
)
|
|
21
|
+
from .param_mapper import AgentParameterMapper
|
|
22
|
+
from .prompt_utils import generate_tools_reference, parse_prompt_frontmatter
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def agentic_workflow(
    *,
    prompts_directory: str,
    tool_registry: dict[str, str] | None = None,
) -> Any:
    """Build a class decorator that records workflow-level options on the class.

    The decorated class gains an ``_agentic_workflow_options`` mapping with the
    keys ``prompts_directory`` (where ``@agent`` methods look up their
    ``<method name>.md`` prompt) and ``tool_registry``. The same class object
    is returned.
    """

    def _attach_options(target: type) -> type:
        # A fresh dict per decorated class, so classes never share option state.
        setattr(
            target,
            "_agentic_workflow_options",
            dict(prompts_directory=prompts_directory, tool_registry=tool_registry),
        )
        return target

    return _attach_options
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def agent(
    *,
    name: str | None = None,
    tools: list[Any] | Callable[..., list[Any]] | None = None,
    model: str | None = None,
    model_settings: dict[str, Any] | None = None,
    output_schema: Any | None = None,
    max_turns: int | None = None,
    branch_param: str | None = None,
) -> Any:
    """Method decorator that turns a method into an agent invocation.

    The decorated method's body is never executed. Each call instead:

    1. loads ``<method name>.md`` from the class's ``prompts_directory`` (set by
       ``@agentic_workflow``) and strips its YAML frontmatter;
    2. resolves ``tools`` (calling it with ``self`` when it is callable) and
       prepends a generated "Available Tools" section to the instructions;
    3. maps the call's arguments, positional *and* keyword with defaults
       applied, to the agent input JSON via ``AgentParameterMapper``;
    4. runs the agent through ``self._ai_agent_service.create_and_run``.

    Args:
        name: Agent name; defaults to ``"<ClassName>:<method name>"``.
        tools: Tool list, or a callable ``(self) -> list`` producing one.
        model: Model override passed through ``AgentConfig``.
        model_settings: Model settings passed through ``AgentConfig``.
        output_schema: Structured output type passed through ``AgentConfig``.
        max_turns: Turn limit; defaults to 8 when tools are present, else 1.
        branch_param: Name of a method parameter whose value becomes
            ``AgentRunConfig.branch``.

    Returns:
        A decorator producing an async wrapper that returns ``AgentRunResult``.
        Calling the wrapper with arguments that do not fit the method's
        signature raises ``TypeError``, as a normal call would.
    """
    import inspect  # local import keeps the module import block unchanged

    def decorator(fn: Any) -> Any:
        mapper = AgentParameterMapper.from_function(fn)
        signature = inspect.signature(fn)

        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> AgentRunResult[Any]:
            ai_service: AiAgentService = self._ai_agent_service
            options = getattr(self.__class__, "_agentic_workflow_options", {})
            prompts_dir = options.get("prompts_directory", "")

            # Load the prompt markdown; a missing file means empty instructions.
            prompt_path = Path(prompts_dir) / f"{fn.__name__}.md"
            if prompt_path.exists():
                raw_prompt = prompt_path.read_text(encoding="utf-8")
            else:
                logger.warning(
                    "Prompt file %s not found; agent %s runs with empty instructions",
                    prompt_path,
                    fn.__name__,
                )
                raw_prompt = ""
            instructions = parse_prompt_frontmatter(raw_prompt).content

            # ``tools`` may be a factory that needs the workflow instance.
            resolved_tools = (tools(self) if callable(tools) else tools) or []

            tools_ref = generate_tools_reference(resolved_tools)
            if tools_ref:
                instructions = tools_ref + "\n" + instructions

            # Bind positional AND keyword arguments against the real signature so
            # keyword-passed values and defaults reach the agent input, in the
            # declaration order the mapper expects (``self`` excluded, as in
            # AgentParameterMapper.from_function).
            bound = signature.bind(self, *args, **kwargs)
            bound.apply_defaults()
            call_args = tuple(
                value for param_name, value in bound.arguments.items() if param_name != "self"
            )
            mapped = mapper.map_arguments(call_args)
            session = mapper.find_session(call_args)

            agent_name = name or f"{self.__class__.__name__}:{fn.__name__}"
            branch = mapper.get_param_value(branch_param, call_args) if branch_param else None
            effective_max_turns = (
                max_turns if max_turns is not None else (8 if resolved_tools else 1)
            )

            config = AgentConfig(
                name=agent_name,
                instructions=instructions,
                model=model,
                tools=resolved_tools,
                output_schema=output_schema,
                model_settings=model_settings,
            )
            run_config = AgentRunConfig(
                input=mapped.input,
                context=mapped.context,
                session=session,
                max_turns=effective_max_turns,
                branch=branch,
            )

            start = time.monotonic()
            result = await ai_service.create_and_run(config, run_config)
            logger.debug(
                "Agent %s completed in %.2fs (success=%s)",
                agent_name,
                time.monotonic() - start,
                result.success,
            )
            return result

        return wrapper

    return decorator
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def consensus_agent(
    *,
    name: str | None = None,
    tools: list[Any] | Callable[..., list[Any]] | None = None,
    model: str | None = None,
    model_settings: dict[str, Any] | None = None,
    output_schema: Any | None = None,
    max_turns: int | None = None,
    runs: int,
    consensus_strategy: ConsensusStrategy,
    judge: Callable[..., Any] | None = None,
    temperature_spread: tuple[float, float] | None = None,
) -> Any:
    """Method decorator for consensus-based multi-run agent invocation.

    Like ``@agent``, but each call runs the agent ``runs`` times concurrently
    and combines the outcomes according to ``consensus_strategy``.

    Args:
        name, tools, model, model_settings, output_schema, max_turns: as for
            ``@agent``.
        runs: Number of concurrent runs; must be a positive odd integer.
        consensus_strategy: How the run outputs are combined.
        judge: Required for ``ConsensusStrategy.JUDGE``; called as
            ``judge(instance, successful_runs)``.
        temperature_spread: ``(low, high)``; run ``i`` gets a temperature spaced
            linearly from ``low`` (first run) to ``high`` (last run).

    Returns:
        A decorator producing an async wrapper that returns ``ConsensusRunResult``.

    Raises:
        ValueError: if ``runs`` is even or below 1, or JUDGE has no ``judge``.
    """
    import inspect  # local import keeps the module import block unchanged

    if runs % 2 == 0:
        raise ValueError("runs must be odd")
    if runs < 1:
        # Negative odd values would otherwise slip through and launch zero runs.
        raise ValueError("runs must be a positive odd integer")
    if consensus_strategy == ConsensusStrategy.JUDGE and judge is None:
        raise ValueError("judge function required for JUDGE strategy")

    def decorator(fn: Any) -> Any:
        mapper = AgentParameterMapper.from_function(fn)
        signature = inspect.signature(fn)

        @functools.wraps(fn)
        async def wrapper(self: Any, *args: Any, **kwargs: Any) -> ConsensusRunResult[Any]:
            ai_service: AiAgentService = self._ai_agent_service
            options = getattr(self.__class__, "_agentic_workflow_options", {})
            prompts_dir = options.get("prompts_directory", "")

            prompt_path = Path(prompts_dir) / f"{fn.__name__}.md"
            if prompt_path.exists():
                raw_prompt = prompt_path.read_text(encoding="utf-8")
            else:
                logger.warning(
                    "Prompt file %s not found; agent %s runs with empty instructions",
                    prompt_path,
                    fn.__name__,
                )
                raw_prompt = ""
            instructions = parse_prompt_frontmatter(raw_prompt).content

            resolved_tools = (tools(self) if callable(tools) else tools) or []
            tools_ref = generate_tools_reference(resolved_tools)
            if tools_ref:
                instructions = tools_ref + "\n" + instructions

            # Bind keyword arguments and defaults too, not just positionals.
            bound = signature.bind(self, *args, **kwargs)
            bound.apply_defaults()
            call_args = tuple(
                value for param_name, value in bound.arguments.items() if param_name != "self"
            )
            mapped = mapper.map_arguments(call_args)
            # NOTE(review): every run receives this same session object, so with a
            # session all runs read and append to one shared history concurrently.
            session = mapper.find_session(call_args)
            agent_name = name or f"{self.__class__.__name__}:{fn.__name__}"
            # ``is None`` rather than ``or`` so an explicit value is never
            # replaced, matching @agent.
            effective_max_turns = (
                max_turns if max_turns is not None else (8 if resolved_tools else 1)
            )

            async def single_run(run_index: int) -> AgentRunResult[Any]:
                settings = dict(model_settings or {})
                if temperature_spread:
                    low, high = temperature_spread
                    settings["temperature"] = low + (high - low) * run_index / max(runs - 1, 1)

                config = AgentConfig(
                    name=agent_name,
                    instructions=instructions,
                    model=model,
                    tools=resolved_tools,
                    output_schema=output_schema,
                    model_settings=settings or None,
                )
                run_config = AgentRunConfig(
                    input=mapped.input,
                    context=mapped.context,
                    session=session,
                    max_turns=effective_max_turns,
                )
                return await ai_service.create_and_run(config, run_config)

            # return_exceptions=True: one crashing run becomes a failed result
            # instead of aborting the whole consensus.
            outcomes = await asyncio.gather(
                *(single_run(i) for i in range(runs)),
                return_exceptions=True,
            )
            all_results: list[AgentRunResult[Any]] = []
            for index, outcome in enumerate(outcomes):
                if isinstance(outcome, BaseException):
                    if not isinstance(outcome, Exception):
                        raise outcome  # cancellation / exit signals must propagate
                    logger.warning(
                        "Consensus run %d of agent %s raised",
                        index,
                        agent_name,
                        exc_info=outcome,
                    )
                    all_results.append(
                        AgentRunResult(
                            output=None,  # type: ignore[arg-type]
                            success=False,
                            error=str(outcome) or type(outcome).__name__,
                        )
                    )
                else:
                    all_results.append(outcome)

            return await _resolve_consensus(
                all_results,
                consensus_strategy,
                runs,
                judge_fn=judge,
                instance=self,
            )

        return wrapper

    return decorator
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
async def _resolve_consensus(
    all_results: list[AgentRunResult[Any]],
    strategy: ConsensusStrategy,
    total_runs: int,
    judge_fn: Callable[..., Any] | None = None,
    instance: Any = None,
) -> ConsensusRunResult[Any]:
    """Combine individual run results into a single ``ConsensusRunResult``.

    Only successful runs take part in the decision. ``runs`` on the returned
    object always lists every run, failures included. If no run succeeded the
    result is a failure ("All runs failed") regardless of strategy.

    * MAJORITY: the most frequent output wins, compared by canonical JSON
      (``sort_keys=True``, ``default=str``). This is a plurality, not a strict
      majority, so check ``agreement`` when that matters. Ties go to the
      output seen first.
    * UNANIMOUS: succeeds only if every successful run produced the same output.
    * JUDGE: ``judge_fn(instance, successful_runs)`` returns the result. It may
      be a coroutine function or a plain callable.

    The chosen run's ``raw_result`` and ``working_dir`` (and, for JUDGE, its
    ``error``) are carried over to the consensus result.

    Raises:
        ValueError: for an unknown strategy, or JUDGE without ``judge_fn``.
    """
    import inspect  # local import keeps the module import block unchanged

    successful = [r for r in all_results if r.success]
    succeeded = len(successful)

    if not successful:
        return ConsensusRunResult(
            output=None,  # type: ignore[arg-type]
            success=False,
            error="All runs failed",
            runs=all_results,
            agreement=0.0,
            total_runs=total_runs,
            successful_runs=0,
        )

    if strategy == ConsensusStrategy.MAJORITY:
        serialized = [json.dumps(r.output, sort_keys=True, default=str) for r in successful]
        winner_key, winner_count = Counter(serialized).most_common(1)[0]
        winner = successful[serialized.index(winner_key)]
        return ConsensusRunResult(
            output=winner.output,
            success=True,
            raw_result=winner.raw_result,
            working_dir=winner.working_dir,
            runs=all_results,
            agreement=winner_count / succeeded,
            total_runs=total_runs,
            successful_runs=succeeded,
        )

    if strategy == ConsensusStrategy.UNANIMOUS:
        serialized = [json.dumps(r.output, sort_keys=True, default=str) for r in successful]
        if len(set(serialized)) != 1:
            return ConsensusRunResult(
                output=None,  # type: ignore[arg-type]
                success=False,
                error="Unanimous consensus not reached",
                runs=all_results,
                agreement=0.0,
                total_runs=total_runs,
                successful_runs=succeeded,
            )
        first = successful[0]
        return ConsensusRunResult(
            output=first.output,
            success=True,
            raw_result=first.raw_result,
            working_dir=first.working_dir,
            runs=all_results,
            agreement=1.0,
            total_runs=total_runs,
            successful_runs=succeeded,
        )

    if strategy == ConsensusStrategy.JUDGE:
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if judge_fn is None:
            raise ValueError("judge function required for JUDGE strategy")
        verdict = judge_fn(instance, successful)
        if inspect.isawaitable(verdict):
            verdict = await verdict
        return ConsensusRunResult(
            output=verdict.output,
            success=verdict.success,
            error=getattr(verdict, "error", None),
            raw_result=getattr(verdict, "raw_result", None),
            working_dir=getattr(verdict, "working_dir", None),
            runs=all_results,
            agreement=0.0,
            total_runs=total_runs,
            successful_runs=succeeded,
        )

    raise ValueError(f"Unknown consensus strategy: {strategy}")
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .agent_service import AiAgentConfig, AiAgentProvider, AiAgentService, AiAgentServiceLocal
|
|
4
|
+
from .service_ollama import AiAgentServiceOllama
|
|
5
|
+
from .service_openai import AiAgentServiceOpenai
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AiAgentFactory:
    """Creates the appropriate AiAgentService based on configuration.

    Default connection settings live in class constants so that
    ``make_agent_service`` and the ``make_*_service`` helpers cannot drift apart.
    """

    DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
    DEFAULT_OLLAMA_MODEL = "qwen2.5:14b"
    DEFAULT_OPENAI_MODEL = "gpt-4o"

    def __init__(self, config: AiAgentConfig) -> None:
        self._config = config

    def make_agent_service(self) -> AiAgentService:
        """Return the service selected by the configuration.

        ``local=True`` always wins and returns the shared ``AiAgentServiceLocal``
        instance. Otherwise the provider decides, and ``default_model``
        overrides the per-provider default when set.

        Raises:
            ValueError: if the OpenAI provider has no API token, or the provider
                is not recognised.
        """
        config = self._config
        if config.local:
            return AiAgentServiceLocal.get_instance()

        if config.provider == AiAgentProvider.OLLAMA:
            return self.make_ollama_service(
                base_url=config.ollama_base_url,
                default_model=config.default_model or self.DEFAULT_OLLAMA_MODEL,
            )

        if config.provider == AiAgentProvider.OPENAI:
            if not config.openai_api_token:
                raise ValueError("openai_api_token is required for the OpenAI provider")
            return self.make_openai_service(
                api_key=config.openai_api_token,
                default_model=config.default_model or self.DEFAULT_OPENAI_MODEL,
            )

        raise ValueError(f"Unknown provider: {config.provider}")

    @staticmethod
    def make_ollama_service(
        base_url: str = DEFAULT_OLLAMA_BASE_URL,
        default_model: str = DEFAULT_OLLAMA_MODEL,
    ) -> AiAgentServiceOllama:
        """Build an Ollama-backed service (OpenAI-compatible endpoint)."""
        return AiAgentServiceOllama(base_url=base_url, default_model=default_model)

    @staticmethod
    def make_openai_service(
        api_key: str,
        default_model: str = DEFAULT_OPENAI_MODEL,
    ) -> AiAgentServiceOpenai:
        """Build an OpenAI-backed service authenticated with ``api_key``."""
        return AiAgentServiceOpenai(api_key=api_key, default_model=default_model)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from .session import RepositorySession
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class MappedArguments:
    """Agent input derived from a decorated method's call arguments."""

    # JSON object (indent=2, default=str) of the arguments that are neither a
    # RepositorySession nor the parameter named "context".
    input: str
    # Value of the parameter named "context", if one was passed.
    context: dict[str, Any] | None = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AgentParameterMapper:
|
|
18
|
+
"""Maps decorated method parameters to agent input JSON."""
|
|
19
|
+
|
|
20
|
+
def __init__(self, param_names: list[str]) -> None:
|
|
21
|
+
self._param_names = param_names
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def from_function(cls, func: Any) -> AgentParameterMapper:
|
|
25
|
+
sig = inspect.signature(func)
|
|
26
|
+
names = [name for name, _param in sig.parameters.items() if name != "self"]
|
|
27
|
+
return cls(names)
|
|
28
|
+
|
|
29
|
+
def map_arguments(self, args: tuple[Any, ...]) -> MappedArguments:
|
|
30
|
+
input_obj: dict[str, Any] = {}
|
|
31
|
+
context: dict[str, Any] | None = None
|
|
32
|
+
|
|
33
|
+
for name, value in zip(self._param_names, args, strict=False):
|
|
34
|
+
if isinstance(value, RepositorySession):
|
|
35
|
+
continue
|
|
36
|
+
if name == "context":
|
|
37
|
+
context = value
|
|
38
|
+
continue
|
|
39
|
+
input_obj[name] = value
|
|
40
|
+
|
|
41
|
+
return MappedArguments(
|
|
42
|
+
input=json.dumps(input_obj, indent=2, default=str),
|
|
43
|
+
context=context,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
def find_session(self, args: tuple[Any, ...]) -> RepositorySession | None:
|
|
47
|
+
for value in args:
|
|
48
|
+
if isinstance(value, RepositorySession):
|
|
49
|
+
return value
|
|
50
|
+
return None
|
|
51
|
+
|
|
52
|
+
def get_param_value(self, name: str, args: tuple[Any, ...]) -> Any:
|
|
53
|
+
for pname, value in zip(self._param_names, args, strict=False):
|
|
54
|
+
if pname == name:
|
|
55
|
+
return value
|
|
56
|
+
return None
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass
class PromptFrontmatter:
    """Metadata declared in a prompt file's YAML frontmatter."""

    # Tool names listed under ``tools:``; empty when absent.
    tools: list[str] = field(default_factory=list)


@dataclass
class ParsedPrompt:
    """A prompt document split into frontmatter and markdown body."""

    frontmatter: PromptFrontmatter
    content: str


# Leading frontmatter block: an opening "---" line, the YAML text, a closing
# "---" line (which may be the last line of the document), then the content.
# Whitespace after either fence is consumed, so blank lines directly after the
# closing fence do not leak into the content. The YAML part may be empty.
_FRONTMATTER_RE = re.compile(
    r"\A---\s*\n(.*?)^---\s*(?:\n|\Z)(.*)\Z",
    re.DOTALL | re.MULTILINE,
)


def parse_prompt_frontmatter(markdown: str) -> ParsedPrompt:
    """Split ``markdown`` into YAML frontmatter and prompt content.

    A document without a leading ``---`` fenced block is returned unchanged
    with empty frontmatter. Only the ``tools`` key is read; a single string is
    accepted as a one-element list. The YAML is parsed with
    ``yaml.safe_load``, so prompt files cannot construct arbitrary objects.

    Raises:
        ValueError: if the frontmatter is not a mapping, or ``tools`` is
            neither a string nor a list.
        yaml.YAMLError: if the frontmatter is not valid YAML.
    """
    match = _FRONTMATTER_RE.match(markdown)
    if match is None:
        return ParsedPrompt(frontmatter=PromptFrontmatter(), content=markdown)

    raw_frontmatter, content = match.group(1), match.group(2)

    # An empty/blank block needs no YAML parsing (safe_load would return None).
    meta = (yaml.safe_load(raw_frontmatter) if raw_frontmatter.strip() else None) or {}
    if not isinstance(meta, dict):
        raise ValueError(
            f"Prompt frontmatter must be a YAML mapping, got {type(meta).__name__}"
        )

    tools = meta.get("tools") or []
    if isinstance(tools, str):
        tools = [tools]
    elif not isinstance(tools, list):
        raise ValueError(
            f"Frontmatter 'tools' must be a list of names, got {type(tools).__name__}"
        )

    return ParsedPrompt(
        frontmatter=PromptFrontmatter(tools=list(tools)),
        content=content,
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def generate_tools_reference(tools: list[Any]) -> str:
    """Render a markdown "Available Tools" section describing ``tools``.

    Each tool contributes a ``### <name>`` heading, its description (when
    non-empty) and a blank line. The name is taken from the tool's ``name``
    attribute, then its ``__name__`` (plain functions), and only then from
    ``str(tool)``. Without the ``__name__`` step a plain function would render
    as ``<function f at 0x...>``, leaking a memory address into the prompt.

    Returns:
        The section text, or ``""`` when ``tools`` is empty or ``None``.
    """
    if not tools:
        return ""

    lines = ["## Available Tools\n"]
    for tool in tools:
        tool_name = getattr(tool, "name", None) or getattr(tool, "__name__", None) or str(tool)
        lines.append(f"### {tool_name}")
        description = getattr(tool, "description", "")
        if description:
            lines.append(str(description))
        lines.append("")

    return "\n".join(lines)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def map_tool_keys(
    tool_keys: list[str],
    registry: dict[str, str],
) -> list[str]:
    """Translate frontmatter tool keys into registry values, preserving order.

    Raises:
        ValueError: naming the first key absent from ``registry`` and listing
            the available keys in sorted order.
    """
    # Validate first so the error names the earliest unknown key.
    for key in tool_keys:
        if key in registry:
            continue
        choices = ", ".join(sorted(registry))
        raise ValueError(f"Tool key '{key}' not found in registry. Available: {choices}")

    return [registry[key] for key in tool_keys]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def validate_tools_match(
    declared_tool_names: list[str],
    actual_tools: list[Any],
) -> None:
    """Check that frontmatter-declared tool names match the tools actually supplied.

    A tool's name is its ``name`` attribute, then its ``__name__`` (plain
    functions), then ``str(tool)``.

    Raises:
        ValueError: listing names declared but not provided and/or names
            provided but not declared. Names are sorted and comma-separated so
            the message is deterministic; set reprs vary with hash ordering.
    """
    actual_names = {
        getattr(tool, "name", None) or getattr(tool, "__name__", None) or str(tool)
        for tool in actual_tools
    }
    declared = set(declared_tool_names)

    missing = sorted(declared - actual_names, key=str)
    extra = sorted(actual_names - declared, key=str)

    problems: list[str] = []
    if missing:
        problems.append(
            f"Declared in frontmatter but not provided: {', '.join(map(str, missing))}"
        )
    if extra:
        problems.append(
            f"Provided but not declared in frontmatter: {', '.join(map(str, extra))}"
        )

    if problems:
        raise ValueError("; ".join(problems))
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from agents import Agent, Runner
|
|
7
|
+
from agents.model_settings import ModelSettings
|
|
8
|
+
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
|
|
9
|
+
from openai import AsyncOpenAI
|
|
10
|
+
|
|
11
|
+
from .agent_service import AgentConfig, AgentRunConfig, AgentRunResult, AgentType, T
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AiAgentServiceOllama:
    """Agent service that runs OpenAI Agents SDK agents against an Ollama server.

    Ollama is reached through its OpenAI-compatible endpoint (``<base_url>/v1``)
    using the SDK's chat-completions model adapter.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        default_model: str = "qwen2.5:14b",
    ) -> None:
        """Create the service and its shared ``AsyncOpenAI`` client.

        Args:
            base_url: Root URL of the Ollama server; a trailing ``/`` is tolerated.
            default_model: Model name used when an ``AgentConfig`` leaves ``model`` unset.
        """
        self._base_url = base_url
        self._default_model = default_model
        self._client = AsyncOpenAI(
            # rstrip avoids a "//v1" path when base_url is given with a trailing slash.
            base_url=f"{base_url.rstrip('/')}/v1",
            # Placeholder credential: the client expects a key; presumably Ollama ignores it.
            api_key="ollama",
        )

    def create_agent(self, config: AgentConfig[T]) -> AgentType[T]:
        """Wrap ``config`` in an ``AgentType``; the SDK agent is only built when run."""
        return AgentType(config=config)

    async def run_agent(
        self,
        agent: AgentType[T],
        config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Run a previously created agent (same as ``create_and_run(agent.config, config)``)."""
        return await self.create_and_run(agent.config, config)

    async def create_and_run(
        self,
        agent_config: AgentConfig[T],
        run_config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Build an SDK agent from ``agent_config`` and run it once.

        Args:
            agent_config: Agent definition: name, instructions, tools, output schema,
                and an optional model and model settings. The settings may be a
                ``ModelSettings`` or a mapping of its keyword arguments.
            run_config: Run input, plus an optional context and ``max_turns``.

        Returns:
            On success, a result carrying the SDK's final output and raw run result.
            If anything raises, an unsuccessful result with ``output=None`` and the
            exception text in ``error``. Exceptions are never propagated.
        """
        try:
            model_name = agent_config.model or self._default_model
            model = OpenAIChatCompletionsModel(
                model=model_name,
                openai_client=self._client,
            )

            ms = agent_config.model_settings
            if ms is not None and not isinstance(ms, ModelSettings):
                ms = ModelSettings(**ms)

            sdk_agent = Agent(
                name=agent_config.name,
                instructions=agent_config.instructions,
                model=model,
                tools=agent_config.tools or [],
                output_type=agent_config.output_schema,
                model_settings=ms or ModelSettings(),
            )

            # Forward only the optional run arguments that were actually set so the
            # SDK's own defaults apply otherwise.
            run_kwargs: dict[str, Any] = {"input": run_config.input}
            if run_config.context is not None:
                run_kwargs["context"] = run_config.context
            if run_config.max_turns is not None:
                run_kwargs["max_turns"] = run_config.max_turns

            result = await Runner.run(sdk_agent, **run_kwargs)

            return AgentRunResult(
                output=result.final_output,
                success=True,
                raw_result=result,
            )
        except Exception as exc:
            error_msg = str(exc)
            lowered = error_msg.lower()
            # The Agents SDK's MaxTurnsExceeded message reads "Max turns (N) exceeded",
            # which a bare "max_turns" substring check never matched; accept both.
            if "max_turns" in lowered or "max turns" in lowered:
                logger.warning("Agent %s reached max turns", agent_config.name)
            else:
                logger.exception("Agent %s failed", agent_config.name)
            return AgentRunResult(
                output=None,  # type: ignore[arg-type]
                success=False,
                error=error_msg,
            )
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from agents import Agent, Runner
|
|
7
|
+
from agents.model_settings import ModelSettings
|
|
8
|
+
from agents.run import RunConfig
|
|
9
|
+
|
|
10
|
+
from .agent_service import AgentConfig, AgentRunConfig, AgentRunResult, AgentType, T
|
|
11
|
+
|
|
12
|
+
# Module-level logger; agent run failures are reported here before being
# converted into unsuccessful AgentRunResult values.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AiAgentServiceOpenai:
    """Agent service backed by the OpenAI Agents SDK and the OpenAI API.

    Model calls for every run go through an ``OpenAIProvider`` bound to the API key
    given at construction time.
    """

    def __init__(self, api_key: str, default_model: str = "gpt-4o") -> None:
        """Create the service.

        Args:
            api_key: OpenAI API key used for this service's model calls. An empty
                value falls back to the SDK's default key resolution.
            default_model: Model name used when an ``AgentConfig`` leaves ``model`` unset.
        """
        # Local import keeps this change self-contained; the module already depends on `agents`.
        from agents.models.openai_provider import OpenAIProvider

        self._api_key = api_key
        self._default_model = default_model
        # Previously the key was stored but never used, so runs silently relied on the
        # process-wide default key. Bind it to a provider that every run uses instead.
        self._model_provider = OpenAIProvider(api_key=api_key or None)

    def create_agent(self, config: AgentConfig[T]) -> AgentType[T]:
        """Wrap ``config`` in an ``AgentType``; the SDK agent is only built when run."""
        return AgentType(config=config)

    async def run_agent(
        self,
        agent: AgentType[T],
        config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Run a previously created agent (same as ``create_and_run(agent.config, config)``)."""
        return await self.create_and_run(agent.config, config)

    async def create_and_run(
        self,
        agent_config: AgentConfig[T],
        run_config: AgentRunConfig,
    ) -> AgentRunResult[T]:
        """Build an SDK agent from ``agent_config`` and run it once.

        Args:
            agent_config: Agent definition: name, instructions, tools, output schema,
                and an optional model and model settings. The settings may be a
                ``ModelSettings`` or a mapping of its keyword arguments.
            run_config: Run input, plus an optional context and ``max_turns``.

        Returns:
            On success, a result carrying the SDK's final output and raw run result.
            If anything raises, an unsuccessful result with ``output=None`` and the
            exception text in ``error``. Exceptions are never propagated.
        """
        try:
            model = agent_config.model or self._default_model
            ms = agent_config.model_settings
            if ms is not None and not isinstance(ms, ModelSettings):
                ms = ModelSettings(**ms)

            sdk_agent = Agent(
                name=agent_config.name,
                instructions=agent_config.instructions,
                model=model,
                tools=agent_config.tools or [],
                output_type=agent_config.output_schema,
                model_settings=ms or ModelSettings(),
            )

            # Forward only the optional run arguments that were actually set so the
            # SDK's own defaults apply otherwise.
            run_kwargs: dict[str, Any] = {"input": run_config.input}
            if run_config.context is not None:
                run_kwargs["context"] = run_config.context
            if run_config.max_turns is not None:
                run_kwargs["max_turns"] = run_config.max_turns

            run_cfg = RunConfig(model=model, model_provider=self._model_provider)

            result = await Runner.run(
                sdk_agent,
                run_config=run_cfg,
                **run_kwargs,
            )

            return AgentRunResult(
                output=result.final_output,
                success=True,
                raw_result=result,
            )
        except Exception as exc:
            error_msg = str(exc)
            lowered = error_msg.lower()
            # The Agents SDK's MaxTurnsExceeded message reads "Max turns (N) exceeded",
            # which a bare "max_turns" substring check never matched; accept both.
            if "max_turns" in lowered or "max turns" in lowered:
                logger.warning("Agent %s reached max turns", agent_config.name)
            else:
                logger.exception("Agent %s failed", agent_config.name)
            return AgentRunResult(
                output=None,  # type: ignore[arg-type]
                success=False,
                error=error_msg,
            )
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from collections.abc import Mapping, Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from time import time
|
|
7
|
+
from typing import Any, Protocol, TypedDict, cast
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
# Lookup key for a ConversationSessionRepository implementation — presumably used
# to register/resolve the repository in a dependency container; TODO confirm with callers.
CONVERSATION_SESSION_REPOSITORY = "CONVERSATION_SESSION_REPOSITORY"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _now_ms() -> int:
|
|
14
|
+
return int(time() * 1000)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(slots=True)
class ConversationSessionModel:
    """A stored conversation session record.

    The in-memory repository fills the timestamps with epoch milliseconds
    (``_now_ms``) and bumps ``updated_at`` on every write to the session.
    """

    session_id: str  # session identifier; the in-memory repository uses a random UUID4 string
    client_id: str  # client the session was created for (passed to create_session)
    created_at: int  # creation timestamp
    updated_at: int  # timestamp of the last modification
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(slots=True)
class ConversationItemModel:
    """A single stored message belonging to a conversation session.

    Items are soft-deleted: the in-memory repository sets ``deleted_at`` rather
    than removing the record, and reads skip items where it is set.
    """

    item_id: str  # item identifier; the in-memory repository uses a random UUID4 string
    session_id: str  # id of the owning session
    sequence_number: int  # position within the session; assigned increasing from 0 in memory
    role: str  # message role, e.g. "user", "assistant" or "system"
    content: str  # message text as stored (non-string content is flattened before storing)
    metadata: str | None  # JSON-encoded metadata dict, or None when absent
    created_at: int  # creation timestamp
    deleted_at: int | None  # soft-delete timestamp; None while the item is active
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ConversationMessage(TypedDict):
    """A message to append to a session via ``add_conversation_items``."""

    role: str  # message role, e.g. "user", "assistant" or "system"
    content: str  # message text
    metadata: dict[str, Any] | None  # optional metadata; the in-memory repository JSON-encodes it
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class InputTextPart(TypedDict):
    """One text content part of a session item's ``content`` list."""

    type: str  # part type; RepositorySession produces "input_text" or "output_text"
    text: str  # the part's text
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SessionItem(TypedDict):
    """A message item as handed back to the agent runtime by ``RepositorySession``."""

    role: str  # message role
    content: str | list[InputTextPart]  # plain text or text parts; RepositorySession emits parts
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ConversationSessionRepository(Protocol):
    """Structural interface for storing conversation sessions and their items.

    ``InMemoryConversationSessionRepository`` implements it in process memory, and
    ``RepositorySession`` consumes it.
    """

    async def create_session(self, client_id: str) -> ConversationSessionModel:
        """Create and return a new session for ``client_id``."""
        ...

    async def get_session(self, session_id: str) -> ConversationSessionModel:
        """Return the session with ``session_id``.

        The in-memory implementation raises ``SessionNotFoundError`` when it is missing.
        """
        ...

    async def get_conversation_items(
        self,
        session_id: str,
        limit: int | None = None,
    ) -> list[ConversationItemModel]:
        """Return the session's active items.

        ``limit`` presumably caps the result to the most recent items (as the
        in-memory implementation does); ``None`` means no cap.
        """
        ...

    async def add_conversation_items(
        self,
        session_id: str,
        items: Sequence[ConversationMessage],
    ) -> None:
        """Append ``items`` to the session, in order."""
        ...

    async def clear_conversation(self, session_id: str) -> None:
        """Remove all items from the session (the in-memory implementation soft-deletes them)."""
        ...

    async def pop_last_item(self, session_id: str) -> ConversationItemModel | None:
        """Remove and return the most recent active item, or ``None`` if there is none."""
        ...
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class SessionNotFoundError(LookupError):
    """Raised when a repository is asked for a session id it does not hold."""
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class InMemoryConversationSessionRepository:
    """Dictionary-backed ``ConversationSessionRepository`` living in process memory.

    Items are soft-deleted: ``clear_conversation`` and ``pop_last_item`` stamp
    ``deleted_at`` instead of removing entries, and reads skip stamped items.
    Nothing outlives the instance.
    """

    def __init__(self) -> None:
        self._sessions: dict[str, ConversationSessionModel] = {}
        # Items per session id, kept in insertion order (which is also sequence order).
        self._items: dict[str, list[ConversationItemModel]] = {}

    async def create_session(self, client_id: str) -> ConversationSessionModel:
        """Create, store and return an empty session for ``client_id``."""
        session_id = str(uuid4())
        now = _now_ms()
        session = ConversationSessionModel(
            session_id=session_id,
            client_id=client_id,
            created_at=now,
            updated_at=now,
        )
        self._sessions[session_id] = session
        self._items[session_id] = []
        return session

    async def get_session(self, session_id: str) -> ConversationSessionModel:
        """Return the stored session.

        Raises:
            SessionNotFoundError: If no session has ``session_id``.
        """
        session = self._sessions.get(session_id)
        if session is None:
            raise SessionNotFoundError(f"Session not found: {session_id}")
        return session

    async def get_conversation_items(
        self,
        session_id: str,
        limit: int | None = None,
    ) -> list[ConversationItemModel]:
        """Return the session's active (not soft-deleted) items in insertion order.

        Args:
            session_id: Session to read. An unknown id yields an empty list.
            limit: ``None`` returns every active item. A positive value returns at
                most the ``limit`` most recent active items; ``0`` or a negative
                value returns an empty list.
        """
        session_items = self._items.get(session_id, [])
        active_items = [item for item in session_items if item.deleted_at is None]
        if limit is None:
            return active_items
        # A truthiness test used to make limit=0 return everything and a negative
        # limit slice from the front; both now mean "no items".
        if limit <= 0:
            return []
        return active_items[-limit:]

    async def add_conversation_items(
        self,
        session_id: str,
        items: Sequence[ConversationMessage],
    ) -> None:
        """Append ``items`` to the session with consecutive sequence numbers.

        Numbering continues after the highest number ever assigned in the session,
        soft-deleted items included, so numbers are never reused. Metadata dicts are
        stored JSON-encoded.

        Raises:
            SessionNotFoundError: If the session does not exist.
        """
        session = await self.get_session(session_id)
        session_items = self._items.setdefault(session_id, [])
        now = _now_ms()
        sequence_number = max((item.sequence_number for item in session_items), default=-1) + 1

        for item in items:
            metadata = item.get("metadata")
            session_items.append(
                ConversationItemModel(
                    item_id=str(uuid4()),
                    session_id=session_id,
                    sequence_number=sequence_number,
                    role=item["role"],
                    content=item["content"],
                    metadata=json.dumps(metadata) if metadata is not None else None,
                    created_at=now,
                    deleted_at=None,
                )
            )
            sequence_number += 1

        session.updated_at = now

    async def clear_conversation(self, session_id: str) -> None:
        """Soft-delete every active item of the session.

        Raises:
            SessionNotFoundError: If the session does not exist.
        """
        session = await self.get_session(session_id)
        now = _now_ms()

        for item in self._items.get(session_id, []):
            if item.deleted_at is None:
                item.deleted_at = now

        session.updated_at = now

    async def pop_last_item(self, session_id: str) -> ConversationItemModel | None:
        """Soft-delete and return the most recently added active item.

        Returns:
            The popped item, or ``None`` when the session has no active items.

        Raises:
            SessionNotFoundError: If the session does not exist.
        """
        session = await self.get_session(session_id)
        session_items = self._items.get(session_id, [])
        active_items = [item for item in session_items if item.deleted_at is None]
        if not active_items:
            return None

        last_item = active_items[-1]
        now = _now_ms()
        last_item.deleted_at = now
        session.updated_at = now
        return last_item
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class RepositorySession:
    """Agent-session adapter over a ``ConversationSessionRepository`` for one session id.

    The unlimited ``get_items`` result is cached on the instance and discarded by
    every write made through this instance.
    """

    def __init__(self, session_id: str, repository: ConversationSessionRepository) -> None:
        self.session_id = session_id
        self.repository = repository
        # Result of get_items(limit=None); reset to None by every mutating call.
        self._cached_items: list[SessionItem] | None = None

    def _to_content_parts(self, role: str, content: str) -> list[InputTextPart]:
        """Wrap stored text as a single content part appropriate for ``role``.

        Assistant turns use ``output_text``. Every other role uses ``input_text``,
        because ``output_text`` only describes model output. Previously, system
        messages were emitted as ``output_text``.
        """
        if role == "assistant":
            return [{"type": "output_text", "text": content}]
        return [{"type": "input_text", "text": content}]

    @staticmethod
    def _serialize_content(content: Any) -> str:
        """Flatten message content into the string stored in the repository.

        Strings are stored as-is. A list/tuple of content parts that all carry a
        string ``text`` field (e.g. ``input_text`` / ``output_text`` parts) is stored
        as the concatenation of those texts, so it reads back as the same text via
        ``get_items``. Anything else is JSON-encoded.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)) and all(
            isinstance(part, Mapping) and isinstance(part.get("text"), str) for part in content
        ):
            # Previously such lists were JSON-dumped and later surfaced verbatim as the
            # message text (a JSON blob), so stored turns did not round-trip.
            return "".join(part["text"] for part in content)
        return json.dumps(content)

    async def get_items(self, limit: int | None = None) -> list[SessionItem]:
        """Return stored messages as session items, in repository order.

        Args:
            limit: Passed through to the repository (the in-memory repository returns
                the ``limit`` most recent items). ``None`` reads everything and is
                served from, or stored into, the instance cache.

        Returns:
            A new list on every call, so callers mutating it cannot corrupt the cache.
        """
        if limit is None and self._cached_items is not None:
            return list(self._cached_items)

        db_items = await self.repository.get_conversation_items(self.session_id, limit)
        items: list[SessionItem] = [
            SessionItem(
                role=item.role,
                content=self._to_content_parts(item.role, item.content),
            )
            for item in db_items
        ]

        if limit is None:
            self._cached_items = items
            return list(items)
        return items

    async def add_items(self, items: Sequence[Mapping[str, Any]]) -> None:
        """Persist the user/assistant/system messages contained in ``items``.

        Items with any other role, or without a ``content`` key, are skipped.
        Content is flattened with ``_serialize_content``. The cache is invalidated
        only when something was written.
        """
        messages: list[ConversationMessage] = []
        for item in items:
            role = item.get("role")
            if role not in {"user", "assistant", "system"} or "content" not in item:
                continue
            messages.append(
                ConversationMessage(
                    role=cast(str, role),
                    content=self._serialize_content(item["content"]),
                    metadata=None,
                )
            )

        if messages:
            await self.repository.add_conversation_items(self.session_id, messages)
            self._cached_items = None

    async def clear_session(self) -> None:
        """Clear the session's conversation in the repository and drop the cache."""
        await self.repository.clear_conversation(self.session_id)
        self._cached_items = None

    async def pop_item(self) -> SessionItem | None:
        """Pop the most recent item from the repository.

        Returns:
            The popped message as a session item, or ``None`` when there was none.
        """
        popped_item = await self.repository.pop_last_item(self.session_id)
        self._cached_items = None
        if popped_item is None:
            return None
        return SessionItem(
            role=popped_item.role,
            content=self._to_content_parts(popped_item.role, popped_item.content),
        )
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .session import ConversationSessionRepository, RepositorySession
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class AiSessionFactory:
    """Hands out ``RepositorySession`` wrappers that share one repository."""

    def __init__(self, repository: ConversationSessionRepository) -> None:
        self._repository = repository

    async def get_or_create_session(
        self,
        client_id: str,
        session_id: str | None = None,
    ) -> RepositorySession:
        """Return a session wrapper, creating a new stored session when no id is given.

        Args:
            client_id: Owner recorded when a new session is created.
            session_id: Existing session to attach to. Its existence is checked with
                ``repository.get_session``, which may raise (the in-memory repository
                raises ``SessionNotFoundError``).

        NOTE(review): when ``session_id`` is given, ``client_id`` is not compared with
        the stored session's owner — confirm whether ownership should be enforced.
        """
        if session_id is None:
            created = await self._repository.create_session(client_id)
            session_id = created.session_id
        else:
            # Awaited only so that unknown ids fail here rather than on first use.
            await self._repository.get_session(session_id)
        return RepositorySession(session_id, self._repository)
|