contextforge-eval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- context_forge/__init__.py +95 -0
- context_forge/core/__init__.py +55 -0
- context_forge/core/trace.py +369 -0
- context_forge/core/types.py +121 -0
- context_forge/evaluation.py +267 -0
- context_forge/exceptions.py +56 -0
- context_forge/graders/__init__.py +44 -0
- context_forge/graders/base.py +264 -0
- context_forge/graders/deterministic/__init__.py +11 -0
- context_forge/graders/deterministic/memory_corruption.py +130 -0
- context_forge/graders/hybrid.py +190 -0
- context_forge/graders/judges/__init__.py +11 -0
- context_forge/graders/judges/backends/__init__.py +9 -0
- context_forge/graders/judges/backends/ollama.py +173 -0
- context_forge/graders/judges/base.py +158 -0
- context_forge/graders/judges/memory_hygiene_judge.py +332 -0
- context_forge/graders/judges/models.py +113 -0
- context_forge/harness/__init__.py +43 -0
- context_forge/harness/user_simulator/__init__.py +70 -0
- context_forge/harness/user_simulator/adapters/__init__.py +13 -0
- context_forge/harness/user_simulator/adapters/base.py +67 -0
- context_forge/harness/user_simulator/adapters/crewai.py +100 -0
- context_forge/harness/user_simulator/adapters/langgraph.py +157 -0
- context_forge/harness/user_simulator/adapters/pydanticai.py +105 -0
- context_forge/harness/user_simulator/llm/__init__.py +5 -0
- context_forge/harness/user_simulator/llm/ollama.py +119 -0
- context_forge/harness/user_simulator/models.py +103 -0
- context_forge/harness/user_simulator/persona.py +154 -0
- context_forge/harness/user_simulator/runner.py +342 -0
- context_forge/harness/user_simulator/scenario.py +95 -0
- context_forge/harness/user_simulator/simulator.py +307 -0
- context_forge/instrumentation/__init__.py +23 -0
- context_forge/instrumentation/base.py +307 -0
- context_forge/instrumentation/instrumentors/__init__.py +17 -0
- context_forge/instrumentation/instrumentors/langchain.py +671 -0
- context_forge/instrumentation/instrumentors/langgraph.py +534 -0
- context_forge/instrumentation/tracer.py +588 -0
- context_forge/py.typed +0 -0
- contextforge_eval-0.1.0.dist-info/METADATA +420 -0
- contextforge_eval-0.1.0.dist-info/RECORD +43 -0
- contextforge_eval-0.1.0.dist-info/WHEEL +5 -0
- contextforge_eval-0.1.0.dist-info/licenses/LICENSE +201 -0
- contextforge_eval-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""CrewAI adapter for user simulation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from langchain_core.messages import AIMessage, BaseMessage
|
|
7
|
+
|
|
8
|
+
from ..models import SimulationState
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CrewAIAdapter:
    """Adapter for CrewAI multi-agent crews.

    Wraps a CrewAI Crew and provides a conversational interface
    for the simulation harness.

    Note: CrewAI is task-oriented rather than conversational.
    This adapter treats each user message as a task input and keeps a
    rolling text transcript that is passed to the crew as ``context``.

    Example usage:
        from crewai import Agent, Crew, Task

        agent = Agent(role="Assistant", goal="Help users", ...)
        crew = Crew(agents=[agent], tasks=[...])

        adapter = CrewAIAdapter(
            crew=crew,
            task_template="User request: {message}",
        )
    """

    def __init__(
        self,
        crew: Any,
        task_template: str = "{message}",
        agent_name: str = "crewai_crew",
        context_window: int = 5,
    ):
        """Initialize CrewAI adapter.

        Args:
            crew: CrewAI Crew instance; only its ``kickoff`` method is used.
            task_template: ``str.format`` template that turns a user message
                into task input; it is formatted with a ``message`` keyword.
            agent_name: Name for identification.
            context_window: Number of recent turns (one user message plus the
                crew's reply each) to include as context. ``0`` or a negative
                value sends an empty context.
        """
        self._crew = crew
        self._task_template = task_template
        self._agent_name = agent_name
        self._context_window = context_window
        # Transcript lines; every completed turn appends exactly two entries
        # ("User: ..." followed by "Agent: ...").
        self._context: list[str] = []

    @property
    def framework(self) -> str:
        """Framework identifier for this adapter."""
        return "crewai"

    @property
    def agent_name(self) -> str:
        """Name used to identify the wrapped crew."""
        return self._agent_name

    async def initialize(self, config: dict[str, Any] | None = None) -> None:
        """Reset context for a new simulation.

        Args:
            config: Accepted for interface compatibility; not used.
        """
        self._context = []

    async def invoke(
        self,
        message: BaseMessage,
        state: SimulationState,
    ) -> BaseMessage:
        """Invoke CrewAI with the user message as task input.

        The blocking ``crew.kickoff`` call runs in a worker thread via
        ``asyncio.to_thread`` so the event loop is not blocked.

        Args:
            message: The simulated user's message.
            state: Current simulation state (not used; this adapter keeps its
                own transcript).

        Returns:
            An ``AIMessage`` containing ``str(result)``, or an empty string when
            the crew returns a falsy result.
        """
        task_input = self._task_template.format(message=message.content)
        context = "\n".join(self._recent_context())

        try:
            result = await asyncio.to_thread(
                self._crew.kickoff,
                inputs={"task": task_input, "context": context, "message": message.content},
            )
        except (TypeError, KeyError, ValueError):
            # The crew rejected these inputs (e.g. ``kickoff`` takes no
            # ``inputs`` argument, or its task templates reference other
            # placeholders). Fall back to a plain kickoff. Other failures
            # (LLM/network errors) propagate instead of silently re-running
            # the entire crew without the user's message.
            result = await asyncio.to_thread(self._crew.kickoff)

        result_str = str(result) if result else ""
        # Record the turn only after the result is known, so the transcript
        # always holds complete User/Agent pairs.
        self._context.append(f"User: {message.content}")
        self._context.append(f"Agent: {result_str}")

        return AIMessage(content=result_str)

    def _recent_context(self) -> list[str]:
        """Return the transcript lines for the last ``context_window`` turns."""
        if self._context_window <= 0:
            # Guard explicitly: a ``[-0:]`` slice would return everything.
            return []
        # Each turn contributes two lines (User + Agent), hence the factor 2.
        return self._context[-2 * self._context_window:]

    async def cleanup(self) -> None:
        """Clean up CrewAI resources (nothing is held by this adapter)."""

    def get_state(self) -> dict[str, Any]:
        """Return current context state.

        Returns:
            ``{"context_turns": n}`` where ``n`` is the number of completed turns.
        """
        return {"context_turns": len(self._context) // 2}
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""LangGraph adapter for user simulation."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Any, Callable, Optional
|
|
5
|
+
|
|
6
|
+
from langchain_core.messages import AIMessage, BaseMessage
|
|
7
|
+
|
|
8
|
+
from ..models import SimulationState
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LangGraphAdapter:
    """Adapter for LangGraph StateGraph agents.

    Wraps a compiled LangGraph and translates between the simulation
    harness message format and LangGraph's state-based invocation.

    Example usage:
        from my_agent import build_my_graph, MyAgentState
        from context_forge.instrumentation import LangChainInstrumentor

        graph = build_my_graph()

        # With instrumentation for trace capture
        instrumentor = LangChainInstrumentor()
        instrumentor.instrument()

        adapter = LangGraphAdapter(
            graph=graph,
            state_class=MyAgentState,
            input_key="message",
            output_key="response",
            callbacks=[instrumentor.get_callback_handler()],
        )
    """

    def __init__(
        self,
        graph: Any,
        state_class: Optional[type] = None,
        input_key: str = "message",
        output_key: str = "response",
        messages_key: str = "messages",
        agent_name: str = "langgraph_agent",
        initial_state: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        state_builder: Optional[Callable[[BaseMessage, SimulationState], dict[str, Any]]] = None,
        callbacks: list[Any] | None = None,
    ):
        """Initialize the LangGraph adapter.

        Args:
            graph: Compiled LangGraph StateGraph
            state_class: TypedDict or Pydantic class for agent state (optional;
                stored but not used by this adapter)
            input_key: State key for user input message
            output_key: State key for agent response
            messages_key: State key for conversation history
            agent_name: Name for identification
            initial_state: Initial state values
            config: LangGraph config (thread_id, etc.). Copied, so later
                ``initialize(config)`` calls never mutate the caller's dict.
            state_builder: Optional custom function to build input state
            callbacks: List of callback handlers for instrumentation
        """
        self._graph = graph
        self._state_class = state_class
        self._input_key = input_key
        self._output_key = output_key
        self._messages_key = messages_key
        self._agent_name = agent_name
        self._initial_state = initial_state or {}
        # Private copy: initialize() updates this dict in place.
        self._config: dict[str, Any] = dict(config) if config else {}
        self._state_builder = state_builder
        self._callbacks = callbacks or []
        # Graph state returned by the most recent invocation.
        self._current_state: dict[str, Any] = {}

    @property
    def framework(self) -> str:
        """Framework identifier for this adapter."""
        return "langgraph"

    @property
    def agent_name(self) -> str:
        """Name used to identify the wrapped graph."""
        return self._agent_name

    async def initialize(self, config: dict[str, Any] | None = None) -> None:
        """Reset state for a new simulation.

        Args:
            config: Optional overrides merged into the stored LangGraph config.
        """
        self._current_state = dict(self._initial_state)
        if config:
            self._config.update(config)

    async def invoke(
        self,
        message: BaseMessage,
        state: SimulationState,
    ) -> BaseMessage:
        """Invoke the LangGraph agent with a user message.

        Args:
            message: The simulated user's message.
            state: Current simulation state (used to build the input history).

        Returns:
            The value stored under ``output_key`` if it is already a
            ``BaseMessage``; otherwise an ``AIMessage`` wrapping its string
            form (empty string when missing or falsy).
        """
        if self._state_builder:
            input_state = self._state_builder(message, state)
        else:
            input_state = self._build_default_state(message, state)

        result = await self._invoke_graph(input_state)

        # Remember the full output state so it is carried into the next turn.
        self._current_state = dict(result)

        response = self._current_state.get(self._output_key, "")
        if isinstance(response, BaseMessage):
            return response
        return AIMessage(content=str(response) if response else "")

    def _build_default_state(
        self,
        message: BaseMessage,
        state: SimulationState,
    ) -> dict[str, Any]:
        """Build default input state from message and simulation state.

        Carried-over graph state is laid down first and the fresh user input
        and conversation history are written on top, so values left over from
        the previous turn can never replace the current message.
        """
        messages = [turn.message for turn in state.turns]

        input_state: dict[str, Any] = {
            **self._current_state,
            self._input_key: message.content,
            self._messages_key: messages,
        }

        # Carry over any fields from initial state that aren't set.
        for key, value in self._initial_state.items():
            input_state.setdefault(key, value)

        return input_state

    async def _invoke_graph(self, input_state: dict) -> dict:
        """Invoke the graph, preferring ``ainvoke`` and falling back to ``invoke``.

        The synchronous path runs in a worker thread via ``asyncio.to_thread``.
        Adapter callbacks are appended after any callbacks already present in
        the stored config.
        """
        invoke_config = dict(self._config)
        if self._callbacks:
            # Treat an explicit ``callbacks: None`` the same as "no callbacks".
            existing_callbacks = invoke_config.get("callbacks") or []
            invoke_config["callbacks"] = [*existing_callbacks, *self._callbacks]

        if hasattr(self._graph, "ainvoke"):
            return await self._graph.ainvoke(input_state, config=invoke_config)
        return await asyncio.to_thread(
            self._graph.invoke, input_state, config=invoke_config
        )

    async def cleanup(self) -> None:
        """No cleanup needed for LangGraph."""

    def get_state(self) -> dict[str, Any]:
        """Return a shallow copy of the current agent state."""
        return dict(self._current_state)
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""PydanticAI adapter for user simulation."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any, Callable, Generic, Optional, TypeVar
|
|
5
|
+
|
|
6
|
+
from langchain_core.messages import AIMessage, BaseMessage
|
|
7
|
+
|
|
8
|
+
from ..models import SimulationState
|
|
9
|
+
|
|
10
|
+
T = TypeVar("T")


class PydanticAIAdapter(Generic[T]):
    """Adapter for PydanticAI agents.

    PydanticAI agents use typed dependencies and structured outputs.
    This adapter manages the dependency injection and conversation state.

    Example usage:
        from pydantic_ai import Agent

        agent = Agent(
            model="ollama:llama3.1",
            system_prompt="You are a helpful assistant.",
        )

        adapter = PydanticAIAdapter(
            agent=agent,
            deps_factory=lambda state: MyDeps(user_id=state.agent_state.get("user_id")),
        )
    """

    def __init__(
        self,
        agent: Any,
        deps_factory: Optional[Callable[[SimulationState], T]] = None,
        agent_name: str = "pydanticai_agent",
    ):
        """Initialize PydanticAI adapter.

        Args:
            agent: PydanticAI Agent instance
            deps_factory: Factory function to create dependencies from state
            agent_name: Name for identification
        """
        self._agent = agent
        self._deps_factory = deps_factory
        self._agent_name = agent_name
        # Messages from the previous run, replayed as ``message_history``.
        self._message_history: list[Any] = []

    @property
    def framework(self) -> str:
        """Framework identifier for this adapter."""
        return "pydanticai"

    @property
    def agent_name(self) -> str:
        """Name used to identify the wrapped agent."""
        return self._agent_name

    async def initialize(self, config: dict[str, Any] | None = None) -> None:
        """Reset message history for a new simulation.

        Args:
            config: Accepted for interface compatibility; not used.
        """
        self._message_history = []

    async def invoke(
        self,
        message: BaseMessage,
        state: SimulationState,
    ) -> BaseMessage:
        """Invoke the PydanticAI agent with the user message.

        Args:
            message: The simulated user's message.
            state: Current simulation state, passed to ``deps_factory``.

        Returns:
            An ``AIMessage`` holding string output as-is, or a JSON
            serialization of structured output.
        """
        run_kwargs: dict[str, Any] = {"message_history": self._message_history}
        if self._deps_factory:
            deps = self._deps_factory(state)
            # Only pass deps when the factory produced something.
            if deps is not None:
                run_kwargs["deps"] = deps

        result = await self._agent.run(message.content, **run_kwargs)

        if hasattr(result, "all_messages"):
            self._message_history = result.all_messages()

        return AIMessage(content=self._to_text(self._extract_output(result)))

    @staticmethod
    def _extract_output(result: Any) -> Any:
        """Return the run's output across PydanticAI result API versions."""
        # Newer releases expose ``output``; older ones used ``data``. Checking
        # ``output`` first also avoids triggering a deprecated ``data`` getter.
        for attr in ("output", "data"):
            if hasattr(result, attr):
                return getattr(result, attr)
        return str(result)

    @staticmethod
    def _to_text(output: Any) -> str:
        """Render agent output as message text."""
        if isinstance(output, str):
            return output
        # Pydantic models are not json.dumps-serializable; ``default=str``
        # would turn them into a quoted repr, so use their own JSON dump.
        if hasattr(output, "model_dump_json"):
            return output.model_dump_json()
        return json.dumps(output, default=str)

    async def cleanup(self) -> None:
        """Clean up PydanticAI resources (nothing is held by this adapter)."""

    def get_state(self) -> dict[str, Any]:
        """Return current message history state."""
        return {"message_history_length": len(self._message_history)}
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Ollama client for user simulation LLM calls."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OllamaConfig(BaseModel):
    """Connection and sampling settings used by ``OllamaClient``.

    Every field has a default, so ``OllamaConfig()`` points at a local Ollama
    server using the ``llama3.2`` model.
    """

    # Root URL of the Ollama HTTP server.
    base_url: str = Field(default="http://localhost:11434")
    # Model name sent with each chat request.
    model: str = Field(default="llama3.2")
    # Sampling temperature forwarded in the request ``options``.
    temperature: float = Field(default=0.7)
    # Generation length limit, forwarded to Ollama as ``num_predict``.
    max_tokens: int = Field(default=500)
    # Timeout (seconds) given to the underlying httpx client.
    timeout: float = Field(default=60.0)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OllamaClient:
    """Async client for Ollama API.

    Used by LLMUserSimulator to generate simulated user responses.

    Example usage:
        async with OllamaClient() as client:
            response = await client.generate(
                prompt="What should the user say next?",
                system="You are simulating a user named Sarah...",
            )
    """

    def __init__(self, config: Optional[OllamaConfig] = None):
        """Initialize the Ollama client.

        Args:
            config: Configuration for Ollama connection (defaults to
                ``OllamaConfig()``).
        """
        self._config = config or OllamaConfig()
        # Created in __aenter__ and closed in __aexit__.
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self) -> "OllamaClient":
        """Enter async context manager and open the HTTP client."""
        self._client = httpx.AsyncClient(timeout=self._config.timeout)
        return self

    async def __aexit__(self, *args) -> None:
        """Exit async context manager and close the HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None

    def _url(self, path: str) -> str:
        """Join ``path`` onto the base URL without producing a double slash."""
        return f"{self._config.base_url.rstrip('/')}{path}"

    def _require_client(self) -> httpx.AsyncClient:
        """Return the open HTTP client.

        Raises:
            RuntimeError: If used outside ``async with``.
        """
        if self._client is None:
            raise RuntimeError("Client not initialized. Use async context manager.")
        return self._client

    async def generate(
        self,
        prompt: str,
        system: Optional[str] = None,
    ) -> str:
        """Generate a response from Ollama's ``/api/chat`` endpoint.

        Args:
            prompt: User prompt to send
            system: Optional system prompt

        Returns:
            Generated text response

        Raises:
            RuntimeError: If used outside ``async with``.
            httpx.HTTPStatusError: If Ollama returns an error status.
        """
        client = self._require_client()

        messages: list[dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        response = await client.post(
            self._url("/api/chat"),
            json={
                "model": self._config.model,
                "messages": messages,
                # Request a single JSON response rather than a token stream.
                "stream": False,
                "options": {
                    "temperature": self._config.temperature,
                    "num_predict": self._config.max_tokens,
                },
            },
        )
        response.raise_for_status()

        data = response.json()
        return data["message"]["content"]

    async def check_health(self) -> bool:
        """Check if Ollama is available.

        Returns:
            True if ``/api/tags`` answers with HTTP 200; False on a non-200
            status or a transport error.

        Raises:
            RuntimeError: If used outside ``async with``.
        """
        client = self._require_client()

        try:
            response = await client.get(self._url("/api/tags"))
            return response.status_code == 200
        except httpx.RequestError:
            return False

    async def list_models(self) -> list[str]:
        """List available models.

        Returns:
            List of model names available in Ollama

        Raises:
            RuntimeError: If used outside ``async with``.
            httpx.HTTPStatusError: If Ollama returns an error status.
        """
        client = self._require_client()

        response = await client.get(self._url("/api/tags"))
        response.raise_for_status()

        data = response.json()
        return [model["name"] for model in data.get("models", [])]
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Pydantic models for simulation state and results."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any, Literal, Optional
|
|
6
|
+
|
|
7
|
+
from langchain_core.messages import BaseMessage
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ConversationRole(str, Enum):
    """Speaker of a simulation turn.

    Members subclass ``str``, so they compare equal to their raw string values
    (e.g. ``ConversationRole.USER == "user"``).
    """

    # The simulated user.
    USER = "user"
    # The agent under test.
    AGENT = "agent"
    # System-level messages.
    SYSTEM = "system"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SimulationTurn(BaseModel):
    """A single message in a simulated conversation, tagged with its speaker."""

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    turn_number: int
    role: ConversationRole
    message: BaseMessage
    # Naive local time (datetime.now) captured when the turn is created.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Free-form extra data attached to the turn.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SimulationState(BaseModel):
    """Snapshot of one simulation run.

    Holds the run identifiers, the ordered list of turns, turn counters,
    timing, lifecycle status, and a free-form ``agent_state`` dict.
    """

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    simulation_id: str
    scenario_id: str
    persona_id: str
    turns: list[SimulationTurn] = Field(default_factory=list)
    current_turn: int = 0
    max_turns: int = 20
    started_at: datetime = Field(default_factory=datetime.now)
    ended_at: Optional[datetime] = None
    status: Literal["running", "completed", "failed", "terminated"] = "running"
    termination_reason: Optional[str] = None
    agent_state: dict[str, Any] = Field(default_factory=dict)

    def get_messages(self) -> list[BaseMessage]:
        """Return every recorded message, oldest first."""
        return [entry.message for entry in self.turns]

    def get_last_agent_message(self) -> Optional[BaseMessage]:
        """Return the newest agent message, or ``None`` if there is none."""
        newest_first = reversed(self.turns)
        return next(
            (entry.message for entry in newest_first if entry.role == ConversationRole.AGENT),
            None,
        )

    def get_last_user_message(self) -> Optional[BaseMessage]:
        """Return the newest user message, or ``None`` if there is none."""
        newest_first = reversed(self.turns)
        return next(
            (entry.message for entry in newest_first if entry.role == ConversationRole.USER),
            None,
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class SimulationResult(BaseModel):
    """Outcome of a simulation: the final state plus metrics and error info."""

    # Allow field types that pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    simulation_id: str
    state: SimulationState
    trace_path: Optional[str] = None
    metrics: dict[str, Any] = Field(default_factory=dict)
    success: bool = False
    error: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Flatten the result into a plain dict intended for JSON serialization.

        Datetimes are rendered with ``isoformat()``; roles become their string
        values; ``total_turns`` is the number of recorded turns.
        """
        run = self.state
        finished = run.ended_at

        transcript = [
            {
                "turn": entry.turn_number,
                "role": entry.role.value,
                "content": entry.message.content,
                "timestamp": entry.timestamp.isoformat(),
            }
            for entry in run.turns
        ]

        summary: dict[str, Any] = {}
        summary["simulation_id"] = self.simulation_id
        summary["scenario_id"] = run.scenario_id
        summary["persona_id"] = run.persona_id
        summary["total_turns"] = len(run.turns)
        summary["status"] = run.status
        summary["termination_reason"] = run.termination_reason
        summary["started_at"] = run.started_at.isoformat()
        if finished:
            summary["ended_at"] = finished.isoformat()
        else:
            summary["ended_at"] = None
        summary["metrics"] = self.metrics
        summary["success"] = self.success
        summary["error"] = self.error
        summary["trace_path"] = self.trace_path
        summary["conversation"] = transcript
        return summary
|