async-easy-model 0.1.12__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- async_easy_model/auto_relationships.py +12 -0
- async_easy_model/model.py +635 -84
- async_easy_model-0.2.1.dist-info/METADATA +343 -0
- async_easy_model-0.2.1.dist-info/RECORD +10 -0
- async_easy_model-0.1.12.dist-info/METADATA +0 -533
- async_easy_model-0.1.12.dist-info/RECORD +0 -10
- {async_easy_model-0.1.12.dist-info → async_easy_model-0.2.1.dist-info}/LICENSE +0 -0
- {async_easy_model-0.1.12.dist-info → async_easy_model-0.2.1.dist-info}/WHEEL +0 -0
- {async_easy_model-0.1.12.dist-info → async_easy_model-0.2.1.dist-info}/top_level.txt +0 -0
async_easy_model/model.py
CHANGED
@@ -212,7 +212,7 @@ class EasyModel(SQLModel):
|
|
212
212
|
@classmethod
|
213
213
|
async def all(
|
214
214
|
cls: Type[T],
|
215
|
-
include_relationships: bool =
|
215
|
+
include_relationships: bool = True,
|
216
216
|
order_by: Optional[Union[str, List[str]]] = None
|
217
217
|
) -> List[T]:
|
218
218
|
"""
|
@@ -243,7 +243,7 @@ class EasyModel(SQLModel):
|
|
243
243
|
@classmethod
|
244
244
|
async def first(
|
245
245
|
cls: Type[T],
|
246
|
-
include_relationships: bool =
|
246
|
+
include_relationships: bool = True,
|
247
247
|
order_by: Optional[Union[str, List[str]]] = None
|
248
248
|
) -> Optional[T]:
|
249
249
|
"""
|
@@ -275,7 +275,7 @@ class EasyModel(SQLModel):
|
|
275
275
|
async def limit(
|
276
276
|
cls: Type[T],
|
277
277
|
count: int,
|
278
|
-
include_relationships: bool =
|
278
|
+
include_relationships: bool = True,
|
279
279
|
order_by: Optional[Union[str, List[str]]] = None
|
280
280
|
) -> List[T]:
|
281
281
|
"""
|
@@ -305,7 +305,7 @@ class EasyModel(SQLModel):
|
|
305
305
|
return result.scalars().all()
|
306
306
|
|
307
307
|
@classmethod
|
308
|
-
async def get_by_id(cls: Type[T], id: int, include_relationships: bool =
|
308
|
+
async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
|
309
309
|
"""
|
310
310
|
Retrieve a record by its primary key.
|
311
311
|
|
@@ -331,7 +331,7 @@ class EasyModel(SQLModel):
|
|
331
331
|
async def get_by_attribute(
|
332
332
|
cls: Type[T],
|
333
333
|
all: bool = False,
|
334
|
-
include_relationships: bool =
|
334
|
+
include_relationships: bool = True,
|
335
335
|
order_by: Optional[Union[str, List[str]]] = None,
|
336
336
|
**kwargs
|
337
337
|
) -> Union[Optional[T], List[T]]:
|
@@ -391,75 +391,475 @@ class EasyModel(SQLModel):
|
|
391
391
|
return result.scalars().first()
|
392
392
|
|
393
393
|
@classmethod
|
394
|
-
async def insert(cls: Type[T], data: Dict[str, Any], include_relationships: bool =
|
394
|
+
async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
|
395
395
|
"""
|
396
|
-
Insert
|
396
|
+
Insert one or more records.
|
397
397
|
|
398
398
|
Args:
|
399
|
-
data: Dictionary of field values
|
400
|
-
include_relationships: If True, return the instance with relationships loaded
|
399
|
+
data: Dictionary of field values or a list of dictionaries for multiple records
|
400
|
+
include_relationships: If True, return the instance(s) with relationships loaded
|
401
401
|
|
402
402
|
Returns:
|
403
|
-
The created model instance
|
403
|
+
The created model instance(s)
|
404
404
|
"""
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
405
|
+
# Handle list of records
|
406
|
+
if isinstance(data, list):
|
407
|
+
objects = []
|
408
|
+
async with cls.get_session() as session:
|
409
|
+
for item in data:
|
410
|
+
try:
|
411
|
+
# Process relationships first
|
412
|
+
processed_item = await cls._process_relationships_for_insert(session, item)
|
413
|
+
|
414
|
+
# Check if a record with unique constraints already exists
|
415
|
+
unique_fields = cls._get_unique_fields()
|
416
|
+
existing_obj = None
|
417
|
+
|
418
|
+
if unique_fields:
|
419
|
+
unique_criteria = {field: processed_item[field]
|
420
|
+
for field in unique_fields
|
421
|
+
if field in processed_item}
|
422
|
+
|
423
|
+
if unique_criteria:
|
424
|
+
# Try to find existing record with these unique values
|
425
|
+
statement = select(cls)
|
426
|
+
for field, value in unique_criteria.items():
|
427
|
+
statement = statement.where(getattr(cls, field) == value)
|
428
|
+
result = await session.execute(statement)
|
429
|
+
existing_obj = result.scalars().first()
|
430
|
+
|
431
|
+
if existing_obj:
|
432
|
+
# Update existing object with new values
|
433
|
+
for key, value in processed_item.items():
|
434
|
+
if key != 'id': # Don't update ID
|
435
|
+
setattr(existing_obj, key, value)
|
436
|
+
objects.append(existing_obj)
|
437
|
+
else:
|
438
|
+
# Create new object
|
439
|
+
obj = cls(**processed_item)
|
440
|
+
session.add(obj)
|
441
|
+
objects.append(obj)
|
442
|
+
except Exception as e:
|
443
|
+
logging.error(f"Error inserting record: {e}")
|
444
|
+
await session.rollback()
|
445
|
+
raise
|
446
|
+
|
447
|
+
try:
|
448
|
+
await session.flush()
|
449
|
+
await session.commit()
|
450
|
+
|
451
|
+
# Refresh with relationships if requested
|
452
|
+
if include_relationships:
|
453
|
+
for obj in objects:
|
454
|
+
await session.refresh(obj)
|
455
|
+
except Exception as e:
|
456
|
+
logging.error(f"Error committing transaction: {e}")
|
457
|
+
await session.rollback()
|
458
|
+
raise
|
459
|
+
|
460
|
+
return objects
|
461
|
+
else:
|
462
|
+
# Single record case
|
463
|
+
async with cls.get_session() as session:
|
464
|
+
try:
|
465
|
+
# Process relationships first
|
466
|
+
processed_data = await cls._process_relationships_for_insert(session, data)
|
467
|
+
|
468
|
+
# Check if a record with unique constraints already exists
|
469
|
+
unique_fields = cls._get_unique_fields()
|
470
|
+
existing_obj = None
|
471
|
+
|
472
|
+
if unique_fields:
|
473
|
+
unique_criteria = {field: processed_data[field]
|
474
|
+
for field in unique_fields
|
475
|
+
if field in processed_data}
|
476
|
+
|
477
|
+
if unique_criteria:
|
478
|
+
# Try to find existing record with these unique values
|
479
|
+
statement = select(cls)
|
480
|
+
for field, value in unique_criteria.items():
|
481
|
+
statement = statement.where(getattr(cls, field) == value)
|
482
|
+
result = await session.execute(statement)
|
483
|
+
existing_obj = result.scalars().first()
|
484
|
+
|
485
|
+
if existing_obj:
|
486
|
+
# Update existing object with new values
|
487
|
+
for key, value in processed_data.items():
|
488
|
+
if key != 'id': # Don't update ID
|
489
|
+
setattr(existing_obj, key, value)
|
490
|
+
obj = existing_obj
|
491
|
+
else:
|
492
|
+
# Create new object
|
493
|
+
obj = cls(**processed_data)
|
494
|
+
session.add(obj)
|
495
|
+
|
496
|
+
await session.flush() # Flush to get the ID
|
497
|
+
await session.commit()
|
498
|
+
|
499
|
+
if include_relationships:
|
500
|
+
# Refresh with relationships
|
501
|
+
statement = select(cls).where(cls.id == obj.id)
|
502
|
+
for rel_name in cls._get_auto_relationship_fields():
|
503
|
+
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
504
|
+
result = await session.execute(statement)
|
505
|
+
return result.scalars().first()
|
506
|
+
else:
|
507
|
+
await session.refresh(obj)
|
508
|
+
return obj
|
509
|
+
except Exception as e:
|
510
|
+
logging.error(f"Error inserting record: {e}")
|
511
|
+
await session.rollback()
|
512
|
+
raise
|
513
|
+
|
514
|
+
@classmethod
|
515
|
+
def _get_unique_fields(cls) -> List[str]:
|
516
|
+
"""
|
517
|
+
Get all fields with unique=True constraint
|
518
|
+
|
519
|
+
Returns:
|
520
|
+
List of field names that have unique constraints
|
521
|
+
"""
|
522
|
+
unique_fields = []
|
523
|
+
for name, field in cls.__fields__.items():
|
524
|
+
if name != 'id' and hasattr(field, 'field_info') and field.field_info.extra.get('unique', False):
|
525
|
+
unique_fields.append(name)
|
526
|
+
return unique_fields
|
420
527
|
|
421
528
|
@classmethod
|
422
|
-
async def
|
423
|
-
cls: Type[T],
|
424
|
-
id: int,
|
425
|
-
data: Dict[str, Any],
|
426
|
-
include_relationships: bool = False
|
427
|
-
) -> Optional[T]:
|
529
|
+
async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
|
428
530
|
"""
|
429
|
-
|
531
|
+
Process relationships in input data for insertion.
|
532
|
+
|
533
|
+
This method handles nested objects in the input data, such as:
|
534
|
+
cart = await ShoppingCart.insert({
|
535
|
+
"user": {"username": "john", "email": "john@example.com"},
|
536
|
+
"product": {"name": "Product X", "price": 19.99},
|
537
|
+
"quantity": 2
|
538
|
+
})
|
539
|
+
|
540
|
+
For each nested object:
|
541
|
+
1. Find the target model class
|
542
|
+
2. Check if an object with the same unique fields already exists
|
543
|
+
3. If found, update existing object with non-unique fields
|
544
|
+
4. If not found, create a new object
|
545
|
+
5. Set the foreign key ID in the result data
|
430
546
|
|
431
547
|
Args:
|
432
|
-
|
433
|
-
data:
|
434
|
-
include_relationships: If True, return the instance with relationships loaded
|
548
|
+
session: The database session to use
|
549
|
+
data: Input data dictionary that may contain nested objects
|
435
550
|
|
436
551
|
Returns:
|
437
|
-
|
552
|
+
Processed data dictionary with nested objects replaced by their foreign key IDs
|
438
553
|
"""
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
554
|
+
import copy
|
555
|
+
result = copy.deepcopy(data)
|
556
|
+
|
557
|
+
# Get all relationship fields for this model
|
558
|
+
relationship_fields = cls._get_auto_relationship_fields()
|
559
|
+
|
560
|
+
# Get foreign key fields
|
561
|
+
foreign_key_fields = []
|
562
|
+
for field_name, field_info in cls.__fields__.items():
|
563
|
+
if field_name.endswith("_id") and hasattr(field_info, "field_info"):
|
564
|
+
if field_info.field_info.extra.get("foreign_key"):
|
565
|
+
foreign_key_fields.append(field_name)
|
566
|
+
|
567
|
+
# Handle nested relationship objects
|
568
|
+
for key, value in data.items():
|
569
|
+
# Skip if the value is not a dictionary
|
570
|
+
if not isinstance(value, dict):
|
571
|
+
continue
|
572
|
+
|
573
|
+
# Check if this is a relationship field (either by name or derived from foreign key)
|
574
|
+
is_rel_field = key in relationship_fields
|
575
|
+
related_key = f"{key}_id"
|
576
|
+
is_derived_rel = related_key in foreign_key_fields
|
445
577
|
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
578
|
+
# If it's a relationship field or derived from a foreign key
|
579
|
+
if is_rel_field or is_derived_rel or related_key in cls.__fields__:
|
580
|
+
# Find the related model class
|
581
|
+
related_model = None
|
582
|
+
|
583
|
+
# Try to get the related model from the attribute
|
584
|
+
if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
|
585
|
+
# Get from relationship attribute
|
586
|
+
rel_attr = getattr(cls, key)
|
587
|
+
related_model = rel_attr.property.mapper.class_
|
588
|
+
else:
|
589
|
+
# Try to find it from foreign key definition
|
590
|
+
fk_definition = None
|
591
|
+
for field_name, field_info in cls.__fields__.items():
|
592
|
+
if field_name == related_key and hasattr(field_info, "field_info"):
|
593
|
+
fk_definition = field_info.field_info.extra.get("foreign_key")
|
594
|
+
break
|
595
|
+
|
596
|
+
if fk_definition:
|
597
|
+
# Parse foreign key definition (e.g. "users.id")
|
598
|
+
target_table, _ = fk_definition.split(".")
|
599
|
+
# Try to find the target model
|
600
|
+
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
|
601
|
+
related_model = get_model_by_table_name(target_table)
|
602
|
+
if not related_model:
|
603
|
+
# Try with the singular form
|
604
|
+
singular_table = singularize_name(target_table)
|
605
|
+
related_model = get_model_by_table_name(singular_table)
|
606
|
+
else:
|
607
|
+
# Try to infer from field name (e.g., "user_id" -> Users)
|
608
|
+
base_name = related_key[:-3] # Remove "_id"
|
609
|
+
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
|
610
|
+
|
611
|
+
# Try singular and plural forms
|
612
|
+
related_model = get_model_by_table_name(base_name)
|
613
|
+
if not related_model:
|
614
|
+
plural_table = pluralize_name(base_name)
|
615
|
+
related_model = get_model_by_table_name(plural_table)
|
616
|
+
if not related_model:
|
617
|
+
singular_table = singularize_name(base_name)
|
618
|
+
related_model = get_model_by_table_name(singular_table)
|
619
|
+
|
620
|
+
if not related_model:
|
621
|
+
logging.warning(f"Could not find related model for {key} in {cls.__name__}")
|
622
|
+
continue
|
623
|
+
|
624
|
+
# Look for unique fields in the related model to use for searching
|
625
|
+
unique_fields = []
|
626
|
+
for field_name, field_info in related_model.__fields__.items():
|
627
|
+
if (hasattr(field_info, "field_info") and
|
628
|
+
field_info.field_info.extra.get('unique', False)):
|
629
|
+
unique_fields.append(field_name)
|
630
|
+
|
631
|
+
# Create a search dictionary using unique fields
|
632
|
+
search_dict = {}
|
633
|
+
for field in unique_fields:
|
634
|
+
if field in value and value[field] is not None:
|
635
|
+
search_dict[field] = value[field]
|
636
|
+
|
637
|
+
# If no unique fields found but ID is provided, use it
|
638
|
+
if not search_dict and 'id' in value and value['id']:
|
639
|
+
search_dict = {'id': value['id']}
|
640
|
+
|
641
|
+
# Special case for products without uniqueness constraints
|
642
|
+
if not search_dict and related_model.__tablename__ == 'products' and 'name' in value:
|
643
|
+
search_dict = {'name': value['name']}
|
644
|
+
|
645
|
+
# Try to find an existing record
|
646
|
+
related_obj = None
|
647
|
+
if search_dict:
|
648
|
+
logging.info(f"Searching for existing {related_model.__name__} with {search_dict}")
|
649
|
+
|
650
|
+
try:
|
651
|
+
# Create a more appropriate search query based on unique fields
|
652
|
+
existing_stmt = select(related_model)
|
653
|
+
for field, field_value in search_dict.items():
|
654
|
+
existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
|
655
|
+
|
656
|
+
existing_result = await session.execute(existing_stmt)
|
657
|
+
related_obj = existing_result.scalars().first()
|
658
|
+
|
659
|
+
if related_obj:
|
660
|
+
logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id}")
|
661
|
+
except Exception as e:
|
662
|
+
logging.error(f"Error finding existing record: {e}")
|
663
|
+
|
664
|
+
if related_obj:
|
665
|
+
# Update the existing record with any non-unique field values
|
666
|
+
for attr, attr_val in value.items():
|
667
|
+
# Skip ID field
|
668
|
+
if attr == 'id':
|
669
|
+
continue
|
670
|
+
|
671
|
+
# Skip unique fields with different values to avoid constraint violations
|
672
|
+
if attr in unique_fields and getattr(related_obj, attr) != attr_val:
|
673
|
+
continue
|
674
|
+
|
675
|
+
# Update non-unique fields
|
676
|
+
current_val = getattr(related_obj, attr, None)
|
677
|
+
if current_val != attr_val:
|
678
|
+
setattr(related_obj, attr, attr_val)
|
679
|
+
|
680
|
+
# Add the updated object to the session
|
681
|
+
session.add(related_obj)
|
682
|
+
logging.info(f"Reusing existing {related_model.__name__} with ID: {related_obj.id}")
|
683
|
+
else:
|
684
|
+
# Create a new record
|
685
|
+
logging.info(f"Creating new {related_model.__name__} for {key}")
|
686
|
+
related_obj = related_model(**value)
|
687
|
+
session.add(related_obj)
|
688
|
+
|
689
|
+
# Ensure the object has an ID by flushing
|
690
|
+
try:
|
691
|
+
await session.flush()
|
692
|
+
except Exception as e:
|
693
|
+
logging.error(f"Error flushing session for {related_model.__name__}: {e}")
|
694
|
+
|
695
|
+
# If there was a uniqueness error, try again to find the existing record
|
696
|
+
if "UNIQUE constraint failed" in str(e):
|
697
|
+
logging.info(f"UNIQUE constraint failed, trying to find existing record again")
|
698
|
+
|
699
|
+
# Try to find by any field provided in the search_dict
|
700
|
+
existing_stmt = select(related_model)
|
701
|
+
for field, field_value in search_dict.items():
|
702
|
+
existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
|
703
|
+
|
704
|
+
# Execute the search query
|
705
|
+
existing_result = await session.execute(existing_stmt)
|
706
|
+
related_obj = existing_result.scalars().first()
|
707
|
+
|
708
|
+
if not related_obj:
|
709
|
+
# We couldn't find an existing record, re-raise the exception
|
710
|
+
raise
|
711
|
+
|
712
|
+
logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id} after constraint error")
|
713
|
+
|
714
|
+
# Update the result with the foreign key ID
|
715
|
+
foreign_key_name = f"{key}_id"
|
716
|
+
result[foreign_key_name] = related_obj.id
|
717
|
+
|
718
|
+
# Remove the relationship dictionary from the result
|
719
|
+
if key in result:
|
720
|
+
del result[key]
|
721
|
+
|
722
|
+
return result
|
723
|
+
|
724
|
+
@classmethod
|
725
|
+
async def update(cls: Type[T], data: Dict[str, Any], criteria: Dict[str, Any], include_relationships: bool = True) -> Optional[T]:
|
726
|
+
"""
|
727
|
+
Update an existing record identified by criteria.
|
728
|
+
|
729
|
+
Args:
|
730
|
+
data: Dictionary of updated field values
|
731
|
+
criteria: Dictionary of field values to identify the record to update
|
732
|
+
include_relationships: If True, return the updated instance with relationships loaded
|
733
|
+
|
734
|
+
Returns:
|
735
|
+
The updated model instance
|
736
|
+
"""
|
737
|
+
async with cls.get_session() as session:
|
738
|
+
try:
|
739
|
+
# Find the record(s) to update
|
740
|
+
statement = select(cls)
|
741
|
+
for field, value in criteria.items():
|
742
|
+
if isinstance(value, str) and '*' in value:
|
743
|
+
# Handle LIKE queries
|
744
|
+
like_value = value.replace('*', '%')
|
745
|
+
statement = statement.where(getattr(cls, field).like(like_value))
|
746
|
+
else:
|
747
|
+
statement = statement.where(getattr(cls, field) == value)
|
748
|
+
|
749
|
+
result = await session.execute(statement)
|
750
|
+
record = result.scalars().first()
|
751
|
+
|
752
|
+
if not record:
|
753
|
+
logging.warning(f"No record found with criteria: {criteria}")
|
754
|
+
return None
|
755
|
+
|
756
|
+
# Check for unique constraints before updating
|
757
|
+
for field_name, new_value in data.items():
|
758
|
+
if field_name != 'id' and hasattr(cls, field_name):
|
759
|
+
field = getattr(cls.__fields__.get(field_name), 'field_info', None)
|
760
|
+
if field and field.extra.get('unique', False):
|
761
|
+
# Check if the new value would conflict with an existing record
|
762
|
+
check_statement = select(cls).where(
|
763
|
+
getattr(cls, field_name) == new_value
|
764
|
+
).where(
|
765
|
+
cls.id != record.id
|
766
|
+
)
|
767
|
+
check_result = await session.execute(check_statement)
|
768
|
+
existing = check_result.scalars().first()
|
769
|
+
|
770
|
+
if existing:
|
771
|
+
raise ValueError(f"Cannot update {field_name} to '{new_value}': value already exists")
|
772
|
+
|
773
|
+
# Apply the updates
|
774
|
+
for key, value in data.items():
|
775
|
+
setattr(record, key, value)
|
776
|
+
|
777
|
+
await session.flush()
|
778
|
+
await session.commit()
|
779
|
+
|
780
|
+
if include_relationships:
|
781
|
+
# Refresh with relationships
|
782
|
+
refresh_statement = select(cls).where(cls.id == record.id)
|
783
|
+
for rel_name in cls._get_auto_relationship_fields():
|
784
|
+
refresh_statement = refresh_statement.options(selectinload(getattr(cls, rel_name)))
|
785
|
+
refresh_result = await session.execute(refresh_statement)
|
786
|
+
return refresh_result.scalars().first()
|
787
|
+
else:
|
788
|
+
await session.refresh(record)
|
789
|
+
return record
|
790
|
+
except Exception as e:
|
791
|
+
logging.error(f"Error updating record: {e}")
|
792
|
+
await session.rollback()
|
793
|
+
raise
|
450
794
|
|
451
795
|
@classmethod
|
452
|
-
async def delete(cls: Type[T],
|
796
|
+
async def delete(cls: Type[T], criteria: Dict[str, Any]) -> int:
|
453
797
|
"""
|
454
|
-
Delete
|
798
|
+
Delete records matching the provided criteria.
|
799
|
+
|
800
|
+
Args:
|
801
|
+
criteria: Dictionary of field values to identify records to delete
|
802
|
+
|
803
|
+
Returns:
|
804
|
+
Number of records deleted
|
455
805
|
"""
|
456
806
|
async with cls.get_session() as session:
|
457
|
-
|
458
|
-
|
459
|
-
|
807
|
+
try:
|
808
|
+
# Find the record(s) to delete
|
809
|
+
statement = select(cls)
|
810
|
+
for field, value in criteria.items():
|
811
|
+
if isinstance(value, str) and '*' in value:
|
812
|
+
# Handle LIKE queries
|
813
|
+
like_value = value.replace('*', '%')
|
814
|
+
statement = statement.where(getattr(cls, field).like(like_value))
|
815
|
+
else:
|
816
|
+
statement = statement.where(getattr(cls, field) == value)
|
817
|
+
|
818
|
+
result = await session.execute(statement)
|
819
|
+
records = result.scalars().all()
|
820
|
+
|
821
|
+
if not records:
|
822
|
+
logging.warning(f"No records found with criteria: {criteria}")
|
823
|
+
return 0
|
824
|
+
|
825
|
+
# Get a list of related tables that might need to be cleared first
|
826
|
+
# This helps with foreign key constraints
|
827
|
+
relationship_fields = cls._get_auto_relationship_fields()
|
828
|
+
to_many_relationships = []
|
829
|
+
|
830
|
+
# Find to-many relationships that need to be handled first
|
831
|
+
for rel_name in relationship_fields:
|
832
|
+
rel_attr = getattr(cls, rel_name, None)
|
833
|
+
if rel_attr and hasattr(rel_attr, 'property'):
|
834
|
+
# Check if this is a to-many relationship (one-to-many or many-to-many)
|
835
|
+
if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
|
836
|
+
to_many_relationships.append(rel_name)
|
837
|
+
|
838
|
+
# For each record, delete related records first (cascade delete)
|
839
|
+
for record in records:
|
840
|
+
# First load all related collections
|
841
|
+
if to_many_relationships:
|
842
|
+
await session.refresh(record, attribute_names=to_many_relationships)
|
843
|
+
|
844
|
+
# Delete related records in collections
|
845
|
+
for rel_name in to_many_relationships:
|
846
|
+
related_collection = getattr(record, rel_name, [])
|
847
|
+
if related_collection:
|
848
|
+
for related_item in related_collection:
|
849
|
+
await session.delete(related_item)
|
850
|
+
|
851
|
+
# Now delete the main record
|
852
|
+
await session.delete(record)
|
853
|
+
|
854
|
+
# Commit the changes
|
855
|
+
await session.flush()
|
460
856
|
await session.commit()
|
461
|
-
|
462
|
-
|
857
|
+
|
858
|
+
return len(records)
|
859
|
+
except Exception as e:
|
860
|
+
logging.error(f"Error deleting records: {e}")
|
861
|
+
await session.rollback()
|
862
|
+
raise
|
463
863
|
|
464
864
|
@classmethod
|
465
865
|
async def create_with_related(
|
@@ -515,7 +915,7 @@ class EasyModel(SQLModel):
|
|
515
915
|
await session.refresh(obj, attribute_names=list(related_data.keys()))
|
516
916
|
return obj
|
517
917
|
|
518
|
-
def to_dict(self, include_relationships: bool =
|
918
|
+
def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
|
519
919
|
"""
|
520
920
|
Convert the model instance to a dictionary.
|
521
921
|
|
@@ -526,32 +926,47 @@ class EasyModel(SQLModel):
|
|
526
926
|
Returns:
|
527
927
|
Dictionary representation of the model
|
528
928
|
"""
|
529
|
-
if max_depth <= 0:
|
530
|
-
return {}
|
531
|
-
|
532
929
|
# Get basic fields
|
533
930
|
result = self.model_dump()
|
534
931
|
|
535
932
|
# Add relationship fields if requested
|
536
933
|
if include_relationships and max_depth > 0:
|
537
934
|
for rel_name in self.__class__._get_auto_relationship_fields():
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
#
|
550
|
-
|
551
|
-
include_relationships=True,
|
552
|
-
max_depth=max_depth-1
|
553
|
-
)
|
935
|
+
# Only include relationships that are already loaded to avoid session errors
|
936
|
+
# We check if the relationship is loaded using SQLAlchemy's inspection API
|
937
|
+
is_loaded = False
|
938
|
+
try:
|
939
|
+
# Check if attribute exists and is not a relationship descriptor
|
940
|
+
rel_value = getattr(self, rel_name, None)
|
941
|
+
|
942
|
+
# If it's an attribute that has been loaded or not a relationship at all
|
943
|
+
# (for basic fields that match relationship naming pattern), include it
|
944
|
+
is_loaded = rel_value is not None and not hasattr(rel_value, 'prop')
|
945
|
+
except Exception:
|
946
|
+
# If accessing the attribute raises an exception, it's not loaded
|
947
|
+
is_loaded = False
|
554
948
|
|
949
|
+
if is_loaded:
|
950
|
+
rel_value = getattr(self, rel_name, None)
|
951
|
+
|
952
|
+
if rel_value is None:
|
953
|
+
result[rel_name] = None
|
954
|
+
elif isinstance(rel_value, list):
|
955
|
+
# Handle one-to-many relationships
|
956
|
+
result[rel_name] = [
|
957
|
+
item.to_dict(include_relationships=True, max_depth=max_depth-1)
|
958
|
+
for item in rel_value
|
959
|
+
]
|
960
|
+
else:
|
961
|
+
# Handle many-to-one relationships
|
962
|
+
result[rel_name] = rel_value.to_dict(
|
963
|
+
include_relationships=True,
|
964
|
+
max_depth=max_depth-1
|
965
|
+
)
|
966
|
+
else:
|
967
|
+
# If max_depth is 0, return the basic fields only
|
968
|
+
return result
|
969
|
+
|
555
970
|
return result
|
556
971
|
|
557
972
|
async def load_related(self, *related_fields: str) -> None:
|
@@ -568,6 +983,133 @@ class EasyModel(SQLModel):
|
|
568
983
|
# Refresh the instance with the specified relationships
|
569
984
|
await session.refresh(self, attribute_names=related_fields)
|
570
985
|
|
986
|
+
@classmethod
|
987
|
+
async def select(
|
988
|
+
cls: Type[T],
|
989
|
+
criteria: Dict[str, Any] = None,
|
990
|
+
all: bool = False,
|
991
|
+
first: bool = False,
|
992
|
+
include_relationships: bool = True,
|
993
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
994
|
+
limit: Optional[int] = None
|
995
|
+
) -> Union[Optional[T], List[T]]:
|
996
|
+
"""
|
997
|
+
Retrieve record(s) by matching attribute values.
|
998
|
+
|
999
|
+
Args:
|
1000
|
+
criteria: Dictionary of field values to filter by (field=value)
|
1001
|
+
all: If True, return all matching records, otherwise return only the first one
|
1002
|
+
first: If True, return only the first record (equivalent to all=False)
|
1003
|
+
include_relationships: If True, eagerly load all relationships
|
1004
|
+
order_by: Field(s) to order by. Can be a string or list of strings.
|
1005
|
+
Prefix with '-' for descending order (e.g. '-created_at')
|
1006
|
+
limit: Maximum number of records to retrieve (if all=True)
|
1007
|
+
If limit > 1, all is automatically set to True
|
1008
|
+
|
1009
|
+
Returns:
|
1010
|
+
A single model instance, a list of instances, or None if not found
|
1011
|
+
"""
|
1012
|
+
# Default to empty criteria if None provided
|
1013
|
+
if criteria is None:
|
1014
|
+
criteria = {}
|
1015
|
+
|
1016
|
+
# If limit is specified and > 1, set all to True
|
1017
|
+
if limit is not None and limit > 1:
|
1018
|
+
all = True
|
1019
|
+
# If first is specified, set all to False (first takes precedence)
|
1020
|
+
if first:
|
1021
|
+
all = False
|
1022
|
+
|
1023
|
+
async with cls.get_session() as session:
|
1024
|
+
# Build the query
|
1025
|
+
statement = select(cls)
|
1026
|
+
|
1027
|
+
# Apply criteria
|
1028
|
+
for field, value in criteria.items():
|
1029
|
+
if isinstance(value, str) and '*' in value:
|
1030
|
+
# Handle LIKE queries (convert '*' wildcard to '%')
|
1031
|
+
like_value = value.replace('*', '%')
|
1032
|
+
statement = statement.where(getattr(cls, field).like(like_value))
|
1033
|
+
else:
|
1034
|
+
# Regular equality check
|
1035
|
+
statement = statement.where(getattr(cls, field) == value)
|
1036
|
+
|
1037
|
+
# Apply ordering
|
1038
|
+
if order_by:
|
1039
|
+
statement = cls._apply_order_by(statement, order_by)
|
1040
|
+
|
1041
|
+
# Apply limit
|
1042
|
+
if limit:
|
1043
|
+
statement = statement.limit(limit)
|
1044
|
+
|
1045
|
+
# Include relationships if requested
|
1046
|
+
if include_relationships:
|
1047
|
+
for rel_name in cls._get_auto_relationship_fields():
|
1048
|
+
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
1049
|
+
|
1050
|
+
# Execute the query
|
1051
|
+
result = await session.execute(statement)
|
1052
|
+
|
1053
|
+
if all:
|
1054
|
+
# Return all results
|
1055
|
+
instances = result.scalars().all()
|
1056
|
+
|
1057
|
+
# Materialize relationships if requested - this ensures they're fully loaded
|
1058
|
+
if include_relationships:
|
1059
|
+
for instance in instances:
|
1060
|
+
# For each relationship, access it once to ensure it's loaded
|
1061
|
+
for rel_name in cls._get_auto_relationship_fields():
|
1062
|
+
try:
|
1063
|
+
# This will force loading the relationship while session is active
|
1064
|
+
_ = getattr(instance, rel_name)
|
1065
|
+
except Exception:
|
1066
|
+
# Skip if the relationship can't be loaded
|
1067
|
+
pass
|
1068
|
+
|
1069
|
+
return instances
|
1070
|
+
else:
|
1071
|
+
# Return only the first result
|
1072
|
+
instance = result.scalars().first()
|
1073
|
+
|
1074
|
+
# Materialize relationships if requested and instance exists
|
1075
|
+
if include_relationships and instance:
|
1076
|
+
# For each relationship, access it once to ensure it's loaded
|
1077
|
+
for rel_name in cls._get_auto_relationship_fields():
|
1078
|
+
try:
|
1079
|
+
# This will force loading the relationship while session is active
|
1080
|
+
_ = getattr(instance, rel_name)
|
1081
|
+
except Exception:
|
1082
|
+
# Skip if the relationship can't be loaded
|
1083
|
+
pass
|
1084
|
+
|
1085
|
+
return instance
|
1086
|
+
|
1087
|
+
@classmethod
|
1088
|
+
async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
|
1089
|
+
"""
|
1090
|
+
Get a record by criteria or create it if it doesn't exist.
|
1091
|
+
|
1092
|
+
Args:
|
1093
|
+
search_criteria: Dictionary of search criteria
|
1094
|
+
defaults: Default values to use when creating a new record
|
1095
|
+
|
1096
|
+
Returns:
|
1097
|
+
Tuple of (model instance, created flag)
|
1098
|
+
"""
|
1099
|
+
# Try to find the record
|
1100
|
+
record = await cls.select(criteria=search_criteria, all=False, first=True)
|
1101
|
+
|
1102
|
+
if record:
|
1103
|
+
return record, False
|
1104
|
+
|
1105
|
+
# Record not found, create it
|
1106
|
+
data = {**search_criteria}
|
1107
|
+
if defaults:
|
1108
|
+
data.update(defaults)
|
1109
|
+
|
1110
|
+
new_record = await cls.insert(data)
|
1111
|
+
return new_record, True
|
1112
|
+
|
571
1113
|
# Register an event listener to update 'updated_at' on instance modifications.
|
572
1114
|
@event.listens_for(Session, "before_flush")
|
573
1115
|
def _update_updated_at(session, flush_context, instances):
|
@@ -591,7 +1133,9 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
|
|
591
1133
|
|
592
1134
|
# Import auto_relationships functions with conditional import to avoid circular imports
|
593
1135
|
try:
|
594
|
-
from .auto_relationships import _auto_relationships_enabled, process_auto_relationships
|
1136
|
+
from .auto_relationships import (_auto_relationships_enabled, process_auto_relationships,
|
1137
|
+
enable_auto_relationships, register_model_class,
|
1138
|
+
process_all_models_for_relationships)
|
595
1139
|
has_auto_relationships = True
|
596
1140
|
except ImportError:
|
597
1141
|
has_auto_relationships = False
|
@@ -603,15 +1147,6 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
|
|
603
1147
|
except ImportError:
|
604
1148
|
has_migrations = False
|
605
1149
|
|
606
|
-
# Process auto-relationships before creating tables if enabled
|
607
|
-
if has_auto_relationships and _auto_relationships_enabled:
|
608
|
-
process_auto_relationships()
|
609
|
-
|
610
|
-
# Create async engine and all tables
|
611
|
-
engine = db_config.get_engine()
|
612
|
-
if not engine:
|
613
|
-
raise ValueError("Database configuration is missing. Use db_config.configure_* methods first.")
|
614
|
-
|
615
1150
|
# Get all SQLModel subclasses (our models) if not provided
|
616
1151
|
if model_classes is None:
|
617
1152
|
model_classes = []
|
@@ -622,6 +1157,18 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
|
|
622
1157
|
if isinstance(cls, type) and issubclass(cls, SQLModel) and cls != SQLModel and cls != EasyModel:
|
623
1158
|
model_classes.append(cls)
|
624
1159
|
|
1160
|
+
# Enable auto-relationships and register all models
|
1161
|
+
if has_auto_relationships:
|
1162
|
+
# Enable auto-relationships with patch_metaclass=False
|
1163
|
+
enable_auto_relationships(patch_metaclass=False)
|
1164
|
+
|
1165
|
+
# Register all model classes
|
1166
|
+
for model_cls in model_classes:
|
1167
|
+
register_model_class(model_cls)
|
1168
|
+
|
1169
|
+
# Process relationships for all registered models
|
1170
|
+
process_all_models_for_relationships()
|
1171
|
+
|
625
1172
|
migration_results = {}
|
626
1173
|
|
627
1174
|
# Check for migrations first if the feature is available and enabled
|
@@ -630,7 +1177,11 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
|
|
630
1177
|
if migration_results:
|
631
1178
|
logging.info(f"Applied migrations: {len(migration_results)} models affected")
|
632
1179
|
|
633
|
-
# Create
|
1180
|
+
# Create async engine and all tables
|
1181
|
+
engine = db_config.get_engine()
|
1182
|
+
if not engine:
|
1183
|
+
raise ValueError("Database configuration is missing. Use db_config.configure_* methods first.")
|
1184
|
+
|
634
1185
|
async with engine.begin() as conn:
|
635
1186
|
if has_migrations:
|
636
1187
|
# Use our safe table creation methods if migrations are available
|