memorisdk 1.0.1__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of memorisdk might be problematic. Click here for more details.
- memori/__init__.py +24 -8
- memori/agents/conscious_agent.py +252 -414
- memori/agents/memory_agent.py +487 -224
- memori/agents/retrieval_agent.py +416 -60
- memori/config/memory_manager.py +323 -0
- memori/core/conversation.py +393 -0
- memori/core/database.py +386 -371
- memori/core/memory.py +1676 -534
- memori/core/providers.py +217 -0
- memori/database/adapters/__init__.py +10 -0
- memori/database/adapters/mysql_adapter.py +331 -0
- memori/database/adapters/postgresql_adapter.py +291 -0
- memori/database/adapters/sqlite_adapter.py +229 -0
- memori/database/auto_creator.py +320 -0
- memori/database/connection_utils.py +207 -0
- memori/database/connectors/base_connector.py +283 -0
- memori/database/connectors/mysql_connector.py +240 -18
- memori/database/connectors/postgres_connector.py +277 -4
- memori/database/connectors/sqlite_connector.py +178 -3
- memori/database/models.py +400 -0
- memori/database/queries/base_queries.py +1 -1
- memori/database/queries/memory_queries.py +91 -2
- memori/database/query_translator.py +222 -0
- memori/database/schema_generators/__init__.py +7 -0
- memori/database/schema_generators/mysql_schema_generator.py +215 -0
- memori/database/search/__init__.py +8 -0
- memori/database/search/mysql_search_adapter.py +255 -0
- memori/database/search/sqlite_search_adapter.py +180 -0
- memori/database/search_service.py +548 -0
- memori/database/sqlalchemy_manager.py +839 -0
- memori/integrations/__init__.py +36 -11
- memori/integrations/litellm_integration.py +340 -6
- memori/integrations/openai_integration.py +506 -240
- memori/utils/input_validator.py +395 -0
- memori/utils/pydantic_models.py +138 -36
- memori/utils/query_builder.py +530 -0
- memori/utils/security_audit.py +594 -0
- memori/utils/security_integration.py +339 -0
- memori/utils/transaction_manager.py +547 -0
- {memorisdk-1.0.1.dist-info → memorisdk-2.0.0.dist-info}/METADATA +144 -34
- memorisdk-2.0.0.dist-info/RECORD +67 -0
- memorisdk-1.0.1.dist-info/RECORD +0 -44
- memorisdk-1.0.1.dist-info/entry_points.txt +0 -2
- {memorisdk-1.0.1.dist-info → memorisdk-2.0.0.dist-info}/WHEEL +0 -0
- {memorisdk-1.0.1.dist-info → memorisdk-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {memorisdk-1.0.1.dist-info → memorisdk-2.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,547 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Transaction management utilities for Memori
|
|
3
|
+
Provides robust transaction handling with proper error recovery
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
from contextlib import contextmanager
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
11
|
+
|
|
12
|
+
from loguru import logger
|
|
13
|
+
|
|
14
|
+
from .exceptions import DatabaseError, ValidationError
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TransactionState(str, Enum):
    """Lifecycle marker recorded for a transaction.

    Members subclass ``str``, so each one compares equal to its lowercase
    string value (for example ``TransactionState.ACTIVE == "active"``).
    """

    # Not assigned anywhere in this module.
    PENDING = "pending"

    # Stored when TransactionManager.transaction() has begun the transaction.
    ACTIVE = "active"

    # Stored after the connection's commit() returned.
    COMMITTED = "committed"

    # Stored after the connection's rollback() returned.
    ROLLED_BACK = "rolled_back"

    # Stored when the rollback attempt itself raised.
    FAILED = "failed"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class IsolationLevel(str, Enum):
    """SQL transaction isolation levels accepted by the transaction manager.

    Each member's value is the level name as written in SQL, and members
    subclass ``str`` so they compare equal to that text.
    """

    READ_UNCOMMITTED = "READ UNCOMMITTED"

    READ_COMMITTED = "READ COMMITTED"

    REPEATABLE_READ = "REPEATABLE READ"

    SERIALIZABLE = "SERIALIZABLE"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class TransactionOperation:
    """One SQL statement queued for execution inside a transaction."""

    # SQL text passed to the transaction context's execute().
    query: str

    # Bind parameters for the query, or None when it takes none.
    params: Optional[List[Any]]

    # One of 'select', 'insert', 'update', 'delete'.
    operation_type: str

    # Table the statement targets, when known.
    table: Optional[str] = None

    # Row count the caller expects; None disables the check.
    expected_rows: Optional[int] = None

    # Compensation query if needed.
    rollback_query: Optional[str] = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class TransactionResult:
    """Outcome summary returned after a batch of operations was attempted."""

    # True when every operation ran and the transaction committed.
    success: bool

    # Final state reported for the transaction.
    state: TransactionState

    # Number of operations that finished before the run ended.
    operations_completed: int

    # Number of operations that were submitted.
    total_operations: int

    # Text of the failure, when there was one.
    error_message: Optional[str] = None

    # Elapsed wall-clock time in seconds, when measured.
    execution_time: Optional[float] = None

    # Whether the failure path reported a rollback.
    rollback_performed: bool = False
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TransactionManager:
    """Run database work inside transactions with rollback and retry support.

    The manager obtains connections from ``connector.get_connection()`` and
    drives them through the DB-API ``commit``/``rollback``/``close`` calls.
    Only one transaction per manager is tracked at a time in
    ``current_transaction``.
    """

    # psycopg2 numeric isolation codes keyed by SQL level name
    # (psycopg2.extensions.ISOLATION_LEVEL_*). Note that READ UNCOMMITTED is 4,
    # not 1; code 0 means autocommit and is intentionally not mapped.
    _PG_ISOLATION_LEVELS = {
        "READ COMMITTED": 1,
        "REPEATABLE READ": 2,
        "SERIALIZABLE": 3,
        "READ UNCOMMITTED": 4,
    }

    # Substrings that make _validate_operation reject a query. This is a
    # coarse heuristic, not a parser, so it can produce false positives
    # (e.g. "xp_" also matches a column called "exp_date").
    _DANGEROUS_SQL_PATTERNS = (
        ";--",
        "; --",
        "/*",
        "*/",
        "xp_",
        "sp_execute",
        "union select",
        "drop table",
        "truncate table",
    )

    # Error-message substrings treated as transient.
    _RETRYABLE_PATTERNS = (
        "timeout",
        "connection",
        "network",
        "temporary",
        "busy",
        "lock",
        "deadlock",
        "serialization",
    )

    # Error-message substrings that always veto a retry.
    _NON_RETRYABLE_PATTERNS = (
        "constraint",
        "unique",
        "foreign key",
        "not null",
        "syntax error",
        "permission",
        "access denied",
    )

    def __init__(self, connector, max_retries: int = 3, retry_delay: float = 0.1):
        """Store the connector and the default retry policy.

        Args:
            connector: Object exposing ``get_connection()``.
            max_retries: Default number of retries after the first attempt.
            retry_delay: Default initial delay in seconds between attempts.
        """
        self.connector = connector
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.current_transaction = None

    @contextmanager
    def transaction(
        self,
        isolation_level: Optional["IsolationLevel"] = None,
        timeout: Optional[float] = 30.0,
        readonly: bool = False,
    ):
        """Context manager for database transactions with proper error handling.

        Yields a ``TransactionContext`` bound to a fresh connection. The
        transaction is committed when the ``with`` body finishes normally and
        rolled back when it raises. The connection is always closed.

        Args:
            isolation_level: Optional isolation level applied before BEGIN.
            timeout: Maximum allowed duration in seconds. It is checked only
                after the ``with`` body has returned, so it cannot interrupt
                a running statement; an overrun causes a rollback instead of
                a commit. ``None`` or 0 disables the check.
            readonly: When true, a read-only transaction is requested.

        Raises:
            DatabaseError: On any failure, wrapping the original exception.
        """
        transaction_id = f"txn_{int(time.time() * 1000)}"
        start_time = time.time()
        # Bound up front so the cleanup below is safe even when
        # get_connection() itself raises.
        conn = None

        try:
            # Get connection and start transaction
            conn = self.connector.get_connection()

            if isolation_level:
                self._set_isolation_level(conn, isolation_level)

            if readonly:
                self._set_readonly(conn, True)

            self._begin_transaction(conn)

            logger.debug(f"Started transaction {transaction_id}")

            # Store transaction context
            self.current_transaction = {
                "id": transaction_id,
                "connection": conn,
                "start_time": start_time,
                "operations": [],
                "state": TransactionState.ACTIVE,
            }

            try:
                yield TransactionContext(self, conn, transaction_id)

                # Post-hoc timeout check: refuse to commit an overlong body.
                if timeout and (time.time() - start_time) > timeout:
                    raise DatabaseError(
                        f"Transaction {transaction_id} timed out after {timeout}s"
                    )

                conn.commit()
                self.current_transaction["state"] = TransactionState.COMMITTED
                logger.debug(f"Committed transaction {transaction_id}")

            except Exception as e:
                # Rollback on any error
                try:
                    conn.rollback()
                    self.current_transaction["state"] = TransactionState.ROLLED_BACK
                    logger.warning(f"Rolled back transaction {transaction_id}: {e}")
                except Exception as rollback_error:
                    self.current_transaction["state"] = TransactionState.FAILED
                    logger.error(
                        f"Failed to rollback transaction {transaction_id}: {rollback_error}"
                    )

                # Bare raise keeps the original traceback intact.
                raise

        except Exception as e:
            logger.error(f"Transaction {transaction_id} failed: {e}")
            raise DatabaseError(f"Transaction failed: {e}") from e

        finally:
            # Cleanup
            if self.current_transaction:
                execution_time = time.time() - start_time
                logger.debug(
                    f"Transaction {transaction_id} completed in {execution_time:.3f}s"
                )
                self.current_transaction = None

            # Close connection (best effort; a failure here must not mask
            # the outcome of the transaction itself).
            if conn is not None:
                try:
                    conn.close()
                except Exception as close_error:
                    logger.debug(
                        f"Could not close connection for transaction {transaction_id}: {close_error}"
                    )

    def execute_atomic_operations(
        self,
        operations: List["TransactionOperation"],
        isolation_level: Optional["IsolationLevel"] = None,
    ) -> "TransactionResult":
        """Execute multiple operations atomically with validation.

        All operations run inside a single transaction. The first failure
        (validation, execution, or an ``expected_rows`` mismatch) aborts the
        whole batch. This method never raises; failures are reported through
        the returned ``TransactionResult``.

        Args:
            operations: Statements to run, in order.
            isolation_level: Optional isolation level for the transaction.

        Returns:
            A ``TransactionResult`` with ``success`` set accordingly. On
            failure ``state`` is reported as ROLLED_BACK and
            ``rollback_performed`` as True regardless of where the failure
            happened.
        """
        start_time = time.time()
        completed_ops = 0

        try:
            with self.transaction(isolation_level=isolation_level) as tx:
                for i, operation in enumerate(operations):
                    try:
                        self._validate_operation(operation)

                        result = tx.execute(operation.query, operation.params)

                        # Validate result if expected rows specified
                        if operation.expected_rows is not None:
                            actual_rows = self._count_rows(operation, result)
                            if (
                                actual_rows is not None
                                and actual_rows != operation.expected_rows
                            ):
                                raise DatabaseError(
                                    f"Operation {i} affected {actual_rows} rows, expected {operation.expected_rows}"
                                )

                        completed_ops += 1

                    except Exception as e:
                        logger.error(f"Operation {i} failed: {e}")
                        raise DatabaseError(
                            f"Operation {i} ({operation.operation_type}) failed: {e}"
                        ) from e

        except Exception as e:
            return TransactionResult(
                success=False,
                state=TransactionState.ROLLED_BACK,
                operations_completed=completed_ops,
                total_operations=len(operations),
                error_message=str(e),
                execution_time=time.time() - start_time,
                rollback_performed=True,
            )

        # Built after the ``with`` block so the result (and its timing)
        # reflects a commit that has actually completed.
        return TransactionResult(
            success=True,
            state=TransactionState.COMMITTED,
            operations_completed=completed_ops,
            total_operations=len(operations),
            execution_time=time.time() - start_time,
        )

    @staticmethod
    def _count_rows(operation, result) -> Optional[int]:
        """Return the row count to compare with ``expected_rows``.

        ``TransactionContext.execute`` returns ``[{"affected_rows": n}]`` for
        statements that are not SELECTs, so the count is ``n`` rather than the
        length of that one-element list. For SELECTs the count is the number
        of returned rows. ``None`` means the count is unknown (the driver
        reported a negative or missing rowcount) and the check is skipped.
        """
        is_select = operation.query.strip().upper().startswith("SELECT")
        if (
            not is_select
            and isinstance(result, list)
            and len(result) == 1
            and isinstance(result[0], dict)
            and "affected_rows" in result[0]
        ):
            affected = result[0]["affected_rows"]
            if affected is None or affected < 0:
                return None
            return affected

        if hasattr(result, "__len__"):
            return len(result)
        return None

    def execute_with_retry(
        self,
        operation: Callable[[], Any],
        max_retries: Optional[int] = None,
        retry_delay: Optional[float] = None,
    ) -> Any:
        """Execute operation with automatic retry on transient failures.

        Args:
            operation: Zero-argument callable to invoke.
            max_retries: Retries after the first attempt; ``None`` uses the
                manager default. An explicit 0 means "try once".
            retry_delay: Initial delay in seconds, doubled after each failed
                attempt; ``None`` uses the manager default.

        Returns:
            Whatever ``operation`` returns.

        Raises:
            DatabaseError: When the operation fails with a non-retryable
                error or exhausts its attempts. The last underlying error is
                attached as ``__cause__``.
        """
        # ``is None`` rather than ``or`` so an explicit 0 is honoured.
        retries = self.max_retries if max_retries is None else max_retries
        retries = max(0, retries)  # always make at least one attempt
        delay = self.retry_delay if retry_delay is None else retry_delay
        last_error = None
        attempts_made = 0

        for attempt in range(retries + 1):
            attempts_made = attempt + 1
            try:
                return operation()
            except Exception as e:
                last_error = e

                # Check if error is retryable
                if not self._is_retryable_error(e):
                    logger.debug(f"Non-retryable error: {e}")
                    break

                if attempt < retries:
                    logger.warning(
                        f"Operation failed (attempt {attempt + 1}/{retries + 1}), retrying in {delay}s: {e}"
                    )
                    time.sleep(delay)
                    delay *= 2  # Exponential backoff
                else:
                    logger.error(f"Operation failed after {retries + 1} attempts: {e}")

        # Report the number of attempts actually made; a non-retryable error
        # stops after the first one.
        raise DatabaseError(
            f"Operation failed after {attempts_made} attempts: {last_error}"
        ) from last_error

    def _validate_operation(self, operation: "TransactionOperation"):
        """Validate transaction operation parameters.

        Raises:
            ValidationError: If the query is empty, ``params`` is neither a
                list nor ``None``, or the lowercased query contains one of
                the ``_DANGEROUS_SQL_PATTERNS`` substrings.
        """
        if not operation.query or not operation.query.strip():
            raise ValidationError("Query cannot be empty")

        if operation.params is not None and not isinstance(operation.params, list):
            raise ValidationError("Parameters must be a list or None")

        # Basic SQL injection detection
        query_lower = operation.query.lower().strip()
        for pattern in self._DANGEROUS_SQL_PATTERNS:
            if pattern in query_lower:
                raise ValidationError(
                    f"Potentially dangerous SQL pattern detected: {pattern}"
                )

    def _set_isolation_level(self, conn, isolation_level: "IsolationLevel"):
        """Set transaction isolation level (database-specific).

        Connections exposing ``set_isolation_level`` are treated as
        PostgreSQL/psycopg2 and receive the numeric psycopg2 code. Every
        other connection gets the SQLite ``PRAGMA read_uncommitted``
        statement, except for READ COMMITTED which is left untouched.
        Failures are logged as warnings and never raised.
        """
        # Accept either an IsolationLevel member or its plain string value.
        level_name = getattr(isolation_level, "value", isolation_level)

        try:
            if hasattr(conn, "set_isolation_level"):
                # PostgreSQL
                pg_level = self._PG_ISOLATION_LEVELS.get(level_name)
                if pg_level is not None:
                    conn.set_isolation_level(pg_level)
            else:
                # SQLite/MySQL - use SQL commands
                # NOTE(review): PRAGMA is SQLite syntax; on other engines the
                # statement is expected to fail and end up in the warning below.
                cursor = conn.cursor()
                if level_name != "READ COMMITTED":  # SQLite default
                    pragma_value = "ON" if level_name == "READ UNCOMMITTED" else "OFF"
                    cursor.execute(f"PRAGMA read_uncommitted = {pragma_value}")

        except Exception as e:
            logger.warning(f"Could not set isolation level: {e}")

    def _set_readonly(self, conn, readonly: bool):
        """Set transaction to readonly mode.

        Best effort: engines that reject ``SET TRANSACTION READ ONLY`` only
        produce a debug log entry.
        """
        try:
            cursor = conn.cursor()
            if readonly:
                # Database-specific readonly settings
                cursor.execute("SET TRANSACTION READ ONLY")
        except Exception as e:
            logger.debug(f"Could not set readonly mode: {e}")

    def _begin_transaction(self, conn):
        """Begin transaction (database-specific).

        Turns autocommit off when the connection exposes that attribute and
        then issues an explicit ``BEGIN``. Failures are logged at debug level
        and ignored.
        """
        try:
            if hasattr(conn, "autocommit"):
                # Ensure autocommit is off
                conn.autocommit = False

            # Explicitly begin transaction
            cursor = conn.cursor()
            cursor.execute("BEGIN")
        except Exception as e:
            logger.debug(f"Could not explicitly begin transaction: {e}")

    def _is_retryable_error(self, error: Exception) -> bool:
        """Determine if an error is retryable.

        Classification is by substring match on the lowercased error text.
        Non-retryable patterns are checked first and win over retryable
        ones; unknown errors are treated as non-retryable.
        """
        error_str = str(error).lower()

        # Check non-retryable first
        if any(pattern in error_str for pattern in self._NON_RETRYABLE_PATTERNS):
            return False

        # Unknown errors fall through to False
        return any(pattern in error_str for pattern in self._RETRYABLE_PATTERNS)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class TransactionContext:
    """Context for operations within a transaction.

    Wraps the connection of an open transaction and counts every call made
    through it in ``operations_count`` (failed calls included).
    """

    def __init__(self, manager: "TransactionManager", connection, transaction_id: str):
        """Bind the context to its manager, connection and transaction id."""
        self.manager = manager
        self.connection = connection
        self.transaction_id = transaction_id
        self.operations_count = 0

    def _close_cursor(self, cursor):
        """Close *cursor* if one was opened, ignoring close failures."""
        if cursor is None:
            return
        try:
            cursor.close()
        except Exception as close_error:
            logger.debug(
                f"Could not close cursor in transaction {self.transaction_id}: {close_error}"
            )

    def execute(
        self, query: str, params: Optional[List[Any]] = None
    ) -> List[Dict[str, Any]]:
        """Execute query within the transaction context.

        Args:
            query: SQL text.
            params: Bind parameters; omitted from the driver call when empty.

        Returns:
            For queries whose text starts with ``SELECT`` (case-insensitive,
            leading whitespace ignored): one dict per row keyed by column
            name. For every other statement: a one-element list
            ``[{"affected_rows": cursor.rowcount}]``.

        Raises:
            DatabaseError: If the driver raises; the driver error is chained.
        """
        cursor = None
        try:
            cursor = self.connection.cursor()

            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)

            if not query.strip().upper().startswith("SELECT"):
                # For non-SELECT queries, return affected row count
                return [{"affected_rows": cursor.rowcount}]

            rows = cursor.fetchall()
            # Column names are the same for every row, so read them once.
            column_names = (
                [desc[0] for desc in cursor.description] if cursor.description else []
            )
            return [
                # Dictionary-like rows convert directly; tuple rows are
                # paired with the column names.
                dict(row) if hasattr(row, "keys") else dict(zip(column_names, row))
                for row in rows
            ]

        except Exception as e:
            logger.error(
                f"Query execution failed in transaction {self.transaction_id}: {e}"
            )
            raise DatabaseError(f"Query execution failed: {e}") from e
        finally:
            self.operations_count += 1
            self._close_cursor(cursor)

    def execute_many(self, query: str, params_list: List[List[Any]]) -> int:
        """Execute query with multiple parameter sets.

        Returns:
            The cursor's ``rowcount`` after ``executemany``.

        Raises:
            DatabaseError: If the driver raises; the driver error is chained.
        """
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.executemany(query, params_list)
            return cursor.rowcount
        except Exception as e:
            logger.error(
                f"Batch execution failed in transaction {self.transaction_id}: {e}"
            )
            raise DatabaseError(f"Batch execution failed: {e}") from e
        finally:
            self.operations_count += 1
            self._close_cursor(cursor)

    def execute_script(self, script: str):
        """Execute SQL script (SQLite specific).

        Uses the cursor's ``executescript`` when available. Otherwise the
        script is split on ``;`` and each non-empty piece is executed
        separately; that split is purely textual, so a semicolon inside a
        string literal will break the statement.

        NOTE(review): the sqlite3 documentation describes an implicit COMMIT
        before ``executescript`` under legacy transaction control, which
        would end the surrounding transaction early. Confirm against the
        Python versions this project supports.

        Raises:
            DatabaseError: If the driver raises; the driver error is chained.
        """
        cursor = None
        try:
            cursor = self.connection.cursor()
            if hasattr(cursor, "executescript"):
                cursor.executescript(script)
            else:
                # Fallback for other databases - split and execute individually
                for statement in script.split(";"):
                    statement = statement.strip()
                    if statement:
                        cursor.execute(statement)
        except Exception as e:
            logger.error(
                f"Script execution failed in transaction {self.transaction_id}: {e}"
            )
            raise DatabaseError(f"Script execution failed: {e}") from e
        finally:
            self.operations_count += 1
            self._close_cursor(cursor)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class SavepointManager:
    """Manage savepoints within transactions for fine-grained rollback control"""

    def __init__(self, transaction_context: TransactionContext):
        """Attach to the transaction context the savepoints are issued on."""
        self.tx_context = transaction_context
        self.savepoint_counter = 0

    @contextmanager
    def savepoint(self, name: Optional[str] = None):
        """Create a savepoint within the current transaction.

        Yields the savepoint name. If the ``with`` body raises, the
        transaction is rolled back to the savepoint and the exception is
        re-raised. The savepoint is released afterwards in either case.

        Args:
            name: Savepoint name; when omitted, ``sp_<counter>`` is generated.
                The name is interpolated directly into the SQL text, so it
                must come from trusted code, never from user input.

        Raises:
            Whatever the transaction context raises when the savepoint cannot
            be created, or the exception raised by the ``with`` body.
        """
        if not name:
            name = f"sp_{self.savepoint_counter}"
            self.savepoint_counter += 1

        # Created outside the try block: when creation fails there is no
        # savepoint to roll back to or release, so those statements would
        # only add follow-up errors.
        self.tx_context.execute(f"SAVEPOINT {name}")
        logger.debug(f"Created savepoint {name}")

        try:
            yield name

        except Exception as e:
            # Rollback to savepoint
            try:
                self.tx_context.execute(f"ROLLBACK TO SAVEPOINT {name}")
                logger.warning(f"Rolled back to savepoint {name}: {e}")
            except Exception as rollback_error:
                logger.error(
                    f"Failed to rollback to savepoint {name}: {rollback_error}"
                )

            # Bare raise preserves the original traceback.
            raise

        finally:
            # Release savepoint
            try:
                self.tx_context.execute(f"RELEASE SAVEPOINT {name}")
                logger.debug(f"Released savepoint {name}")
            except Exception as e:
                logger.warning(f"Failed to release savepoint {name}: {e}")
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
# Convenience functions for common transaction patterns
|
|
493
|
+
def atomic_operation(connector):
    """Decorator that runs the wrapped function through the retry helper.

    Each call creates a ``TransactionManager`` for *connector* and invokes
    the function via ``execute_with_retry``. No transaction is opened by the
    wrapper itself; the function must manage its own if it needs atomicity.

    Args:
        connector: Connector handed to the ``TransactionManager``.

    Returns:
        A decorator. The decorated function keeps its name and docstring and
        raises ``DatabaseError`` when all attempts fail.
    """
    # Local import: functools is not among this module's top-level imports.
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            tm = TransactionManager(connector)
            return tm.execute_with_retry(lambda: func(*args, **kwargs))

        return wrapper

    return decorator
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def bulk_insert_transaction(
    connector, table: str, data: List[Dict[str, Any]], batch_size: int = 1000
) -> "TransactionResult":
    """Perform bulk insert with proper transaction management.

    Every row is validated with ``DatabaseInputValidator`` and turned into
    one single-row INSERT; all inserts then run in one transaction via
    ``TransactionManager.execute_atomic_operations``.

    Args:
        connector: Connector exposing ``get_connection()`` and
            ``database_type.value``.
        table: Target table name.
        data: Rows to insert, each a column-to-value mapping.
        batch_size: Kept for backward compatibility. It must be positive but
            does not change the generated statements: rows are always
            inserted one statement each inside a single transaction.

    Returns:
        The ``TransactionResult`` of the atomic run.

    Raises:
        ValueError: If ``batch_size`` is not positive. Previously a negative
            value silently inserted nothing and still reported success.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    from .input_validator import DatabaseInputValidator

    tm = TransactionManager(connector)
    operations = []

    if data:
        # The placeholder style depends only on the connector, so decide once.
        placeholder = "?" if connector.database_type.value == "sqlite" else "%s"

        for row in data:
            validated_row = DatabaseInputValidator.validate_insert_params(table, row)

            # Create insert operation
            columns = list(validated_row.keys())
            placeholders = ",".join([placeholder] * len(columns))

            # Table and column names cannot be bound as parameters; they are
            # interpolated here and rely on the validator call above.
            query = f"INSERT INTO {table} ({','.join(columns)}) VALUES ({placeholders})"

            operations.append(
                TransactionOperation(
                    query=query,
                    params=list(validated_row.values()),
                    operation_type="insert",
                    table=table,
                    expected_rows=1,
                )
            )

    return tm.execute_atomic_operations(operations)
|