massgen 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of massgen might be problematic. Click here for more details.
- massgen/__init__.py +94 -0
- massgen/agent_config.py +507 -0
- massgen/backend/CLAUDE_API_RESEARCH.md +266 -0
- massgen/backend/Function calling openai responses.md +1161 -0
- massgen/backend/GEMINI_API_DOCUMENTATION.md +410 -0
- massgen/backend/OPENAI_RESPONSES_API_FORMAT.md +65 -0
- massgen/backend/__init__.py +25 -0
- massgen/backend/base.py +180 -0
- massgen/backend/chat_completions.py +228 -0
- massgen/backend/claude.py +661 -0
- massgen/backend/gemini.py +652 -0
- massgen/backend/grok.py +187 -0
- massgen/backend/response.py +397 -0
- massgen/chat_agent.py +440 -0
- massgen/cli.py +686 -0
- massgen/configs/README.md +293 -0
- massgen/configs/creative_team.yaml +53 -0
- massgen/configs/gemini_4o_claude.yaml +31 -0
- massgen/configs/news_analysis.yaml +51 -0
- massgen/configs/research_team.yaml +51 -0
- massgen/configs/single_agent.yaml +18 -0
- massgen/configs/single_flash2.5.yaml +44 -0
- massgen/configs/technical_analysis.yaml +51 -0
- massgen/configs/three_agents_default.yaml +31 -0
- massgen/configs/travel_planning.yaml +51 -0
- massgen/configs/two_agents.yaml +39 -0
- massgen/frontend/__init__.py +20 -0
- massgen/frontend/coordination_ui.py +945 -0
- massgen/frontend/displays/__init__.py +24 -0
- massgen/frontend/displays/base_display.py +83 -0
- massgen/frontend/displays/rich_terminal_display.py +3497 -0
- massgen/frontend/displays/simple_display.py +93 -0
- massgen/frontend/displays/terminal_display.py +381 -0
- massgen/frontend/logging/__init__.py +9 -0
- massgen/frontend/logging/realtime_logger.py +197 -0
- massgen/message_templates.py +431 -0
- massgen/orchestrator.py +1222 -0
- massgen/tests/__init__.py +10 -0
- massgen/tests/multi_turn_conversation_design.md +214 -0
- massgen/tests/multiturn_llm_input_analysis.md +189 -0
- massgen/tests/test_case_studies.md +113 -0
- massgen/tests/test_claude_backend.py +310 -0
- massgen/tests/test_grok_backend.py +160 -0
- massgen/tests/test_message_context_building.py +293 -0
- massgen/tests/test_rich_terminal_display.py +378 -0
- massgen/tests/test_v3_3agents.py +117 -0
- massgen/tests/test_v3_simple.py +216 -0
- massgen/tests/test_v3_three_agents.py +272 -0
- massgen/tests/test_v3_two_agents.py +176 -0
- massgen/utils.py +79 -0
- massgen/v1/README.md +330 -0
- massgen/v1/__init__.py +91 -0
- massgen/v1/agent.py +605 -0
- massgen/v1/agents.py +330 -0
- massgen/v1/backends/gemini.py +584 -0
- massgen/v1/backends/grok.py +410 -0
- massgen/v1/backends/oai.py +571 -0
- massgen/v1/cli.py +351 -0
- massgen/v1/config.py +169 -0
- massgen/v1/examples/fast-4o-mini-config.yaml +44 -0
- massgen/v1/examples/fast_config.yaml +44 -0
- massgen/v1/examples/production.yaml +70 -0
- massgen/v1/examples/single_agent.yaml +39 -0
- massgen/v1/logging.py +974 -0
- massgen/v1/main.py +368 -0
- massgen/v1/orchestrator.py +1138 -0
- massgen/v1/streaming_display.py +1190 -0
- massgen/v1/tools.py +160 -0
- massgen/v1/types.py +245 -0
- massgen/v1/utils.py +199 -0
- massgen-0.0.3.dist-info/METADATA +568 -0
- massgen-0.0.3.dist-info/RECORD +76 -0
- massgen-0.0.3.dist-info/WHEEL +5 -0
- massgen-0.0.3.dist-info/entry_points.txt +2 -0
- massgen-0.0.3.dist-info/licenses/LICENSE +204 -0
- massgen-0.0.3.dist-info/top_level.txt +1 -0
massgen/v1/tools.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
import subprocess
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, Union, Optional, Dict, List
|
|
10
|
+
import ast
|
|
11
|
+
import operator
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
# Global tool registry
|
|
15
|
+
# Global registry mapping tool names to their Python callables; populated later in this module.
register_tool = {}
|
|
16
|
+
|
|
17
|
+
# Mock functions removed - actual functionality is implemented in agent classes
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def python_interpreter(code: str, timeout: Optional[int] = 10) -> str:
    """
    Execute Python code in a separate subprocess and return its output as JSON.

    The code runs in a fresh interpreter (``sys.executable -c code``). This is
    process separation only, NOT a sandbox: the child inherits the host
    process's environment and has the same OS-level permissions.

    Args:
        code: The Python code string to execute
        timeout: Maximum execution time in seconds (default: 10). Values are
            clamped to the range [0, 60]; ``None`` falls back to the default.

    Returns:
        A JSON-encoded string (produced by ``json.dumps``) of an object with:
        - 'stdout': Standard output from the code execution
        - 'stderr': Standard error from the code execution
        - 'returncode': Exit code of the process (0 for success, -1 on failure to run/timeout)
        - 'success': Boolean indicating if execution was successful
        - 'error': Error message if execution failed, otherwise null
    """
    # Optional[int] allows None; treat it as the default instead of crashing in min().
    if timeout is None:
        timeout = 10
    # Ensure timeout is between 0 and 60 seconds
    timeout = max(min(timeout, 60), 0)

    try:
        # Run the code in a separate Python process
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return json.dumps(
            {
                "stdout": "",
                "stderr": "",
                "returncode": -1,
                "success": False,
                "error": f"Code execution timed out after {timeout} seconds",
            }
        )
    except Exception as e:
        return json.dumps(
            {
                "stdout": "",
                "stderr": "",
                "returncode": -1,
                "success": False,
                "error": f"Failed to execute code: {str(e)}",
            }
        )

    return json.dumps(
        {
            "stdout": result.stdout,
            "stderr": result.stderr,
            "returncode": result.returncode,
            "success": result.returncode == 0,
            "error": None,
        }
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def calculator(expression: str) -> Dict[str, Any]:
    """
    Safely evaluate a mathematical expression without using ``eval``.

    Args:
        expression: Mathematical expression to evaluate (e.g., '2 + 3 * 4',
            'sqrt(16)', 'sin(pi/2)', 'round(3.14159, ndigits=2)', 'sum([1, 2, 3])')

    Returns:
        A dictionary with the original 'expression' and a boolean 'success'.
        On success it also contains 'result'; on failure it contains an
        'error' message. Errors are reported in the dict, never raised.

    Only numeric literals, list/tuple literals, the arithmetic operators in
    ``safe_operators`` and the names in ``safe_functions`` are accepted.
    """
    safe_operators = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
        ast.Mod: operator.mod,
    }

    # Safe functions (and constants) that names in the expression may refer to
    safe_functions = {
        "abs": abs,
        "round": round,
        "max": max,
        "min": min,
        "sum": sum,
        "sqrt": math.sqrt,
        "sin": math.sin,
        "cos": math.cos,
        "tan": math.tan,
        "log": math.log,
        "log10": math.log10,
        "exp": math.exp,
        "pi": math.pi,
        "e": math.e,
    }

    # Cap on the estimated bit size of an integer power, so inputs such as
    # "9 ** 9 ** 9" fail fast instead of hanging or exhausting memory.
    max_int_pow_bits = 1_000_000

    def _as_number(value):
        """Return value if numeric; arithmetic on anything else is refused."""
        if isinstance(value, (int, float, complex)):
            return value
        raise ValueError(f"Unsupported operand type: {type(value).__name__}")

    def _check_pow(base, exponent):
        """Reject integer powers whose result would be unreasonably large."""
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and abs(base).bit_length() * exponent > max_int_pow_bits
        ):
            raise ValueError("Exponent too large")

    def _safe_eval(node):
        """Safely evaluate an AST node"""
        if isinstance(node, ast.Constant):
            # Numbers only: string/bytes constants would allow "a" * 10**9.
            if isinstance(node.value, (int, float, complex)):
                return node.value
            raise ValueError(f"Unsupported constant: {node.value!r}")
        elif isinstance(node, ast.Name):  # Functions/constants
            if node.id in safe_functions:
                return safe_functions[node.id]
            raise ValueError(f"Unknown variable: {node.id}")
        elif isinstance(node, (ast.List, ast.Tuple)):
            # Needed so sum/max/min can receive an iterable, e.g. sum([1, 2, 3]).
            values = [_safe_eval(elt) for elt in node.elts]
            return values if isinstance(node, ast.List) else tuple(values)
        elif isinstance(node, ast.BinOp):  # Binary operations
            left = _as_number(_safe_eval(node.left))
            right = _as_number(_safe_eval(node.right))
            if type(node.op) not in safe_operators:
                raise ValueError(f"Unsupported operation: {type(node.op)}")
            if isinstance(node.op, ast.Pow):
                _check_pow(left, right)
            return safe_operators[type(node.op)](left, right)
        elif isinstance(node, ast.UnaryOp):  # Unary operations
            operand = _as_number(_safe_eval(node.operand))
            if type(node.op) not in safe_operators:
                raise ValueError(f"Unsupported unary operation: {type(node.op)}")
            return safe_operators[type(node.op)](operand)
        elif isinstance(node, ast.Call):  # Function calls
            func = _safe_eval(node.func)
            args = [_safe_eval(arg) for arg in node.args]
            kwargs = {}
            for kw in node.keywords:
                # kw.arg is None for "**mapping" unpacking, which is not allowed.
                if kw.arg is None:
                    raise ValueError("Keyword argument unpacking is not supported")
                kwargs[kw.arg] = _safe_eval(kw.value)
            return func(*args, **kwargs)
        raise ValueError(f"Unsupported node type: {type(node)}")

    try:
        # Parse the expression
        tree = ast.parse(expression, mode="eval")

        # Evaluate safely
        result = _safe_eval(tree.body)

        return {"expression": expression, "result": result, "success": True}

    except Exception as e:
        return {"expression": expression, "error": str(e), "success": False}
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Expose both tools through the module-level registry so they can be looked up by name.
register_tool.update(
    python_interpreter=python_interpreter,
    calculator=calculator,
)

if __name__ == "__main__":
    # Manual smoke test for the calculator tool.
    print(calculator("24423 + 312 * log(10)"))
|
massgen/v1/types.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MassGen System Types
|
|
3
|
+
|
|
4
|
+
This module contains all the core type definitions and dataclasses
|
|
5
|
+
used throughout the MassGen framework.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from dataclasses import dataclass, field, asdict
|
|
10
|
+
from typing import Any, Dict, List, Optional
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class AnswerRecord:
    """One entry in an agent's answer history.

    Attributes:
        timestamp: When the answer was recorded. A falsy value (e.g. 0) is
            replaced with the current ``time.time()`` after initialization.
        answer: The answer text.
        status: The agent status stored alongside the answer.
    """

    timestamp: float
    answer: str
    status: str

    def __post_init__(self):
        """Substitute the current time when no usable timestamp was given."""
        self.timestamp = self.timestamp or time.time()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class VoteRecord:
    """Records a vote cast by an agent.

    Attributes:
        voter_id: ID of the agent casting the vote.
        target_id: ID of the agent the vote is cast for.
        reason: The full response text that led to this vote.
        timestamp: When the vote was cast; a falsy value (the default 0.0) is
            replaced with the current ``time.time()`` after initialization.
    """

    voter_id: int
    target_id: int
    reason: str = ""  # the full response text that led to this vote
    timestamp: float = 0.0

    def __post_init__(self):
        """Ensure timestamp is set if not provided."""
        # ``time`` is imported at module level; no function-local import needed.
        if not self.timestamp:
            self.timestamp = time.time()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class ModelConfig:
    """Model-level parameters for an agent.

    Optional fields default to ``None`` (unset).
    """

    # Model name.
    model: Optional[str] = None
    # Names of the tools available to the model.
    tools: Optional[List[str]] = None
    # Maximum number of retries for each LLM call.
    max_retries: int = 10
    # Maximum number of rounds for the task.
    max_rounds: int = 10
    # Generation limits / sampling parameters.
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # Timeout for a single inference, in seconds.
    inference_timeout: Optional[float] = 180
    # Whether to stream the response.
    stream: bool = True
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class TaskInput:
    """A task to be processed by the MassGen system.

    Attributes:
        question: The question / task text.
        context: Extra task information; may carry more (e.g. images) in the future.
        task_id: Optional identifier for the task.
    """

    question: str
    # A fresh dict per instance (never shared between tasks).
    context: Dict[str, Any] = field(default_factory=dict)
    task_id: Optional[str] = None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class SystemState:
    """Overall state of the MassGen orchestrator.

    Phase values for ``phase``:
    - "collaboration": agents are working together to solve the task
    - "completed": the representative agent has presented the solution and
      the task is completed

    NOTE(review): the original field comment also listed "debate" as a phase
    value; confirm against the orchestrator which values are actually used.
    """

    # The task being solved, once one is assigned.
    task: Optional[TaskInput] = None
    # Current phase; starts in "collaboration".
    phase: str = "collaboration"
    # Wall-clock bounds of the run, when recorded.
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    # Whether the agents reached consensus.
    consensus_reached: bool = False
    # ID of the agent chosen to present the final solution.
    representative_agent_id: Optional[int] = None
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class AgentState:
    """Represents the current state of an agent in the MassGen system."""

    agent_id: int
    status: str = "working"  # "working", "voted", "failed"
    curr_answer: str = ""  # the latest answer of the agent's work
    updated_answers: List[AnswerRecord] = field(default_factory=list)  # a list of answer records
    curr_vote: Optional[VoteRecord] = None  # which agent's solution this agent voted for
    cast_votes: List[VoteRecord] = field(default_factory=list)  # a list of vote records
    seen_updates_timestamps: Dict[int, float] = field(
        default_factory=dict
    )  # agent_id -> last_seen_timestamp
    chat_history: List[Dict[str, Any]] = field(default_factory=list)  # a list of conversation records
    chat_round: int = 0  # the number of chat rounds the agent has participated in
    execution_start_time: Optional[float] = None
    execution_end_time: Optional[float] = None

    @property
    def execution_time(self) -> Optional[float]:
        """Return ``end - start`` execution time, or None if either time is unset.

        Uses ``is None`` checks so a legitimate 0.0 timestamp (e.g. from a
        relative clock) is not mistaken for a missing one.
        """
        if self.execution_start_time is None or self.execution_end_time is None:
            return None
        return self.execution_end_time - self.execution_start_time

    def add_update(self, answer: str, timestamp: Optional[float] = None):
        """Append a new answer to the history and make it the current answer.

        Args:
            answer: The new answer text.
            timestamp: When the answer was produced; defaults to ``time.time()``.
        """
        if timestamp is None:
            timestamp = time.time()

        record = AnswerRecord(
            timestamp=timestamp,
            answer=answer,
            status=self.status,
        )
        self.updated_answers.append(record)
        self.curr_answer = answer

    def mark_updates_seen(self, agent_updates: Dict[int, float]):
        """Record the latest-seen update timestamp for each other agent.

        Args:
            agent_updates: Mapping of agent_id -> update timestamp.
        """
        for agent_id, timestamp in agent_updates.items():
            if agent_id != self.agent_id:  # Don't track own updates
                self.seen_updates_timestamps[agent_id] = timestamp

    def has_unseen_updates(self, other_agent_updates: Dict[int, float]) -> bool:
        """Return True if any other agent has an update newer than the last one seen.

        Args:
            other_agent_updates: Mapping of agent_id -> update timestamp.
        """
        for agent_id, timestamp in other_agent_updates.items():
            if agent_id != self.agent_id:  # Don't check own updates
                last_seen = self.seen_updates_timestamps.get(agent_id, 0)
                if timestamp > last_seen:
                    return True
        return False
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclass
class AgentResponse:
    """Response from an agent's process_message function.

    Attributes:
        text: The response text.
        code: Code snippets in the response (fresh list per instance).
        citations: Citation records (fresh list per instance).
        function_calls: Function-call records (fresh list per instance).
    """

    text: str
    code: List[str] = field(default_factory=list)
    citations: List[Dict[str, Any]] = field(default_factory=list)
    function_calls: List[Dict[str, Any]] = field(default_factory=list)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@dataclass
class LogEntry:
    """A single log entry in the MassGen system.

    Attributes:
        timestamp: When the event happened.
        event_type: Kind of event, e.g. "agent_answer_update", "voting", "phase_change".
        agent_id: Agent the event relates to, or None.
        phase: System phase at the time of the event.
        data: Event payload.
        session_id: Optional session identifier.
    """

    timestamp: float
    event_type: str
    agent_id: Optional[int]
    phase: str
    data: Dict[str, Any]
    session_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a recursively copied dict (via ``dataclasses.asdict``).

        Intended for JSON serialization; assumes ``data`` holds JSON-serializable values.
        """
        as_mapping = asdict(self)
        return as_mapping
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass
class StreamingDisplayConfig:
    """Settings for the streaming display system.

    Attributes:
        display_enabled: Whether the display is turned on.
        max_lines: Maximum number of lines shown (default 10).
        save_logs: Whether logs are saved.
        stream_callback: Optional callback; typed as ``Any`` rather than
            ``Callable`` to avoid circular imports.
    """

    display_enabled: bool = True
    max_lines: int = 10
    save_logs: bool = True
    stream_callback: Optional[Any] = None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@dataclass
class LoggingConfig:
    """Settings for the logging system.

    Attributes:
        log_dir: Directory for log output (default "logs").
        session_id: Optional session identifier.
        non_blocking: Flag for non-blocking logging (default False).
    """

    log_dir: str = "logs"
    session_id: Optional[str] = None
    non_blocking: bool = False
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@dataclass
class OrchestratorConfig:
    """Settings for the MassGen orchestrator.

    Attributes:
        max_duration: Overall duration limit (default 600; presumably seconds - confirm).
        consensus_threshold: Must lie in [0.0, 1.0] (checked by MassConfig.validate).
        max_debate_rounds: Maximum number of debate rounds.
        status_check_interval: Interval between status checks.
        thread_pool_timeout: Timeout used for the thread pool.
    """

    max_duration: int = 600
    consensus_threshold: float = 0.0
    max_debate_rounds: int = 1
    status_check_interval: float = 2.0
    thread_pool_timeout: int = 5
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@dataclass
class AgentConfig:
    """Full configuration of a single agent.

    Attributes:
        agent_id: Identifier of the agent.
        agent_type: Backend type; one of "openai", "gemini", "grok".
        model_config: Model parameters for the agent.
    """

    agent_id: int
    agent_type: str
    model_config: ModelConfig

    def __post_init__(self):
        """Reject unsupported agent types at construction time.

        Raises:
            ValueError: If ``agent_type`` is not a supported backend.
        """
        supported_types = ("openai", "gemini", "grok")
        if self.agent_type in supported_types:
            return
        raise ValueError(
            f"Invalid agent_type: {self.agent_type}. Must be one of: openai, gemini, grok"
        )
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@dataclass
class MassConfig:
    """Complete MassGen system configuration.

    Attributes:
        orchestrator: Orchestrator settings.
        agents: Agent configurations (at least one required by ``validate``).
        streaming_display: Streaming display settings.
        logging: Logging settings.
        task: Optional task-specific configuration.
    """

    orchestrator: OrchestratorConfig = field(default_factory=OrchestratorConfig)
    agents: List[AgentConfig] = field(default_factory=list)
    streaming_display: StreamingDisplayConfig = field(
        default_factory=StreamingDisplayConfig
    )
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    task: Optional[Dict[str, Any]] = None

    def validate(self) -> bool:
        """Check the configuration and return True when it is valid.

        Raises:
            ValueError: If no agents are configured, agent IDs repeat, or the
                consensus threshold lies outside [0.0, 1.0].
        """
        if not self.agents:
            raise ValueError("At least one agent must be configured")

        ids = [cfg.agent_id for cfg in self.agents]
        # A set collapses duplicates, so it is smaller exactly when an ID repeats.
        if len(set(ids)) < len(ids):
            raise ValueError("Agent IDs must be unique")

        threshold = self.orchestrator.consensus_threshold
        if not (0.0 <= threshold <= 1.0):
            raise ValueError("Consensus threshold must be between 0.0 and 1.0")

        return True
|
massgen/v1/utils.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
import subprocess
|
|
5
|
+
import sys
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Any, Union, Optional, Dict, List
|
|
10
|
+
import ast
|
|
11
|
+
import operator
|
|
12
|
+
import math
|
|
13
|
+
|
|
14
|
+
# Model mappings and constants
|
|
15
|
+
# Supported model names, grouped by the agent type (backend) that serves them.
MODEL_MAPPINGS = dict(
    openai=[
        # GPT-4.1 variants
        "gpt-4.1",
        "gpt-4.1-mini",
        # GPT-4o variants
        "gpt-4o-mini",
        "gpt-4o",
        # o1 (-> o1-2024-12-17)
        "o1",
        # o3
        "o3",
        "o3-low",
        "o3-medium",
        "o3-high",
        # o3 mini
        "o3-mini",
        "o3-mini-low",
        "o3-mini-medium",
        "o3-mini-high",
        # o4 mini
        "o4-mini",
        "o4-mini-low",
        "o4-mini-medium",
        "o4-mini-high",
    ],
    gemini=[
        "gemini-2.5-flash",
        "gemini-2.5-pro",
    ],
    grok=[
        "grok-3-mini",
        "grok-3",
        "grok-4",
    ],
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_agent_type_from_model(model: str) -> str:
    """
    Determine the agent type (backend) that serves a model name.

    Matching is an exact, case-insensitive lookup against ``MODEL_MAPPINGS``;
    surrounding whitespace is ignored.

    Args:
        model: The model name (e.g., "gpt-4o", "gemini-2.5-pro", "grok-3")

    Returns:
        Agent type string ("openai", "gemini", "grok"). An empty or missing
        model name defaults to "openai".

    Raises:
        ValueError: If the model name is not listed in ``MODEL_MAPPINGS``.
    """
    if not model:
        return "openai"  # Default to OpenAI

    # Normalize so e.g. " GPT-4o " matches the lowercase entries.
    model_lower = model.strip().lower()

    for agent_type, models in MODEL_MAPPINGS.items():
        if model_lower in models:
            return agent_type
    raise ValueError(f"Unknown model: {model}")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def get_available_models() -> list:
    """Return every model name from ``MODEL_MAPPINGS`` as one new flat list.

    Names are grouped by agent type, in the mapping's insertion order.
    """
    return [model_name for group in MODEL_MAPPINGS.values() for model_name in group]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def generate_random_id(length: int = 24) -> str:
    """Return a random ID of ``length`` ASCII letters and digits.

    Characters are drawn with ``random.choice``. NOTE: the ``random`` module is
    not cryptographically secure, so these IDs should not be used as secrets.
    """
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Utility functions (originally from util.py)
|
|
89
|
+
def _function_call_output(call_id, output):
    """Build one ``function_call_output`` item for the conversation."""
    return {
        "type": "function_call_output",
        "call_id": call_id,
        "output": output,
    }


def _parse_function_arguments(raw_arguments):
    """Decode a call's ``arguments`` into a kwargs dict (None / blank string -> {})."""
    if raw_arguments is None:
        return {}
    if isinstance(raw_arguments, str):
        # Some APIs send "" for functions that take no arguments.
        if not raw_arguments.strip():
            return {}
        raw_arguments = json.loads(raw_arguments)
    if not isinstance(raw_arguments, dict):
        raise ValueError(f"Unknown arguments type: {type(raw_arguments)}")
    return raw_arguments


def execute_function_calls(function_calls, tool_mapping):
    """Execute function calls and return formatted outputs for the conversation.

    Args:
        function_calls: Iterable of call dicts with keys ``name``, ``call_id`` and
            ``arguments`` (a JSON-encoded string or an already-decoded dict).
            A missing, ``None`` or blank ``arguments`` value means no arguments.
        tool_mapping: Mapping from function name to the Python callable to run.

    Returns:
        A list with one ``{"type": "function_call_output", "call_id": ..., "output": str}``
        dict per input call, in input order. Failures (unknown function, bad
        arguments, exceptions raised by the tool) are reported as an error
        string in ``output`` instead of being raised.
    """
    function_outputs = []
    for function_call in function_calls:
        call_id = function_call.get("call_id")
        function_name = function_call.get("name")
        try:
            if function_name not in tool_mapping:
                function_outputs.append(
                    _function_call_output(
                        call_id,
                        f"Error: Function '{function_name}' not found in tool mapping",
                    )
                )
                continue

            target_function = tool_mapping[function_name]
            arguments = _parse_function_arguments(function_call.get("arguments"))
            result = target_function(**arguments)

            # Tool outputs are passed back to the model as plain strings.
            function_outputs.append(_function_call_output(call_id, str(result)))
        except Exception as e:
            function_outputs.append(
                _function_call_output(call_id, f"Error executing function: {str(e)}")
            )

    return function_outputs
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def function_to_json(func) -> dict:
    """
    Converts a Python function into a JSON-serializable dictionary
    that describes the function's signature, including its name,
    description, and parameters.

    Parameter types come from annotations:
    - ``str``/``int``/``float``/``bool``/``list``/``dict``/``None`` map directly.
    - ``Optional[X]`` maps to ``X``'s type.
    - Parameterized ``List[...]`` / ``Dict[...]`` map to "array" / "object".
    - Anything else, including a missing annotation, falls back to "string".

    ``*args`` / ``**kwargs`` are omitted because they cannot be expressed as
    named properties.

    Args:
        func: The function to be converted.

    Returns:
        A dictionary representing the function's signature in JSON format.

    Raises:
        ValueError: If ``inspect.signature`` cannot introspect ``func``.
    """
    import types
    from typing import Union, get_args, get_origin

    type_map = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
        type(None): "null",
    }
    # typing.Union covers Optional[X]; types.UnionType covers "X | None" (3.10+).
    union_origins = (Union,) + ((types.UnionType,) if hasattr(types, "UnionType") else ())

    def _json_type(annotation):
        """Map one parameter annotation to a JSON-schema type name."""
        origin = get_origin(annotation)
        if origin is None:
            return type_map.get(annotation, "string")
        if origin in type_map:  # e.g. List[str] -> list, Dict[str, Any] -> dict
            return type_map[origin]
        if origin in union_origins:
            non_none = [arg for arg in get_args(annotation) if arg is not type(None)]
            if len(non_none) == 1:  # Optional[X]
                return _json_type(non_none[0])
        return "string"

    try:
        signature = inspect.signature(func)
    except ValueError as e:
        raise ValueError(
            f"Failed to get signature for function {func.__name__}: {str(e)}"
        ) from e

    # *args / **kwargs cannot be represented as named schema properties.
    named_params = [
        param
        for param in signature.parameters.values()
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]

    parameters = {
        param.name: {"type": _json_type(param.annotation)} for param in named_params
    }

    # Identity check: "==" can misbehave for defaults with a custom __eq__.
    required = [
        param.name for param in named_params if param.default is inspect.Parameter.empty
    ]

    return {
        "type": "function",
        "name": func.__name__,
        "description": func.__doc__ or "",
        "parameters": {
            "type": "object",
            "properties": parameters,
            "required": required,
        },
    }
|