iridet-bot 0.1.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- iribot/.env.example +4 -0
- iribot/__init__.py +5 -0
- iribot/__main__.py +7 -0
- iribot/ag_ui_protocol.py +247 -0
- iribot/agent.py +155 -0
- iribot/cli.py +33 -0
- iribot/config.py +45 -0
- iribot/executor.py +73 -0
- iribot/main.py +300 -0
- iribot/models.py +79 -0
- iribot/prompt_generator.py +104 -0
- iribot/session_manager.py +194 -0
- iribot/templates/system_prompt.j2 +185 -0
- iribot/tools/__init__.py +27 -0
- iribot/tools/base.py +80 -0
- iribot/tools/execute_command.py +572 -0
- iribot/tools/list_directory.py +49 -0
- iribot/tools/read_file.py +43 -0
- iribot/tools/write_file.py +49 -0
- iridet_bot-0.1.1a1.dist-info/METADATA +369 -0
- iridet_bot-0.1.1a1.dist-info/RECORD +24 -0
- iridet_bot-0.1.1a1.dist-info/WHEEL +5 -0
- iridet_bot-0.1.1a1.dist-info/entry_points.txt +2 -0
- iridet_bot-0.1.1a1.dist-info/top_level.txt +1 -0
iribot/main.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Main FastAPI application"""
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from fastapi import FastAPI, HTTPException
|
|
5
|
+
from fastapi.staticfiles import StaticFiles
|
|
6
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
7
|
+
from fastapi.responses import StreamingResponse
|
|
8
|
+
|
|
9
|
+
from .config import settings
|
|
10
|
+
from .models import (
|
|
11
|
+
SessionCreate,
|
|
12
|
+
SystemPromptUpdate,
|
|
13
|
+
ChatRequest,
|
|
14
|
+
MessageRecord,
|
|
15
|
+
ToolCallRecord,
|
|
16
|
+
SystemPromptGenerateRequest,
|
|
17
|
+
SystemPromptGenerateResponse,
|
|
18
|
+
)
|
|
19
|
+
from .session_manager import session_manager
|
|
20
|
+
from .agent import agent
|
|
21
|
+
from .executor import tool_executor
|
|
22
|
+
from .prompt_generator import generate_system_prompt
|
|
23
|
+
|
|
24
|
+
# Application instance: title and debug flag are taken from settings.
app = FastAPI(debug=settings.debug, title=settings.app_title)

# CORS policy: allowed origins come from settings; credentials plus every
# HTTP method and request header are permitted.
_cors_options = {
    "allow_origins": settings.cors_origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Prompt Generation Endpoints
@app.post("/api/prompt/generate")
def generate_prompt(request: SystemPromptGenerateRequest):
    """
    Build a system prompt (current date/time, safety policies, tool descriptions).

    Args:
        request: SystemPromptGenerateRequest carrying optional custom_instructions

    Returns:
        SystemPromptGenerateResponse with the rendered prompt and current
        date/time information
    """
    from .prompt_generator import get_current_datetime_info

    rendered = generate_system_prompt(custom_instructions=request.custom_instructions)
    response = SystemPromptGenerateResponse(
        system_prompt=rendered,
        datetime_info=get_current_datetime_info(),
    )
    return response
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@app.get("/api/prompt/generate")
def generate_prompt_get(custom_instructions: str = ""):
    """
    GET variant of prompt generation; instructions arrive as a query parameter.

    Query Parameters:
        custom_instructions: Optional extra instructions (empty by default)

    Returns:
        SystemPromptGenerateResponse with the rendered prompt and current
        date/time information
    """
    from .prompt_generator import get_current_datetime_info

    prompt_text = generate_system_prompt(custom_instructions=custom_instructions)
    return SystemPromptGenerateResponse(
        datetime_info=get_current_datetime_info(),
        system_prompt=prompt_text,
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@app.post("/api/sessions")
def create_session(request: SessionCreate):
    """Create a chat session; render a default system prompt if none is supplied."""
    prompt = (
        request.system_prompt
        if request.system_prompt is not None
        else generate_system_prompt(custom_instructions="")
    )
    new_session = session_manager.create_session(title=request.title, system_prompt=prompt)
    return new_session.model_dump()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@app.get("/api/sessions")
def list_sessions():
    """Return every session as a plain dict, in session_manager's listing order."""
    payload = []
    for item in session_manager.list_sessions():
        payload.append(item.model_dump())
    return payload
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@app.get("/api/sessions/{session_id}")
def get_session(session_id: str):
    """Return one session including all of its records; 404 if the id is unknown."""
    found = session_manager.get_session(session_id)
    if found:
        return found.model_dump()
    raise HTTPException(status_code=404, detail="Session not found")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@app.delete("/api/sessions/{session_id}")
def delete_session(session_id: str):
    """Delete a session; 404 if the id is unknown."""
    if session_manager.delete_session(session_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Session not found")
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@app.get("/api/tools/status")
def get_tools_status():
    """Report the status of every tool, exactly as the tool executor provides it."""
    statuses = tool_executor.get_all_tool_statuses()
    return statuses
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _sse_event(payload: dict) -> str:
    """Format *payload* as one Server-Sent Events ``data:`` frame.

    ``default=str`` lets records with ``datetime`` timestamps, and any other
    non-JSON-native values, serialize instead of raising ``TypeError``
    mid-stream.
    """
    return f"data: {json.dumps(payload, default=str)}\n\n"


# Streaming chat endpoint
@app.post("/api/chat/stream")
def chat_stream(request: ChatRequest):
    """Send a message and stream the AI response as Server-Sent Events.

    Each frame carries a JSON object whose ``type`` is one of:
      - ``record``: a persisted MessageRecord (user, assistant, or fallback)
      - ``content``: an incremental text chunk from the model
      - ``tool_calls_start``: the tool calls requested in the current round
      - ``tool_start`` / ``tool_result``: per-tool execution progress
      - ``done``: the turn finished
      - ``error``: the session was not found or an exception occurred

    After each round of tool calls, the model is called again with the tool
    results. This repeats for at most 50 rounds.
    """

    def generate():
        session = session_manager.get_session(request.session_id)
        if not session:
            yield _sse_event({'type': 'error', 'content': 'Session not found'})
            return

        # Persist and echo the user's message before calling the model.
        user_record = MessageRecord(
            role="user",
            content=request.message,
            binary_content=request.binary_content,
        )
        session_manager.add_record(request.session_id, user_record.model_dump())
        yield _sse_event({'type': 'record', 'record': user_record.model_dump()})

        messages = session_manager.get_messages_for_llm(request.session_id)

        try:
            max_iterations = 50

            for iteration in range(max_iterations):
                current_content = ""
                tool_calls = []

                # messages[0] is the session's system record (create_session
                # always stores it first); the prompt is passed separately.
                # Attached images are only sent on the first round of the turn.
                for chunk in agent.chat_stream(
                    messages=messages[1:],
                    system_prompt=session.system_prompt,
                    images=[bc.get("data") for bc in (request.binary_content or []) if bc.get("data")] if iteration == 0 else None,
                ):
                    if chunk["type"] == "content":
                        current_content += chunk["content"]
                        yield _sse_event({'type': 'content', 'content': chunk['content']})
                    elif chunk["type"] == "done":
                        tool_calls = chunk.get("tool_calls", [])

                if not tool_calls:
                    # Final answer: persist it if any text was produced, then finish.
                    if current_content:
                        assistant_record = MessageRecord(role="assistant", content=current_content)
                        session_manager.add_record(request.session_id, assistant_record.model_dump())
                        yield _sse_event({'type': 'record', 'record': assistant_record.model_dump()})
                    yield _sse_event({'type': 'done'})
                    return

                # Tool round: persist any text the model emitted alongside the calls.
                if current_content:
                    thinking_record = MessageRecord(role="assistant", content=current_content)
                    session_manager.add_record(request.session_id, thinking_record.model_dump())
                    yield _sse_event({'type': 'record', 'record': thinking_record.model_dump()})
                messages.append({"role": "assistant", "content": current_content, "tool_calls": tool_calls})

                yield _sse_event({'type': 'tool_calls_start', 'tool_calls': tool_calls})

                for tool_call in tool_calls:
                    tool_name = tool_call["function"]["name"]
                    tool_args_str = tool_call["function"]["arguments"]
                    tool_id = tool_call["id"]

                    try:
                        tool_args = json.loads(tool_args_str) if isinstance(tool_args_str, str) else tool_args_str
                    except json.JSONDecodeError:
                        tool_args = {"raw_arguments": tool_args_str}
                    if not isinstance(tool_args, dict):
                        # ToolCallRecord.arguments is Dict[str, Any]; a JSON null,
                        # list or scalar would fail validation and abort the turn.
                        tool_args = {"raw_arguments": tool_args_str}

                    yield _sse_event({
                        'type': 'tool_start',
                        'tool_call_id': tool_id,
                        'tool_name': tool_name,
                        'arguments': tool_args,
                        'success': None,
                    })

                    # The agent receives the raw argument string, not the parsed dict.
                    result = agent.process_tool_call(
                        tool_name,
                        tool_args_str,
                        context={"session_id": request.session_id}
                    )

                    # The record stores the output on success and the error otherwise.
                    tool_record = ToolCallRecord(
                        tool_call_id=tool_id,
                        tool_name=tool_name,
                        arguments=tool_args,
                        result=result.get("result") if result.get("success") else result.get("error"),
                        success=result.get("success", False)
                    )
                    session_manager.add_record(request.session_id, tool_record.model_dump())
                    yield _sse_event({'type': 'tool_result', 'record': tool_record.model_dump()})

                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_id,
                        "content": json.dumps(result, default=str) if isinstance(result, dict) else str(result)
                    })

            # Max iterations reached without a final answer.
            final_content = "Tool execution reached maximum iterations. Please try again with a simpler request."
            final_record = MessageRecord(role="assistant", content=final_content)
            session_manager.add_record(request.session_id, final_record.model_dump())
            yield _sse_event({'type': 'record', 'record': final_record.model_dump()})
            yield _sse_event({'type': 'done'})

        except Exception as e:
            # Persist the error as an assistant message so it survives a reload.
            error_content = f"Error: {str(e)}"
            error_record = MessageRecord(role="assistant", content=error_content)
            session_manager.add_record(request.session_id, error_record.model_dump())
            yield _sse_event({'type': 'error', 'content': error_content})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# System prompt endpoints
@app.post("/api/sessions/{session_id}/system-prompt")
def update_system_prompt(session_id: str, request: SystemPromptUpdate):
    """Replace a session's system prompt; 404 if the session does not exist."""
    updated = session_manager.update_system_prompt(session_id, request.system_prompt)
    if updated is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return updated.model_dump()
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# Health check
@app.get("/api/health")
def health_check():
    """Lightweight health endpoint; always answers ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
# Serve frontend static files from the package's "static" directory.
# NOTE(review): this mount is registered after the /api routes above,
# presumably so the "/" catch-all does not shadow them — keep it last.
static_dir = Path(__file__).parent / "static"
try:
    static_dir.mkdir(parents=True, exist_ok=True)
except OSError:
    # Best effort: an installed package may live in a read-only location.
    # The is_dir() check below then skips the mount instead of crashing
    # the whole import.
    pass
if static_dir.is_dir():
    app.mount("/", StaticFiles(directory=static_dir, html=True), name="static")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn.
    # NOTE(review): 0.0.0.0 binds every network interface, exposing the chat
    # endpoint (which executes tools) beyond localhost — confirm intended.
    import uvicorn

    host, port = "0.0.0.0", 8009
    uvicorn.run(app, host=host, port=port)
|
iribot/models.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Data models for the Agent application"""
|
|
2
|
+
from typing import Optional, List, Dict, Any, Literal, Union
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_local_now():
    """Return the current wall-clock time as a naive local ``datetime``."""
    now = datetime.now()
    return now
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ============ Session Record Types ============

class MessageRecord(BaseModel):
    """A single chat message (system, user or assistant) stored in a session.

    ``type`` is the fixed tag ``"message"``. ``binary_content`` optionally
    carries attachments such as images or files.
    """
    type: Literal["message"] = Field(default="message")
    role: Literal["system", "user", "assistant"]
    content: str
    binary_content: Union[List[Dict[str, Any]], None] = Field(default=None)  # images, files, ...
    timestamp: datetime = Field(default_factory=get_local_now)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ToolCallRecord(BaseModel):
    """One tool invocation together with its outcome.

    ``type`` is the fixed tag ``"tool_call"``. ``result`` is free-form
    (the chat endpoint stores the tool output on success and the error
    otherwise), and ``success`` records which of the two it is.
    """
    type: Literal["tool_call"] = Field(default="tool_call")
    tool_call_id: str
    tool_name: str
    arguments: Dict[str, Any]
    result: Any
    success: bool
    timestamp: datetime = Field(default_factory=get_local_now)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Either kind of record a session may hold. Session.records itself keeps
# them as plain dicts (model_dump output) rather than model instances.
SessionRecord = Union[MessageRecord, ToolCallRecord]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ============ Session Model ============

class Session(BaseModel):
    """A chat session and its ordered history of records.

    ``id`` defaults to the creation time's POSIX timestamp rendered as a
    string. ``records`` holds dumped MessageRecord / ToolCallRecord dicts.
    ``system_prompt`` has no default and must always be supplied.
    """
    id: str = Field(default_factory=lambda: str(datetime.now().timestamp()))
    title: str
    records: List[Dict[str, Any]] = Field(default_factory=list)  # MessageRecord / ToolCallRecord dicts
    created_at: datetime = Field(default_factory=get_local_now)
    updated_at: datetime = Field(default_factory=get_local_now)
    system_prompt: str  # required
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ============ API Request/Response Models ============

class ChatRequest(BaseModel):
    """Body of a chat request: target session, user text, optional attachments."""
    session_id: str
    message: str
    binary_content: Union[List[Dict[str, Any]], None] = Field(default=None)
|
|
57
|
+
|
|
58
|
+
class SystemPromptUpdate(BaseModel):
    """Body of a request that replaces a session's system prompt."""
    system_prompt: str = Field(...)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ============ Prompt Generation Models ============

class SystemPromptGenerateRequest(BaseModel):
    """Body of a prompt-generation request; instructions default to empty."""
    custom_instructions: str = Field(default="")
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SystemPromptGenerateResponse(BaseModel):
    """A rendered system prompt together with current date/time information."""
    system_prompt: str = Field(...)
    datetime_info: dict = Field(...)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class SessionCreate(BaseModel):
    """Body of a create-session request.

    Leaving ``system_prompt`` unset (None) lets the server generate one.
    """
    title: str = Field(default="New Chat")
    system_prompt: Union[str, None] = Field(default=None)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""System Prompt Generator for Agent"""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from zoneinfo import ZoneInfo
|
|
5
|
+
from typing import List, Dict, Any
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from jinja2 import Environment, FileSystemLoader
|
|
8
|
+
from .executor import tool_executor
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Jinja2 environment for the templates shipped in the package's "templates"
# directory. Autoescaping stays off: the rendered output is prompt text.
TEMPLATE_DIR = Path(__file__).parent.joinpath("templates")
jinja_env = Environment(
    autoescape=False,
    loader=FileSystemLoader(str(TEMPLATE_DIR)),
    lstrip_blocks=True,
    trim_blocks=True,
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_current_datetime_info() -> Dict[str, str]:
    """
    Get the current time in UTC and in the host's local timezone.

    Both strings are derived from one clock reading, so they always
    describe the same instant.

    Returns:
        Dictionary with:
          - "current_utc": "YYYY-MM-DD HH:MM:SS UTC"
          - "current_local": "YYYY-MM-DD HH:MM:SS <tz name>"
          - "timezone": str() of the local tzinfo
    """
    from datetime import timezone

    # timezone.utc needs no tz database. ZoneInfo("UTC") raises
    # ZoneInfoNotFoundError on hosts without one (e.g. Windows without the
    # tzdata package).
    utc_now = datetime.now(timezone.utc)
    # astimezone() with no argument converts to the system's local zone.
    local_now = utc_now.astimezone()
    local_tz = local_now.tzinfo

    return {
        "current_utc": utc_now.strftime("%Y-%m-%d %H:%M:%S UTC"),
        "current_local": local_now.strftime("%Y-%m-%d %H:%M:%S %Z"),
        "timezone": str(local_tz),
    }
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_available_tools_description(tools=None) -> str:
    """
    Get a Markdown description of the available tools for the prompt.

    Args:
        tools: Optional list of tool specs in OpenAI function format
            ({"function": {"name", "description", "parameters"}}). When
            omitted (None), the tools registered with tool_executor are used.

    Returns:
        A "## Available Tools" section with one "### <name>" entry per tool,
        or "No tools are currently available." when there are none.
    """
    if tools is None:
        tools = tool_executor.get_all_tools()

    if not tools:
        return "No tools are currently available."

    parts = ["## Available Tools\n\n"]

    for tool in tools:
        # `or {}` also covers keys that are present but explicitly None.
        func = tool.get('function') or {}
        tool_name = func.get('name', 'Unknown')
        tool_desc = func.get('description', 'No description available')
        params = func.get('parameters') or {}

        parts.append(f"### {tool_name}\n")
        parts.append(f"Description: {tool_desc}\n")

        if 'properties' in params:
            parts.append("Parameters:\n")
            # Hoisted out of the per-parameter loop; a set for O(1) lookup.
            required_names = set(params.get('required') or [])
            for param_name, param_info in (params['properties'] or {}).items():
                param_type = param_info.get('type', 'unknown')
                param_desc = param_info.get('description', 'No description')
                req_mark = " (required)" if param_name in required_names else " (optional)"
                parts.append(f" - `{param_name}` ({param_type}){req_mark}: {param_desc}\n")

        parts.append("\n")

    return "".join(parts)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def generate_system_prompt(custom_instructions: str = "") -> str:
    """
    Render the agent's system prompt from the "system_prompt.j2" template.

    The template receives the current UTC and local time, the local
    timezone, the tool descriptions and the caller's custom instructions.

    Args:
        custom_instructions: Extra instructions forwarded to the template

    Returns:
        The rendered prompt text
    """
    dt = get_current_datetime_info()
    tools_text = get_available_tools_description()

    template = jinja_env.get_template("system_prompt.j2")
    context = {
        "current_utc": dt['current_utc'],
        "current_local": dt['current_local'],
        "timezone": dt['timezone'],
        "tools_description": tools_text,
        "custom_instructions": custom_instructions,
    }
    return template.render(**context)
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""Session management for storing and retrieving chat sessions"""
|
|
2
|
+
from typing import Dict, Optional, List, Any
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from .models import Session, MessageRecord
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SessionManager:
    """Keep chat sessions in memory and persist each one to a JSON file.

    Sessions live in ``self.sessions`` keyed by id. Each is written to
    ``<storage_path>/<id>.json`` whenever it changes, and every ``*.json``
    file in the storage directory is loaded at construction time.
    """

    def __init__(self, storage_path: str = "./sessions"):
        """Prepare the storage directory and load any saved sessions.

        Args:
            storage_path: Directory for the per-session JSON files. A relative
                path resolves against the current working directory.
        """
        self.storage_path = Path(storage_path)
        # parents=True: a nested storage path must not fail on first run.
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.sessions: Dict[str, Session] = {}
        self._load_all_sessions()

    def create_session(self, title: str = "New Chat", system_prompt: Optional[str] = None) -> Session:
        """Create, register and persist a session seeded with a system record.

        Args:
            title: Display title of the session.
            system_prompt: Prompt text. It is stored on the session and as the
                first (role="system") record.

        Raises:
            ValueError: If ``system_prompt`` is None.
        """
        if system_prompt is None:
            raise ValueError("system_prompt is required. Please generate a system prompt using the prompt generator first.")

        session = Session(
            title=title,
            system_prompt=system_prompt
        )
        # The system message is always records[0]; update_system_prompt
        # checks that slot and the chat endpoint skips messages[0].
        system_record = MessageRecord(
            role="system",
            content=session.system_prompt
        )
        session.records.append(system_record.model_dump())

        self.sessions[session.id] = session
        self._save_session(session)
        return session

    def get_session(self, session_id: str) -> Optional[Session]:
        """Return the session with ``session_id``, or None if unknown."""
        return self.sessions.get(session_id)

    def list_sessions(self) -> List[Session]:
        """Return all sessions, most recently updated first."""
        return sorted(
            self.sessions.values(),
            key=lambda s: s.updated_at,
            reverse=True
        )

    def add_record(self, session_id: str, record: Dict[str, Any]) -> Optional[Session]:
        """Append one record dict, bump ``updated_at`` and persist.

        Returns:
            The updated session, or None if ``session_id`` is unknown.
        """
        session = self.sessions.get(session_id)
        if not session:
            return None

        session.records.append(record)
        session.updated_at = datetime.now()
        self._save_session(session)
        return session

    def add_records(self, session_id: str, records: List[Dict[str, Any]]) -> Optional[Session]:
        """Append several record dicts, bump ``updated_at`` and persist once.

        Returns:
            The updated session, or None if ``session_id`` is unknown.
        """
        session = self.sessions.get(session_id)
        if not session:
            return None

        session.records.extend(records)
        session.updated_at = datetime.now()
        self._save_session(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Remove a session from memory and delete its file if present.

        Returns:
            False if ``session_id`` is unknown, True otherwise.
        """
        # The membership check also ensures only ids created or loaded here
        # are ever turned into a file path below.
        if session_id not in self.sessions:
            return False

        del self.sessions[session_id]
        session_file = self.storage_path / f"{session_id}.json"
        if session_file.exists():
            session_file.unlink()
        return True

    def update_system_prompt(self, session_id: str, system_prompt: str) -> Optional[Session]:
        """Replace the session's prompt and its leading system record's content.

        Returns:
            The updated session, or None if ``session_id`` is unknown.
        """
        session = self.sessions.get(session_id)
        if not session:
            return None

        session.system_prompt = system_prompt
        # Update the system message record if it exists
        if session.records and session.records[0].get("type") == "message" and session.records[0].get("role") == "system":
            session.records[0]["content"] = system_prompt

        session.updated_at = datetime.now()
        self._save_session(session)
        return session

    def get_messages_for_llm(self, session_id: str) -> List[Dict[str, Any]]:
        """Rebuild an OpenAI-style chat message list from a session's records.

        Message records become ``{"role", "content"}`` dicts. Each tool_call
        record contributes two things:
          - an entry in the ``tool_calls`` list of the assistant message that
            issued it
          - a ``{"role": "tool"}`` message carrying the stored result

        Tool-call records that follow each other without an intervening
        message record are attached to the same assistant message. If no
        assistant message precedes them, an empty one (content "") is
        synthesized. This happens because the chat endpoint only stores an
        assistant record when the model emitted text alongside its calls.
        As a result, every "tool" message answers a preceding ``tool_calls``
        entry.

        Returns:
            The message list, or [] if the session does not exist.
        """
        session = self.sessions.get(session_id)
        if not session:
            return []

        messages: List[Dict[str, Any]] = []
        # Assistant message currently collecting tool calls; a system or user
        # message starts a new exchange and clears it.
        calling_message: Optional[Dict[str, Any]] = None

        for record in session.records:
            record_type = record.get("type")

            if record_type == "message":
                role = record.get("role")
                content = record.get("content", "")

                if role in ["system", "user"]:
                    messages.append({"role": role, "content": content})
                    calling_message = None
                elif role == "assistant":
                    calling_message = {"role": "assistant", "content": content}
                    messages.append(calling_message)

            elif record_type == "tool_call":
                tool_call_id = record.get("tool_call_id")
                tool_name = record.get("tool_name")
                arguments = record.get("arguments", {})
                result = record.get("result")

                if calling_message is None:
                    calling_message = {"role": "assistant", "content": ""}
                    messages.append(calling_message)

                calling_message.setdefault("tool_calls", []).append({
                    "id": tool_call_id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
                    }
                })

                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "content": json.dumps(result, default=str) if isinstance(result, dict) else str(result)
                })

        return messages

    def _save_session(self, session: Session) -> None:
        """Persist *session* to ``<storage_path>/<id>.json``.

        The JSON is first written to ``<id>.json.tmp`` and then moved over the
        target with ``Path.replace`` (os.replace). An interrupted write
        therefore leaves the previous file in place rather than a truncated
        one. The ``.tmp`` name is not matched by the ``*.json`` glob used when
        loading.
        """
        session_file = self.storage_path / f"{session.id}.json"
        tmp_file = session_file.with_name(session_file.name + ".tmp")
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(session.model_dump(), f, indent=2, default=str, ensure_ascii=False)
            tmp_file.replace(session_file)
        except Exception:
            # Do not leave a half-written temp file behind.
            tmp_file.unlink(missing_ok=True)
            raise

    def _load_all_sessions(self) -> None:
        """Load every ``*.json`` session file, migrating the old format.

        A file that fails to load is reported with print() and skipped.
        """
        for session_file in self.storage_path.glob("*.json"):
            try:
                with open(session_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # Handle migration from old format
                if "messages" in data and "records" not in data:
                    data = self._migrate_old_format(data)
                session = Session(**data)
                self.sessions[session.id] = session
            except Exception as e:
                print(f"Error loading session {session_file}: {e}")

    def _migrate_old_format(self, data: Dict) -> Dict:
        """Convert an old ``messages``-based session dict to ``records``.

        A system record built from ``system_prompt`` is placed first, then one
        message record per old message. ``messages`` is removed from *data*,
        which is modified in place and returned.
        """
        records = []

        # Add system message
        records.append({
            "type": "message",
            "role": "system",
            "content": data.get("system_prompt", ""),
            "timestamp": data.get("created_at")
        })

        # Convert old messages to records
        for msg in data.get("messages", []):
            records.append({
                "type": "message",
                "role": msg.get("role"),
                "content": msg.get("content", ""),
                "binary_content": msg.get("binary_content"),
                "timestamp": msg.get("timestamp")
            })

        data["records"] = records
        del data["messages"]
        return data
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# Module-wide SessionManager shared by the API. It is created at import time
# and stores its files under ./sessions, relative to the process's working
# directory.
session_manager = SessionManager(storage_path="./sessions")
|