esuls 0.1.15__tar.gz → 0.1.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: esuls
3
- Version: 0.1.15
3
+ Version: 0.1.16
4
4
  Summary: Utility library for async database operations, HTTP requests, and parallel execution
5
5
  Author-email: IperGiove <ipergiove@gmail.com>
6
6
  License: MIT
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "esuls"
7
- version = "0.1.15"
7
+ version = "0.1.16"
8
8
  description = "Utility library for async database operations, HTTP requests, and parallel execution"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.14"
@@ -1,6 +1,10 @@
1
1
  import asyncio
2
2
  import aiosqlite
3
+ import ast
3
4
  import json
5
+ import re
6
+ import threading
7
+ import dataclasses
4
8
  from datetime import datetime
5
9
  from pathlib import Path
6
10
  from typing import Any, Dict, List, Optional, TypeVar, Generic, Type, get_type_hints, Union, Tuple
@@ -11,6 +15,13 @@ import contextlib
11
15
  import enum
12
16
  from loguru import logger
13
17
 
18
+ _VALID_IDENTIFIER = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
19
+
20
+ def _validate_identifier(name: str) -> str:
21
+ if not _VALID_IDENTIFIER.match(name):
22
+ raise ValueError(f"Invalid SQL identifier: {name!r}")
23
+ return name
24
+
14
25
  T = TypeVar('T')
15
26
  SchemaType = TypeVar('SchemaType', bound='BaseModel')
16
27
 
@@ -33,6 +44,8 @@ class AsyncDB(Generic[SchemaType]):
33
44
  _db_locks: dict[str, asyncio.Lock] = {}
34
45
  # Lock for schema initialization (class-level)
35
46
  _schema_init_lock: asyncio.Lock = None
47
+ # Threading lock to guard class-level dict mutations
48
+ _db_locks_guard = threading.Lock()
36
49
 
37
50
  def __init__(self, db_path: Union[str, Path], table_name: str, schema_class: Type[SchemaType]):
38
51
  """Initialize AsyncDB with a path and schema dataclass."""
@@ -41,34 +54,49 @@ class AsyncDB(Generic[SchemaType]):
41
54
 
42
55
  self.db_path = Path(db_path).resolve()
43
56
  self.schema_class = schema_class
44
- self.table_name = table_name
57
+ self.table_name = _validate_identifier(table_name)
45
58
  self.db_path.parent.mkdir(parents=True, exist_ok=True)
46
59
 
60
+ # Validate all field names upfront
61
+ for f in fields(schema_class):
62
+ _validate_identifier(f.name)
63
+
47
64
  # Make schema initialization unique per instance
48
65
  self._db_key = f"{str(self.db_path)}:{self.table_name}:{self.schema_class.__name__}"
49
66
 
50
67
  # Use shared lock per database file (not per instance)
51
68
  db_path_str = str(self.db_path)
52
- if db_path_str not in AsyncDB._db_locks:
53
- AsyncDB._db_locks[db_path_str] = asyncio.Lock()
54
- self._write_lock = AsyncDB._db_locks[db_path_str]
69
+ with AsyncDB._db_locks_guard:
70
+ if db_path_str not in AsyncDB._db_locks:
71
+ AsyncDB._db_locks[db_path_str] = asyncio.Lock()
72
+ self._write_lock = AsyncDB._db_locks[db_path_str]
55
73
 
56
74
  self._type_hints = get_type_hints(schema_class)
57
-
75
+
76
+ # Persistent connection (lazy init)
77
+ self._connection: Optional[aiosqlite.Connection] = None
78
+
58
79
  # Use a class-level set to track initialized schemas
59
80
  if not hasattr(AsyncDB, '_initialized_schemas'):
60
81
  AsyncDB._initialized_schemas = set()
61
82
 
62
- async def _get_connection(self, max_retries: int = 5) -> aiosqlite.Connection:
63
- """Create a new optimized connection with retry logic for concurrent access."""
83
+ async def _ensure_connection(self, max_retries: int = 5) -> aiosqlite.Connection:
84
+ """Return the persistent connection, creating it on first call with retry logic."""
85
+ if self._connection is not None:
86
+ return self._connection
87
+
64
88
  # Ensure schema init lock exists (lazy init for asyncio compatibility)
65
- if AsyncDB._schema_init_lock is None:
66
- AsyncDB._schema_init_lock = asyncio.Lock()
89
+ with AsyncDB._db_locks_guard:
90
+ if AsyncDB._schema_init_lock is None:
91
+ AsyncDB._schema_init_lock = asyncio.Lock()
67
92
 
68
93
  last_error = None
69
94
  for attempt in range(max_retries):
70
95
  try:
71
- db = await aiosqlite.connect(self.db_path, timeout=30.0)
96
+ db = aiosqlite.connect(self.db_path, timeout=30.0)
97
+ # Mark aiosqlite's thread as daemon so it won't block process exit
98
+ db.daemon = True
99
+ db = await db
72
100
  # Fast WAL mode with minimal sync
73
101
  await db.execute("PRAGMA journal_mode=WAL")
74
102
  await db.execute("PRAGMA synchronous=NORMAL")
@@ -83,6 +111,7 @@ class AsyncDB(Generic[SchemaType]):
83
111
  await self._init_schema(db)
84
112
  AsyncDB._initialized_schemas.add(self._db_key)
85
113
 
114
+ self._connection = db
86
115
  return db
87
116
  except Exception as e:
88
117
  last_error = e
@@ -93,7 +122,22 @@ class AsyncDB(Generic[SchemaType]):
93
122
  continue
94
123
  raise
95
124
  raise last_error
96
-
125
+
126
+ async def __aenter__(self):
127
+ return self
128
+
129
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
130
+ await self.close()
131
+
132
+ async def close(self) -> None:
133
+ """Explicitly close the persistent connection."""
134
+ if self._connection is not None:
135
+ try:
136
+ await self._connection.close()
137
+ except Exception:
138
+ pass
139
+ self._connection = None
140
+
97
141
  async def _init_schema(self, db: aiosqlite.Connection) -> None:
98
142
  """Generate schema from dataclass structure with support for field additions."""
99
143
  logger.debug(f"Initializing schema for {self.schema_class.__name__} in table {self.table_name}")
@@ -146,7 +190,7 @@ class AsyncDB(Generic[SchemaType]):
146
190
  constraints.append("PRIMARY KEY")
147
191
  if f.metadata.get('unique'):
148
192
  constraints.append("UNIQUE")
149
- if not f.default and not f.default_factory and f.metadata.get('required', True):
193
+ if f.default is dataclasses.MISSING and f.default_factory is dataclasses.MISSING and f.metadata.get('required', True):
150
194
  constraints.append("NOT NULL")
151
195
 
152
196
  field_def = f"{field_name} {sql_type} {' '.join(constraints)}"
@@ -193,16 +237,27 @@ class AsyncDB(Generic[SchemaType]):
193
237
 
194
238
  @contextlib.asynccontextmanager
195
239
  async def transaction(self):
196
- """Run operations in a transaction with reliable cleanup."""
197
- db = await self._get_connection()
240
+ """Run operations in a transaction with reliable cleanup and auto-reconnect."""
241
+ db = await self._ensure_connection()
198
242
  try:
199
243
  yield db
200
244
  await db.commit()
201
- except Exception:
202
- await db.rollback()
203
- raise
204
- finally:
205
- await db.close()
245
+ except Exception as e:
246
+ try:
247
+ await db.rollback()
248
+ except Exception:
249
+ pass
250
+ if "closed" in str(e).lower() or "no active connection" in str(e).lower():
251
+ self._connection = None
252
+ db = await self._ensure_connection()
253
+ try:
254
+ yield db
255
+ await db.commit()
256
+ except Exception:
257
+ await db.rollback()
258
+ raise
259
+ else:
260
+ raise
206
261
 
207
262
  # @lru_cache(maxsize=128)
208
263
  def _serialize_value(self, value: Any) -> Any:
@@ -232,12 +287,15 @@ class AsyncDB(Generic[SchemaType]):
232
287
  return value
233
288
  # If somehow stored as string, convert back
234
289
  if isinstance(value, str):
235
- import ast
236
290
  try:
237
291
  return ast.literal_eval(value)
238
- except:
292
+ except (ValueError, SyntaxError):
239
293
  return value.encode('utf-8')
240
294
 
295
+ # Handle bool fields - SQLite stores as INTEGER, need to convert back
296
+ if field_type is bool:
297
+ return bool(value)
298
+
241
299
  # Handle string fields - ensure phone numbers are strings
242
300
  if field_type is str or (hasattr(field_type, '__origin__') and field_type.__origin__ is Union and str in getattr(field_type, '__args__', ())):
243
301
  return str(value)
@@ -250,12 +308,12 @@ class AsyncDB(Generic[SchemaType]):
250
308
  # Handle Optional[EnumType] case
251
309
  args = getattr(field_type, '__args__', ())
252
310
  for arg in args:
253
- if arg is not type(None) and hasattr(arg, '__bases__') and enum.Enum in arg.__bases__:
311
+ if arg is not type(None) and isinstance(arg, type) and issubclass(arg, enum.Enum):
254
312
  try:
255
313
  return arg(value)
256
314
  except (ValueError, TypeError):
257
315
  pass
258
- elif hasattr(field_type, '__bases__') and enum.Enum in field_type.__bases__:
316
+ elif isinstance(field_type, type) and issubclass(field_type, enum.Enum):
259
317
  # Handle direct enum types
260
318
  try:
261
319
  return field_type(value)
@@ -281,6 +339,20 @@ class AsyncDB(Generic[SchemaType]):
281
339
  VALUES ({placeholders},?)
282
340
  """
283
341
 
342
+ def _prepare_item(self, item: SchemaType) -> Tuple[str, List[Any]]:
343
+ """Prepare an item for saving. Returns (sql, values)."""
344
+ data = asdict(item)
345
+ item_id = data.pop('id', None) or str(uuid.uuid4())
346
+ now = datetime.now()
347
+ if not data.get('created_at'):
348
+ data['created_at'] = now
349
+ data['updated_at'] = now
350
+ field_names = tuple(sorted(data.keys()))
351
+ sql = self._generate_save_sql(field_names)
352
+ values = [self._serialize_value(data[name]) for name in field_names]
353
+ values.append(item_id)
354
+ return sql, values
355
+
284
356
  async def save_batch(self, items: List[SchemaType], skip_errors: bool = True) -> int:
285
357
  """Save multiple items in a single transaction for better performance.
286
358
 
@@ -299,6 +371,7 @@ class AsyncDB(Generic[SchemaType]):
299
371
  max_retries = 3
300
372
  for attempt in range(max_retries):
301
373
  try:
374
+ saved_count = 0
302
375
  async with self._write_lock:
303
376
  async with self.transaction() as db:
304
377
  for item in items:
@@ -308,23 +381,7 @@ class AsyncDB(Generic[SchemaType]):
308
381
  raise TypeError(f"Expected {self.schema_class.__name__}, got {type(item).__name__}")
309
382
  continue
310
383
 
311
- # Extract and process data
312
- data = asdict(item)
313
- item_id = data.pop('id', None) or str(uuid.uuid4())
314
-
315
- # Ensure created_at and updated_at are set
316
- now = datetime.now()
317
- if not data.get('created_at'):
318
- data['created_at'] = now
319
- data['updated_at'] = now
320
-
321
- # Prepare SQL and values
322
- field_names = tuple(sorted(data.keys()))
323
- sql = self._generate_save_sql(field_names)
324
- values = [self._serialize_value(data[name]) for name in field_names]
325
- values.append(item_id)
326
-
327
- # Execute save
384
+ sql, values = self._prepare_item(item)
328
385
  await db.execute(sql, values)
329
386
  saved_count += 1
330
387
 
@@ -360,21 +417,7 @@ class AsyncDB(Generic[SchemaType]):
360
417
  return False
361
418
  raise TypeError(f"Expected {self.schema_class.__name__}, got {type(item).__name__}")
362
419
 
363
- # Extract and process data
364
- data = asdict(item)
365
- item_id = data.pop('id', None) or str(uuid.uuid4())
366
-
367
- # Ensure created_at and updated_at are set
368
- now = datetime.now()
369
- if not data.get('created_at'):
370
- data['created_at'] = now
371
- data['updated_at'] = now
372
-
373
- # Prepare SQL and values
374
- field_names = tuple(sorted(data.keys()))
375
- sql = self._generate_save_sql(field_names)
376
- values = [self._serialize_value(data[name]) for name in field_names]
377
- values.append(item_id)
420
+ sql, values = self._prepare_item(item)
378
421
 
379
422
  # Perform save with reliable transaction (retry on "database is locked")
380
423
  max_retries = 3
@@ -425,10 +468,6 @@ class AsyncDB(Generic[SchemaType]):
425
468
  values = []
426
469
 
427
470
  for key, value in filters.items():
428
- # Handle special values
429
- if value == 'now':
430
- value = datetime.now()
431
-
432
471
  # Parse field and operator
433
472
  parts = key.split('__', 1)
434
473
  field = parts[0]
@@ -451,22 +490,32 @@ class AsyncDB(Generic[SchemaType]):
451
490
 
452
491
  return f"WHERE {' AND '.join(conditions)}", values
453
492
 
454
- async def find(self, order_by=None, **filters) -> List[SchemaType]:
493
+ async def find(self, order_by=None, limit: int = None, offset: int = None, **filters) -> List[SchemaType]:
455
494
  """Query items with reliable connection handling."""
456
495
  where_clause, values = self._build_where_clause(filters)
457
-
496
+
458
497
  # Build query
459
498
  query = f"SELECT * FROM {self.table_name} {where_clause}"
460
-
499
+
461
500
  # Add ORDER BY clause if specified
462
501
  if order_by:
463
502
  order_fields = [order_by] if isinstance(order_by, str) else order_by
464
503
  order_clauses = [
465
- f"{field[1:]} DESC" if field.startswith('-') else f"{field} ASC"
504
+ f"{field[1:]} DESC" if field.startswith('-') else f"{field} ASC"
466
505
  for field in order_fields
467
506
  ]
468
507
  query += f" ORDER BY {', '.join(order_clauses)}"
469
-
508
+
509
+ # Add LIMIT/OFFSET (SQLite requires LIMIT before OFFSET)
510
+ if limit is not None:
511
+ query += " LIMIT ?"
512
+ values.append(limit)
513
+ elif offset is not None:
514
+ query += " LIMIT -1"
515
+ if offset is not None:
516
+ query += " OFFSET ?"
517
+ values.append(offset)
518
+
470
519
  # Execute query with reliable transaction
471
520
  async with self.transaction() as db:
472
521
  cursor = await db.execute(query, values)
@@ -504,4 +553,33 @@ class AsyncDB(Generic[SchemaType]):
504
553
  async with self._write_lock:
505
554
  async with self.transaction() as db:
506
555
  cursor = await db.execute(f"DELETE FROM {self.table_name} WHERE id = ?", (id,))
556
+ return cursor.rowcount > 0
557
+
558
+ async def exists(self, **filters) -> bool:
559
+ """Check if any record matches the filters without fetching data."""
560
+ return await self.count(**filters) > 0
561
+
562
+ async def delete_many(self, **filters) -> int:
563
+ """Delete all items matching filters. Returns count of deleted rows."""
564
+ if not filters:
565
+ raise ValueError("delete_many() requires at least one filter to prevent accidental full table delete")
566
+ where_clause, values = self._build_where_clause(filters)
567
+ async with self._write_lock:
568
+ async with self.transaction() as db:
569
+ cursor = await db.execute(f"DELETE FROM {self.table_name} {where_clause}", values)
570
+ return cursor.rowcount
571
+
572
+ async def update_fields(self, id: str, **fields) -> bool:
573
+ """Update specific fields on a record by ID without fetching the full record."""
574
+ if not fields:
575
+ return False
576
+ fields['updated_at'] = datetime.now()
577
+ set_clause = ', '.join(f"{_validate_identifier(k)} = ?" for k in fields)
578
+ values = [self._serialize_value(v) for v in fields.values()]
579
+ values.append(id)
580
+ async with self._write_lock:
581
+ async with self.transaction() as db:
582
+ cursor = await db.execute(
583
+ f"UPDATE {self.table_name} SET {set_clause} WHERE id = ?", values
584
+ )
507
585
  return cursor.rowcount > 0
@@ -33,6 +33,7 @@ async def test_concurrent_reads(temp_db):
33
33
 
34
34
  # All reads should succeed and return same data
35
35
  assert all(len(r) == 10 for r in results)
36
+ await db.close()
36
37
  print(f"✓ 100 concurrent reads completed successfully")
37
38
 
38
39
 
@@ -53,6 +54,7 @@ async def test_concurrent_writes(temp_db):
53
54
  # Verify all items were saved
54
55
  items = await db.find()
55
56
  assert len(items) == 50
57
+ await db.close()
56
58
  print(f"✓ 50 concurrent writes completed successfully")
57
59
 
58
60
 
@@ -92,6 +94,7 @@ async def test_concurrent_mixed_operations(temp_db):
92
94
  # Verify final state
93
95
  items = await db.find()
94
96
  assert len(items) == 55 # 5 seed + 50 writes
97
+ await db.close()
95
98
  print(f"✓ 200 concurrent mixed operations completed successfully")
96
99
 
97
100
 
@@ -124,6 +127,7 @@ async def test_stress_concurrent_access(temp_db):
124
127
 
125
128
  # Should have very few or no failures with retry logic
126
129
  assert len(exceptions) == 0, f"{len(exceptions)} operations failed"
130
+ await db.close()
127
131
  print(f"✓ 500 concurrent stress operations completed successfully")
128
132
 
129
133
 
@@ -0,0 +1,443 @@
1
+ """
2
+ Tests for specific fixes applied to AsyncDB.
3
+ Covers: limit/offset, save_batch reset, enum deserialization,
4
+ NOT NULL schema logic, identifier validation, connection pooling.
5
+ """
6
+ import asyncio
7
+ import enum
8
+ import tempfile
9
+ from dataclasses import dataclass, field
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ from esuls.db_cli import AsyncDB, BaseModel, _validate_identifier
14
+
15
+
16
+ # --- Test models ---
17
+
18
+ class Color(enum.Enum):
19
+ RED = "red"
20
+ GREEN = "green"
21
+ BLUE = "blue"
22
+
23
+
24
+ class Priority(enum.IntEnum):
25
+ LOW = 1
26
+ MEDIUM = 2
27
+ HIGH = 3
28
+
29
+
30
+ class Status(str, enum.Enum):
31
+ ACTIVE = "active"
32
+ INACTIVE = "inactive"
33
+
34
+
35
+ @dataclass
36
+ class EnumItem(BaseModel):
37
+ color: Optional[Color] = None
38
+ priority: Optional[Priority] = None
39
+ status: Optional[Status] = None
40
+
41
+
42
+ @dataclass
43
+ class DefaultsItem(BaseModel):
44
+ name: str = ""
45
+ count: int = 0
46
+ flag: bool = False
47
+ score: float = 0.0
48
+ required_field: str = field(default="", metadata={"required": False})
49
+
50
+
51
+ @dataclass
52
+ class TestItem(BaseModel):
53
+ name: str = ""
54
+ value: int = 0
55
+
56
+
57
+ # --- Tests ---
58
+
59
+ async def test_find_limit_offset(temp_db):
60
+ """Test limit and offset parameters on find()."""
61
+ db = AsyncDB(temp_db, "items", TestItem)
62
+
63
+ try:
64
+ # Insert 20 items with predictable ordering
65
+ for i in range(20):
66
+ await db.save(TestItem(name=f"item_{i:02d}", value=i))
67
+
68
+ # Test limit
69
+ results = await db.find(order_by="value", limit=5)
70
+ assert len(results) == 5
71
+ assert results[0].value == 0
72
+ assert results[4].value == 4
73
+
74
+ # Test offset
75
+ results = await db.find(order_by="value", limit=5, offset=10)
76
+ assert len(results) == 5
77
+ assert results[0].value == 10
78
+ assert results[4].value == 14
79
+
80
+ # Test offset beyond data
81
+ results = await db.find(order_by="value", limit=5, offset=18)
82
+ assert len(results) == 2
83
+ assert results[0].value == 18
84
+
85
+ # Test limit without offset
86
+ results = await db.find(order_by="value", limit=3)
87
+ assert len(results) == 3
88
+
89
+ # Test offset without limit (should return from offset to end)
90
+ results = await db.find(order_by="value", offset=17)
91
+ assert len(results) == 3
92
+
93
+ # Test no limit/offset (all results)
94
+ results = await db.find()
95
+ assert len(results) == 20
96
+
97
+ print("✓ find() limit/offset works correctly")
98
+ finally:
99
+ await db.close()
100
+
101
+
102
+ async def test_save_batch_count_reset(temp_db):
103
+ """Test that saved_count resets properly on retry in save_batch()."""
104
+ db = AsyncDB(temp_db, "items", TestItem)
105
+
106
+ try:
107
+ items = [TestItem(name=f"batch_{i}", value=i) for i in range(10)]
108
+ count = await db.save_batch(items)
109
+ assert count == 10, f"Expected 10, got {count}"
110
+
111
+ # Save again (upsert) — count should still be exactly 10
112
+ count2 = await db.save_batch(items)
113
+ assert count2 == 10, f"Expected 10 on re-save, got {count2}"
114
+
115
+ print("✓ save_batch() count is accurate")
116
+ finally:
117
+ await db.close()
118
+
119
+
120
+ async def test_enum_deserialization_subclasses(temp_db):
121
+ """Test enum deserialization with IntEnum, StrEnum, and regular Enum."""
122
+ db = AsyncDB(temp_db, "enums", EnumItem)
123
+
124
+ try:
125
+ # Save with all enum types
126
+ item = EnumItem(
127
+ color=Color.GREEN,
128
+ priority=Priority.HIGH,
129
+ status=Status.ACTIVE,
130
+ )
131
+ await db.save(item)
132
+
133
+ # Read back
134
+ results = await db.find()
135
+ assert len(results) == 1
136
+ loaded = results[0]
137
+
138
+ assert loaded.color == Color.GREEN, f"Expected Color.GREEN, got {loaded.color!r}"
139
+ assert loaded.priority == Priority.HIGH, f"Expected Priority.HIGH, got {loaded.priority!r}"
140
+ assert loaded.status == Status.ACTIVE, f"Expected Status.ACTIVE, got {loaded.status!r}"
141
+
142
+ # Test None values
143
+ item2 = EnumItem()
144
+ await db.save(item2)
145
+ results = await db.find()
146
+ none_items = [r for r in results if r.color is None]
147
+ assert len(none_items) == 1
148
+ assert none_items[0].priority is None
149
+ assert none_items[0].status is None
150
+
151
+ print("✓ Enum deserialization works for Enum, IntEnum, StrEnum")
152
+ finally:
153
+ await db.close()
154
+
155
+
156
+ async def test_not_null_with_falsy_defaults(temp_db):
157
+ """Test that fields with falsy defaults (0, '', False) don't get NOT NULL."""
158
+ db = AsyncDB(temp_db, "defaults_test", DefaultsItem)
159
+
160
+ try:
161
+ # If the schema was created correctly, saving an item with default values should work
162
+ item = DefaultsItem()
163
+ result = await db.save(item)
164
+ assert result is True
165
+
166
+ # Verify roundtrip
167
+ items = await db.find()
168
+ assert len(items) == 1
169
+ loaded = items[0]
170
+ assert loaded.name == ""
171
+ assert loaded.count == 0
172
+ assert loaded.flag is False
173
+
174
+ print("✓ NOT NULL logic correctly handles falsy defaults")
175
+ finally:
176
+ await db.close()
177
+
178
+
179
+ async def test_identifier_validation():
180
+ """Test that invalid SQL identifiers are rejected."""
181
+ # Valid identifiers
182
+ assert _validate_identifier("items") == "items"
183
+ assert _validate_identifier("my_table") == "my_table"
184
+ assert _validate_identifier("_private") == "_private"
185
+ assert _validate_identifier("Table123") == "Table123"
186
+
187
+ # Invalid identifiers
188
+ invalid_names = [
189
+ "Robert'; DROP TABLE students;--",
190
+ "my table",
191
+ "123abc",
192
+ "my-table",
193
+ "",
194
+ "table.name",
195
+ "col(umn)",
196
+ ]
197
+ for name in invalid_names:
198
+ try:
199
+ _validate_identifier(name)
200
+ raise AssertionError(f"Should have rejected: {name!r}")
201
+ except ValueError:
202
+ pass # Expected
203
+
204
+ print("✓ Identifier validation rejects injection attempts")
205
+
206
+
207
+ async def test_identifier_validation_in_constructor(temp_db):
208
+ """Test that AsyncDB constructor rejects invalid table names."""
209
+ try:
210
+ AsyncDB(temp_db, "valid_table; DROP TABLE x", TestItem)
211
+ raise AssertionError("Should have rejected invalid table name")
212
+ except ValueError:
213
+ pass
214
+
215
+ print("✓ Constructor rejects invalid table names")
216
+
217
+
218
+ async def test_connection_pooling(temp_db):
219
+ """Test that persistent connection is reused across operations."""
220
+ db = AsyncDB(temp_db, "items", TestItem)
221
+
222
+ # First operation creates the connection
223
+ await db.save(TestItem(name="first", value=1))
224
+ conn1 = db._connection
225
+ assert conn1 is not None
226
+
227
+ # Subsequent operations reuse it
228
+ await db.find()
229
+ conn2 = db._connection
230
+ assert conn2 is conn1, "Connection should be reused"
231
+
232
+ await db.count()
233
+ conn3 = db._connection
234
+ assert conn3 is conn1, "Connection should still be reused"
235
+
236
+ await db.save(TestItem(name="second", value=2))
237
+ conn4 = db._connection
238
+ assert conn4 is conn1, "Connection should persist across saves"
239
+
240
+ # Explicit close
241
+ await db.close()
242
+ assert db._connection is None
243
+
244
+ # Next operation creates a new connection
245
+ await db.find()
246
+ conn5 = db._connection
247
+ assert conn5 is not None
248
+ assert conn5 is not conn1, "Should be a new connection after close"
249
+
250
+ await db.close()
251
+ print("✓ Connection pooling reuses connections correctly")
252
+
253
+
254
+ async def test_close_idempotent(temp_db):
255
+ """Test that close() can be called multiple times safely."""
256
+ db = AsyncDB(temp_db, "items", TestItem)
257
+ await db.save(TestItem(name="test", value=1))
258
+
259
+ await db.close()
260
+ await db.close() # Should not raise
261
+ await db.close() # Should not raise
262
+
263
+ # Should still work after close
264
+ items = await db.find()
265
+ assert len(items) == 1
266
+
267
+ await db.close()
268
+ print("✓ close() is idempotent")
269
+
270
+
271
+ async def test_exists(temp_db):
272
+ """Test exists() returns bool without fetching full records."""
273
+ db = AsyncDB(temp_db, "items", TestItem)
274
+
275
+ try:
276
+ assert await db.exists(name="nope") is False
277
+
278
+ await db.save(TestItem(name="hello", value=1))
279
+ assert await db.exists(name="hello") is True
280
+ assert await db.exists(name="nope") is False
281
+ assert await db.exists(value=1) is True
282
+ assert await db.exists(value=999) is False
283
+
284
+ print("✓ exists() works correctly")
285
+ finally:
286
+ await db.close()
287
+
288
+
289
+ async def test_delete_many(temp_db):
290
+ """Test delete_many() deletes matching records and returns count."""
291
+ db = AsyncDB(temp_db, "items", TestItem)
292
+
293
+ try:
294
+ for i in range(10):
295
+ await db.save(TestItem(name="group_a" if i < 6 else "group_b", value=i))
296
+
297
+ # Delete group_a
298
+ deleted = await db.delete_many(name="group_a")
299
+ assert deleted == 6, f"Expected 6 deleted, got {deleted}"
300
+
301
+ remaining = await db.count()
302
+ assert remaining == 4, f"Expected 4 remaining, got {remaining}"
303
+
304
+ # Delete non-existent
305
+ deleted = await db.delete_many(name="group_c")
306
+ assert deleted == 0
307
+
308
+ # Must raise on empty filters
309
+ try:
310
+ await db.delete_many()
311
+ raise AssertionError("Should have raised ValueError")
312
+ except ValueError:
313
+ pass
314
+
315
+ print("✓ delete_many() works correctly")
316
+ finally:
317
+ await db.close()
318
+
319
+
320
+ async def test_update_fields(temp_db):
321
+ """Test update_fields() updates specific fields without full fetch."""
322
+ db = AsyncDB(temp_db, "items", TestItem)
323
+
324
+ try:
325
+ item = TestItem(name="original", value=10)
326
+ await db.save(item)
327
+
328
+ # Update single field
329
+ result = await db.update_fields(item.id, name="updated")
330
+ assert result is True
331
+
332
+ loaded = await db.get_by_id(item.id)
333
+ assert loaded.name == "updated"
334
+ assert loaded.value == 10 # unchanged
335
+
336
+ # Update multiple fields
337
+ result = await db.update_fields(item.id, name="final", value=99)
338
+ assert result is True
339
+ loaded = await db.get_by_id(item.id)
340
+ assert loaded.name == "final"
341
+ assert loaded.value == 99
342
+
343
+ # Non-existent ID
344
+ result = await db.update_fields("nonexistent-id", name="x")
345
+ assert result is False
346
+
347
+ # Empty fields
348
+ result = await db.update_fields(item.id)
349
+ assert result is False
350
+
351
+ print("✓ update_fields() works correctly")
352
+ finally:
353
+ await db.close()
354
+
355
+
356
+ async def test_context_manager(temp_db):
357
+ """Test async context manager for automatic cleanup."""
358
+ async with AsyncDB(temp_db, "items", TestItem) as db:
359
+ await db.save(TestItem(name="ctx", value=42))
360
+ items = await db.find()
361
+ assert len(items) == 1
362
+ assert items[0].name == "ctx"
363
+
364
+ # Connection should be closed after exiting context
365
+ assert db._connection is None
366
+
367
+ # Should still work when reopened
368
+ async with AsyncDB(temp_db, "items", TestItem) as db:
369
+ items = await db.find()
370
+ assert len(items) == 1
371
+
372
+ print("✓ Context manager works correctly")
373
+
374
+
375
+ async def test_prepare_item_dedup(temp_db):
376
+ """Test that _prepare_item produces correct SQL and values for save."""
377
+ db = AsyncDB(temp_db, "items", TestItem)
378
+
379
+ try:
380
+ item = TestItem(name="test", value=5)
381
+ sql, values = db._prepare_item(item)
382
+ assert "INSERT OR REPLACE" in sql
383
+ assert "test" in values
384
+ assert 5 in values
385
+
386
+ # Verify save still works end-to-end
387
+ await db.save(item)
388
+ loaded = await db.get_by_id(item.id)
389
+ assert loaded.name == "test"
390
+ assert loaded.value == 5
391
+
392
+ # Verify batch save still works
393
+ items = [TestItem(name=f"batch_{i}", value=i) for i in range(5)]
394
+ count = await db.save_batch(items)
395
+ assert count == 5
396
+
397
+ print("✓ _prepare_item() deduplication works correctly")
398
+ finally:
399
+ await db.close()
400
+
401
+
402
+ if __name__ == "__main__":
403
+ async def run_all_tests():
404
+ with tempfile.TemporaryDirectory() as tmpdir:
405
+ print("\n" + "=" * 60)
406
+ print("ASYNCDB FIX VERIFICATION TESTS")
407
+ print("=" * 60)
408
+
409
+ test_num = 0
410
+
411
+ async def run_test(name, coro):
412
+ nonlocal test_num
413
+ test_num += 1
414
+ # Use a fresh db for each test
415
+ db_path = Path(tmpdir) / f"test_fix_{test_num}.db"
416
+ print(f"\n[Test {test_num}] {name}...")
417
+ await coro(db_path)
418
+
419
+ async def run_test_no_db(name, coro):
420
+ nonlocal test_num
421
+ test_num += 1
422
+ print(f"\n[Test {test_num}] {name}...")
423
+ await coro()
424
+
425
+ await run_test("Limit/offset in find()", test_find_limit_offset)
426
+ await run_test("save_batch count reset", test_save_batch_count_reset)
427
+ await run_test("Enum subclass deserialization", test_enum_deserialization_subclasses)
428
+ await run_test("NOT NULL with falsy defaults", test_not_null_with_falsy_defaults)
429
+ await run_test_no_db("Identifier validation", test_identifier_validation)
430
+ await run_test("Constructor identifier validation", test_identifier_validation_in_constructor)
431
+ await run_test("Connection pooling", test_connection_pooling)
432
+ await run_test("Close idempotent", test_close_idempotent)
433
+ await run_test("exists()", test_exists)
434
+ await run_test("delete_many()", test_delete_many)
435
+ await run_test("update_fields()", test_update_fields)
436
+ await run_test("Context manager", test_context_manager)
437
+ await run_test("_prepare_item dedup", test_prepare_item_dedup)
438
+
439
+ print("\n" + "=" * 60)
440
+ print("ALL TESTS PASSED!")
441
+ print("=" * 60)
442
+
443
+ asyncio.run(run_all_tests())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: esuls
3
- Version: 0.1.15
3
+ Version: 0.1.16
4
4
  Summary: Utility library for async database operations, HTTP requests, and parallel execution
5
5
  Author-email: IperGiove <ipergiove@gmail.com>
6
6
  License: MIT
@@ -11,4 +11,5 @@ src/esuls.egg-info/SOURCES.txt
11
11
  src/esuls.egg-info/dependency_links.txt
12
12
  src/esuls.egg-info/requires.txt
13
13
  src/esuls.egg-info/top_level.txt
14
- src/esuls/tests/test_db_concurrent.py
14
+ src/esuls/tests/test_db_concurrent.py
15
+ src/esuls/tests/test_db_fixes.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes