async-easy-model 0.1.12__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
async_easy_model/model.py CHANGED
@@ -212,7 +212,7 @@ class EasyModel(SQLModel):
212
212
  @classmethod
213
213
  async def all(
214
214
  cls: Type[T],
215
- include_relationships: bool = False,
215
+ include_relationships: bool = True,
216
216
  order_by: Optional[Union[str, List[str]]] = None
217
217
  ) -> List[T]:
218
218
  """
@@ -243,7 +243,7 @@ class EasyModel(SQLModel):
243
243
  @classmethod
244
244
  async def first(
245
245
  cls: Type[T],
246
- include_relationships: bool = False,
246
+ include_relationships: bool = True,
247
247
  order_by: Optional[Union[str, List[str]]] = None
248
248
  ) -> Optional[T]:
249
249
  """
@@ -275,7 +275,7 @@ class EasyModel(SQLModel):
275
275
  async def limit(
276
276
  cls: Type[T],
277
277
  count: int,
278
- include_relationships: bool = False,
278
+ include_relationships: bool = True,
279
279
  order_by: Optional[Union[str, List[str]]] = None
280
280
  ) -> List[T]:
281
281
  """
@@ -305,7 +305,7 @@ class EasyModel(SQLModel):
305
305
  return result.scalars().all()
306
306
 
307
307
  @classmethod
308
- async def get_by_id(cls: Type[T], id: int, include_relationships: bool = False) -> Optional[T]:
308
+ async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
309
309
  """
310
310
  Retrieve a record by its primary key.
311
311
 
@@ -331,7 +331,7 @@ class EasyModel(SQLModel):
331
331
  async def get_by_attribute(
332
332
  cls: Type[T],
333
333
  all: bool = False,
334
- include_relationships: bool = False,
334
+ include_relationships: bool = True,
335
335
  order_by: Optional[Union[str, List[str]]] = None,
336
336
  **kwargs
337
337
  ) -> Union[Optional[T], List[T]]:
@@ -391,78 +391,128 @@ class EasyModel(SQLModel):
391
391
  return result.scalars().first()
392
392
 
393
393
  @classmethod
394
- async def insert(cls: Type[T], data: Dict[str, Any], include_relationships: bool = False) -> T:
394
+ async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
395
395
  """
396
- Insert a new record.
396
+ Insert one or more records.
397
397
 
398
398
  Args:
399
- data: Dictionary of field values
400
- include_relationships: If True, return the instance with relationships loaded
399
+ data: Dictionary of field values or a list of dictionaries for multiple records
400
+ include_relationships: If True, return the instance(s) with relationships loaded
401
401
 
402
402
  Returns:
403
- The created model instance
403
+ The created model instance(s)
404
404
  """
405
- async with cls.get_session() as session:
406
- obj = cls(**data)
407
- session.add(obj)
408
- await session.commit()
409
-
410
- if include_relationships:
411
- # Refresh with relationships
412
- statement = select(cls).where(cls.id == obj.id)
413
- for rel_name in cls._get_auto_relationship_fields():
414
- statement = statement.options(selectinload(getattr(cls, rel_name)))
415
- result = await session.execute(statement)
416
- return result.scalars().first()
417
- else:
418
- await session.refresh(obj)
419
- return obj
420
-
421
- @classmethod
422
- async def update(
423
- cls: Type[T],
424
- id: int,
425
- data: Dict[str, Any],
426
- include_relationships: bool = False
427
- ) -> Optional[T]:
428
- """
429
- Update an existing record by its ID.
430
-
431
- Args:
432
- id: The primary key value
433
- data: Dictionary of field values to update
434
- include_relationships: If True, return the instance with relationships loaded
435
-
436
- Returns:
437
- The updated model instance or None if not found
438
- """
439
- async with cls.get_session() as session:
440
- # Explicitly update updated_at since bulk updates bypass ORM events.
441
- data["updated_at"] = datetime.now(tz.utc)
442
- statement = sqlalchemy_update(cls).where(cls.id == id).values(**data).execution_options(synchronize_session="fetch")
443
- await session.execute(statement)
444
- await session.commit()
445
-
446
- if include_relationships:
447
- return await cls.get_with_related(id, *cls._get_auto_relationship_fields())
448
- else:
449
- return await cls.get_by_id(id)
450
-
451
- @classmethod
452
- async def delete(cls: Type[T], id: int) -> bool:
453
- """
454
- Delete a record by its ID.
455
- """
456
- async with cls.get_session() as session:
457
- obj = await session.get(cls, id)
458
- if obj:
459
- await session.delete(obj)
460
- await session.commit()
461
- return True
462
- return False
463
-
405
+ # Handle list of records
406
+ if isinstance(data, list):
407
+ objects = []
408
+ async with cls.get_session() as session:
409
+ for item in data:
410
+ try:
411
+ # Process relationships first
412
+ processed_item = await cls._process_relationships_for_insert(session, item)
413
+
414
+ # Check if a record with unique constraints already exists
415
+ unique_fields = cls._get_unique_fields()
416
+ existing_obj = None
417
+
418
+ if unique_fields:
419
+ unique_criteria = {field: processed_item[field]
420
+ for field in unique_fields
421
+ if field in processed_item}
422
+
423
+ if unique_criteria:
424
+ # Try to find existing record with these unique values
425
+ statement = select(cls)
426
+ for field, value in unique_criteria.items():
427
+ statement = statement.where(getattr(cls, field) == value)
428
+ result = await session.execute(statement)
429
+ existing_obj = result.scalars().first()
430
+
431
+ if existing_obj:
432
+ # Update existing object with new values
433
+ for key, value in processed_item.items():
434
+ if key != 'id': # Don't update ID
435
+ setattr(existing_obj, key, value)
436
+ objects.append(existing_obj)
437
+ else:
438
+ # Create new object
439
+ obj = cls(**processed_item)
440
+ session.add(obj)
441
+ objects.append(obj)
442
+ except Exception as e:
443
+ logging.error(f"Error inserting record: {e}")
444
+ await session.rollback()
445
+ raise
446
+
447
+ try:
448
+ await session.flush()
449
+ await session.commit()
450
+
451
+ # Refresh with relationships if requested
452
+ if include_relationships:
453
+ for obj in objects:
454
+ await session.refresh(obj)
455
+ except Exception as e:
456
+ logging.error(f"Error committing transaction: {e}")
457
+ await session.rollback()
458
+ raise
459
+
460
+ return objects
461
+ else:
462
+ # Single record case
463
+ async with cls.get_session() as session:
464
+ try:
465
+ # Process relationships first
466
+ processed_data = await cls._process_relationships_for_insert(session, data)
467
+
468
+ # Check if a record with unique constraints already exists
469
+ unique_fields = cls._get_unique_fields()
470
+ existing_obj = None
471
+
472
+ if unique_fields:
473
+ unique_criteria = {field: processed_data[field]
474
+ for field in unique_fields
475
+ if field in processed_data}
476
+
477
+ if unique_criteria:
478
+ # Try to find existing record with these unique values
479
+ statement = select(cls)
480
+ for field, value in unique_criteria.items():
481
+ statement = statement.where(getattr(cls, field) == value)
482
+ result = await session.execute(statement)
483
+ existing_obj = result.scalars().first()
484
+
485
+ if existing_obj:
486
+ # Update existing object with new values
487
+ for key, value in processed_data.items():
488
+ if key != 'id': # Don't update ID
489
+ setattr(existing_obj, key, value)
490
+ obj = existing_obj
491
+ else:
492
+ # Create new object
493
+ obj = cls(**processed_data)
494
+ session.add(obj)
495
+
496
+ await session.flush() # Flush to get the ID
497
+ await session.commit()
498
+
499
+ if include_relationships:
500
+ # Refresh with relationships
501
+ statement = select(cls).where(cls.id == obj.id)
502
+ for rel_name in cls._get_auto_relationship_fields():
503
+ statement = statement.options(selectinload(getattr(cls, rel_name)))
504
+ result = await session.execute(statement)
505
+ return result.scalars().first()
506
+ else:
507
+ await session.refresh(obj)
508
+ return obj
509
+ except Exception as e:
510
+ logging.error(f"Error inserting record: {e}")
511
+ await session.rollback()
512
+ raise
513
+
464
514
  @classmethod
465
- async def create_with_related(
515
+ async def insert_with_related(
466
516
  cls: Type[T],
467
517
  data: Dict[str, Any],
468
518
  related_data: Dict[str, List[Dict[str, Any]]] = None
@@ -515,7 +565,357 @@ class EasyModel(SQLModel):
515
565
  await session.refresh(obj, attribute_names=list(related_data.keys()))
516
566
  return obj
517
567
 
518
- def to_dict(self, include_relationships: bool = False, max_depth: int = 1) -> Dict[str, Any]:
568
+ @classmethod
569
+ def _get_unique_fields(cls) -> List[str]:
570
+ """
571
+ Get all fields with unique=True constraint
572
+
573
+ Returns:
574
+ List of field names that have unique constraints
575
+ """
576
+ unique_fields = []
577
+ for name, field in cls.__fields__.items():
578
+ if name != 'id' and hasattr(field, 'field_info') and field.field_info.extra.get('unique', False):
579
+ unique_fields.append(name)
580
+ return unique_fields
581
+
582
+ @classmethod
583
+ async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
584
+ """
585
+ Process relationships in input data for insertion.
586
+
587
+ This method handles nested objects in the input data, such as:
588
+ cart = await ShoppingCart.insert({
589
+ "user": {"username": "john", "email": "john@example.com"},
590
+ "product": {"name": "Product X", "price": 19.99},
591
+ "quantity": 2
592
+ })
593
+
594
+ For each nested object:
595
+ 1. Find the target model class
596
+ 2. Check if an object with the same unique fields already exists
597
+ 3. If found, update existing object with non-unique fields
598
+ 4. If not found, create a new object
599
+ 5. Set the foreign key ID in the result data
600
+
601
+ Args:
602
+ session: The database session to use
603
+ data: Input data dictionary that may contain nested objects
604
+
605
+ Returns:
606
+ Processed data dictionary with nested objects replaced by their foreign key IDs
607
+ """
608
+ import copy
609
+ result = copy.deepcopy(data)
610
+
611
+ # Get all relationship fields for this model
612
+ relationship_fields = cls._get_auto_relationship_fields()
613
+
614
+ # Get foreign key fields
615
+ foreign_key_fields = []
616
+ for field_name, field_info in cls.__fields__.items():
617
+ if field_name.endswith("_id") and hasattr(field_info, "field_info"):
618
+ if field_info.field_info.extra.get("foreign_key"):
619
+ foreign_key_fields.append(field_name)
620
+
621
+ # Handle nested relationship objects
622
+ for key, value in data.items():
623
+ # Skip if the value is not a dictionary
624
+ if not isinstance(value, dict):
625
+ continue
626
+
627
+ # Check if this is a relationship field (either by name or derived from foreign key)
628
+ is_rel_field = key in relationship_fields
629
+ related_key = f"{key}_id"
630
+ is_derived_rel = related_key in foreign_key_fields
631
+
632
+ # If it's a relationship field or derived from a foreign key
633
+ if is_rel_field or is_derived_rel or related_key in cls.__fields__:
634
+ # Find the related model class
635
+ related_model = None
636
+
637
+ # Try to get the related model from the attribute
638
+ if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
639
+ # Get from relationship attribute
640
+ rel_attr = getattr(cls, key)
641
+ related_model = rel_attr.property.mapper.class_
642
+ else:
643
+ # Try to find it from foreign key definition
644
+ fk_definition = None
645
+ for field_name, field_info in cls.__fields__.items():
646
+ if field_name == related_key and hasattr(field_info, "field_info"):
647
+ fk_definition = field_info.field_info.extra.get("foreign_key")
648
+ break
649
+
650
+ if fk_definition:
651
+ # Parse foreign key definition (e.g. "users.id")
652
+ target_table, _ = fk_definition.split(".")
653
+ # Try to find the target model
654
+ from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
655
+ related_model = get_model_by_table_name(target_table)
656
+ if not related_model:
657
+ # Try with the singular form
658
+ singular_table = singularize_name(target_table)
659
+ related_model = get_model_by_table_name(singular_table)
660
+ else:
661
+ # Try to infer from field name (e.g., "user_id" -> Users)
662
+ base_name = related_key[:-3] # Remove "_id"
663
+ from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
664
+
665
+ # Try singular and plural forms
666
+ related_model = get_model_by_table_name(base_name)
667
+ if not related_model:
668
+ plural_table = pluralize_name(base_name)
669
+ related_model = get_model_by_table_name(plural_table)
670
+ if not related_model:
671
+ singular_table = singularize_name(base_name)
672
+ related_model = get_model_by_table_name(singular_table)
673
+
674
+ if not related_model:
675
+ logging.warning(f"Could not find related model for {key} in {cls.__name__}")
676
+ continue
677
+
678
+ # Look for unique fields in the related model to use for searching
679
+ unique_fields = []
680
+ for field_name, field_info in related_model.__fields__.items():
681
+ if (hasattr(field_info, "field_info") and
682
+ field_info.field_info.extra.get('unique', False)):
683
+ unique_fields.append(field_name)
684
+
685
+ # Create a search dictionary using unique fields
686
+ search_dict = {}
687
+ for field in unique_fields:
688
+ if field in value and value[field] is not None:
689
+ search_dict[field] = value[field]
690
+
691
+ # If no unique fields found but ID is provided, use it
692
+ if not search_dict and 'id' in value and value['id']:
693
+ search_dict = {'id': value['id']}
694
+
695
+ # Special case for products without uniqueness constraints
696
+ if not search_dict and related_model.__tablename__ == 'products' and 'name' in value:
697
+ search_dict = {'name': value['name']}
698
+
699
+ # Try to find an existing record
700
+ related_obj = None
701
+ if search_dict:
702
+ logging.info(f"Searching for existing {related_model.__name__} with {search_dict}")
703
+
704
+ try:
705
+ # Create a more appropriate search query based on unique fields
706
+ existing_stmt = select(related_model)
707
+ for field, field_value in search_dict.items():
708
+ existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
709
+
710
+ existing_result = await session.execute(existing_stmt)
711
+ related_obj = existing_result.scalars().first()
712
+
713
+ if related_obj:
714
+ logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id}")
715
+ except Exception as e:
716
+ logging.error(f"Error finding existing record: {e}")
717
+
718
+ if related_obj:
719
+ # Update the existing record with any non-unique field values
720
+ for attr, attr_val in value.items():
721
+ # Skip ID field
722
+ if attr == 'id':
723
+ continue
724
+
725
+ # Skip unique fields with different values to avoid constraint violations
726
+ if attr in unique_fields and getattr(related_obj, attr) != attr_val:
727
+ continue
728
+
729
+ # Update non-unique fields
730
+ current_val = getattr(related_obj, attr, None)
731
+ if current_val != attr_val:
732
+ setattr(related_obj, attr, attr_val)
733
+
734
+ # Add the updated object to the session
735
+ session.add(related_obj)
736
+ logging.info(f"Reusing existing {related_model.__name__} with ID: {related_obj.id}")
737
+ else:
738
+ # Create a new record
739
+ logging.info(f"Creating new {related_model.__name__} for {key}")
740
+ related_obj = related_model(**value)
741
+ session.add(related_obj)
742
+
743
+ # Ensure the object has an ID by flushing
744
+ try:
745
+ await session.flush()
746
+ except Exception as e:
747
+ logging.error(f"Error flushing session for {related_model.__name__}: {e}")
748
+
749
+ # If there was a uniqueness error, try again to find the existing record
750
+ if "UNIQUE constraint failed" in str(e):
751
+ logging.info(f"UNIQUE constraint failed, trying to find existing record again")
752
+
753
+ # Try to find by any field provided in the search_dict
754
+ existing_stmt = select(related_model)
755
+ for field, field_value in search_dict.items():
756
+ existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
757
+
758
+ # Execute the search query
759
+ existing_result = await session.execute(existing_stmt)
760
+ related_obj = existing_result.scalars().first()
761
+
762
+ if not related_obj:
763
+ # We couldn't find an existing record, re-raise the exception
764
+ raise
765
+
766
+ logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id} after constraint error")
767
+
768
+ # Update the result with the foreign key ID
769
+ foreign_key_name = f"{key}_id"
770
+ result[foreign_key_name] = related_obj.id
771
+
772
+ # Remove the relationship dictionary from the result
773
+ if key in result:
774
+ del result[key]
775
+
776
+ return result
777
+
778
+ @classmethod
779
+ async def update(cls: Type[T], data: Dict[str, Any], criteria: Dict[str, Any], include_relationships: bool = True) -> Optional[T]:
780
+ """
781
+ Update an existing record identified by criteria.
782
+
783
+ Args:
784
+ data: Dictionary of updated field values
785
+ criteria: Dictionary of field values to identify the record to update
786
+ include_relationships: If True, return the updated instance with relationships loaded
787
+
788
+ Returns:
789
+ The updated model instance
790
+ """
791
+ async with cls.get_session() as session:
792
+ try:
793
+ # Find the record(s) to update
794
+ statement = select(cls)
795
+ for field, value in criteria.items():
796
+ if isinstance(value, str) and '*' in value:
797
+ # Handle LIKE queries
798
+ like_value = value.replace('*', '%')
799
+ statement = statement.where(getattr(cls, field).like(like_value))
800
+ else:
801
+ statement = statement.where(getattr(cls, field) == value)
802
+
803
+ result = await session.execute(statement)
804
+ record = result.scalars().first()
805
+
806
+ if not record:
807
+ logging.warning(f"No record found with criteria: {criteria}")
808
+ return None
809
+
810
+ # Check for unique constraints before updating
811
+ for field_name, new_value in data.items():
812
+ if field_name != 'id' and hasattr(cls, field_name):
813
+ field = getattr(cls.__fields__.get(field_name), 'field_info', None)
814
+ if field and field.extra.get('unique', False):
815
+ # Check if the new value would conflict with an existing record
816
+ check_statement = select(cls).where(
817
+ getattr(cls, field_name) == new_value
818
+ ).where(
819
+ cls.id != record.id
820
+ )
821
+ check_result = await session.execute(check_statement)
822
+ existing = check_result.scalars().first()
823
+
824
+ if existing:
825
+ raise ValueError(f"Cannot update {field_name} to '{new_value}': value already exists")
826
+
827
+ # Apply the updates
828
+ for key, value in data.items():
829
+ setattr(record, key, value)
830
+
831
+ await session.flush()
832
+ await session.commit()
833
+
834
+ if include_relationships:
835
+ # Refresh with relationships
836
+ refresh_statement = select(cls).where(cls.id == record.id)
837
+ for rel_name in cls._get_auto_relationship_fields():
838
+ refresh_statement = refresh_statement.options(selectinload(getattr(cls, rel_name)))
839
+ refresh_result = await session.execute(refresh_statement)
840
+ return refresh_result.scalars().first()
841
+ else:
842
+ await session.refresh(record)
843
+ return record
844
+ except Exception as e:
845
+ logging.error(f"Error updating record: {e}")
846
+ await session.rollback()
847
+ raise
848
+
849
+ @classmethod
850
+ async def delete(cls: Type[T], criteria: Dict[str, Any]) -> int:
851
+ """
852
+ Delete records matching the provided criteria.
853
+
854
+ Args:
855
+ criteria: Dictionary of field values to identify records to delete
856
+
857
+ Returns:
858
+ Number of records deleted
859
+ """
860
+ async with cls.get_session() as session:
861
+ try:
862
+ # Find the record(s) to delete
863
+ statement = select(cls)
864
+ for field, value in criteria.items():
865
+ if isinstance(value, str) and '*' in value:
866
+ # Handle LIKE queries
867
+ like_value = value.replace('*', '%')
868
+ statement = statement.where(getattr(cls, field).like(like_value))
869
+ else:
870
+ statement = statement.where(getattr(cls, field) == value)
871
+
872
+ result = await session.execute(statement)
873
+ records = result.scalars().all()
874
+
875
+ if not records:
876
+ logging.warning(f"No records found with criteria: {criteria}")
877
+ return 0
878
+
879
+ # Get a list of related tables that might need to be cleared first
880
+ # This helps with foreign key constraints
881
+ relationship_fields = cls._get_auto_relationship_fields()
882
+ to_many_relationships = []
883
+
884
+ # Find to-many relationships that need to be handled first
885
+ for rel_name in relationship_fields:
886
+ rel_attr = getattr(cls, rel_name, None)
887
+ if rel_attr and hasattr(rel_attr, 'property'):
888
+ # Check if this is a to-many relationship (one-to-many or many-to-many)
889
+ if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
890
+ to_many_relationships.append(rel_name)
891
+
892
+ # For each record, delete related records first (cascade delete)
893
+ for record in records:
894
+ # First load all related collections
895
+ if to_many_relationships:
896
+ await session.refresh(record, attribute_names=to_many_relationships)
897
+
898
+ # Delete related records in collections
899
+ for rel_name in to_many_relationships:
900
+ related_collection = getattr(record, rel_name, [])
901
+ if related_collection:
902
+ for related_item in related_collection:
903
+ await session.delete(related_item)
904
+
905
+ # Now delete the main record
906
+ await session.delete(record)
907
+
908
+ # Commit the changes
909
+ await session.flush()
910
+ await session.commit()
911
+
912
+ return len(records)
913
+ except Exception as e:
914
+ logging.error(f"Error deleting records: {e}")
915
+ await session.rollback()
916
+ raise
917
+
918
+ def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
519
919
  """
520
920
  Convert the model instance to a dictionary.
521
921
 
@@ -526,32 +926,47 @@ class EasyModel(SQLModel):
526
926
  Returns:
527
927
  Dictionary representation of the model
528
928
  """
529
- if max_depth <= 0:
530
- return {}
531
-
532
929
  # Get basic fields
533
930
  result = self.model_dump()
534
931
 
535
932
  # Add relationship fields if requested
536
933
  if include_relationships and max_depth > 0:
537
934
  for rel_name in self.__class__._get_auto_relationship_fields():
538
- rel_value = getattr(self, rel_name, None)
539
-
540
- if rel_value is None:
541
- result[rel_name] = None
542
- elif isinstance(rel_value, list):
543
- # Handle one-to-many relationships
544
- result[rel_name] = [
545
- item.to_dict(include_relationships=True, max_depth=max_depth-1)
546
- for item in rel_value
547
- ]
548
- else:
549
- # Handle many-to-one relationships
550
- result[rel_name] = rel_value.to_dict(
551
- include_relationships=True,
552
- max_depth=max_depth-1
553
- )
935
+ # Only include relationships that are already loaded to avoid session errors
936
+ # We check if the relationship is loaded using SQLAlchemy's inspection API
937
+ is_loaded = False
938
+ try:
939
+ # Check if attribute exists and is not a relationship descriptor
940
+ rel_value = getattr(self, rel_name, None)
941
+
942
+ # If it's an attribute that has been loaded or not a relationship at all
943
+ # (for basic fields that match relationship naming pattern), include it
944
+ is_loaded = rel_value is not None and not hasattr(rel_value, 'prop')
945
+ except Exception:
946
+ # If accessing the attribute raises an exception, it's not loaded
947
+ is_loaded = False
948
+
949
+ if is_loaded:
950
+ rel_value = getattr(self, rel_name, None)
554
951
 
952
+ if rel_value is None:
953
+ result[rel_name] = None
954
+ elif isinstance(rel_value, list):
955
+ # Handle one-to-many relationships
956
+ result[rel_name] = [
957
+ item.to_dict(include_relationships=True, max_depth=max_depth-1)
958
+ for item in rel_value
959
+ ]
960
+ else:
961
+ # Handle many-to-one relationships
962
+ result[rel_name] = rel_value.to_dict(
963
+ include_relationships=True,
964
+ max_depth=max_depth-1
965
+ )
966
+ else:
967
+ # If max_depth is 0, return the basic fields only
968
+ return result
969
+
555
970
  return result
556
971
 
557
972
  async def load_related(self, *related_fields: str) -> None:
@@ -568,6 +983,133 @@ class EasyModel(SQLModel):
568
983
  # Refresh the instance with the specified relationships
569
984
  await session.refresh(self, attribute_names=related_fields)
570
985
 
986
+ @classmethod
987
+ async def select(
988
+ cls: Type[T],
989
+ criteria: Dict[str, Any] = None,
990
+ all: bool = False,
991
+ first: bool = False,
992
+ include_relationships: bool = True,
993
+ order_by: Optional[Union[str, List[str]]] = None,
994
+ limit: Optional[int] = None
995
+ ) -> Union[Optional[T], List[T]]:
996
+ """
997
+ Retrieve record(s) by matching attribute values.
998
+
999
+ Args:
1000
+ criteria: Dictionary of field values to filter by (field=value)
1001
+ all: If True, return all matching records, otherwise return only the first one
1002
+ first: If True, return only the first record (equivalent to all=False)
1003
+ include_relationships: If True, eagerly load all relationships
1004
+ order_by: Field(s) to order by. Can be a string or list of strings.
1005
+ Prefix with '-' for descending order (e.g. '-created_at')
1006
+ limit: Maximum number of records to retrieve (if all=True)
1007
+ If limit > 1, all is automatically set to True
1008
+
1009
+ Returns:
1010
+ A single model instance, a list of instances, or None if not found
1011
+ """
1012
+ # Default to empty criteria if None provided
1013
+ if criteria is None:
1014
+ criteria = {}
1015
+
1016
+ # If limit is specified and > 1, set all to True
1017
+ if limit is not None and limit > 1:
1018
+ all = True
1019
+ # If first is specified, set all to False (first takes precedence)
1020
+ if first:
1021
+ all = False
1022
+
1023
+ async with cls.get_session() as session:
1024
+ # Build the query
1025
+ statement = select(cls)
1026
+
1027
+ # Apply criteria
1028
+ for field, value in criteria.items():
1029
+ if isinstance(value, str) and '*' in value:
1030
+ # Handle LIKE queries (convert '*' wildcard to '%')
1031
+ like_value = value.replace('*', '%')
1032
+ statement = statement.where(getattr(cls, field).like(like_value))
1033
+ else:
1034
+ # Regular equality check
1035
+ statement = statement.where(getattr(cls, field) == value)
1036
+
1037
+ # Apply ordering
1038
+ if order_by:
1039
+ statement = cls._apply_order_by(statement, order_by)
1040
+
1041
+ # Apply limit
1042
+ if limit:
1043
+ statement = statement.limit(limit)
1044
+
1045
+ # Include relationships if requested
1046
+ if include_relationships:
1047
+ for rel_name in cls._get_auto_relationship_fields():
1048
+ statement = statement.options(selectinload(getattr(cls, rel_name)))
1049
+
1050
+ # Execute the query
1051
+ result = await session.execute(statement)
1052
+
1053
+ if all:
1054
+ # Return all results
1055
+ instances = result.scalars().all()
1056
+
1057
+ # Materialize relationships if requested - this ensures they're fully loaded
1058
+ if include_relationships:
1059
+ for instance in instances:
1060
+ # For each relationship, access it once to ensure it's loaded
1061
+ for rel_name in cls._get_auto_relationship_fields():
1062
+ try:
1063
+ # This will force loading the relationship while session is active
1064
+ _ = getattr(instance, rel_name)
1065
+ except Exception:
1066
+ # Skip if the relationship can't be loaded
1067
+ pass
1068
+
1069
+ return instances
1070
+ else:
1071
+ # Return only the first result
1072
+ instance = result.scalars().first()
1073
+
1074
+ # Materialize relationships if requested and instance exists
1075
+ if include_relationships and instance:
1076
+ # For each relationship, access it once to ensure it's loaded
1077
+ for rel_name in cls._get_auto_relationship_fields():
1078
+ try:
1079
+ # This will force loading the relationship while session is active
1080
+ _ = getattr(instance, rel_name)
1081
+ except Exception:
1082
+ # Skip if the relationship can't be loaded
1083
+ pass
1084
+
1085
+ return instance
1086
+
1087
+ @classmethod
1088
+ async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
1089
+ """
1090
+ Get a record by criteria or create it if it doesn't exist.
1091
+
1092
+ Args:
1093
+ search_criteria: Dictionary of search criteria
1094
+ defaults: Default values to use when creating a new record
1095
+
1096
+ Returns:
1097
+ Tuple of (model instance, created flag)
1098
+ """
1099
+ # Try to find the record
1100
+ record = await cls.select(criteria=search_criteria, all=False, first=True)
1101
+
1102
+ if record:
1103
+ return record, False
1104
+
1105
+ # Record not found, create it
1106
+ data = {**search_criteria}
1107
+ if defaults:
1108
+ data.update(defaults)
1109
+
1110
+ new_record = await cls.insert(data)
1111
+ return new_record, True
1112
+
571
1113
  # Register an event listener to update 'updated_at' on instance modifications.
572
1114
  @event.listens_for(Session, "before_flush")
573
1115
  def _update_updated_at(session, flush_context, instances):
@@ -591,7 +1133,9 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
591
1133
 
592
1134
  # Import auto_relationships functions with conditional import to avoid circular imports
593
1135
  try:
594
- from .auto_relationships import _auto_relationships_enabled, process_auto_relationships
1136
+ from .auto_relationships import (_auto_relationships_enabled, process_auto_relationships,
1137
+ enable_auto_relationships, register_model_class,
1138
+ process_all_models_for_relationships)
595
1139
  has_auto_relationships = True
596
1140
  except ImportError:
597
1141
  has_auto_relationships = False
@@ -603,15 +1147,6 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
603
1147
  except ImportError:
604
1148
  has_migrations = False
605
1149
 
606
- # Process auto-relationships before creating tables if enabled
607
- if has_auto_relationships and _auto_relationships_enabled:
608
- process_auto_relationships()
609
-
610
- # Create async engine and all tables
611
- engine = db_config.get_engine()
612
- if not engine:
613
- raise ValueError("Database configuration is missing. Use db_config.configure_* methods first.")
614
-
615
1150
  # Get all SQLModel subclasses (our models) if not provided
616
1151
  if model_classes is None:
617
1152
  model_classes = []
@@ -622,6 +1157,18 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
622
1157
  if isinstance(cls, type) and issubclass(cls, SQLModel) and cls != SQLModel and cls != EasyModel:
623
1158
  model_classes.append(cls)
624
1159
 
1160
+ # Enable auto-relationships and register all models
1161
+ if has_auto_relationships:
1162
+ # Enable auto-relationships with patch_metaclass=False
1163
+ enable_auto_relationships(patch_metaclass=False)
1164
+
1165
+ # Register all model classes
1166
+ for model_cls in model_classes:
1167
+ register_model_class(model_cls)
1168
+
1169
+ # Process relationships for all registered models
1170
+ process_all_models_for_relationships()
1171
+
625
1172
  migration_results = {}
626
1173
 
627
1174
  # Check for migrations first if the feature is available and enabled
@@ -630,7 +1177,11 @@ async def init_db(migrate: bool = True, model_classes: List[Type[SQLModel]] = No
630
1177
  if migration_results:
631
1178
  logging.info(f"Applied migrations: {len(migration_results)} models affected")
632
1179
 
633
- # Create tables that don't exist yet - using safe index creation
1180
+ # Create async engine and all tables
1181
+ engine = db_config.get_engine()
1182
+ if not engine:
1183
+ raise ValueError("Database configuration is missing. Use db_config.configure_* methods first.")
1184
+
634
1185
  async with engine.begin() as conn:
635
1186
  if has_migrations:
636
1187
  # Use our safe table creation methods if migrations are available