econagents 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
econagents/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ econagents: A Python library for setting up and running economic experiments with LLMs or human subjects.
3
+ """
4
+
5
+ from econagents.core.agent_role import AgentRole
6
+ from econagents.core.game_runner import GameRunner, HybridGameRunnerConfig, TurnBasedGameRunnerConfig
7
+ from econagents.core.manager import AgentManager
8
+ from econagents.core.manager.phase import PhaseManager, HybridPhaseManager, TurnBasedPhaseManager
9
+ from econagents.core.state.fields import EventField
10
+ from econagents.core.state.game import GameState, MetaInformation, PrivateInformation, PublicInformation
11
+ from econagents.llm.openai import ChatOpenAI
12
+
13
# Don't manually change, let poetry-dynamic-versioning handle it.
__version__ = "0.0.1"

# Names re-exported at the package root. Same entries, same order as before;
# the comments only group them by the module they are imported from above.
__all__: list[str] = [
    # econagents.core.agent_role / econagents.core.manager
    "AgentRole",
    "AgentManager",
    # econagents.llm.openai
    "ChatOpenAI",
    # econagents.core.manager.phase
    "PhaseManager",
    "TurnBasedPhaseManager",
    "HybridPhaseManager",
    # econagents.core.state.game
    "GameState",
    "MetaInformation",
    "PrivateInformation",
    "PublicInformation",
    # econagents.core.game_runner
    "GameRunner",
    "TurnBasedGameRunnerConfig",
    "HybridGameRunnerConfig",
    # econagents.core.state.fields
    "EventField",
]
@@ -0,0 +1,5 @@
1
class Foo:
    # Placeholder signatures only: both bodies are `...`, so no behavior is defined here.
    def __init__(self) -> None:
        ...

    def __call__(self):
        ...
4
+
5
def divide(x: float, y: float) -> float:
    # Placeholder signature only: the body is `...`, so nothing is computed here.
    ...
@@ -0,0 +1,7 @@
1
+ from econagents.core.agent_role import AgentRole
2
+ from econagents.core.events import Message
3
+ from econagents.core.logging_mixin import LoggerMixin
4
+ from econagents.core.manager import AgentManager
5
+ from econagents.core.manager.phase import TurnBasedPhaseManager, HybridPhaseManager
6
+
7
# Public names of econagents.core, in the same order as the original single-line list.
__all__ = [
    "AgentRole",
    "AgentManager",
    "TurnBasedPhaseManager",
    "HybridPhaseManager",
    "Message",
    "LoggerMixin",
]
@@ -0,0 +1,360 @@
1
+ import json
2
+ import logging
3
+ import re
4
+ from abc import ABC
5
+ from pathlib import Path
6
+ from typing import Any, Callable, ClassVar, Dict, Generic, Literal, Optional, Pattern, Protocol, TypeVar
7
+
8
+ from jinja2.sandbox import SandboxedEnvironment
9
+
10
+ from econagents.core.logging_mixin import LoggerMixin
11
+ from econagents.core.state.game import GameStateProtocol
12
+ from econagents.llm.openai import ChatOpenAI
13
+
14
# Contravariant type variable for the game-state argument of agent hooks,
# restricted to types satisfying GameStateProtocol.
StateT_contra = TypeVar("StateT_contra", contravariant=True, bound=GameStateProtocol)
15
+
16
+
17
class AgentProtocol(Protocol):
    # Class-level numeric identifier of the role.
    role: ClassVar[int]
    # Class-level display name of the role.
    name: ClassVar[str]
    # Per-instance language model client.
    llm: ChatOpenAI
    # Class-level list of phase numbers the agent takes part in.
    task_phases: ClassVar[list[int]]
22
+
23
+
24
# Signatures of the per-phase hooks that AgentRole can register.

# state -> system prompt text
SystemPromptHandler = Callable[[StateT_contra], str]
# state -> user prompt text
UserPromptHandler = Callable[[StateT_contra], str]
# (raw LLM response, state) -> parsed result
ResponseParser = Callable[[str, StateT_contra], dict]
# (phase number, state) -> phase result (return type left open)
PhaseHandler = Callable[[int, StateT_contra], Any]
28
+
29
+
30
class AgentRole(ABC, Generic[StateT_contra], LoggerMixin):
    """Base agent role class with common attributes and phase handling.

    This class provides a flexible framework for handling different phases in a game or task workflow.
    It uses template-based prompts and allows customization of behavior for specific phases.

    Phase-specific hooks are picked up automatically at construction time from methods whose
    names match the patterns below (``N`` is a phase number; ``0`` is valid). They can also be
    registered explicitly with the ``register_*`` methods:

    - ``get_phase_N_system_prompt(state) -> str``
    - ``get_phase_N_user_prompt(state) -> str``
    - ``parse_phase_N_llm_response(response, state) -> dict``
    - ``handle_phase_N(phase, state)`` (plain function or coroutine function)

    Args:
        logger (Optional[logging.Logger]): External logger to use, defaults to None
    """

    role: ClassVar[int]
    """Unique identifier for this role"""
    name: ClassVar[str]
    """Human-readable name for this role"""
    llm: ChatOpenAI
    """Language model instance for generating responses"""
    task_phases: ClassVar[list[int]] = []  # Empty list means no specific phases are required
    """List of phases this agent should participate in (empty means all phases)"""
    # Phases this agent should skip; may not be combined with task_phases (validated in __init__).
    task_phases_excluded: ClassVar[list[int]] = []  # Empty list means no phases are excluded

    # Regex patterns for method name extraction. All four are anchored with ``$`` so that a helper
    # such as ``get_phase_1_system_prompt_helper`` is not mistaken for a phase-1 hook.
    _SYSTEM_PROMPT_PATTERN: ClassVar[Pattern] = re.compile(r"get_phase_(\d+)_system_prompt$")
    _USER_PROMPT_PATTERN: ClassVar[Pattern] = re.compile(r"get_phase_(\d+)_user_prompt$")
    _RESPONSE_PARSER_PATTERN: ClassVar[Pattern] = re.compile(r"parse_phase_(\d+)_llm_response$")
    _PHASE_HANDLER_PATTERN: ClassVar[Pattern] = re.compile(r"handle_phase_(\d+)$")

    def __init__(self, logger: Optional[logging.Logger] = None):
        # Only override self.logger when one is passed; otherwise the existing attribute
        # (presumably provided by LoggerMixin - not visible here) is used.
        if logger is not None:
            self.logger = logger

        # Validate that only one of task_phases or task_phases_excluded is specified
        if self.task_phases and self.task_phases_excluded:
            raise ValueError(
                f"Only one of task_phases or task_phases_excluded should be specified, not both. "
                f"Got task_phases={self.task_phases} and task_phases_excluded={self.task_phases_excluded}"
            )

        # Handler registries, keyed by phase number
        self._system_prompt_handlers: Dict[int, SystemPromptHandler] = {}
        self._user_prompt_handlers: Dict[int, UserPromptHandler] = {}
        self._response_parsers: Dict[int, ResponseParser] = {}
        self._phase_handlers: Dict[int, PhaseHandler] = {}

        # Auto-register phase-specific methods if they exist
        self._register_phase_specific_methods()

    def _resolve_prompt_file(
        self, prompt_type: Literal["system", "user"], phase: int, role: str, prompts_path: Path
    ) -> Optional[Path]:
        """Resolve the prompt file path for the given parameters.

        Looks for ``<role>_<type>_phase_<phase>.jinja2`` first, then ``<role>_<type>.jinja2``
        (role name lower-cased), both inside ``prompts_path``.

        Args:
            prompt_type (Literal["system", "user"]): Type of prompt (system, user)
            phase (int): Game phase number
            role (str): Agent role name
            prompts_path (Path): Path to prompt templates directory

        Returns:
            Optional[Path]: Path to the prompt file if found, None otherwise (this method never raises
            for a missing template; ``render_prompt`` does).
        """
        # Try phase-specific prompt first
        phase_file = prompts_path / f"{role.lower()}_{prompt_type}_phase_{phase}.jinja2"
        if phase_file.exists():
            return phase_file

        # Fall back to general prompt
        general_file = prompts_path / f"{role.lower()}_{prompt_type}.jinja2"
        if general_file.exists():
            return general_file

        return None

    def render_prompt(
        self, context: dict, prompt_type: Literal["system", "user"], phase: int, prompts_path: Path
    ) -> str:
        """Render a prompt template with the given context.

        Template resolution order:

        1. Agent-specific phase prompt (e.g., "agent_name_system_phase_1.jinja2")

        2. Agent-specific general prompt (e.g., "agent_name_system.jinja2")

        3. All-role phase prompt (e.g., "all_system_phase_1.jinja2")

        4. All-role general prompt (e.g., "all_system.jinja2")

        Templates are read as UTF-8 and rendered in a Jinja2 ``SandboxedEnvironment``.

        Args:
            context (dict): Template context variables
            prompt_type (Literal["system", "user"]): Type of prompt (system, user)
            phase (int): Game phase number
            prompts_path (Path): Path to prompt templates directory

        Returns:
            str: Rendered prompt

        Raises:
            FileNotFoundError: If no matching prompt template is found
        """
        # Try role-specific prompt first, then fall back to 'all'
        for role in [self.name, "all"]:
            prompt_file = self._resolve_prompt_file(prompt_type, phase, role, prompts_path)
            if prompt_file is not None:
                # Explicit encoding so rendering does not depend on the platform's locale default.
                source = prompt_file.read_text(encoding="utf-8")
                template = SandboxedEnvironment().from_string(source)
                return template.render(**context)

        raise FileNotFoundError(
            f"No prompt template found for type={prompt_type}, phase={phase}, "
            f"roles=[{self.name}, all] in {prompts_path}"
        )

    def _extract_phase_from_pattern(self, attr_name: str, pattern: Pattern) -> Optional[int]:
        """Extract phase number from a method name using regex pattern.

        Note that ``0`` is a valid phase, so callers must compare the result against ``None``
        rather than relying on truthiness.

        Args:
            attr_name (str): Method name
            pattern (Pattern): Regex pattern with a capturing group for the phase number

        Returns:
            Optional[int]: Phase number if found and valid, None otherwise
        """
        if match := pattern.match(attr_name):
            try:
                return int(match.group(1))
            except (ValueError, IndexError):
                self.logger.warning(f"Failed to extract phase number from {attr_name}")
        return None

    def _register_phase_specific_methods(self) -> None:
        """Automatically register phase-specific methods if they exist in the subclass.

        Scans ``dir(self)`` for names matching the hook patterns and registers the bound method
        with the corresponding ``register_*`` method. The first matching pattern wins.
        """
        registrars = (
            (self._SYSTEM_PROMPT_PATTERN, self.register_system_prompt_handler),
            (self._USER_PROMPT_PATTERN, self.register_user_prompt_handler),
            (self._RESPONSE_PARSER_PATTERN, self.register_response_parser),
            (self._PHASE_HANDLER_PATTERN, self.register_phase_handler),
        )
        for attr_name in dir(self):
            # Skip special methods
            if attr_name.startswith("__"):
                continue

            for pattern, register in registrars:
                phase = self._extract_phase_from_pattern(attr_name, pattern)
                # `is None` (not truthiness) so that phase 0 hooks are registered too.
                if phase is None:
                    continue
                # Look the attribute up only once its name matches, so unrelated properties
                # are never evaluated during construction.
                handler = getattr(self, attr_name, None)
                if callable(handler):
                    register(phase, handler)
                break

    def register_system_prompt_handler(self, phase: int, handler: SystemPromptHandler) -> None:
        """Register a custom system prompt handler for a specific phase.

        Args:
            phase (int): Game phase number
            handler (SystemPromptHandler): Function that generates system prompts for this phase
        """
        self._system_prompt_handlers[phase] = handler
        self.logger.debug(f"Registered system prompt handler for phase {phase}")

    def register_user_prompt_handler(self, phase: int, handler: UserPromptHandler) -> None:
        """Register a custom user prompt handler for a specific phase.

        Args:
            phase (int): Game phase number
            handler (UserPromptHandler): Function that generates user prompts for this phase
        """
        self._user_prompt_handlers[phase] = handler
        self.logger.debug(f"Registered user prompt handler for phase {phase}")

    def register_response_parser(self, phase: int, parser: ResponseParser) -> None:
        """Register a custom response parser for a specific phase.

        Args:
            phase (int): Game phase number
            parser (ResponseParser): Function that parses LLM responses for this phase
        """
        self._response_parsers[phase] = parser
        self.logger.debug(f"Registered response parser for phase {phase}")

    def register_phase_handler(self, phase: int, handler: PhaseHandler) -> None:
        """Register a custom phase handler for a specific phase.

        Args:
            phase (int): Game phase number
            handler (PhaseHandler): Function that handles this phase; may return a value
                directly or an awaitable (e.g. an ``async def`` function)
        """
        self._phase_handlers[phase] = handler
        self.logger.debug(f"Registered phase handler for phase {phase}")

    def get_phase_system_prompt(self, state: StateT_contra, prompts_path: Path) -> str:
        """Get the system prompt for the current phase.

        This method will use a phase-specific handler if registered,
        otherwise it falls back to the default implementation using templates.

        Args:
            state (StateT_contra): Current game state
            prompts_path (Path): Path to prompt templates directory

        Returns:
            str: System prompt string

        Raises:
            FileNotFoundError: If no handler is registered and no template is found
        """
        phase = state.meta.phase
        if phase in self._system_prompt_handlers:
            return self._system_prompt_handlers[phase](state)
        return self.render_prompt(
            context=state.model_dump(), prompt_type="system", phase=phase, prompts_path=prompts_path
        )

    def get_phase_user_prompt(self, state: StateT_contra, prompts_path: Path) -> str:
        """Get the user prompt for the current phase.

        This method will use a phase-specific handler if registered,
        otherwise it falls back to the default implementation using templates.

        Args:
            state (StateT_contra): Current game state
            prompts_path (Path): Path to prompt templates directory

        Returns:
            str: User prompt string

        Raises:
            FileNotFoundError: If no handler is registered and no template is found
        """
        phase = state.meta.phase
        if phase in self._user_prompt_handlers:
            return self._user_prompt_handlers[phase](state)
        return self.render_prompt(
            context=state.model_dump(), prompt_type="user", phase=phase, prompts_path=prompts_path
        )

    def parse_phase_llm_response(self, response: str, state: StateT_contra) -> dict:
        """Parse the LLM response for the current phase.

        This method will use a phase-specific parser if registered,
        otherwise it falls back to the default implementation which attempts
        to parse the response as JSON.

        Args:
            response (str): Raw LLM response string
            state (StateT_contra): Current game state

        Returns:
            dict: Parsed response, or ``{"error": ..., "raw_response": ...}`` when the default
            JSON parsing fails (including when ``response`` is not a string, e.g. ``None``)
        """
        phase = state.meta.phase
        if phase in self._response_parsers:
            return self._response_parsers[phase](response, state)

        try:
            return json.loads(response)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers a non-string response (e.g. None), which json.loads rejects.
            self.logger.error(f"Failed to parse LLM response as JSON: {e}")
            self.logger.debug(f"Raw response: {response}")
            return {"error": "Failed to parse response", "raw_response": response}

    async def handle_phase(self, phase: int, state: StateT_contra, prompts_path: Path) -> Optional[dict]:
        """Handle the current phase of the task or game.

        This method will use a phase-specific handler if registered,
        otherwise it falls back to the default implementation using the LLM.
        Custom handlers may be plain functions or coroutine functions; their result is
        awaited only when it is awaitable.

        By default, the agent acts in all phases unless:
        1. task_phases is non-empty and the phase is not in task_phases, or
        2. phase is explicitly listed in task_phases_excluded

        Args:
            phase (int): Game phase number
            state (StateT_contra): Current game state
            prompts_path (Path): Path to prompt templates directory

        Returns:
            Optional[dict]: Phase result dictionary or None if phase is not handled
        """
        import inspect

        # Skip the phase if it's in the excluded list
        if phase in self.task_phases_excluded:
            self.logger.debug(f"Phase {phase} is in excluded phases {self.task_phases_excluded}, skipping")
            return None

        # Skip the phase if task_phases is non-empty and phase is not in it
        if self.task_phases and phase not in self.task_phases:
            self.logger.debug(f"Phase {phase} not in task phases {self.task_phases}, skipping")
            return None

        if phase in self._phase_handlers:
            self.logger.debug(f"Using custom handler for phase {phase}")
            result = self._phase_handlers[phase](phase, state)
            # Support synchronous handlers as well as async ones.
            if inspect.isawaitable(result):
                result = await result
            return result

        self.logger.debug(f"Using default LLM handler for phase {phase}")
        return await self.handle_phase_with_llm(phase, state, prompts_path=prompts_path)

    async def handle_phase_with_llm(self, phase: int, state: StateT_contra, prompts_path: Path) -> Optional[dict]:
        """Handle the phase using the LLM.

        This is the default implementation that uses the LLM to handle the phase
        by generating prompts, sending them to the LLM, and parsing the response.

        Args:
            phase (int): Game phase number
            state (StateT_contra): Current game state
            prompts_path (Path): Path to prompt templates directory

        Returns:
            Optional[dict]: Parsed response, or ``{"error": ..., "phase": phase}`` if the LLM call
            or response parsing raises

        Raises:
            FileNotFoundError: If a prompt must be rendered from templates and none is found
                (prompt building happens before the error-catching block)
        """
        system_prompt = self.get_phase_system_prompt(state, prompts_path=prompts_path)
        self.logger.debug("\n+-----SYSTEM PROMPT----+\n" + f"{system_prompt}\n+------------------+")

        user_prompt = self.get_phase_user_prompt(state, prompts_path=prompts_path)
        self.logger.debug("\n+-----USER PROMPT----+\n" + f"{user_prompt}\n+------------------+")

        messages = self.llm.build_messages(system_prompt, user_prompt)

        try:
            response = await self.llm.get_response(
                messages=messages,
                tracing_extra={
                    "state": state.model_dump(),
                },
            )
            return self.parse_phase_llm_response(response, state)
        except Exception as e:
            # Boundary handler: report the failure as data, but keep the traceback in the log.
            self.logger.exception(f"Error getting LLM response: {e}")
            return {"error": str(e), "phase": phase}
@@ -0,0 +1,14 @@
1
+ from typing import Any
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class Message(BaseModel):
    """A message from the server to the agent."""

    # Field order is part of the pydantic model definition (validation and serialization
    # follow it), so the three fields keep their original order.

    message_type: str
    """Type of message"""

    event_type: str
    """Type of event"""

    data: dict[str, Any]
    """Data associated with the message"""