memorisdk 1.0.2__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of memorisdk has been flagged as potentially problematic.
- memori/__init__.py +24 -8
- memori/agents/conscious_agent.py +252 -414
- memori/agents/memory_agent.py +487 -224
- memori/agents/retrieval_agent.py +491 -68
- memori/config/memory_manager.py +323 -0
- memori/core/conversation.py +393 -0
- memori/core/database.py +386 -371
- memori/core/memory.py +1683 -532
- memori/core/providers.py +217 -0
- memori/database/adapters/__init__.py +10 -0
- memori/database/adapters/mysql_adapter.py +331 -0
- memori/database/adapters/postgresql_adapter.py +291 -0
- memori/database/adapters/sqlite_adapter.py +229 -0
- memori/database/auto_creator.py +320 -0
- memori/database/connection_utils.py +207 -0
- memori/database/connectors/base_connector.py +283 -0
- memori/database/connectors/mysql_connector.py +240 -18
- memori/database/connectors/postgres_connector.py +277 -4
- memori/database/connectors/sqlite_connector.py +178 -3
- memori/database/models.py +400 -0
- memori/database/queries/base_queries.py +1 -1
- memori/database/queries/memory_queries.py +91 -2
- memori/database/query_translator.py +222 -0
- memori/database/schema_generators/__init__.py +7 -0
- memori/database/schema_generators/mysql_schema_generator.py +215 -0
- memori/database/search/__init__.py +8 -0
- memori/database/search/mysql_search_adapter.py +255 -0
- memori/database/search/sqlite_search_adapter.py +180 -0
- memori/database/search_service.py +700 -0
- memori/database/sqlalchemy_manager.py +888 -0
- memori/integrations/__init__.py +36 -11
- memori/integrations/litellm_integration.py +340 -6
- memori/integrations/openai_integration.py +506 -240
- memori/tools/memory_tool.py +94 -4
- memori/utils/input_validator.py +395 -0
- memori/utils/pydantic_models.py +138 -36
- memori/utils/query_builder.py +530 -0
- memori/utils/security_audit.py +594 -0
- memori/utils/security_integration.py +339 -0
- memori/utils/transaction_manager.py +547 -0
- {memorisdk-1.0.2.dist-info → memorisdk-2.0.1.dist-info}/METADATA +56 -23
- memorisdk-2.0.1.dist-info/RECORD +66 -0
- memori/scripts/llm_text.py +0 -50
- memorisdk-1.0.2.dist-info/RECORD +0 -44
- memorisdk-1.0.2.dist-info/entry_points.txt +0 -2
- {memorisdk-1.0.2.dist-info → memorisdk-2.0.1.dist-info}/WHEEL +0 -0
- {memorisdk-1.0.2.dist-info → memorisdk-2.0.1.dist-info}/licenses/LICENSE +0 -0
- {memorisdk-1.0.2.dist-info → memorisdk-2.0.1.dist-info}/top_level.txt +0 -0
memori/tools/memory_tool.py
CHANGED
@@ -73,11 +73,24 @@ class MemoryTool:
 
             # Use retrieval agent for intelligent search
             try:
+                logger.debug(
+                    f"Attempting to import MemorySearchEngine for query: '{query}'"
+                )
                 from ..agents.retrieval_agent import MemorySearchEngine
 
+                logger.debug("Successfully imported MemorySearchEngine")
+
                 # Create search engine if not already initialized
                 if not hasattr(self, "_search_engine"):
-                    self._search_engine = MemorySearchEngine()
+                    if (
+                        hasattr(self.memori, "provider_config")
+                        and self.memori.provider_config
+                    ):
+                        self._search_engine = MemorySearchEngine(
+                            provider_config=self.memori.provider_config
+                        )
+                    else:
+                        self._search_engine = MemorySearchEngine()
 
             # Execute search using retrieval agent
             results = self._search_engine.execute_search(
@@ -88,18 +101,62 @@ class MemoryTool:
             )
 
             if not results:
+                logger.debug(
+                    f"Primary search returned no results for query: '{query}', trying fallback search"
+                )
+                # Try fallback direct database search
+                try:
+                    fallback_results = self.memori.db_manager.search_memories(
+                        query=query, namespace=self.memori.namespace, limit=5
+                    )
+
+                    if fallback_results:
+                        logger.debug(
+                            f"Fallback search found {len(fallback_results)} results"
+                        )
+                        results = fallback_results
+                    else:
+                        logger.warning(
+                            f"Both primary and fallback search returned no results for query: '{query}'"
+                        )
+                        return f"No relevant memories found for query: '{query}'"
+
+                except Exception as fallback_e:
+                    logger.error(
+                        f"Fallback search also failed for query '{query}': {fallback_e}"
+                    )
+                    return f"No relevant memories found for query: '{query}'"
+
+            # Ensure we have results to format
+            if not results:
+                logger.warning(
+                    f"No results available for formatting for query: '{query}'"
+                )
                 return f"No relevant memories found for query: '{query}'"
 
             # Format results as a readable string
+            logger.debug(
+                f"Starting to format {len(results)} results for query: '{query}'"
+            )
             formatted_output = f"🔍 Memory Search Results for: '{query}'\n\n"
 
             for i, result in enumerate(results, 1):
                 try:
+                    logger.debug(
+                        f"Formatting result {i}: type={type(result)}, keys={list(result.keys()) if isinstance(result, dict) else 'not-dict'}"
+                    )
+
                     # Try to parse processed data for better formatting
                     if "processed_data" in result:
                         import json
 
-                        processed_data = json.loads(result["processed_data"])
+                        if isinstance(result["processed_data"], dict):
+                            processed_data = result["processed_data"]
+                        elif isinstance(result["processed_data"], str):
+                            processed_data = json.loads(result["processed_data"])
+                        else:
+                            raise ValueError("Error, wrong 'processed_data' format")
+
                         summary = processed_data.get("summary", "")
                         category = processed_data.get("category", {}).get(
                             "primary_category", ""
@@ -124,34 +181,63 @@ class MemoryTool:
 
                     formatted_output += "\n"
 
-                except Exception:
+                except Exception as format_e:
+                    logger.warning(f"Error formatting result {i}: {format_e}")
                     # Fallback formatting
                     content = result.get(
                         "searchable_content", "Memory content available"
                     )[:100]
                     formatted_output += f"{i}. {content}...\n\n"
 
+            logger.debug(
+                f"Successfully formatted results, output length: {len(formatted_output)}"
+            )
             return formatted_output.strip()
 
-        except ImportError:
+        except ImportError as import_e:
+            logger.warning(
+                f"Failed to import MemorySearchEngine for query '{query}': {import_e}"
+            )
             # Fallback to original search methods if retrieval agent is not available
+            logger.debug(
+                f"Using ImportError fallback search methods for query: '{query}'"
+            )
+
             # Try different search strategies based on query content
             if any(word in query.lower() for word in ["name", "who am i", "about me"]):
+                logger.debug(
+                    f"Trying essential conversations for personal query: '{query}'"
+                )
                 # Personal information query - try essential conversations first
                 essential_result = self._get_essential_conversations()
                 if essential_result.get("count", 0) > 0:
+                    logger.debug(
+                        f"Essential conversations found {essential_result.get('count', 0)} results"
+                    )
                    return self._format_dict_to_string(essential_result)
 
             # General search
+            logger.debug(f"Trying general search for query: '{query}'")
             search_result = self._search_memories(query=query, limit=10)
+            logger.debug(
+                f"General search returned results_count: {search_result.get('results_count', 0)}"
+            )
             if search_result.get("results_count", 0) > 0:
                 return self._format_dict_to_string(search_result)
 
             # Fallback to context retrieval
+            logger.debug(f"Trying context retrieval fallback for query: '{query}'")
             context_result = self._retrieve_context(query=query, limit=5)
+            logger.debug(
+                f"Context retrieval returned context_count: {context_result.get('context_count', 0)}"
+            )
             return self._format_dict_to_string(context_result)
 
         except Exception as e:
+            logger.error(
+                f"Unexpected error in memory tool execute for query '{query}': {e}",
+                exc_info=True,
+            )
             return f"Error searching memories: {str(e)}"
 
     def _format_dict_to_string(self, result_dict: Dict[str, Any]) -> str:
@@ -293,6 +379,8 @@ class MemoryTool:
 
     def _get_stats(self, **kwargs) -> Dict[str, Any]:
         """Get memory and integration statistics"""
+        # kwargs can be used for future filtering options
+        _ = kwargs  # Mark as intentionally unused
         try:
             memory_stats = self.memori.get_memory_stats()
             integration_stats = self.memori.get_integration_stats()
@@ -348,6 +436,8 @@ class MemoryTool:
 
     def _trigger_analysis(self, **kwargs) -> Dict[str, Any]:
         """Trigger conscious agent analysis"""
+        # kwargs can be used for future analysis options
+        _ = kwargs  # Mark as intentionally unused
         try:
             if hasattr(self.memori, "trigger_conscious_analysis"):
                 self.memori.trigger_conscious_analysis()
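For context, the new path in MemoryTool.execute lazily builds a MemorySearchEngine, forwards the provider configuration when the parent Memori instance exposes one, and falls back to a direct database search when the retrieval agent returns nothing. The standalone sketch below re-states that pattern outside the class; the stub objects are illustrative only and are not part of the memorisdk API.

# Standalone sketch of the lazy, provider-aware initialization plus the
# database fallback added in 2.0.1. The stub classes stand in for the real
# MemorySearchEngine / Memori objects and are NOT the memorisdk API.

class StubSearchEngine:
    def __init__(self, provider_config=None):
        self.provider_config = provider_config  # forwarded when available

    def execute_search(self, query, **kwargs):
        return []  # pretend the retrieval agent found nothing


class StubMemori:
    provider_config = {"model": "example"}
    namespace = "default"

    class db_manager:  # mimics an attribute holding a database manager
        @staticmethod
        def search_memories(query, namespace, limit):
            return [{"searchable_content": f"fallback hit for {query!r}"}]


def search(memori, query):
    # Lazy init: reuse one engine, pass provider_config through if present.
    if not hasattr(search, "_engine"):
        if getattr(memori, "provider_config", None):
            search._engine = StubSearchEngine(provider_config=memori.provider_config)
        else:
            search._engine = StubSearchEngine()

    results = search._engine.execute_search(query)
    if not results:
        # Fallback: query the database directly, as memory_tool.py now does.
        results = memori.db_manager.search_memories(
            query=query, namespace=memori.namespace, limit=5
        )
    return results


print(search(StubMemori(), "project deadlines"))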
memori/utils/input_validator.py
ADDED (new file, 395 lines)

"""
Input validation and sanitization utilities for Memori
Provides security-focused validation for all database inputs
"""

import html
import json
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

from loguru import logger

from .exceptions import ValidationError


class InputValidator:
    """Comprehensive input validation and sanitization"""

    # SQL injection patterns to detect and block
    SQL_INJECTION_PATTERNS = [
        r"(\b(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|UNION)\b)",
        r"(\b(OR|AND)\s+[\w\s]*=[\w\s]*)",
        r"(;|\|\||&&)",
        r"(\-\-|\#|/\*|\*/)",
        r"(\bxp_cmdshell\b|\bsp_executesql\b)",
        r"(\bINTO\s+OUTFILE\b|\bINTO\s+DUMPFILE\b)",
    ]

    # XSS patterns to detect and sanitize
    XSS_PATTERNS = [
        r"<\s*script[^>]*>.*?</\s*script\s*>",
        r"<\s*iframe[^>]*>.*?</\s*iframe\s*>",
        r"<\s*object[^>]*>.*?</\s*object\s*>",
        r"<\s*embed[^>]*>",
        r"javascript\s*:",
        r"on\w+\s*=",
    ]

    @classmethod
    def validate_and_sanitize_query(cls, query: str, max_length: int = 10000) -> str:
        """Validate and sanitize search query input"""
        if not isinstance(query, (str, type(None))):
            raise ValidationError("Query must be a string or None")

        if query is None:
            return ""

        # Length validation
        if len(query) > max_length:
            raise ValidationError(f"Query too long (max {max_length} characters)")

        # Check for SQL injection patterns
        query_lower = query.lower()
        for pattern in cls.SQL_INJECTION_PATTERNS:
            if re.search(pattern, query_lower, re.IGNORECASE):
                logger.warning(f"Potential SQL injection attempt blocked: {pattern}")
                raise ValidationError(
                    "Invalid query: contains potentially dangerous content"
                )

        # Check for XSS patterns
        for pattern in cls.XSS_PATTERNS:
            if re.search(pattern, query, re.IGNORECASE):
                logger.warning(f"Potential XSS attempt blocked: {pattern}")
                # Sanitize instead of blocking for XSS
                query = re.sub(pattern, "", query, flags=re.IGNORECASE)

        # HTML escape for additional safety
        sanitized_query = html.escape(query.strip())

        return sanitized_query

    @classmethod
    def validate_namespace(cls, namespace: str) -> str:
        """Validate and sanitize namespace"""
        if not isinstance(namespace, str):
            raise ValidationError("Namespace must be a string")

        # Namespace validation rules
        sanitized_namespace = namespace.strip()

        if not sanitized_namespace:
            sanitized_namespace = "default"

        # Only allow alphanumeric, underscore, hyphen
        if not re.match(r"^[a-zA-Z0-9_\-]+$", sanitized_namespace):
            raise ValidationError(
                "Namespace contains invalid characters (only alphanumeric, underscore, hyphen allowed)"
            )

        if len(sanitized_namespace) > 100:
            raise ValidationError("Namespace too long (max 100 characters)")

        return sanitized_namespace

    @classmethod
    def validate_category_filter(
        cls, category_filter: Optional[List[str]]
    ) -> List[str]:
        """Validate and sanitize category filter list"""
        if category_filter is None:
            return []

        if not isinstance(category_filter, list):
            raise ValidationError("Category filter must be a list or None")

        if len(category_filter) > 50:  # Reasonable limit
            raise ValidationError("Too many categories in filter (max 50)")

        sanitized_categories = []
        for category in category_filter:
            if not isinstance(category, str):
                continue  # Skip non-string categories

            sanitized_category = category.strip()
            if not sanitized_category:
                continue  # Skip empty categories

            # Validate category format
            if not re.match(r"^[a-zA-Z0-9_\-\s]+$", sanitized_category):
                logger.warning(f"Invalid category format: {sanitized_category}")
                continue  # Skip invalid categories

            if len(sanitized_category) > 100:
                sanitized_category = sanitized_category[:100]  # Truncate if too long

            sanitized_categories.append(sanitized_category)

        return sanitized_categories

    @classmethod
    def validate_limit(cls, limit: Union[int, str]) -> int:
        """Validate and sanitize limit parameter"""
        try:
            int_limit = int(limit)
        except (ValueError, TypeError):
            raise ValidationError("Limit must be a valid integer")

        # Enforce reasonable bounds
        if int_limit < 1:
            return 1
        elif int_limit > 1000:  # Maximum reasonable limit
            return 1000

        return int_limit

    @classmethod
    def validate_memory_id(cls, memory_id: str) -> str:
        """Validate memory ID format"""
        if not isinstance(memory_id, str):
            raise ValidationError("Memory ID must be a string")

        sanitized_id = memory_id.strip()

        if not sanitized_id:
            raise ValidationError("Memory ID cannot be empty")

        # UUID-like format validation
        if not re.match(r"^[a-fA-F0-9\-]{36}$", sanitized_id):
            # Also allow shorter alphanumeric IDs for flexibility
            if not re.match(r"^[a-zA-Z0-9_\-]+$", sanitized_id):
                raise ValidationError("Invalid memory ID format")

        if len(sanitized_id) > 100:
            raise ValidationError("Memory ID too long")

        return sanitized_id

    @classmethod
    def validate_json_field(cls, json_data: Any, field_name: str = "data") -> str:
        """Validate and sanitize JSON data"""
        if json_data is None:
            return "{}"

        try:
            if isinstance(json_data, str):
                # Validate it's proper JSON
                parsed_data = json.loads(json_data)
                # Re-serialize to ensure clean format
                clean_json = json.dumps(
                    parsed_data, ensure_ascii=True, separators=(",", ":")
                )
            else:
                # Serialize Python object to JSON
                clean_json = json.dumps(
                    json_data, ensure_ascii=True, separators=(",", ":")
                )

            # Size limit check (1MB for JSON data)
            if len(clean_json) > 1024 * 1024:
                raise ValidationError(f"{field_name} JSON too large (max 1MB)")

            return clean_json

        except (json.JSONDecodeError, TypeError) as e:
            raise ValidationError(f"Invalid JSON in {field_name}: {e}")

    @classmethod
    def validate_text_content(
        cls, content: str, field_name: str = "content", max_length: int = 100000
    ) -> str:
        """Validate and sanitize text content"""
        if not isinstance(content, str):
            raise ValidationError(f"{field_name} must be a string")

        # Length check
        if len(content) > max_length:
            raise ValidationError(
                f"{field_name} too long (max {max_length} characters)"
            )

        # XSS sanitization
        sanitized_content = content
        for pattern in cls.XSS_PATTERNS:
            sanitized_content = re.sub(
                pattern, "", sanitized_content, flags=re.IGNORECASE
            )

        # Basic HTML escaping for storage
        sanitized_content = html.escape(sanitized_content)

        return sanitized_content.strip()

    @classmethod
    def validate_timestamp(cls, timestamp: Union[datetime, str, None]) -> datetime:
        """Validate and normalize timestamp"""
        if timestamp is None:
            return datetime.now()

        if isinstance(timestamp, datetime):
            # Make timezone-naive for SQLite compatibility
            return timestamp.replace(tzinfo=None)

        if isinstance(timestamp, str):
            try:
                # Try to parse ISO format
                parsed_timestamp = datetime.fromisoformat(
                    timestamp.replace("Z", "+00:00")
                )
                return parsed_timestamp.replace(tzinfo=None)
            except ValueError:
                raise ValidationError("Invalid timestamp format (use ISO format)")

        raise ValidationError("Timestamp must be datetime object, ISO string, or None")

    @classmethod
    def validate_score(
        cls, score: Union[float, int, str], field_name: str = "score"
    ) -> float:
        """Validate and normalize score values (0.0 to 1.0)"""
        try:
            float_score = float(score)
        except (ValueError, TypeError):
            raise ValidationError(f"{field_name} must be a valid number")

        # Clamp to valid range
        if float_score < 0.0:
            return 0.0
        elif float_score > 1.0:
            return 1.0

        return float_score

    @classmethod
    def validate_boolean_field(cls, value: Any, field_name: str = "field") -> bool:
        """Validate and convert boolean field"""
        if isinstance(value, bool):
            return value

        if isinstance(value, int):
            return bool(value)

        if isinstance(value, str):
            return value.lower() in ("true", "1", "yes", "on")

        return False  # Default to False for safety

    @classmethod
    def sanitize_sql_identifier(cls, identifier: str) -> str:
        """Sanitize SQL identifiers (table names, column names)"""
        if not isinstance(identifier, str):
            raise ValidationError("SQL identifier must be a string")

        # Remove dangerous characters and validate format
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "", identifier)

        if not sanitized or not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", sanitized):
            raise ValidationError("Invalid SQL identifier format")

        if len(sanitized) > 64:  # SQL standard limit
            raise ValidationError("SQL identifier too long")

        # Block reserved words (basic list)
        reserved_words = {
            "SELECT",
            "INSERT",
            "UPDATE",
            "DELETE",
            "DROP",
            "CREATE",
            "ALTER",
            "TABLE",
            "DATABASE",
            "INDEX",
            "VIEW",
            "TRIGGER",
            "PROCEDURE",
            "FUNCTION",
            "EXEC",
            "EXECUTE",
            "UNION",
            "WHERE",
            "FROM",
            "JOIN",
        }

        if sanitized.upper() in reserved_words:
            raise ValidationError(
                f"Cannot use reserved word as identifier: {sanitized}"
            )

        return sanitized


class DatabaseInputValidator:
    """Database-specific input validation"""

    @classmethod
    def validate_insert_params(
        cls, table: str, params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate parameters for database insert operations"""
        sanitized_params = {}

        # Validate table name
        InputValidator.sanitize_sql_identifier(table)

        for key, value in params.items():
            # Validate column names
            sanitized_key = InputValidator.sanitize_sql_identifier(key)

            # Type-specific validation
            if key.endswith("_id"):
                if value is not None:
                    sanitized_params[sanitized_key] = InputValidator.validate_memory_id(
                        str(value)
                    )
                else:
                    sanitized_params[sanitized_key] = None
            elif key == "namespace":
                sanitized_params[sanitized_key] = InputValidator.validate_namespace(
                    str(value)
                )
            elif key.endswith("_score"):
                sanitized_params[sanitized_key] = InputValidator.validate_score(
                    value, key
                )
            elif key.endswith("_at") or key == "timestamp":
                sanitized_params[sanitized_key] = InputValidator.validate_timestamp(
                    value
                )
            elif key.endswith("_json") or key == "metadata":
                sanitized_params[sanitized_key] = InputValidator.validate_json_field(
                    value, key
                )
            elif isinstance(value, bool) or key.startswith("is_"):
                sanitized_params[sanitized_key] = InputValidator.validate_boolean_field(
                    value, key
                )
            elif isinstance(value, str):
                sanitized_params[sanitized_key] = InputValidator.validate_text_content(
                    value, key, max_length=50000
                )
            else:
                # Pass through numeric and other safe types
                sanitized_params[sanitized_key] = value

        return sanitized_params

    @classmethod
    def validate_search_params(
        cls,
        query: str,
        namespace: str,
        category_filter: Optional[List[str]],
        limit: int,
    ) -> Dict[str, Any]:
        """Validate all search parameters together"""
        return {
            "query": InputValidator.validate_and_sanitize_query(query),
            "namespace": InputValidator.validate_namespace(namespace),
            "category_filter": InputValidator.validate_category_filter(category_filter),
            "limit": InputValidator.validate_limit(limit),
        }
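Taken together, the two classes give callers a single entry point for sanitizing user-supplied input before it reaches SQL. A short usage sketch follows; it assumes memorisdk 2.0.1 is installed, so the module is importable under the path memori.utils.input_validator that mirrors the file location above.

# Usage sketch for the new validators (assumes memorisdk 2.0.1 is installed;
# the import path mirrors the file location memori/utils/input_validator.py).
from memori.utils.exceptions import ValidationError
from memori.utils.input_validator import DatabaseInputValidator, InputValidator

# Individual validators clamp, sanitize, or reject their inputs.
print(InputValidator.validate_limit("250"))                 # -> 250
print(InputValidator.validate_limit(5000))                  # -> 1000 (clamped)
print(InputValidator.validate_namespace("  work-notes "))   # -> "work-notes"

try:
    InputValidator.validate_and_sanitize_query("name'; DROP TABLE memories; --")
except ValidationError as exc:
    print(f"rejected: {exc}")  # SQL-injection patterns are blocked outright

# XSS fragments are stripped and HTML-escaped rather than rejected.
print(InputValidator.validate_and_sanitize_query("<script>alert(1)</script>hello"))

# validate_search_params bundles the checks used by the search path.
params = DatabaseInputValidator.validate_search_params(
    query="what did I say about deadlines?",
    namespace="default",
    category_filter=["fact", "preference"],
    limit=10,
)
print(params)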