mem-llm 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem_llm/__init__.py +98 -0
- mem_llm/api_server.py +595 -0
- mem_llm/base_llm_client.py +201 -0
- mem_llm/builtin_tools.py +311 -0
- mem_llm/cli.py +254 -0
- mem_llm/clients/__init__.py +22 -0
- mem_llm/clients/lmstudio_client.py +393 -0
- mem_llm/clients/ollama_client.py +354 -0
- mem_llm/config.yaml.example +52 -0
- mem_llm/config_from_docs.py +180 -0
- mem_llm/config_manager.py +231 -0
- mem_llm/conversation_summarizer.py +372 -0
- mem_llm/data_export_import.py +640 -0
- mem_llm/dynamic_prompt.py +298 -0
- mem_llm/knowledge_loader.py +88 -0
- mem_llm/llm_client.py +225 -0
- mem_llm/llm_client_factory.py +260 -0
- mem_llm/logger.py +129 -0
- mem_llm/mem_agent.py +1611 -0
- mem_llm/memory_db.py +612 -0
- mem_llm/memory_manager.py +321 -0
- mem_llm/memory_tools.py +253 -0
- mem_llm/prompt_security.py +304 -0
- mem_llm/response_metrics.py +221 -0
- mem_llm/retry_handler.py +193 -0
- mem_llm/thread_safe_db.py +301 -0
- mem_llm/tool_system.py +429 -0
- mem_llm/vector_store.py +278 -0
- mem_llm/web_launcher.py +129 -0
- mem_llm/web_ui/README.md +44 -0
- mem_llm/web_ui/__init__.py +7 -0
- mem_llm/web_ui/index.html +641 -0
- mem_llm/web_ui/memory.html +569 -0
- mem_llm/web_ui/metrics.html +75 -0
- mem_llm-2.0.0.dist-info/METADATA +667 -0
- mem_llm-2.0.0.dist-info/RECORD +39 -0
- mem_llm-2.0.0.dist-info/WHEEL +5 -0
- mem_llm-2.0.0.dist-info/entry_points.txt +3 -0
- mem_llm-2.0.0.dist-info/top_level.txt +1 -0
mem_llm/tool_system.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tool System for Function Calling
|
|
3
|
+
=================================
|
|
4
|
+
|
|
5
|
+
Enables agents to call external functions/tools to perform actions.
|
|
6
|
+
Inspired by OpenAI's function calling and LangChain's tool system.
|
|
7
|
+
|
|
8
|
+
Features:
|
|
9
|
+
- Decorator-based tool definition
|
|
10
|
+
- Automatic schema generation from type hints
|
|
11
|
+
- Tool execution with error handling
|
|
12
|
+
- Tool result formatting
|
|
13
|
+
- Built-in common tools
|
|
14
|
+
|
|
15
|
+
Author: C. Emre Karataş
|
|
16
|
+
Version: 2.0.0
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import json
|
|
20
|
+
import inspect
|
|
21
|
+
import re
|
|
22
|
+
from typing import Callable, Dict, List, Any, Optional, get_type_hints
|
|
23
|
+
from dataclasses import dataclass, field
|
|
24
|
+
from enum import Enum
|
|
25
|
+
import logging
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ToolCallStatus(Enum):
    """Outcome categories attached to a ToolCallResult."""

    # The tool ran and returned a value.
    SUCCESS = "success"
    # The tool raised an exception other than ValueError.
    ERROR = "error"
    # No tool is registered under the requested name.
    NOT_FOUND = "not_found"
    # The tool (or its argument validation) raised ValueError.
    INVALID_ARGS = "invalid_args"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class ToolParameter:
    """Description of one argument accepted by a tool."""

    # Argument name, used as the keyword when the tool is called.
    name: str
    # JSON-schema style type name (e.g. "string", "integer").
    type: str
    # Human-readable explanation shown to the LLM.
    description: str
    # Whether the caller must supply this argument.
    required: bool = True
    # Fallback value recorded for optional arguments.
    default: Any = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class Tool:
    """A callable exposed to the agent, together with its schema metadata."""

    # Name the LLM uses to refer to the tool.
    name: str
    # Explanation of what the tool does.
    description: str
    # The Python callable that performs the work.
    function: Callable
    # Declared arguments of the callable.
    parameters: List[ToolParameter] = field(default_factory=list)
    # JSON-schema style name of the return type.
    return_type: str = "string"
    # Free-form grouping label.
    category: str = "general"

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary description of this tool for the LLM.

        Returns:
            Dict with the tool's name, description, category, an
            object-style parameter schema and the return type.
        """
        properties: Dict[str, Any] = {}
        required_names: List[str] = []
        for spec in self.parameters:
            properties[spec.name] = {
                "type": spec.type,
                "description": spec.description,
            }
            if spec.required:
                required_names.append(spec.name)

        schema = {
            "type": "object",
            "properties": properties,
            "required": required_names,
        }
        return {
            "name": self.name,
            "description": self.description,
            "category": self.category,
            "parameters": schema,
            "return_type": self.return_type,
        }

    def execute(self, **kwargs) -> Any:
        """Run the wrapped function with the given keyword arguments.

        Returns:
            Whatever the wrapped function returns.

        Raises:
            ValueError: If a required parameter is absent from ``kwargs``.
            Exception: Anything raised by the wrapped function; it is
                logged and then re-raised unchanged.
        """
        try:
            # Check required arguments before touching the function.
            absent = [
                spec.name
                for spec in self.parameters
                if spec.required and spec.name not in kwargs
            ]
            if absent:
                raise ValueError(f"Missing required parameters: {absent}")

            return self.function(**kwargs)
        except Exception as e:
            logger.error(f"Tool execution error ({self.name}): {e}")
            raise
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@dataclass
class ToolCallResult:
    """Record describing the outcome of one tool invocation."""

    # Name of the tool that was requested.
    tool_name: str
    # Outcome category of the call.
    status: ToolCallStatus
    # Value returned by the tool; stays None when nothing was returned.
    result: Any = None
    # Error message; stays None when no error was recorded.
    error: Optional[str] = None
    # Elapsed time of the call, in seconds.
    execution_time: float = 0.0
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Python annotation -> JSON-schema type name. Shared by parameters and
# return values so both are described consistently.
_JSON_SCHEMA_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
}


def _json_schema_type(annotation: Any) -> str:
    """Map a type annotation to a JSON-schema type name ("string" if unknown)."""
    # Function-scope imports: these names are not part of the module's imports.
    import types
    from typing import Union

    origin = getattr(annotation, "__origin__", None)
    # ``Optional[X]`` / ``Union[X, None]`` / ``X | None`` -> describe X.
    is_union = origin is Union or isinstance(
        annotation, getattr(types, "UnionType", ())
    )
    if is_union:
        members = [
            arg for arg in getattr(annotation, "__args__", ())
            if arg is not type(None)
        ]
        if len(members) == 1:
            return _json_schema_type(members[0])
        return "string"

    # ``List[str]`` has origin ``list``; plain classes have no origin.
    key = origin if origin is not None else annotation
    try:
        return _JSON_SCHEMA_TYPES.get(key, "string")
    except TypeError:
        # Unhashable annotation object: nothing sensible to map it to.
        return "string"


def tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
    category: str = "general"
) -> Callable:
    """
    Decorator to define a tool/function that the agent can call.

    The decorated function is returned unchanged; the generated ``Tool``
    object is attached to it as the ``_tool`` attribute.

    Args:
        name: Tool name (defaults to function name)
        description: Tool description (defaults to docstring)
        category: Tool category for organization

    Returns:
        A decorator that attaches ``_tool`` to the function and returns it.

    Example:
        @tool(name="count_words", description="Count words in a text")
        def count_words(text: str) -> int:
            '''Count whitespace-separated words'''
            return len(text.split())
    """
    def decorator(func: Callable) -> Callable:
        # Get function metadata
        func_name = name or func.__name__
        func_desc = description or (func.__doc__ or "").strip()

        # Resolve annotations. An unresolvable forward reference must not
        # make the decorated module fail to import, so fall back to the raw
        # annotations (unknown entries are then described as "string").
        try:
            type_hints = get_type_hints(func)
        except (NameError, TypeError):
            type_hints = dict(getattr(func, "__annotations__", {}))

        empty = inspect.Parameter.empty
        variadic = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        parameters = []

        for param_name, param in inspect.signature(func).parameters.items():
            # *args / **kwargs cannot be supplied as a single named value.
            if param.kind in variadic:
                continue

            if param_name in type_hints:
                json_type = _json_schema_type(type_hints[param_name])
            else:
                json_type = "string"

            # Identity comparison: ``==`` would call the default's own
            # ``__eq__``, which may not return a plain bool.
            has_default = param.default is not empty

            parameters.append(ToolParameter(
                name=param_name,
                type=json_type,
                description=f"Parameter: {param_name}",
                required=not has_default,
                default=param.default if has_default else None
            ))

        # Get return type
        return_type = "string"
        if "return" in type_hints:
            return_type = _json_schema_type(type_hints["return"])

        # Create Tool object
        tool_obj = Tool(
            name=func_name,
            description=func_desc,
            function=func,
            parameters=parameters,
            return_type=return_type,
            category=category
        )

        # Attach tool metadata to function
        func._tool = tool_obj
        return func

    return decorator
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class ToolRegistry:
    """Registry for managing available tools.

    Tools are stored by name in ``self.tools``; registering a second tool
    under an existing name replaces the first.
    """

    def __init__(self):
        self.tools: Dict[str, Tool] = {}
        self._load_builtin_tools()

    def _load_builtin_tools(self):
        """Register every decorated function listed in ``BUILTIN_TOOLS``.

        The built-in module is optional: if it cannot be imported the
        registry simply starts empty.
        """
        try:
            from .builtin_tools import BUILTIN_TOOLS
        except ImportError as e:
            # Best effort by design, but leave a trace for debugging.
            logger.debug(f"Built-in tools not loaded: {e}")
            return

        for tool_func in BUILTIN_TOOLS:
            if hasattr(tool_func, '_tool'):
                self.register(tool_func._tool)

    def register(self, tool: Tool):
        """Register a tool under its own name."""
        self.tools[tool.name] = tool
        logger.info(f"Registered tool: {tool.name}")

    def register_function(self, func: Callable):
        """Register a function as a tool.

        Functions already processed by the ``tool`` decorator are registered
        as-is; plain functions are decorated with default settings first.
        """
        if not hasattr(func, '_tool'):
            # Auto-create tool metadata from the bare function.
            func = tool()(func)
        if hasattr(func, '_tool'):
            self.register(func._tool)

    def get(self, name: str) -> Optional[Tool]:
        """Get a tool by name, or None when no such tool is registered."""
        return self.tools.get(name)

    def list_tools(self, category: Optional[str] = None) -> List[Tool]:
        """List all tools, optionally filtered by category."""
        tools = list(self.tools.values())
        if category:
            tools = [t for t in tools if t.category == category]
        return tools

    def get_tools_schema(self) -> List[Dict[str, Any]]:
        """Get schema for all tools (for LLM prompt)."""
        return [registered.to_dict() for registered in self.tools.values()]

    def execute(self, tool_name: str, **kwargs) -> ToolCallResult:
        """Execute a tool by name and report the outcome.

        This method does not raise for tool failures; they are reported
        through the returned status:

        - NOT_FOUND: no tool is registered as ``tool_name``
        - INVALID_ARGS: the call raised ValueError (including the
          missing-required-parameter check)
        - ERROR: the call raised any other exception
        - SUCCESS: the call returned normally

        Args:
            tool_name: Name of the registered tool.
            **kwargs: Keyword arguments forwarded to the tool.

        Returns:
            ToolCallResult with status, result or error text, and the
            elapsed time in seconds.
        """
        import time

        # perf_counter is monotonic; time.time() can jump when the system
        # clock is adjusted and yield negative or inflated durations.
        started = time.perf_counter()

        def elapsed() -> float:
            return time.perf_counter() - started

        target = self.get(tool_name)
        if target is None:
            return ToolCallResult(
                tool_name=tool_name,
                status=ToolCallStatus.NOT_FOUND,
                error=f"Tool '{tool_name}' not found",
                execution_time=elapsed()
            )

        try:
            result = target.execute(**kwargs)
        except ValueError as e:
            return ToolCallResult(
                tool_name=tool_name,
                status=ToolCallStatus.INVALID_ARGS,
                error=str(e),
                execution_time=elapsed()
            )
        except Exception as e:
            return ToolCallResult(
                tool_name=tool_name,
                status=ToolCallStatus.ERROR,
                error=str(e),
                execution_time=elapsed()
            )

        return ToolCallResult(
            tool_name=tool_name,
            status=ToolCallStatus.SUCCESS,
            result=result,
            execution_time=elapsed()
        )
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class ToolCallParser:
    """Parse LLM output to detect and extract tool calls.

    Expected format, one call per line:

        TOOL_CALL: tool_name(arg1="value1", arg2="value2")

    The argument list is located by balancing parentheses and quotes, so
    values that themselves contain parentheses, such as
    ``expression="(25 * 4) + 10"``, are extracted in full. Backslash
    escapes inside quoted values are not interpreted. All argument values
    are returned as strings.
    """

    # Legacy single-regex form of the call syntax. It stops at the first
    # ")" and therefore truncates arguments containing parentheses; it is
    # kept only as a public attribute for backward compatibility and is no
    # longer used by the methods below.
    TOOL_CALL_PATTERN = r'TOOL_CALL:\s*(\w+)\((.*?)\)'

    # Start of a call, up to and including the opening parenthesis.
    _CALL_START = re.compile(r'TOOL_CALL:\s*(\w+)\(')
    # "identifier = value" form of a single argument.
    _KEYWORD_ARG = re.compile(r'\s*([A-Za-z_]\w*)\s*=(.*)', re.DOTALL)
    _QUOTES = ('"', "'")
    _OPENERS = '([{'
    _CLOSERS = ')]}'
    # Parameter name assumed for the first positional argument of
    # well-known tools.
    _FIRST_POSITIONAL_NAME = {
        'calculate': 'expression',
        'count_words': 'text',
        'reverse_text': 'text',
        'to_uppercase': 'text',
        'to_lowercase': 'text',
        'get_weather': 'city',
        'read_file': 'filepath',
        'write_file': 'filepath',
    }

    @staticmethod
    def _closing_paren_index(segment: str) -> int:
        """Return the index in ``segment`` of the ")" that ends the call.

        ``segment`` is the text following the opening parenthesis. Quoted
        text and nested parentheses are skipped. If the text is unbalanced
        (for example an unterminated quote), fall back to the first ")" so
        malformed model output is still recognised. Returns -1 when the
        segment contains no ")" at all.
        """
        depth = 0
        quote = None
        for index, char in enumerate(segment):
            if quote is not None:
                if char == quote:
                    quote = None
            elif char in ToolCallParser._QUOTES:
                quote = char
            elif char == '(':
                depth += 1
            elif char == ')':
                if depth == 0:
                    return index
                depth -= 1
        return segment.find(')')

    @staticmethod
    def _iter_calls(text: str):
        """Yield ``(start, end, tool_name, args_str)`` for each tool call.

        ``start``/``end`` delimit the whole call in ``text`` (end is
        exclusive). The argument list never extends past the end of the
        line on which it starts.
        """
        search_from = 0
        while True:
            opening = ToolCallParser._CALL_START.search(text, search_from)
            if opening is None:
                return
            args_start = opening.end()
            line_end = text.find('\n', args_start)
            if line_end == -1:
                line_end = len(text)
            offset = ToolCallParser._closing_paren_index(
                text[args_start:line_end]
            )
            if offset == -1:
                # Never closed on this line: not a call, keep looking.
                search_from = args_start
                continue
            args_end = args_start + offset
            yield (opening.start(), args_end + 1, opening.group(1),
                   text[args_start:args_end])
            search_from = args_end + 1

    @staticmethod
    def _split_arguments(args_str: str) -> List[str]:
        """Split on top-level commas, ignoring commas in quotes/brackets."""
        parts = []
        buffer = []
        quote = None
        depth = 0
        for char in args_str:
            if quote is not None:
                if char == quote:
                    quote = None
            elif char in ToolCallParser._QUOTES:
                quote = char
            elif char in ToolCallParser._OPENERS:
                depth += 1
            elif char in ToolCallParser._CLOSERS:
                # Clamp so a stray closer cannot disable splitting.
                depth = max(0, depth - 1)
            elif char == ',' and depth == 0:
                piece = ''.join(buffer).strip()
                if piece:
                    parts.append(piece)
                buffer = []
                continue
            buffer.append(char)

        piece = ''.join(buffer).strip()
        if piece:
            parts.append(piece)
        return parts

    @staticmethod
    def _unquote(raw: str) -> str:
        """Strip one pair of surrounding quotes, keeping inner quotes."""
        value = raw.strip()
        quotes = ToolCallParser._QUOTES
        if len(value) >= 2 and value[0] in quotes and value[-1] == value[0]:
            return value[1:-1]
        # Unbalanced quoting (e.g. truncated output): drop the stray quote.
        if value and value[0] in quotes:
            return value[1:]
        if value and value[-1] in quotes:
            return value[:-1]
        return value

    @staticmethod
    def _parse_arguments(tool_name: str, args_str: str) -> Dict[str, Any]:
        """Turn the text between the parentheses into an arguments dict."""
        arguments: Dict[str, Any] = {}
        for index, part in enumerate(ToolCallParser._split_arguments(args_str)):
            keyword = ToolCallParser._KEYWORD_ARG.fullmatch(part)
            if keyword is not None:
                arguments[keyword.group(1)] = ToolCallParser._unquote(
                    keyword.group(2)
                )
                continue

            # Positional argument: the first one of a well-known tool gets
            # that tool's main parameter name, anything else is keyed by
            # its position.
            value = ToolCallParser._unquote(part)
            inferred = None
            if index == 0 and value:
                inferred = ToolCallParser._FIRST_POSITIONAL_NAME.get(tool_name)
            arguments[inferred or f'arg{index}'] = value
        return arguments

    @staticmethod
    def parse(text: str) -> List[Dict[str, Any]]:
        """
        Parse text to extract tool calls.

        Returns:
            List of dicts with 'tool' and 'arguments' keys
        """
        return [
            {
                "tool": tool_name,
                "arguments": ToolCallParser._parse_arguments(tool_name, args_str)
            }
            for _, _, tool_name, args_str in ToolCallParser._iter_calls(text)
        ]

    @staticmethod
    def has_tool_call(text: str) -> bool:
        """Check if text contains a tool call"""
        return next(ToolCallParser._iter_calls(text), None) is not None

    @staticmethod
    def remove_tool_calls(text: str) -> str:
        """Remove tool call syntax from text, keeping only natural language"""
        kept = []
        cursor = 0
        for start, end, _, _ in ToolCallParser._iter_calls(text):
            kept.append(text[cursor:start])
            cursor = end
        kept.append(text[cursor:])
        return ''.join(kept).strip()
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def format_tools_for_prompt(tools: List[Tool]) -> str:
    """
    Render the tool catalogue and calling instructions for an LLM prompt.

    Args:
        tools: Tools to describe; each contributes its name, description
            and parameter list.

    Returns:
        Newline-joined text listing the tools followed by fixed usage
        instructions, or an empty string when ``tools`` is empty.
    """
    if not tools:
        return ""

    catalog = ["You have access to the following tools:", ""]

    for entry in tools:
        catalog.append(f"- **{entry.name}**: {entry.description}")

        if entry.parameters:
            catalog.append(" Parameters:")
            for spec in entry.parameters:
                if spec.required:
                    req = "required"
                else:
                    req = "optional"
                catalog.append(f" - {spec.name} ({spec.type}, {req}): {spec.description}")

        catalog.append("")

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Fixed instruction block appended after the catalogue.
    instructions = [
        heavy_rule,
        "TOOL USAGE INSTRUCTIONS:",
        light_rule,
        "To call a tool, use EXACTLY this format:",
        ' TOOL_CALL: tool_name(param1="value1", param2="value2")',
        "",
        "EXAMPLES:",
        ' TOOL_CALL: calculate(expression="(25 * 4) + 10")',
        ' TOOL_CALL: count_words(text="Hello world from AI")',
        ' TOOL_CALL: get_current_time()',
        ' TOOL_CALL: read_file(filepath="data.txt")',
        "",
        "IMPORTANT RULES:",
        " 1. Always use named parameters (param=\"value\")",
        " 2. Put ALL parameters inside the parentheses",
        " 3. Use double quotes for string values",
        " 4. One tool call per line",
        " 5. After tool execution, you will receive results to continue your response",
        heavy_rule,
        "",
    ]

    return "\n".join(catalog + instructions)
|
|
429
|
+
|
mem_llm/vector_store.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Vector Store Abstraction Layer
|
|
3
|
+
Supports multiple vector databases (Chroma, FAISS, etc.)
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import List, Dict, Optional, Any
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class VectorStore(ABC):
    """Contract every vector store backend must implement."""

    @abstractmethod
    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """
        Store documents in the backend.

        Args:
            documents: List of dicts with 'id', 'text', 'metadata'
        """

    @abstractmethod
    def search(self, query: str, limit: int = 5, filter_metadata: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """
        Find documents similar to a query.

        Args:
            query: Search query text
            limit: Maximum number of results
            filter_metadata: Optional metadata filters

        Returns:
            List of similar documents with scores
        """

    @abstractmethod
    def delete_collection(self) -> None:
        """Remove every vector held in the collection."""

    @abstractmethod
    def get_stats(self) -> Dict[str, Any]:
        """Report statistics about the vector store."""
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ChromaDB is an optional dependency: record whether it can be imported so
# callers can degrade gracefully instead of failing at import time.
try:
    import chromadb
    from chromadb.config import Settings
except ImportError:
    # Stay silent here; the absence is reported only when a store is requested.
    CHROMA_AVAILABLE = False
else:
    CHROMA_AVAILABLE = True
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ChromaVectorStore(VectorStore):
    """ChromaDB implementation of VectorStore"""

    def __init__(self, collection_name: str = "knowledge_base",
                 persist_directory: Optional[str] = None,
                 embedding_model: str = "all-MiniLM-L6-v2"):
        """
        Initialize ChromaDB vector store

        Args:
            collection_name: Name of the collection
            persist_directory: Directory to persist data (None = in-memory)
            embedding_model: Embedding model name (sentence-transformers compatible)

        Raises:
            ImportError: If chromadb (or, for the fallback embedder,
                sentence-transformers) is not installed.
            Exception: Any error raised while creating the collection is
                logged and re-raised.
        """
        if not CHROMA_AVAILABLE:
            raise ImportError(
                "ChromaDB is not installed. Install with: pip install chromadb"
            )

        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self.embedding_model = embedding_model

        # Initialize Chroma client
        if persist_directory:
            self.client = chromadb.PersistentClient(path=persist_directory)
        else:
            self.client = chromadb.Client()

        # Filled in by _get_embedding_function() on first use.
        self._embedding_fn = None

        # Get or create collection with embedding function
        try:
            self.collection = self.client.get_or_create_collection(
                name=collection_name,
                embedding_function=self._get_embedding_function(),
                metadata={"hnsw:space": "cosine"}
            )
        except Exception as e:
            logger.error(f"Failed to create Chroma collection: {e}")
            raise

    def _get_embedding_function(self):
        """Return the embedding function, creating and caching it on first call.

        ChromaDB's own SentenceTransformerEmbeddingFunction is preferred; if
        it cannot be created for any reason, a thin wrapper around
        ``sentence_transformers.SentenceTransformer`` is used instead.

        Raises:
            ImportError: If the fallback is needed and sentence-transformers
                is not installed.
        """
        if self._embedding_fn is not None:
            return self._embedding_fn

        try:
            from chromadb.utils import embedding_functions
            self._embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_model
            )
            logger.info(f"Loaded embedding model using ChromaDB native function: {self.embedding_model}")
        except Exception as e:
            # Any failure of the native path falls through to the wrapper.
            self._embedding_fn = self._build_fallback_embedding_function(e)

        return self._embedding_fn

    def _build_fallback_embedding_function(self, reason: Exception):
        """Wrap a SentenceTransformer model in a ChromaDB-callable object."""
        try:
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(self.embedding_model)
        except ImportError as import_error:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            ) from import_error

        class CustomEmbeddingFunction:
            def __init__(self, model, model_name):
                self.model = model
                self.model_name = model_name
                self.name = model_name  # ChromaDB may check for 'name' attribute

            # The argument is named ``input`` because ChromaDB's documented
            # custom embedding function signature is ``__call__(self, input)``.
            # NOTE(review): current releases are believed to reject any other
            # argument name, and older ones to pass the texts positionally,
            # so the name should be harmless there - confirm against the
            # pinned chromadb version.
            def __call__(self, input: List[str]) -> List[List[float]]:
                embeddings = self.model.encode(input, show_progress_bar=False)
                return embeddings.tolist()

            def encode_queries(self, queries: List[str]) -> List[List[float]]:
                return self.__call__(queries)

        logger.info(f"Loaded embedding model using custom wrapper: {self.embedding_model} (fallback: {reason})")
        return CustomEmbeddingFunction(model, self.embedding_model)

    def add_documents(self, documents: List[Dict[str, Any]]) -> None:
        """Add documents to ChromaDB.

        Args:
            documents: Dicts with a required 'text' and optional 'id' and
                'metadata'. A missing id falls back to the text. Ids are
                stringified, cut to 100 characters and made unique within
                the batch by appending a numeric suffix.

        Raises:
            KeyError: If a document has no 'text'.
            Exception: Errors from ChromaDB are logged and re-raised.
        """
        if not documents:
            return

        ids = []
        texts = []
        metadatas = []
        used_ids = set()

        for doc in documents:
            raw_id = doc.get('id')
            if raw_id is None:
                raw_id = doc.get('text', '')
            # Stringify BEFORE slicing so non-string ids (e.g. ints) work.
            base_id = str(raw_id)[:100]

            # Ensure unique IDs; keep bumping the suffix because a suffixed
            # id may itself already be taken.
            doc_id = base_id
            suffix = len(ids)
            while doc_id in used_ids:
                doc_id = f"{base_id}_{suffix}"
                suffix += 1
            used_ids.add(doc_id)

            ids.append(doc_id)
            texts.append(doc['text'])

            # Keep primitive metadata values as they are and stringify
            # anything else. ``or {}`` also covers an explicit None.
            metadata = doc.get('metadata') or {}
            metadatas.append({
                key: value
                if value is None or isinstance(value, (str, int, float, bool))
                else str(value)
                for key, value in metadata.items()
            })

        # Add to collection (Chroma will use embedding function automatically)
        try:
            self.collection.add(
                ids=ids,
                documents=texts,
                metadatas=metadatas
            )
            logger.debug(f"Added {len(documents)} documents to Chroma")
        except Exception as e:
            logger.error(f"Error adding documents to Chroma: {e}")
            raise

    def search(self, query: str, limit: int = 5,
               filter_metadata: Optional[Dict] = None) -> List[Dict[str, Any]]:
        """Search in ChromaDB.

        Args:
            query: Search query text
            limit: Maximum number of results
            filter_metadata: Optional metadata filter passed as ``where``

        Returns:
            List of dicts with 'id', 'text', 'metadata' and 'score', where
            score is ``1 - distance`` clamped at 0. Any error is logged and
            an empty list is returned.
        """
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=limit,
                where=filter_metadata or None
            )

            def first_row(key: str) -> list:
                # Results are nested per query; only one query was sent.
                rows = results.get(key) or []
                return (rows[0] if rows else None) or []

            documents = first_row('documents')
            if not documents:
                return []

            result_ids = first_row('ids')
            result_metadatas = first_row('metadatas')
            distances = first_row('distances')

            formatted_results = []
            for i, text in enumerate(documents):
                distance = distances[i] if i < len(distances) else 0.0
                # Distances up to 1 map to 1 - distance as before. Larger
                # distances are clamped to 0 so the score never increases
                # as the distance grows.
                similarity = max(0.0, 1.0 - distance)

                formatted_results.append({
                    'id': result_ids[i] if i < len(result_ids) else f"doc_{i}",
                    'text': text,
                    'metadata': (result_metadatas[i] if i < len(result_metadatas) else None) or {},
                    'score': similarity
                })

            return formatted_results
        except Exception as e:
            logger.error(f"Error searching Chroma: {e}")
            return []

    def delete_collection(self) -> None:
        """Delete collection; failures are logged, not raised."""
        try:
            self.client.delete_collection(self.collection_name)
            logger.info(f"Deleted Chroma collection: {self.collection_name}")
        except Exception as e:
            logger.error(f"Error deleting collection: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get collection statistics.

        Returns:
            Dict with 'total_documents', 'collection_name' and
            'embedding_model'; on error only ``{'total_documents': 0}``.
        """
        try:
            count = self.collection.count()
            return {
                'total_documents': count,
                'collection_name': self.collection_name,
                'embedding_model': self.embedding_model
            }
        except Exception as e:
            logger.error(f"Error getting stats: {e}")
            return {'total_documents': 0}
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def create_vector_store(store_type: str = "chroma", **kwargs) -> Optional[VectorStore]:
    """
    Build a vector store of the requested kind.

    Args:
        store_type: Backend identifier; only 'chroma' is handled here
        **kwargs: Forwarded to the backend's constructor

    Returns:
        A VectorStore instance, or None when the type is unknown or the
        backend's dependency is not installed
    """
    # Guard: anything other than Chroma is unsupported.
    if store_type != "chroma":
        logger.warning(f"Unknown vector store type: {store_type}")
        return None

    if CHROMA_AVAILABLE:
        return ChromaVectorStore(**kwargs)

    logger.info("ℹ️ Vector search disabled (ChromaDB not installed). For semantic search, run: pip install chromadb sentence-transformers")
    return None
|
|
278
|
+
|