mem_llm-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem_llm/__init__.py +98 -0
- mem_llm/api_server.py +595 -0
- mem_llm/base_llm_client.py +201 -0
- mem_llm/builtin_tools.py +311 -0
- mem_llm/cli.py +254 -0
- mem_llm/clients/__init__.py +22 -0
- mem_llm/clients/lmstudio_client.py +393 -0
- mem_llm/clients/ollama_client.py +354 -0
- mem_llm/config.yaml.example +52 -0
- mem_llm/config_from_docs.py +180 -0
- mem_llm/config_manager.py +231 -0
- mem_llm/conversation_summarizer.py +372 -0
- mem_llm/data_export_import.py +640 -0
- mem_llm/dynamic_prompt.py +298 -0
- mem_llm/knowledge_loader.py +88 -0
- mem_llm/llm_client.py +225 -0
- mem_llm/llm_client_factory.py +260 -0
- mem_llm/logger.py +129 -0
- mem_llm/mem_agent.py +1611 -0
- mem_llm/memory_db.py +612 -0
- mem_llm/memory_manager.py +321 -0
- mem_llm/memory_tools.py +253 -0
- mem_llm/prompt_security.py +304 -0
- mem_llm/response_metrics.py +221 -0
- mem_llm/retry_handler.py +193 -0
- mem_llm/thread_safe_db.py +301 -0
- mem_llm/tool_system.py +429 -0
- mem_llm/vector_store.py +278 -0
- mem_llm/web_launcher.py +129 -0
- mem_llm/web_ui/README.md +44 -0
- mem_llm/web_ui/__init__.py +7 -0
- mem_llm/web_ui/index.html +641 -0
- mem_llm/web_ui/memory.html +569 -0
- mem_llm/web_ui/metrics.html +75 -0
- mem_llm-2.0.0.dist-info/METADATA +667 -0
- mem_llm-2.0.0.dist-info/RECORD +39 -0
- mem_llm-2.0.0.dist-info/WHEEL +5 -0
- mem_llm-2.0.0.dist-info/entry_points.txt +3 -0
- mem_llm-2.0.0.dist-info/top_level.txt +1 -0
mem_llm/__init__.py
ADDED
@@ -0,0 +1,98 @@
"""
Memory-LLM: Memory-Enabled Mini Assistant
AI library that remembers user interactions
"""

from .mem_agent import MemAgent
from .memory_manager import MemoryManager
from .llm_client import OllamaClient  # Backward compatibility
from .base_llm_client import BaseLLMClient
from .llm_client_factory import LLMClientFactory

# New multi-backend support (v1.3.0+)
from .clients import OllamaClient as OllamaClientNew
from .clients import LMStudioClient

# Tools (optional)
try:
    from .memory_tools import MemoryTools, ToolExecutor
    __all_tools__ = ["MemoryTools", "ToolExecutor"]
except ImportError:
    __all_tools__ = []

# Pro version imports (optional)
try:
    from .memory_db import SQLMemoryManager
    from .config_manager import get_config
    from .config_from_docs import create_config_from_document
    from .dynamic_prompt import dynamic_prompt_builder
    __all_pro__ = ["SQLMemoryManager", "get_config", "create_config_from_document", "dynamic_prompt_builder"]
except ImportError:
    __all_pro__ = []

# Security features (optional, v1.1.0+)
try:
    from .prompt_security import (
        PromptInjectionDetector,
        InputSanitizer,
        SecurePromptBuilder
    )
    __all_security__ = ["PromptInjectionDetector", "InputSanitizer", "SecurePromptBuilder"]
except ImportError:
    __all_security__ = []

# Enhanced features (v1.1.0+)
try:
    from .logger import get_logger, MemLLMLogger
    from .retry_handler import exponential_backoff_retry, SafeExecutor
    __all_enhanced__ = ["get_logger", "MemLLMLogger", "exponential_backoff_retry", "SafeExecutor"]
except ImportError:
    __all_enhanced__ = []

# Conversation Summarization (v1.2.0+)
try:
    from .conversation_summarizer import ConversationSummarizer, AutoSummarizer
    __all_summarizer__ = ["ConversationSummarizer", "AutoSummarizer"]
except ImportError:
    __all_summarizer__ = []

# Data Export/Import (v1.2.0+)
try:
    from .data_export_import import DataExporter, DataImporter
    __all_export_import__ = ["DataExporter", "DataImporter"]
except ImportError:
    __all_export_import__ = []

# Response Metrics (v1.3.1+)
try:
    from .response_metrics import ChatResponse, ResponseMetricsAnalyzer, calculate_confidence
    __all_metrics__ = ["ChatResponse", "ResponseMetricsAnalyzer", "calculate_confidence"]
except ImportError:
    __all_metrics__ = []

__version__ = "2.0.0"
__author__ = "Cihat Emre Karataş"

# Multi-backend LLM support (v1.3.0+)
__all_llm_backends__ = ["BaseLLMClient", "LLMClientFactory", "OllamaClientNew", "LMStudioClient"]

# Tool system (v2.0.0+)
try:
    from .tool_system import tool, Tool, ToolRegistry
    from .builtin_tools import BUILTIN_TOOLS
    __all_tools__ = ["tool", "Tool", "ToolRegistry", "BUILTIN_TOOLS"]
except ImportError:
    __all_tools__ = []

# CLI
try:
    from .cli import cli
    __all_cli__ = ["cli"]
except ImportError:
    __all_cli__ = []

__all__ = [
    "MemAgent",
    "MemoryManager",
    "OllamaClient",
] + __all_llm_backends__ + __all_tools__ + __all_pro__ + __all_cli__ + __all_security__ + __all_enhanced__ + __all_summarizer__ + __all_export_import__ + __all_metrics__
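
Editor's note: a minimal usage sketch of the public API exported above, not part of the wheel. Optional names only land in __all__ when their imports succeed, so availability can be probed with a membership test. The MemAgent keyword arguments are borrowed from DEFAULT_CONFIG in api_server.py below (which passes them straight to MemAgent(**config)), and assume a local Ollama server; adjust for your setup.

import mem_llm
from mem_llm import MemAgent

# Optional features appear in __all__ only if their imports succeeded.
if "DataExporter" in mem_llm.__all__:
    print("Export/import support is available")

# Assumed config: mirrors DEFAULT_CONFIG in api_server.py.
agent = MemAgent(
    model="granite4:3b",
    backend="ollama",
    base_url="http://localhost:11434",
    use_sql=True,
    load_knowledge_base=True,
)
agent.set_user("alice")  # scope memory to a single user, as the API server does
print(agent.chat(message="Hello! My favorite color is green."))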
mem_llm/api_server.py
ADDED
@@ -0,0 +1,595 @@
"""
Mem-LLM REST API Server
========================

FastAPI-based REST API server for Mem-LLM.
Provides HTTP endpoints and WebSocket support for streaming responses.

Features:
- RESTful API endpoints
- WebSocket streaming support
- Multi-user management
- Knowledge base operations
- Memory search and export
- CORS support for web frontends
- Auto-generated API documentation (Swagger UI)

Usage:
    # Run server
    python -m mem_llm.api_server

    # Or with uvicorn directly
    uvicorn mem_llm.api_server:app --reload --host 0.0.0.0 --port 8000

API Documentation:
- Swagger UI: http://localhost:8000/docs
- ReDoc: http://localhost:8000/redoc

Author: C. Emre Karataş
Version: 1.3.3
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
import json
import logging
import asyncio
from contextlib import asynccontextmanager

# Import Mem-LLM components
from .mem_agent import MemAgent
from .response_metrics import ChatResponse

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Store active agents for each user
agents: Dict[str, MemAgent] = {}

# Default agent configuration
DEFAULT_CONFIG = {
    "model": "granite4:3b",
    "backend": "ollama",
    "base_url": "http://localhost:11434",
    "use_sql": True,
    "load_knowledge_base": True
}

def get_or_create_agent(user_id: str, config: Optional[Dict] = None) -> MemAgent:
    """
    Get existing agent or create new one for user

    Args:
        user_id: User identifier
        config: Optional agent configuration

    Returns:
        MemAgent instance
    """
    if user_id not in agents:
        agent_config = DEFAULT_CONFIG.copy()
        if config:
            agent_config.update(config)

        logger.info(f"Creating new agent for user: {user_id}")
        agent = MemAgent(**agent_config)
        agent.set_user(user_id)
        agents[user_id] = agent

    return agents[user_id]

# Lifespan context manager for startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("🚀 Mem-LLM API Server starting...")
    logger.info(f"📝 API Documentation: http://localhost:8000/docs")
    logger.info(f"🔌 WebSocket endpoint: ws://localhost:8000/ws/chat/{'{user_id}'}")
    yield
    # Shutdown
    logger.info("🛑 Mem-LLM API Server shutting down...")
    agents.clear()

# Create FastAPI app
app = FastAPI(
    title="Mem-LLM API",
    description="REST API for Mem-LLM - Privacy-first, Memory-enabled AI Assistant (100% Local)",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan
)

# Add CORS middleware for web frontends
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============================================================================
# Request/Response Models
# ============================================================================

class ChatRequest(BaseModel):
    """Chat request model"""
    message: str = Field(..., description="User's message")
    user_id: str = Field(..., description="User identifier")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
    stream: bool = Field(False, description="Enable streaming response")
    return_metrics: bool = Field(False, description="Return detailed metrics")

class ChatResponse_API(BaseModel):
    """Chat response model"""
    text: str = Field(..., description="Bot's response text")
    user_id: str = Field(..., description="User identifier")
    confidence: Optional[float] = Field(None, description="Response confidence score (0-1)")
    source: Optional[str] = Field(None, description="Response source (model/kb/hybrid)")
    latency: Optional[float] = Field(None, description="Response latency in milliseconds")
    timestamp: str = Field(..., description="Response timestamp")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")

class KnowledgeEntryRequest(BaseModel):
    """Knowledge base entry request"""
    category: str = Field(..., description="Entry category")
    question: str = Field(..., description="Question text")
    answer: str = Field(..., description="Answer text")

class KnowledgeSearchRequest(BaseModel):
    """Knowledge base search request"""
    query: str = Field(..., description="Search query")
    limit: int = Field(5, description="Maximum number of results")

class UserProfileResponse(BaseModel):
    """User profile response"""
    user_id: str
    name: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None
    summary: Optional[str] = None
    interaction_count: int = 0

class MemorySearchRequest(BaseModel):
    """Memory search request"""
    user_id: str = Field(..., description="User identifier")
    query: str = Field(..., description="Search query")
    limit: int = Field(10, description="Maximum number of results")

class AgentConfigRequest(BaseModel):
    """Agent configuration request"""
    model: Optional[str] = Field(None, description="LLM model name")
    backend: Optional[str] = Field(None, description="LLM backend (ollama/lmstudio)")
    base_url: Optional[str] = Field(None, description="Backend base URL")
    temperature: Optional[float] = Field(None, description="Sampling temperature")

# ============================================================================
# Health & Info Endpoints
# ============================================================================

@app.get("/api/v1/info", tags=["General"])
async def api_info():
    """API information endpoint"""
    return {
        "name": "Mem-LLM API",
        "version": "2.0.0",
        "status": "running",
        "documentation": "/docs",
        "endpoints": {
            "chat": "/api/v1/chat",
            "websocket": "/ws/chat/{user_id}",
            "knowledge_base": "/api/v1/kb",
            "memory": "/api/v1/memory",
            "users": "/api/v1/users"
        }
    }

@app.get("/health", tags=["General"])
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "active_users": len(agents)
    }

# ============================================================================
# Chat Endpoints
# ============================================================================

@app.post("/api/v1/chat", response_model=ChatResponse_API, tags=["Chat"])
async def chat(request: ChatRequest):
    """
    Send a chat message and get response

    This endpoint supports both regular and streaming responses.
    For streaming, use the WebSocket endpoint instead.
    """
    try:
        # Get or create agent for user
        agent = get_or_create_agent(request.user_id)

        # Get response
        if request.return_metrics:
            response = agent.chat(
                message=request.message,
                metadata=request.metadata,
                return_metrics=True
            )

            return ChatResponse_API(
                text=response.text,
                user_id=request.user_id,
                confidence=response.confidence,
                source=response.source,
                latency=response.latency,
                timestamp=response.timestamp.isoformat(),
                metadata=response.metadata
            )
        else:
            response_text = agent.chat(
                message=request.message,
                metadata=request.metadata
            )

            return ChatResponse_API(
                text=response_text,
                user_id=request.user_id,
                timestamp=datetime.now().isoformat()
            )

    except Exception as e:
        logger.error(f"Chat error for user {request.user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/v1/chat/stream", tags=["Chat"])
async def chat_stream(request: ChatRequest):
    """
    Send a chat message and get streaming response

    Returns a Server-Sent Events (SSE) stream.
    """
    try:
        agent = get_or_create_agent(request.user_id)

        async def generate():
            """Generate streaming response"""
            try:
                for chunk in agent.chat_stream(
                    message=request.message,
                    metadata=request.metadata
                ):
                    # Send as SSE format
                    yield f"data: {json.dumps({'chunk': chunk})}\n\n"

                # Send completion marker
                yield f"data: {json.dumps({'done': True})}\n\n"

            except Exception as e:
                logger.error(f"Streaming error: {e}")
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(
            generate(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no"
            }
        )

    except Exception as e:
        logger.error(f"Chat stream error for user {request.user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================================================
# WebSocket Endpoint (For Real-time Streaming)
# ============================================================================

@app.websocket("/ws/chat/{user_id}")
async def websocket_chat(websocket: WebSocket, user_id: str):
    """
    WebSocket endpoint for real-time streaming chat

    Client sends: {"message": "Hello", "metadata": {}}
    Server streams: {"type": "chunk", "content": "..."} or {"type": "done"}
    """
    await websocket.accept()
    logger.info(f"WebSocket connected: {user_id}")

    try:
        # Get or create agent
        agent = get_or_create_agent(user_id)

        while True:
            # Receive message from client
            data = await websocket.receive_json()
            message = data.get("message", "")
            metadata = data.get("metadata")

            if not message:
                await websocket.send_json({"type": "error", "content": "Empty message"})
                continue

            # Send acknowledgment
            await websocket.send_json({"type": "start"})

            # Stream response
            try:
                for chunk in agent.chat_stream(message=message, metadata=metadata):
                    await websocket.send_json({
                        "type": "chunk",
                        "content": chunk
                    })
                    # Small delay to prevent overwhelming the client
                    await asyncio.sleep(0.01)

                # Send completion
                await websocket.send_json({"type": "done"})

            except Exception as e:
                logger.error(f"Error during streaming: {e}")
                await websocket.send_json({
                    "type": "error",
                    "content": str(e)
                })

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {user_id}")
    except Exception as e:
        logger.error(f"WebSocket error for {user_id}: {e}")
        try:
            await websocket.send_json({"type": "error", "content": str(e)})
        except:
            pass

# ============================================================================
# Knowledge Base Endpoints
# ============================================================================

@app.post("/api/v1/kb/add", tags=["Knowledge Base"])
async def add_knowledge(entry: KnowledgeEntryRequest, user_id: str = "admin"):
    """Add entry to knowledge base"""
    try:
        agent = get_or_create_agent(user_id)
        agent.add_kb_entry(
            category=entry.category,
            question=entry.question,
            answer=entry.answer
        )
        return {"status": "success", "message": "Entry added to knowledge base"}
    except Exception as e:
        logger.error(f"KB add error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/v1/kb/search", tags=["Knowledge Base"])
async def search_knowledge(search: KnowledgeSearchRequest, user_id: str = "admin"):
    """Search knowledge base"""
    try:
        agent = get_or_create_agent(user_id)

        if hasattr(agent.memory, 'search_knowledge'):
            results = agent.memory.search_knowledge(
                query=search.query,
                limit=search.limit
            )
            return {"results": results, "count": len(results)}
        else:
            raise HTTPException(
                status_code=400,
                detail="Knowledge base not available. Use use_sql=True"
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"KB search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/v1/kb/categories", tags=["Knowledge Base"])
async def get_kb_categories(user_id: str = "admin"):
    """Get all knowledge base categories"""
    try:
        agent = get_or_create_agent(user_id)

        if hasattr(agent.memory, 'get_kb_categories'):
            categories = agent.memory.get_kb_categories()
            return {"categories": categories, "count": len(categories)}
        else:
            return {"categories": [], "count": 0}
    except Exception as e:
        logger.error(f"KB categories error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================================================
# Memory & User Endpoints
# ============================================================================

@app.get("/api/v1/users/{user_id}/profile", response_model=UserProfileResponse, tags=["Users"])
async def get_user_profile(user_id: str):
    """Get user profile"""
    try:
        agent = get_or_create_agent(user_id)
        profile = agent.get_user_profile()

        return UserProfileResponse(
            user_id=user_id,
            name=profile.get("name"),
            preferences=profile.get("preferences"),
            summary=profile.get("summary"),
            interaction_count=len(profile.get("conversations", []))
        )
    except Exception as e:
        logger.error(f"Profile error for {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/v1/memory/search", tags=["Memory"])
async def search_memory(search: MemorySearchRequest):
    """Search user's memory"""
    try:
        agent = get_or_create_agent(search.user_id)
        # Use memory manager's get_conversation_history instead
        history = agent.memory.get_conversation_history(user_id=search.user_id)

        # Filter by query if provided
        if search.query:
            results = [
                msg for msg in history
                if search.query.lower() in str(msg).lower()
            ][:search.limit]
        else:
            results = history[:search.limit]

        return {"results": results, "count": len(results)}
    except Exception as e:
        logger.error(f"Memory search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/v1/memory/stats", tags=["Memory"])
async def get_memory_stats():
    """Get memory statistics"""
    try:
        total_memories = 0
        total_users = len(agents)

        # Count memories from all agents
        for agent in agents.values():
            try:
                if hasattr(agent.memory, 'get_all_users'):
                    users = agent.memory.get_all_users()
                    total_users = len(users)
                    for user_id in users:
                        history = agent.memory.get_conversation_history(user_id=user_id)
                        total_memories += len(history)
                break  # Only need one agent for DB stats
            except:
                pass

        return {
            "total_users": total_users,
            "total_memories": total_memories,
            "active_agents": len(agents)
        }
    except Exception as e:
        logger.error(f"Memory stats error: {e}")
        # Return empty stats instead of error
        return {
            "total_users": len(agents),
            "total_memories": 0,
            "active_agents": len(agents)
        }

@app.delete("/api/v1/users/{user_id}/memory", tags=["Users"])
async def clear_user_memory(user_id: str):
    """Clear user's memory"""
    try:
        if user_id in agents:
            agent = agents[user_id]
            # Clear memory (implementation depends on memory backend)
            if hasattr(agent.memory, 'clear_user'):
                agent.memory.clear_user(user_id)

            # Remove agent from cache
            del agents[user_id]

            return {"status": "success", "message": f"Memory cleared for user {user_id}"}
        else:
            raise HTTPException(status_code=404, detail="User not found")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Clear memory error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================================================
# Agent Configuration
# ============================================================================

@app.post("/api/v1/agent/configure/{user_id}", tags=["Agent"])
async def configure_agent(user_id: str, config: AgentConfigRequest):
    """Configure agent settings for a user"""
    try:
        # Remove existing agent if exists
        if user_id in agents:
            del agents[user_id]

        # Create new agent with config
        config_dict = {k: v for k, v in config.dict().items() if v is not None}
        agent = get_or_create_agent(user_id, config_dict)

        return {
            "status": "success",
            "message": "Agent configured",
            "config": agent.get_info()
        }
    except Exception as e:
        logger.error(f"Configure error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/api/v1/agent/info/{user_id}", tags=["Agent"])
async def get_agent_info(user_id: str):
    """Get agent information"""
    try:
        agent = get_or_create_agent(user_id)
        return agent.get_info()
    except Exception as e:
        logger.error(f"Agent info error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# ============================================================================
# Main Entry Point
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    print("\n" + "="*60)
    print(" 🚀 Starting Mem-LLM API Server")
    print("="*60)
    print("\n📝 API Documentation: http://localhost:8000/docs")
    print("🔌 WebSocket endpoint: ws://localhost:8000/ws/chat/{user_id}")
    print("\nPress CTRL+C to stop the server\n")

    uvicorn.run(
        "mem_llm.api_server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info"
    )

# Mount Web UI static files
web_ui_path = Path(__file__).parent / "web_ui"
if web_ui_path.exists():
    app.mount("/static", StaticFiles(directory=str(web_ui_path)), name="static")

@app.get("/")
async def root():
    """Serve Web UI index page"""
    index_path = web_ui_path / "index.html"
    if index_path.exists():
        return FileResponse(str(index_path), media_type="text/html")
    return {"message": "Mem-LLM API Server", "version": "2.0.0"}

@app.get("/memory")
async def memory_page():
    """Serve memory management page"""
    memory_path = web_ui_path / "memory.html"
    if memory_path.exists():
        return FileResponse(str(memory_path), media_type="text/html")
    return {"error": "Page not found"}

@app.get("/metrics")
async def metrics_page():
    """Serve metrics dashboard page"""
    metrics_path = web_ui_path / "metrics.html"
    if metrics_path.exists():
        return FileResponse(str(metrics_path), media_type="text/html")
    return {"error": "Page not found"}
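
Editor's note: a hedged client sketch for the endpoints above, not part of the wheel. It assumes the server is running on localhost:8000; the request bodies and frame shapes are taken from the handlers and Pydantic models in api_server.py, while requests and websockets are dependencies of this sketch only.

import asyncio
import json

import requests    # example dependency, not required by mem-llm
import websockets  # example dependency, not required by mem-llm

BASE = "http://localhost:8000"

# Plain chat: POST /api/v1/chat returns a ChatResponse_API JSON body.
resp = requests.post(f"{BASE}/api/v1/chat", json={
    "message": "Hi there!",
    "user_id": "alice",
    "return_metrics": True,
})
resp.raise_for_status()
print(resp.json()["text"])

# Streaming chat: POST /api/v1/chat/stream emits SSE lines of the form
# data: {"chunk": ...}, ending with data: {"done": true} (or an "error").
with requests.post(f"{BASE}/api/v1/chat/stream",
                   json={"message": "Tell me a story", "user_id": "alice"},
                   stream=True) as stream:
    for line in stream.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip blank keep-alive lines between events
        event = json.loads(line[len(b"data: "):])
        if event.get("done") or event.get("error"):
            break
        print(event["chunk"], end="", flush=True)

# WebSocket chat: send {"message": ...}; the server replies with frames
# tagged "start", "chunk", "done", or "error" until the turn completes.
async def ws_chat():
    async with websockets.connect("ws://localhost:8000/ws/chat/alice") as ws:
        await ws.send(json.dumps({"message": "Hello over WebSocket"}))
        while True:
            frame = json.loads(await ws.recv())
            if frame["type"] == "chunk":
                print(frame["content"], end="", flush=True)
            elif frame["type"] in ("done", "error"):
                break

asyncio.run(ws_chat())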