esuls 0.1.15__tar.gz → 0.1.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {esuls-0.1.15/src/esuls.egg-info → esuls-0.1.17}/PKG-INFO +1 -1
- {esuls-0.1.15 → esuls-0.1.17}/pyproject.toml +1 -1
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/db_cli.py +145 -65
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/request_cli.py +9 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/tests/test_db_concurrent.py +4 -0
- esuls-0.1.17/src/esuls/tests/test_db_fixes.py +443 -0
- {esuls-0.1.15 → esuls-0.1.17/src/esuls.egg-info}/PKG-INFO +1 -1
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls.egg-info/SOURCES.txt +2 -1
- {esuls-0.1.15 → esuls-0.1.17}/LICENSE +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/README.md +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/setup.cfg +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/__init__.py +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/download_icon.py +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls/utils.py +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls.egg-info/dependency_links.txt +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls.egg-info/requires.txt +0 -0
- {esuls-0.1.15 → esuls-0.1.17}/src/esuls.egg-info/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "esuls"
|
|
7
|
-
version = "0.1.
|
|
7
|
+
version = "0.1.17"
|
|
8
8
|
description = "Utility library for async database operations, HTTP requests, and parallel execution"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.14"
|
|
@@ -1,6 +1,10 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import aiosqlite
|
|
3
|
+
import ast
|
|
3
4
|
import json
|
|
5
|
+
import re
|
|
6
|
+
import threading
|
|
7
|
+
import dataclasses
|
|
4
8
|
from datetime import datetime
|
|
5
9
|
from pathlib import Path
|
|
6
10
|
from typing import Any, Dict, List, Optional, TypeVar, Generic, Type, get_type_hints, Union, Tuple
|
|
@@ -11,6 +15,13 @@ import contextlib
|
|
|
11
15
|
import enum
|
|
12
16
|
from loguru import logger
|
|
13
17
|
|
|
18
|
+
_VALID_IDENTIFIER = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
|
|
19
|
+
|
|
20
|
+
def _validate_identifier(name: str) -> str:
|
|
21
|
+
if not _VALID_IDENTIFIER.match(name):
|
|
22
|
+
raise ValueError(f"Invalid SQL identifier: {name!r}")
|
|
23
|
+
return name
|
|
24
|
+
|
|
14
25
|
T = TypeVar('T')
|
|
15
26
|
SchemaType = TypeVar('SchemaType', bound='BaseModel')
|
|
16
27
|
|
|
@@ -33,6 +44,8 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
33
44
|
_db_locks: dict[str, asyncio.Lock] = {}
|
|
34
45
|
# Lock for schema initialization (class-level)
|
|
35
46
|
_schema_init_lock: asyncio.Lock = None
|
|
47
|
+
# Threading lock to guard class-level dict mutations
|
|
48
|
+
_db_locks_guard = threading.Lock()
|
|
36
49
|
|
|
37
50
|
def __init__(self, db_path: Union[str, Path], table_name: str, schema_class: Type[SchemaType]):
|
|
38
51
|
"""Initialize AsyncDB with a path and schema dataclass."""
|
|
@@ -41,34 +54,49 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
41
54
|
|
|
42
55
|
self.db_path = Path(db_path).resolve()
|
|
43
56
|
self.schema_class = schema_class
|
|
44
|
-
self.table_name = table_name
|
|
57
|
+
self.table_name = _validate_identifier(table_name)
|
|
45
58
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
46
59
|
|
|
60
|
+
# Validate all field names upfront
|
|
61
|
+
for f in fields(schema_class):
|
|
62
|
+
_validate_identifier(f.name)
|
|
63
|
+
|
|
47
64
|
# Make schema initialization unique per instance
|
|
48
65
|
self._db_key = f"{str(self.db_path)}:{self.table_name}:{self.schema_class.__name__}"
|
|
49
66
|
|
|
50
67
|
# Use shared lock per database file (not per instance)
|
|
51
68
|
db_path_str = str(self.db_path)
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
69
|
+
with AsyncDB._db_locks_guard:
|
|
70
|
+
if db_path_str not in AsyncDB._db_locks:
|
|
71
|
+
AsyncDB._db_locks[db_path_str] = asyncio.Lock()
|
|
72
|
+
self._write_lock = AsyncDB._db_locks[db_path_str]
|
|
55
73
|
|
|
56
74
|
self._type_hints = get_type_hints(schema_class)
|
|
57
|
-
|
|
75
|
+
|
|
76
|
+
# Persistent connection (lazy init)
|
|
77
|
+
self._connection: Optional[aiosqlite.Connection] = None
|
|
78
|
+
|
|
58
79
|
# Use a class-level set to track initialized schemas
|
|
59
80
|
if not hasattr(AsyncDB, '_initialized_schemas'):
|
|
60
81
|
AsyncDB._initialized_schemas = set()
|
|
61
82
|
|
|
62
|
-
async def
|
|
63
|
-
"""
|
|
83
|
+
async def _ensure_connection(self, max_retries: int = 5) -> aiosqlite.Connection:
|
|
84
|
+
"""Return the persistent connection, creating it on first call with retry logic."""
|
|
85
|
+
if self._connection is not None:
|
|
86
|
+
return self._connection
|
|
87
|
+
|
|
64
88
|
# Ensure schema init lock exists (lazy init for asyncio compatibility)
|
|
65
|
-
|
|
66
|
-
AsyncDB._schema_init_lock
|
|
89
|
+
with AsyncDB._db_locks_guard:
|
|
90
|
+
if AsyncDB._schema_init_lock is None:
|
|
91
|
+
AsyncDB._schema_init_lock = asyncio.Lock()
|
|
67
92
|
|
|
68
93
|
last_error = None
|
|
69
94
|
for attempt in range(max_retries):
|
|
70
95
|
try:
|
|
71
|
-
db =
|
|
96
|
+
db = aiosqlite.connect(self.db_path, timeout=30.0)
|
|
97
|
+
# Mark aiosqlite's thread as daemon so it won't block process exit
|
|
98
|
+
db.daemon = True
|
|
99
|
+
db = await db
|
|
72
100
|
# Fast WAL mode with minimal sync
|
|
73
101
|
await db.execute("PRAGMA journal_mode=WAL")
|
|
74
102
|
await db.execute("PRAGMA synchronous=NORMAL")
|
|
@@ -83,6 +111,7 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
83
111
|
await self._init_schema(db)
|
|
84
112
|
AsyncDB._initialized_schemas.add(self._db_key)
|
|
85
113
|
|
|
114
|
+
self._connection = db
|
|
86
115
|
return db
|
|
87
116
|
except Exception as e:
|
|
88
117
|
last_error = e
|
|
@@ -93,7 +122,22 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
93
122
|
continue
|
|
94
123
|
raise
|
|
95
124
|
raise last_error
|
|
96
|
-
|
|
125
|
+
|
|
126
|
+
async def __aenter__(self):
|
|
127
|
+
return self
|
|
128
|
+
|
|
129
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
130
|
+
await self.close()
|
|
131
|
+
|
|
132
|
+
async def close(self) -> None:
|
|
133
|
+
"""Explicitly close the persistent connection."""
|
|
134
|
+
if self._connection is not None:
|
|
135
|
+
try:
|
|
136
|
+
await self._connection.close()
|
|
137
|
+
except Exception:
|
|
138
|
+
pass
|
|
139
|
+
self._connection = None
|
|
140
|
+
|
|
97
141
|
async def _init_schema(self, db: aiosqlite.Connection) -> None:
|
|
98
142
|
"""Generate schema from dataclass structure with support for field additions."""
|
|
99
143
|
logger.debug(f"Initializing schema for {self.schema_class.__name__} in table {self.table_name}")
|
|
@@ -146,7 +190,7 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
146
190
|
constraints.append("PRIMARY KEY")
|
|
147
191
|
if f.metadata.get('unique'):
|
|
148
192
|
constraints.append("UNIQUE")
|
|
149
|
-
if
|
|
193
|
+
if f.default is dataclasses.MISSING and f.default_factory is dataclasses.MISSING and f.metadata.get('required', True):
|
|
150
194
|
constraints.append("NOT NULL")
|
|
151
195
|
|
|
152
196
|
field_def = f"{field_name} {sql_type} {' '.join(constraints)}"
|
|
@@ -193,16 +237,27 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
193
237
|
|
|
194
238
|
@contextlib.asynccontextmanager
|
|
195
239
|
async def transaction(self):
|
|
196
|
-
"""Run operations in a transaction with reliable cleanup."""
|
|
197
|
-
db = await self.
|
|
240
|
+
"""Run operations in a transaction with reliable cleanup and auto-reconnect."""
|
|
241
|
+
db = await self._ensure_connection()
|
|
198
242
|
try:
|
|
199
243
|
yield db
|
|
200
244
|
await db.commit()
|
|
201
|
-
except Exception:
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
245
|
+
except Exception as e:
|
|
246
|
+
try:
|
|
247
|
+
await db.rollback()
|
|
248
|
+
except Exception:
|
|
249
|
+
pass
|
|
250
|
+
if "closed" in str(e).lower() or "no active connection" in str(e).lower():
|
|
251
|
+
self._connection = None
|
|
252
|
+
db = await self._ensure_connection()
|
|
253
|
+
try:
|
|
254
|
+
yield db
|
|
255
|
+
await db.commit()
|
|
256
|
+
except Exception:
|
|
257
|
+
await db.rollback()
|
|
258
|
+
raise
|
|
259
|
+
else:
|
|
260
|
+
raise
|
|
206
261
|
|
|
207
262
|
# @lru_cache(maxsize=128)
|
|
208
263
|
def _serialize_value(self, value: Any) -> Any:
|
|
@@ -232,12 +287,15 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
232
287
|
return value
|
|
233
288
|
# If somehow stored as string, convert back
|
|
234
289
|
if isinstance(value, str):
|
|
235
|
-
import ast
|
|
236
290
|
try:
|
|
237
291
|
return ast.literal_eval(value)
|
|
238
|
-
except:
|
|
292
|
+
except (ValueError, SyntaxError):
|
|
239
293
|
return value.encode('utf-8')
|
|
240
294
|
|
|
295
|
+
# Handle bool fields - SQLite stores as INTEGER, need to convert back
|
|
296
|
+
if field_type is bool:
|
|
297
|
+
return bool(value)
|
|
298
|
+
|
|
241
299
|
# Handle string fields - ensure phone numbers are strings
|
|
242
300
|
if field_type is str or (hasattr(field_type, '__origin__') and field_type.__origin__ is Union and str in getattr(field_type, '__args__', ())):
|
|
243
301
|
return str(value)
|
|
@@ -250,12 +308,12 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
250
308
|
# Handle Optional[EnumType] case
|
|
251
309
|
args = getattr(field_type, '__args__', ())
|
|
252
310
|
for arg in args:
|
|
253
|
-
if arg is not type(None) and
|
|
311
|
+
if arg is not type(None) and isinstance(arg, type) and issubclass(arg, enum.Enum):
|
|
254
312
|
try:
|
|
255
313
|
return arg(value)
|
|
256
314
|
except (ValueError, TypeError):
|
|
257
315
|
pass
|
|
258
|
-
elif
|
|
316
|
+
elif isinstance(field_type, type) and issubclass(field_type, enum.Enum):
|
|
259
317
|
# Handle direct enum types
|
|
260
318
|
try:
|
|
261
319
|
return field_type(value)
|
|
@@ -276,11 +334,27 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
276
334
|
columns = ','.join(field_names)
|
|
277
335
|
placeholders = ','.join('?' for _ in field_names)
|
|
278
336
|
|
|
337
|
+
set_clause = ','.join(f'{col}=excluded.{col}' for col in field_names if col != 'created_at')
|
|
279
338
|
return f"""
|
|
280
|
-
INSERT
|
|
339
|
+
INSERT INTO {self.table_name} ({columns},id)
|
|
281
340
|
VALUES ({placeholders},?)
|
|
341
|
+
ON CONFLICT(id) DO UPDATE SET {set_clause}
|
|
282
342
|
"""
|
|
283
343
|
|
|
344
|
+
def _prepare_item(self, item: SchemaType) -> Tuple[str, List[Any]]:
|
|
345
|
+
"""Prepare an item for saving. Returns (sql, values)."""
|
|
346
|
+
data = asdict(item)
|
|
347
|
+
item_id = data.pop('id', None) or str(uuid.uuid4())
|
|
348
|
+
now = datetime.now()
|
|
349
|
+
if not data.get('created_at'):
|
|
350
|
+
data['created_at'] = now
|
|
351
|
+
data['updated_at'] = now
|
|
352
|
+
field_names = tuple(sorted(data.keys()))
|
|
353
|
+
sql = self._generate_save_sql(field_names)
|
|
354
|
+
values = [self._serialize_value(data[name]) for name in field_names]
|
|
355
|
+
values.append(item_id)
|
|
356
|
+
return sql, values
|
|
357
|
+
|
|
284
358
|
async def save_batch(self, items: List[SchemaType], skip_errors: bool = True) -> int:
|
|
285
359
|
"""Save multiple items in a single transaction for better performance.
|
|
286
360
|
|
|
@@ -299,6 +373,7 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
299
373
|
max_retries = 3
|
|
300
374
|
for attempt in range(max_retries):
|
|
301
375
|
try:
|
|
376
|
+
saved_count = 0
|
|
302
377
|
async with self._write_lock:
|
|
303
378
|
async with self.transaction() as db:
|
|
304
379
|
for item in items:
|
|
@@ -308,23 +383,7 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
308
383
|
raise TypeError(f"Expected {self.schema_class.__name__}, got {type(item).__name__}")
|
|
309
384
|
continue
|
|
310
385
|
|
|
311
|
-
|
|
312
|
-
data = asdict(item)
|
|
313
|
-
item_id = data.pop('id', None) or str(uuid.uuid4())
|
|
314
|
-
|
|
315
|
-
# Ensure created_at and updated_at are set
|
|
316
|
-
now = datetime.now()
|
|
317
|
-
if not data.get('created_at'):
|
|
318
|
-
data['created_at'] = now
|
|
319
|
-
data['updated_at'] = now
|
|
320
|
-
|
|
321
|
-
# Prepare SQL and values
|
|
322
|
-
field_names = tuple(sorted(data.keys()))
|
|
323
|
-
sql = self._generate_save_sql(field_names)
|
|
324
|
-
values = [self._serialize_value(data[name]) for name in field_names]
|
|
325
|
-
values.append(item_id)
|
|
326
|
-
|
|
327
|
-
# Execute save
|
|
386
|
+
sql, values = self._prepare_item(item)
|
|
328
387
|
await db.execute(sql, values)
|
|
329
388
|
saved_count += 1
|
|
330
389
|
|
|
@@ -360,21 +419,7 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
360
419
|
return False
|
|
361
420
|
raise TypeError(f"Expected {self.schema_class.__name__}, got {type(item).__name__}")
|
|
362
421
|
|
|
363
|
-
|
|
364
|
-
data = asdict(item)
|
|
365
|
-
item_id = data.pop('id', None) or str(uuid.uuid4())
|
|
366
|
-
|
|
367
|
-
# Ensure created_at and updated_at are set
|
|
368
|
-
now = datetime.now()
|
|
369
|
-
if not data.get('created_at'):
|
|
370
|
-
data['created_at'] = now
|
|
371
|
-
data['updated_at'] = now
|
|
372
|
-
|
|
373
|
-
# Prepare SQL and values
|
|
374
|
-
field_names = tuple(sorted(data.keys()))
|
|
375
|
-
sql = self._generate_save_sql(field_names)
|
|
376
|
-
values = [self._serialize_value(data[name]) for name in field_names]
|
|
377
|
-
values.append(item_id)
|
|
422
|
+
sql, values = self._prepare_item(item)
|
|
378
423
|
|
|
379
424
|
# Perform save with reliable transaction (retry on "database is locked")
|
|
380
425
|
max_retries = 3
|
|
@@ -425,10 +470,6 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
425
470
|
values = []
|
|
426
471
|
|
|
427
472
|
for key, value in filters.items():
|
|
428
|
-
# Handle special values
|
|
429
|
-
if value == 'now':
|
|
430
|
-
value = datetime.now()
|
|
431
|
-
|
|
432
473
|
# Parse field and operator
|
|
433
474
|
parts = key.split('__', 1)
|
|
434
475
|
field = parts[0]
|
|
@@ -451,22 +492,32 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
451
492
|
|
|
452
493
|
return f"WHERE {' AND '.join(conditions)}", values
|
|
453
494
|
|
|
454
|
-
async def find(self, order_by=None, **filters) -> List[SchemaType]:
|
|
495
|
+
async def find(self, order_by=None, limit: int = None, offset: int = None, **filters) -> List[SchemaType]:
|
|
455
496
|
"""Query items with reliable connection handling."""
|
|
456
497
|
where_clause, values = self._build_where_clause(filters)
|
|
457
|
-
|
|
498
|
+
|
|
458
499
|
# Build query
|
|
459
500
|
query = f"SELECT * FROM {self.table_name} {where_clause}"
|
|
460
|
-
|
|
501
|
+
|
|
461
502
|
# Add ORDER BY clause if specified
|
|
462
503
|
if order_by:
|
|
463
504
|
order_fields = [order_by] if isinstance(order_by, str) else order_by
|
|
464
505
|
order_clauses = [
|
|
465
|
-
f"{field[1:]} DESC" if field.startswith('-') else f"{field} ASC"
|
|
506
|
+
f"{field[1:]} DESC" if field.startswith('-') else f"{field} ASC"
|
|
466
507
|
for field in order_fields
|
|
467
508
|
]
|
|
468
509
|
query += f" ORDER BY {', '.join(order_clauses)}"
|
|
469
|
-
|
|
510
|
+
|
|
511
|
+
# Add LIMIT/OFFSET (SQLite requires LIMIT before OFFSET)
|
|
512
|
+
if limit is not None:
|
|
513
|
+
query += " LIMIT ?"
|
|
514
|
+
values.append(limit)
|
|
515
|
+
elif offset is not None:
|
|
516
|
+
query += " LIMIT -1"
|
|
517
|
+
if offset is not None:
|
|
518
|
+
query += " OFFSET ?"
|
|
519
|
+
values.append(offset)
|
|
520
|
+
|
|
470
521
|
# Execute query with reliable transaction
|
|
471
522
|
async with self.transaction() as db:
|
|
472
523
|
cursor = await db.execute(query, values)
|
|
@@ -504,4 +555,33 @@ class AsyncDB(Generic[SchemaType]):
|
|
|
504
555
|
async with self._write_lock:
|
|
505
556
|
async with self.transaction() as db:
|
|
506
557
|
cursor = await db.execute(f"DELETE FROM {self.table_name} WHERE id = ?", (id,))
|
|
558
|
+
return cursor.rowcount > 0
|
|
559
|
+
|
|
560
|
+
async def exists(self, **filters) -> bool:
|
|
561
|
+
"""Check if any record matches the filters without fetching data."""
|
|
562
|
+
return await self.count(**filters) > 0
|
|
563
|
+
|
|
564
|
+
async def delete_many(self, **filters) -> int:
|
|
565
|
+
"""Delete all items matching filters. Returns count of deleted rows."""
|
|
566
|
+
if not filters:
|
|
567
|
+
raise ValueError("delete_many() requires at least one filter to prevent accidental full table delete")
|
|
568
|
+
where_clause, values = self._build_where_clause(filters)
|
|
569
|
+
async with self._write_lock:
|
|
570
|
+
async with self.transaction() as db:
|
|
571
|
+
cursor = await db.execute(f"DELETE FROM {self.table_name} {where_clause}", values)
|
|
572
|
+
return cursor.rowcount
|
|
573
|
+
|
|
574
|
+
async def update_fields(self, id: str, **fields) -> bool:
|
|
575
|
+
"""Update specific fields on a record by ID without fetching the full record."""
|
|
576
|
+
if not fields:
|
|
577
|
+
return False
|
|
578
|
+
fields['updated_at'] = datetime.now()
|
|
579
|
+
set_clause = ', '.join(f"{_validate_identifier(k)} = ?" for k in fields)
|
|
580
|
+
values = [self._serialize_value(v) for v in fields.values()]
|
|
581
|
+
values.append(id)
|
|
582
|
+
async with self._write_lock:
|
|
583
|
+
async with self.transaction() as db:
|
|
584
|
+
cursor = await db.execute(
|
|
585
|
+
f"UPDATE {self.table_name} SET {set_clause} WHERE id = ?", values
|
|
586
|
+
)
|
|
507
587
|
return cursor.rowcount > 0
|
|
@@ -256,6 +256,15 @@ async def close_shared_client() -> None:
|
|
|
256
256
|
_domain_clients.clear()
|
|
257
257
|
|
|
258
258
|
|
|
259
|
+
async def cleanup_all() -> None:
|
|
260
|
+
"""Close all global HTTP resources (domain clients + cffi session)."""
|
|
261
|
+
await close_shared_client()
|
|
262
|
+
if _get_session_cffi.cache_info().currsize > 0:
|
|
263
|
+
cffi_session = _get_session_cffi()
|
|
264
|
+
await cffi_session.close()
|
|
265
|
+
_get_session_cffi.cache_clear()
|
|
266
|
+
|
|
267
|
+
|
|
259
268
|
async def close_domain_client(url: str, http2: Optional[bool] = None) -> None:
|
|
260
269
|
"""Close HTTP client for a specific domain. If http2 is None, closes both h1 and h2 clients."""
|
|
261
270
|
domain = _extract_domain(url)
|
|
@@ -33,6 +33,7 @@ async def test_concurrent_reads(temp_db):
|
|
|
33
33
|
|
|
34
34
|
# All reads should succeed and return same data
|
|
35
35
|
assert all(len(r) == 10 for r in results)
|
|
36
|
+
await db.close()
|
|
36
37
|
print(f"✓ 100 concurrent reads completed successfully")
|
|
37
38
|
|
|
38
39
|
|
|
@@ -53,6 +54,7 @@ async def test_concurrent_writes(temp_db):
|
|
|
53
54
|
# Verify all items were saved
|
|
54
55
|
items = await db.find()
|
|
55
56
|
assert len(items) == 50
|
|
57
|
+
await db.close()
|
|
56
58
|
print(f"✓ 50 concurrent writes completed successfully")
|
|
57
59
|
|
|
58
60
|
|
|
@@ -92,6 +94,7 @@ async def test_concurrent_mixed_operations(temp_db):
|
|
|
92
94
|
# Verify final state
|
|
93
95
|
items = await db.find()
|
|
94
96
|
assert len(items) == 55 # 5 seed + 50 writes
|
|
97
|
+
await db.close()
|
|
95
98
|
print(f"✓ 200 concurrent mixed operations completed successfully")
|
|
96
99
|
|
|
97
100
|
|
|
@@ -124,6 +127,7 @@ async def test_stress_concurrent_access(temp_db):
|
|
|
124
127
|
|
|
125
128
|
# Should have very few or no failures with retry logic
|
|
126
129
|
assert len(exceptions) == 0, f"{len(exceptions)} operations failed"
|
|
130
|
+
await db.close()
|
|
127
131
|
print(f"✓ 500 concurrent stress operations completed successfully")
|
|
128
132
|
|
|
129
133
|
|
|
@@ -0,0 +1,443 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for specific fixes applied to AsyncDB.
|
|
3
|
+
Covers: limit/offset, save_batch reset, enum deserialization,
|
|
4
|
+
NOT NULL schema logic, identifier validation, connection pooling.
|
|
5
|
+
"""
|
|
6
|
+
import asyncio
|
|
7
|
+
import enum
|
|
8
|
+
import tempfile
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Optional
|
|
12
|
+
|
|
13
|
+
from esuls.db_cli import AsyncDB, BaseModel, _validate_identifier
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# --- Test models ---
|
|
17
|
+
|
|
18
|
+
class Color(enum.Enum):
|
|
19
|
+
RED = "red"
|
|
20
|
+
GREEN = "green"
|
|
21
|
+
BLUE = "blue"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Priority(enum.IntEnum):
|
|
25
|
+
LOW = 1
|
|
26
|
+
MEDIUM = 2
|
|
27
|
+
HIGH = 3
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Status(str, enum.Enum):
|
|
31
|
+
ACTIVE = "active"
|
|
32
|
+
INACTIVE = "inactive"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
|
|
36
|
+
class EnumItem(BaseModel):
|
|
37
|
+
color: Optional[Color] = None
|
|
38
|
+
priority: Optional[Priority] = None
|
|
39
|
+
status: Optional[Status] = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
|
|
43
|
+
class DefaultsItem(BaseModel):
|
|
44
|
+
name: str = ""
|
|
45
|
+
count: int = 0
|
|
46
|
+
flag: bool = False
|
|
47
|
+
score: float = 0.0
|
|
48
|
+
required_field: str = field(default="", metadata={"required": False})
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
|
|
52
|
+
class TestItem(BaseModel):
|
|
53
|
+
name: str = ""
|
|
54
|
+
value: int = 0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# --- Tests ---
|
|
58
|
+
|
|
59
|
+
async def test_find_limit_offset(temp_db):
|
|
60
|
+
"""Test limit and offset parameters on find()."""
|
|
61
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
62
|
+
|
|
63
|
+
try:
|
|
64
|
+
# Insert 20 items with predictable ordering
|
|
65
|
+
for i in range(20):
|
|
66
|
+
await db.save(TestItem(name=f"item_{i:02d}", value=i))
|
|
67
|
+
|
|
68
|
+
# Test limit
|
|
69
|
+
results = await db.find(order_by="value", limit=5)
|
|
70
|
+
assert len(results) == 5
|
|
71
|
+
assert results[0].value == 0
|
|
72
|
+
assert results[4].value == 4
|
|
73
|
+
|
|
74
|
+
# Test offset
|
|
75
|
+
results = await db.find(order_by="value", limit=5, offset=10)
|
|
76
|
+
assert len(results) == 5
|
|
77
|
+
assert results[0].value == 10
|
|
78
|
+
assert results[4].value == 14
|
|
79
|
+
|
|
80
|
+
# Test offset beyond data
|
|
81
|
+
results = await db.find(order_by="value", limit=5, offset=18)
|
|
82
|
+
assert len(results) == 2
|
|
83
|
+
assert results[0].value == 18
|
|
84
|
+
|
|
85
|
+
# Test limit without offset
|
|
86
|
+
results = await db.find(order_by="value", limit=3)
|
|
87
|
+
assert len(results) == 3
|
|
88
|
+
|
|
89
|
+
# Test offset without limit (should return from offset to end)
|
|
90
|
+
results = await db.find(order_by="value", offset=17)
|
|
91
|
+
assert len(results) == 3
|
|
92
|
+
|
|
93
|
+
# Test no limit/offset (all results)
|
|
94
|
+
results = await db.find()
|
|
95
|
+
assert len(results) == 20
|
|
96
|
+
|
|
97
|
+
print("✓ find() limit/offset works correctly")
|
|
98
|
+
finally:
|
|
99
|
+
await db.close()
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
async def test_save_batch_count_reset(temp_db):
|
|
103
|
+
"""Test that saved_count resets properly on retry in save_batch()."""
|
|
104
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
items = [TestItem(name=f"batch_{i}", value=i) for i in range(10)]
|
|
108
|
+
count = await db.save_batch(items)
|
|
109
|
+
assert count == 10, f"Expected 10, got {count}"
|
|
110
|
+
|
|
111
|
+
# Save again (upsert) — count should still be exactly 10
|
|
112
|
+
count2 = await db.save_batch(items)
|
|
113
|
+
assert count2 == 10, f"Expected 10 on re-save, got {count2}"
|
|
114
|
+
|
|
115
|
+
print("✓ save_batch() count is accurate")
|
|
116
|
+
finally:
|
|
117
|
+
await db.close()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def test_enum_deserialization_subclasses(temp_db):
|
|
121
|
+
"""Test enum deserialization with IntEnum, StrEnum, and regular Enum."""
|
|
122
|
+
db = AsyncDB(temp_db, "enums", EnumItem)
|
|
123
|
+
|
|
124
|
+
try:
|
|
125
|
+
# Save with all enum types
|
|
126
|
+
item = EnumItem(
|
|
127
|
+
color=Color.GREEN,
|
|
128
|
+
priority=Priority.HIGH,
|
|
129
|
+
status=Status.ACTIVE,
|
|
130
|
+
)
|
|
131
|
+
await db.save(item)
|
|
132
|
+
|
|
133
|
+
# Read back
|
|
134
|
+
results = await db.find()
|
|
135
|
+
assert len(results) == 1
|
|
136
|
+
loaded = results[0]
|
|
137
|
+
|
|
138
|
+
assert loaded.color == Color.GREEN, f"Expected Color.GREEN, got {loaded.color!r}"
|
|
139
|
+
assert loaded.priority == Priority.HIGH, f"Expected Priority.HIGH, got {loaded.priority!r}"
|
|
140
|
+
assert loaded.status == Status.ACTIVE, f"Expected Status.ACTIVE, got {loaded.status!r}"
|
|
141
|
+
|
|
142
|
+
# Test None values
|
|
143
|
+
item2 = EnumItem()
|
|
144
|
+
await db.save(item2)
|
|
145
|
+
results = await db.find()
|
|
146
|
+
none_items = [r for r in results if r.color is None]
|
|
147
|
+
assert len(none_items) == 1
|
|
148
|
+
assert none_items[0].priority is None
|
|
149
|
+
assert none_items[0].status is None
|
|
150
|
+
|
|
151
|
+
print("✓ Enum deserialization works for Enum, IntEnum, StrEnum")
|
|
152
|
+
finally:
|
|
153
|
+
await db.close()
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
async def test_not_null_with_falsy_defaults(temp_db):
|
|
157
|
+
"""Test that fields with falsy defaults (0, '', False) don't get NOT NULL."""
|
|
158
|
+
db = AsyncDB(temp_db, "defaults_test", DefaultsItem)
|
|
159
|
+
|
|
160
|
+
try:
|
|
161
|
+
# If the schema was created correctly, saving an item with default values should work
|
|
162
|
+
item = DefaultsItem()
|
|
163
|
+
result = await db.save(item)
|
|
164
|
+
assert result is True
|
|
165
|
+
|
|
166
|
+
# Verify roundtrip
|
|
167
|
+
items = await db.find()
|
|
168
|
+
assert len(items) == 1
|
|
169
|
+
loaded = items[0]
|
|
170
|
+
assert loaded.name == ""
|
|
171
|
+
assert loaded.count == 0
|
|
172
|
+
assert loaded.flag is False
|
|
173
|
+
|
|
174
|
+
print("✓ NOT NULL logic correctly handles falsy defaults")
|
|
175
|
+
finally:
|
|
176
|
+
await db.close()
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
async def test_identifier_validation():
|
|
180
|
+
"""Test that invalid SQL identifiers are rejected."""
|
|
181
|
+
# Valid identifiers
|
|
182
|
+
assert _validate_identifier("items") == "items"
|
|
183
|
+
assert _validate_identifier("my_table") == "my_table"
|
|
184
|
+
assert _validate_identifier("_private") == "_private"
|
|
185
|
+
assert _validate_identifier("Table123") == "Table123"
|
|
186
|
+
|
|
187
|
+
# Invalid identifiers
|
|
188
|
+
invalid_names = [
|
|
189
|
+
"Robert'; DROP TABLE students;--",
|
|
190
|
+
"my table",
|
|
191
|
+
"123abc",
|
|
192
|
+
"my-table",
|
|
193
|
+
"",
|
|
194
|
+
"table.name",
|
|
195
|
+
"col(umn)",
|
|
196
|
+
]
|
|
197
|
+
for name in invalid_names:
|
|
198
|
+
try:
|
|
199
|
+
_validate_identifier(name)
|
|
200
|
+
raise AssertionError(f"Should have rejected: {name!r}")
|
|
201
|
+
except ValueError:
|
|
202
|
+
pass # Expected
|
|
203
|
+
|
|
204
|
+
print("✓ Identifier validation rejects injection attempts")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
async def test_identifier_validation_in_constructor(temp_db):
|
|
208
|
+
"""Test that AsyncDB constructor rejects invalid table names."""
|
|
209
|
+
try:
|
|
210
|
+
AsyncDB(temp_db, "valid_table; DROP TABLE x", TestItem)
|
|
211
|
+
raise AssertionError("Should have rejected invalid table name")
|
|
212
|
+
except ValueError:
|
|
213
|
+
pass
|
|
214
|
+
|
|
215
|
+
print("✓ Constructor rejects invalid table names")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
async def test_connection_pooling(temp_db):
|
|
219
|
+
"""Test that persistent connection is reused across operations."""
|
|
220
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
221
|
+
|
|
222
|
+
# First operation creates the connection
|
|
223
|
+
await db.save(TestItem(name="first", value=1))
|
|
224
|
+
conn1 = db._connection
|
|
225
|
+
assert conn1 is not None
|
|
226
|
+
|
|
227
|
+
# Subsequent operations reuse it
|
|
228
|
+
await db.find()
|
|
229
|
+
conn2 = db._connection
|
|
230
|
+
assert conn2 is conn1, "Connection should be reused"
|
|
231
|
+
|
|
232
|
+
await db.count()
|
|
233
|
+
conn3 = db._connection
|
|
234
|
+
assert conn3 is conn1, "Connection should still be reused"
|
|
235
|
+
|
|
236
|
+
await db.save(TestItem(name="second", value=2))
|
|
237
|
+
conn4 = db._connection
|
|
238
|
+
assert conn4 is conn1, "Connection should persist across saves"
|
|
239
|
+
|
|
240
|
+
# Explicit close
|
|
241
|
+
await db.close()
|
|
242
|
+
assert db._connection is None
|
|
243
|
+
|
|
244
|
+
# Next operation creates a new connection
|
|
245
|
+
await db.find()
|
|
246
|
+
conn5 = db._connection
|
|
247
|
+
assert conn5 is not None
|
|
248
|
+
assert conn5 is not conn1, "Should be a new connection after close"
|
|
249
|
+
|
|
250
|
+
await db.close()
|
|
251
|
+
print("✓ Connection pooling reuses connections correctly")
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
async def test_close_idempotent(temp_db):
|
|
255
|
+
"""Test that close() can be called multiple times safely."""
|
|
256
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
257
|
+
await db.save(TestItem(name="test", value=1))
|
|
258
|
+
|
|
259
|
+
await db.close()
|
|
260
|
+
await db.close() # Should not raise
|
|
261
|
+
await db.close() # Should not raise
|
|
262
|
+
|
|
263
|
+
# Should still work after close
|
|
264
|
+
items = await db.find()
|
|
265
|
+
assert len(items) == 1
|
|
266
|
+
|
|
267
|
+
await db.close()
|
|
268
|
+
print("✓ close() is idempotent")
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
async def test_exists(temp_db):
|
|
272
|
+
"""Test exists() returns bool without fetching full records."""
|
|
273
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
274
|
+
|
|
275
|
+
try:
|
|
276
|
+
assert await db.exists(name="nope") is False
|
|
277
|
+
|
|
278
|
+
await db.save(TestItem(name="hello", value=1))
|
|
279
|
+
assert await db.exists(name="hello") is True
|
|
280
|
+
assert await db.exists(name="nope") is False
|
|
281
|
+
assert await db.exists(value=1) is True
|
|
282
|
+
assert await db.exists(value=999) is False
|
|
283
|
+
|
|
284
|
+
print("✓ exists() works correctly")
|
|
285
|
+
finally:
|
|
286
|
+
await db.close()
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
async def test_delete_many(temp_db):
|
|
290
|
+
"""Test delete_many() deletes matching records and returns count."""
|
|
291
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
292
|
+
|
|
293
|
+
try:
|
|
294
|
+
for i in range(10):
|
|
295
|
+
await db.save(TestItem(name="group_a" if i < 6 else "group_b", value=i))
|
|
296
|
+
|
|
297
|
+
# Delete group_a
|
|
298
|
+
deleted = await db.delete_many(name="group_a")
|
|
299
|
+
assert deleted == 6, f"Expected 6 deleted, got {deleted}"
|
|
300
|
+
|
|
301
|
+
remaining = await db.count()
|
|
302
|
+
assert remaining == 4, f"Expected 4 remaining, got {remaining}"
|
|
303
|
+
|
|
304
|
+
# Delete non-existent
|
|
305
|
+
deleted = await db.delete_many(name="group_c")
|
|
306
|
+
assert deleted == 0
|
|
307
|
+
|
|
308
|
+
# Must raise on empty filters
|
|
309
|
+
try:
|
|
310
|
+
await db.delete_many()
|
|
311
|
+
raise AssertionError("Should have raised ValueError")
|
|
312
|
+
except ValueError:
|
|
313
|
+
pass
|
|
314
|
+
|
|
315
|
+
print("✓ delete_many() works correctly")
|
|
316
|
+
finally:
|
|
317
|
+
await db.close()
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
async def test_update_fields(temp_db):
|
|
321
|
+
"""Test update_fields() updates specific fields without full fetch."""
|
|
322
|
+
db = AsyncDB(temp_db, "items", TestItem)
|
|
323
|
+
|
|
324
|
+
try:
|
|
325
|
+
item = TestItem(name="original", value=10)
|
|
326
|
+
await db.save(item)
|
|
327
|
+
|
|
328
|
+
# Update single field
|
|
329
|
+
result = await db.update_fields(item.id, name="updated")
|
|
330
|
+
assert result is True
|
|
331
|
+
|
|
332
|
+
loaded = await db.get_by_id(item.id)
|
|
333
|
+
assert loaded.name == "updated"
|
|
334
|
+
assert loaded.value == 10 # unchanged
|
|
335
|
+
|
|
336
|
+
# Update multiple fields
|
|
337
|
+
result = await db.update_fields(item.id, name="final", value=99)
|
|
338
|
+
assert result is True
|
|
339
|
+
loaded = await db.get_by_id(item.id)
|
|
340
|
+
assert loaded.name == "final"
|
|
341
|
+
assert loaded.value == 99
|
|
342
|
+
|
|
343
|
+
# Non-existent ID
|
|
344
|
+
result = await db.update_fields("nonexistent-id", name="x")
|
|
345
|
+
assert result is False
|
|
346
|
+
|
|
347
|
+
# Empty fields
|
|
348
|
+
result = await db.update_fields(item.id)
|
|
349
|
+
assert result is False
|
|
350
|
+
|
|
351
|
+
print("✓ update_fields() works correctly")
|
|
352
|
+
finally:
|
|
353
|
+
await db.close()
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
async def test_context_manager(temp_db):
    """Verify AsyncDB can be used with ``async with`` and cleans up on exit.

    A row saved inside one context must still be visible from a second,
    freshly opened instance on the same database file.
    """
    async with AsyncDB(temp_db, "items", TestItem) as db:
        await db.save(TestItem(name="ctx", value=42))
        found = await db.find()
        assert len(found) == 1
        assert found[0].name == "ctx"

    # Leaving the block must have released the connection.
    assert db._connection is None

    # Reopening the same file in a new context still sees the saved row.
    async with AsyncDB(temp_db, "items", TestItem) as reopened:
        assert len(await reopened.find()) == 1

    print("✓ Context manager works correctly")
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
async def test_prepare_item_dedup(temp_db):
    """Test that _prepare_item produces correct SQL and values for save.

    Inspects the statement and bound values returned by _prepare_item, then
    confirms that save() and save_batch() still persist rows end-to-end.
    """
    db = AsyncDB(temp_db, "items", TestItem)

    try:
        item = TestItem(name="test", value=5)
        sql, values = db._prepare_item(item)
        # The generated statement must be an upsert keyed on the id column.
        assert "ON CONFLICT(id) DO UPDATE SET" in sql
        assert "test" in values
        assert 5 in values

        # Verify save still works end-to-end
        await db.save(item)
        loaded = await db.get_by_id(item.id)
        assert loaded.name == "test"
        assert loaded.value == 5

        # Verify batch save still works. The returned count alone does not
        # prove that rows were written, so also check the table total:
        # 1 row from save() above + 5 rows from this batch.
        items = [TestItem(name=f"batch_{i}", value=i) for i in range(5)]
        count = await db.save_batch(items)
        assert count == 5
        total = await db.count()
        assert total == 6, f"Expected 6 rows after batch save, got {total}"

        print("✓ _prepare_item() deduplication works correctly")
    finally:
        await db.close()
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
if __name__ == "__main__":
    async def run_all_tests():
        """Run every fix-verification test in order, printing a numbered header.

        Tests that take a database path each get their own fresh file
        ``test_fix_<n>.db`` inside one shared temporary directory. Any
        failing test raises and aborts the run before the success banner.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            print("\n" + "=" * 60)
            print("ASYNCDB FIX VERIFICATION TESTS")
            print("=" * 60)

            # (label, test coroutine function, whether it takes a db path)
            suite = (
                ("Limit/offset in find()", test_find_limit_offset, True),
                ("save_batch count reset", test_save_batch_count_reset, True),
                ("Enum subclass deserialization", test_enum_deserialization_subclasses, True),
                ("NOT NULL with falsy defaults", test_not_null_with_falsy_defaults, True),
                ("Identifier validation", test_identifier_validation, False),
                ("Constructor identifier validation", test_identifier_validation_in_constructor, True),
                ("Connection pooling", test_connection_pooling, True),
                ("Close idempotent", test_close_idempotent, True),
                ("exists()", test_exists, True),
                ("delete_many()", test_delete_many, True),
                ("update_fields()", test_update_fields, True),
                ("Context manager", test_context_manager, True),
                ("_prepare_item dedup", test_prepare_item_dedup, True),
            )

            for number, (label, test_fn, needs_db) in enumerate(suite, start=1):
                print(f"\n[Test {number}] {label}...")
                if needs_db:
                    # Fresh database file per test, numbered by run position.
                    await test_fn(Path(tmpdir) / f"test_fix_{number}.db")
                else:
                    await test_fn()

            print("\n" + "=" * 60)
            print("ALL TESTS PASSED!")
            print("=" * 60)

    asyncio.run(run_all_tests())
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|