async-easy-model 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
async_easy_model/model.py CHANGED
@@ -10,6 +10,7 @@ from datetime import datetime, timezone as tz
10
10
  import inspect
11
11
  import json
12
12
  import logging
13
+ import re
13
14
 
14
15
  T = TypeVar("T", bound="EasyModel")
15
16
 
@@ -212,8 +213,9 @@ class EasyModel(SQLModel):
212
213
  @classmethod
213
214
  async def all(
214
215
  cls: Type[T],
215
- include_relationships: bool = True,
216
- order_by: Optional[Union[str, List[str]]] = None
216
+ include_relationships: bool = True,
217
+ order_by: Optional[Union[str, List[str]]] = None,
218
+ max_depth: int = 2
217
219
  ) -> List[T]:
218
220
  """
219
221
  Retrieve all records of this model.
@@ -222,29 +224,20 @@ class EasyModel(SQLModel):
222
224
  include_relationships: If True, eagerly load all relationships
223
225
  order_by: Field(s) to order by. Can be a string or list of strings.
224
226
  Prefix with '-' for descending order (e.g. '-created_at')
227
+ max_depth: Maximum depth for loading nested relationships
225
228
 
226
229
  Returns:
227
230
  A list of all model instances
228
231
  """
229
- async with cls.get_session() as session:
230
- statement = select(cls)
231
-
232
- # Apply ordering
233
- statement = cls._apply_order_by(statement, order_by)
234
-
235
- if include_relationships:
236
- # Get all relationship attributes, including auto-detected ones
237
- for rel_name in cls._get_auto_relationship_fields():
238
- statement = statement.options(selectinload(getattr(cls, rel_name)))
239
-
240
- result = await session.execute(statement)
241
- return result.scalars().all()
232
+ return await cls.select({}, all=True, include_relationships=include_relationships,
233
+ order_by=order_by, max_depth=max_depth)
242
234
 
243
235
  @classmethod
244
236
  async def first(
245
237
  cls: Type[T],
246
- include_relationships: bool = True,
247
- order_by: Optional[Union[str, List[str]]] = None
238
+ include_relationships: bool = True,
239
+ order_by: Optional[Union[str, List[str]]] = None,
240
+ max_depth: int = 2
248
241
  ) -> Optional[T]:
249
242
  """
250
243
  Retrieve the first record of this model.
@@ -253,56 +246,37 @@ class EasyModel(SQLModel):
253
246
  include_relationships: If True, eagerly load all relationships
254
247
  order_by: Field(s) to order by. Can be a string or list of strings.
255
248
  Prefix with '-' for descending order (e.g. '-created_at')
249
+ max_depth: Maximum depth for loading nested relationships
256
250
 
257
251
  Returns:
258
252
  The first model instance or None if no records exist
259
253
  """
260
- async with cls.get_session() as session:
261
- statement = select(cls)
262
-
263
- # Apply ordering
264
- statement = cls._apply_order_by(statement, order_by)
265
-
266
- if include_relationships:
267
- # Get all relationship attributes, including auto-detected ones
268
- for rel_name in cls._get_auto_relationship_fields():
269
- statement = statement.options(selectinload(getattr(cls, rel_name)))
270
-
271
- result = await session.execute(statement)
272
- return result.scalars().first()
254
+ return await cls.select({}, first=True, include_relationships=include_relationships,
255
+ order_by=order_by, max_depth=max_depth)
273
256
 
274
257
  @classmethod
275
258
  async def limit(
276
259
  cls: Type[T],
277
260
  count: int,
278
- include_relationships: bool = True,
279
- order_by: Optional[Union[str, List[str]]] = None
261
+ include_relationships: bool = True,
262
+ order_by: Optional[Union[str, List[str]]] = None,
263
+ max_depth: int = 2
280
264
  ) -> List[T]:
281
265
  """
282
- Retrieve a limited number of records of this model.
266
+ Retrieve a limited number of records.
283
267
 
284
268
  Args:
285
- count: Maximum number of records to retrieve
269
+ count: Maximum number of records to return
286
270
  include_relationships: If True, eagerly load all relationships
287
271
  order_by: Field(s) to order by. Can be a string or list of strings.
288
272
  Prefix with '-' for descending order (e.g. '-created_at')
273
+ max_depth: Maximum depth for loading nested relationships
289
274
 
290
275
  Returns:
291
- A list of model instances up to the specified count
276
+ A list of model instances
292
277
  """
293
- async with cls.get_session() as session:
294
- statement = select(cls).limit(count)
295
-
296
- # Apply ordering
297
- statement = cls._apply_order_by(statement, order_by)
298
-
299
- if include_relationships:
300
- # Get all relationship attributes, including auto-detected ones
301
- for rel_name in cls._get_auto_relationship_fields():
302
- statement = statement.options(selectinload(getattr(cls, rel_name)))
303
-
304
- result = await session.execute(statement)
305
- return result.scalars().all()
278
+ return await cls.select({}, all=True, include_relationships=include_relationships,
279
+ order_by=order_by, limit=count, max_depth=max_depth)
306
280
 
307
281
  @classmethod
308
282
  async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
@@ -327,6 +301,20 @@ class EasyModel(SQLModel):
327
301
  else:
328
302
  return await session.get(cls, id)
329
303
 
304
+ @classmethod
305
+ def _get_unique_fields(cls) -> List[str]:
306
+ """
307
+ Get all fields with unique=True constraint
308
+
309
+ Returns:
310
+ List of field names that have unique constraints
311
+ """
312
+ unique_fields = []
313
+ for name, field in cls.__fields__.items():
314
+ if name != 'id' and hasattr(field, "field_info") and field.field_info.extra.get('unique', False):
315
+ unique_fields.append(name)
316
+ return unique_fields
317
+
330
318
  @classmethod
331
319
  async def get_by_attribute(
332
320
  cls: Type[T],
@@ -391,192 +379,76 @@ class EasyModel(SQLModel):
391
379
  return result.scalars().first()
392
380
 
393
381
  @classmethod
394
- async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
382
+ async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True, max_depth: int = 2) -> Union[T, List[T]]:
395
383
  """
396
384
  Insert one or more records.
397
385
 
398
386
  Args:
399
387
  data: Dictionary of field values or a list of dictionaries for multiple records
400
388
  include_relationships: If True, return the instance(s) with relationships loaded
389
+ max_depth: Maximum depth for loading nested relationships
401
390
 
402
391
  Returns:
403
392
  The created model instance(s)
404
393
  """
405
- # Handle list of records
394
+ if not data:
395
+ return None
396
+
397
+ # Handle single dict or list of dicts
406
398
  if isinstance(data, list):
407
- objects = []
408
- async with cls.get_session() as session:
409
- for item in data:
410
- try:
411
- # Process relationships first
412
- processed_item = await cls._process_relationships_for_insert(session, item)
413
-
414
- # Extract special _related_* fields for post-processing
415
- related_fields = {}
416
- for key in list(processed_item.keys()):
417
- if key.startswith("_related_"):
418
- rel_name = key[9:] # Remove "_related_" prefix
419
- related_fields[rel_name] = processed_item.pop(key)
420
-
421
- # Check if a record with unique constraints already exists
422
- unique_fields = cls._get_unique_fields()
423
- existing_obj = None
424
-
425
- if unique_fields:
426
- unique_criteria = {field: processed_item[field]
427
- for field in unique_fields
428
- if field in processed_item}
429
-
430
- if unique_criteria:
431
- # Try to find existing record with these unique values
432
- statement = select(cls)
433
- for field, value in unique_criteria.items():
434
- statement = statement.where(getattr(cls, field) == value)
435
- result = await session.execute(statement)
436
- existing_obj = result.scalars().first()
437
-
438
- if existing_obj:
439
- # Update existing object with new values
440
- for key, value in processed_item.items():
441
- if key != 'id': # Don't update ID
442
- setattr(existing_obj, key, value)
443
- obj = existing_obj
444
- else:
445
- # Create new object
446
- obj = cls(**processed_item)
447
- session.add(obj)
448
-
449
- # Flush to get the ID for this object
450
- await session.flush()
451
-
452
- # Now handle any one-to-many relationships
453
- for rel_name, related_objects in related_fields.items():
454
- # Check if the relationship attribute exists in the class (not the instance)
455
- if hasattr(cls, rel_name):
456
- # Get the relationship attribute from the class
457
- rel_attr = getattr(cls, rel_name)
458
-
459
- # Check if it's a SQLAlchemy relationship
460
- if hasattr(rel_attr, 'property') and hasattr(rel_attr.property, 'back_populates'):
461
- back_attr = rel_attr.property.back_populates
462
-
463
- # For each related object, set the back reference to this object
464
- for related_obj in related_objects:
465
- setattr(related_obj, back_attr, obj)
466
- # Make sure the related object is in the session
467
- session.add(related_obj)
468
-
469
- objects.append(obj)
470
- except Exception as e:
471
- logging.error(f"Error inserting record: {e}")
472
- await session.rollback()
473
- raise
399
+ results = []
400
+ for item in data:
401
+ result = await cls.insert(item, include_relationships, max_depth)
402
+ results.append(result)
403
+ return results
404
+
405
+ # Store many-to-many relationship data for later processing
406
+ many_to_many_data = {}
407
+ many_to_many_rels = cls._get_many_to_many_relationships()
408
+
409
+ # Extract many-to-many data before processing other relationships
410
+ for rel_name in many_to_many_rels:
411
+ if rel_name in data:
412
+ many_to_many_data[rel_name] = data[rel_name]
413
+
414
+ # Process relationships to convert nested objects to foreign keys
415
+ async with cls.get_session() as session:
416
+ try:
417
+ processed_data = await cls._process_relationships_for_insert(session, data)
474
418
 
475
- try:
476
- await session.flush()
477
- await session.commit()
478
-
479
- # Refresh with relationships if requested
480
- if include_relationships:
481
- for obj in objects:
482
- await session.refresh(obj)
483
- except Exception as e:
484
- logging.error(f"Error committing transaction: {e}")
485
- await session.rollback()
486
- raise
487
-
488
- return objects
489
- else:
490
- # Single record case
491
- async with cls.get_session() as session:
492
- try:
493
- # Process relationships first
494
- processed_data = await cls._process_relationships_for_insert(session, data)
495
-
496
- # Extract special _related_* fields for post-processing
497
- related_fields = {}
498
- for key in list(processed_data.keys()):
499
- if key.startswith("_related_"):
500
- rel_name = key[9:] # Remove "_related_" prefix
501
- related_fields[rel_name] = processed_data.pop(key)
502
-
503
- # Check if a record with unique constraints already exists
504
- unique_fields = cls._get_unique_fields()
505
- existing_obj = None
506
-
507
- if unique_fields:
508
- unique_criteria = {field: processed_data[field]
509
- for field in unique_fields
510
- if field in processed_data}
511
-
512
- if unique_criteria:
513
- # Try to find existing record with these unique values
514
- statement = select(cls)
515
- for field, value in unique_criteria.items():
516
- statement = statement.where(getattr(cls, field) == value)
517
- result = await session.execute(statement)
518
- existing_obj = result.scalars().first()
519
-
520
- if existing_obj:
521
- # Update existing object with new values
522
- for key, value in processed_data.items():
523
- if key != 'id': # Don't update ID
524
- setattr(existing_obj, key, value)
525
- obj = existing_obj
526
- else:
527
- # Create new object
528
- obj = cls(**processed_data)
529
- session.add(obj)
530
-
531
- await session.flush() # Flush to get the ID
532
-
533
- # Now handle any one-to-many relationships
534
- for rel_name, related_objects in related_fields.items():
535
- # Check if the relationship attribute exists in the class (not the instance)
536
- if hasattr(cls, rel_name):
537
- # Get the relationship attribute from the class
538
- rel_attr = getattr(cls, rel_name)
539
-
540
- # Check if it's a SQLAlchemy relationship
541
- if hasattr(rel_attr, 'property') and hasattr(rel_attr.property, 'back_populates'):
542
- back_attr = rel_attr.property.back_populates
543
-
544
- # For each related object, set the back reference to this object
545
- for related_obj in related_objects:
546
- setattr(related_obj, back_attr, obj)
547
- # Make sure the related object is in the session
548
- session.add(related_obj)
549
-
550
- await session.commit()
419
+ # Create the model instance
420
+ obj = cls(**processed_data)
421
+ session.add(obj)
422
+
423
+ # Flush to get the object ID
424
+ await session.flush()
425
+
426
+ # Now process many-to-many relationships if any
427
+ for rel_name, rel_data in many_to_many_data.items():
428
+ if isinstance(rel_data, list):
429
+ await cls._process_many_to_many_relationship(
430
+ session, obj, rel_name, rel_data
431
+ )
432
+
433
+ # Commit the transaction
434
+ await session.commit()
435
+
436
+ if include_relationships:
437
+ # Reload with relationships
438
+ return await cls._load_relationships_recursively(session, obj, max_depth)
439
+ else:
440
+ return obj
551
441
 
552
- if include_relationships:
553
- # Refresh with relationships
554
- statement = select(cls).where(cls.id == obj.id)
555
- for rel_name in cls._get_auto_relationship_fields():
556
- statement = statement.options(selectinload(getattr(cls, rel_name)))
557
- result = await session.execute(statement)
558
- return result.scalars().first()
559
- else:
560
- await session.refresh(obj)
561
- return obj
562
- except Exception as e:
563
- logging.error(f"Error inserting record: {e}")
564
- await session.rollback()
565
- raise
566
-
567
- @classmethod
568
- def _get_unique_fields(cls) -> List[str]:
569
- """
570
- Get all fields with unique=True constraint
571
-
572
- Returns:
573
- List of field names that have unique constraints
574
- """
575
- unique_fields = []
576
- for name, field in cls.__fields__.items():
577
- if name != 'id' and hasattr(field, 'field_info') and field.field_info.extra.get('unique', False):
578
- unique_fields.append(name)
579
- return unique_fields
442
+ except Exception as e:
443
+ await session.rollback()
444
+ logging.error(f"Error inserting {cls.__name__}: {e}")
445
+ if "UNIQUE constraint failed" in str(e):
446
+ field_match = re.search(r"UNIQUE constraint failed: \w+\.(\w+)", str(e))
447
+ if field_match:
448
+ field_name = field_match.group(1)
449
+ value = data.get(field_name)
450
+ raise ValueError(f"A record with {field_name}='{value}' already exists")
451
+ raise
580
452
 
581
453
  @classmethod
582
454
  async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -613,97 +485,65 @@ class EasyModel(SQLModel):
613
485
  Returns:
614
486
  Processed data dictionary with nested objects replaced by their foreign key IDs
615
487
  """
616
- import copy
617
- result = copy.deepcopy(data)
488
+ if not data:
489
+ return {}
490
+
491
+ result = dict(data)
618
492
 
619
- # Get all relationship fields for this model
493
+ # Get relationship fields for this model
620
494
  relationship_fields = cls._get_auto_relationship_fields()
621
495
 
622
- # Get foreign key fields
623
- foreign_key_fields = []
624
- for field_name, field_info in cls.__fields__.items():
625
- if field_name.endswith("_id") and hasattr(field_info, "field_info"):
626
- if field_info.field_info.extra.get("foreign_key"):
627
- foreign_key_fields.append(field_name)
496
+ # Get many-to-many relationships separately
497
+ many_to_many_relationships = cls._get_many_to_many_relationships()
628
498
 
629
- # Handle nested relationship objects
630
- for key, value in data.items():
631
- # Skip if None
499
+ # Set of field names already processed as many-to-many relationships
500
+ processed_m2m_fields = set()
501
+
502
+ # Process each relationship field in the input data
503
+ for key in list(result.keys()):
504
+ value = result[key]
505
+
506
+ # Skip empty values
632
507
  if value is None:
633
508
  continue
634
-
635
- # Check if this is a relationship field (either by name or derived from foreign key)
636
- is_rel_field = key in relationship_fields
637
- related_key = f"{key}_id"
638
- is_derived_rel = related_key in foreign_key_fields
639
509
 
640
- # If it's a relationship field or derived from a foreign key
641
- if is_rel_field or is_derived_rel or related_key in cls.__fields__:
642
- # Find the related model class
510
+ # Check if this is a relationship field
511
+ if key in relationship_fields:
512
+ # Get the related model class
643
513
  related_model = None
644
514
 
645
- # Try to get the related model from the attribute
646
- if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
647
- # Get from relationship attribute
515
+ if hasattr(cls, key):
648
516
  rel_attr = getattr(cls, key)
649
- related_model = rel_attr.property.mapper.class_
650
- else:
651
- # Try to find it from foreign key definition
652
- fk_definition = None
653
- for field_name, field_info in cls.__fields__.items():
654
- if field_name == related_key and hasattr(field_info, "field_info"):
655
- fk_definition = field_info.field_info.extra.get("foreign_key")
656
- break
657
-
658
- if fk_definition:
659
- # Parse foreign key definition (e.g. "users.id")
660
- target_table, _ = fk_definition.split(".")
661
- # Try to find the target model
662
- from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
663
- related_model = get_model_by_table_name(target_table)
664
- if not related_model:
665
- # Try with the singular form
666
- singular_table = singularize_name(target_table)
667
- related_model = get_model_by_table_name(singular_table)
668
- else:
669
- # Try to infer from field name (e.g., "user_id" -> Users)
670
- base_name = related_key[:-3] # Remove "_id"
671
- from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
672
-
673
- # Try singular and plural forms
674
- related_model = get_model_by_table_name(base_name)
675
- if not related_model:
676
- plural_table = pluralize_name(base_name)
677
- related_model = get_model_by_table_name(plural_table)
678
- if not related_model:
679
- singular_table = singularize_name(base_name)
680
- related_model = get_model_by_table_name(singular_table)
517
+ if hasattr(rel_attr, 'prop') and hasattr(rel_attr.prop, 'mapper'):
518
+ related_model = rel_attr.prop.mapper.class_
681
519
 
682
520
  if not related_model:
683
- logging.warning(f"Could not find related model for {key} in {cls.__name__}")
521
+ logging.warning(f"Could not determine related model for {key}, skipping")
522
+ continue
523
+
524
+ # Check if this is a many-to-many relationship field
525
+ if key in many_to_many_relationships:
526
+ # Store this separately - we'll handle it after the main object is created
527
+ processed_m2m_fields.add(key)
684
528
  continue
685
529
 
686
- # Check if the value is a list (one-to-many) or dict (one-to-one)
530
+ # Handle different relationship types based on data type
687
531
  if isinstance(value, list):
688
532
  # Handle one-to-many relationship (list of dictionaries)
689
- related_objects = []
690
-
533
+ related_ids = []
691
534
  for item in value:
692
- if not isinstance(item, dict):
693
- logging.warning(f"Skipping non-dict item in list for {key}")
694
- continue
695
-
696
- related_obj = await cls._process_single_relationship_item(
697
- session, related_model, item
698
- )
699
- if related_obj:
700
- related_objects.append(related_obj)
535
+ if isinstance(item, dict):
536
+ related_obj = await cls._process_single_relationship_item(
537
+ session, related_model, item
538
+ )
539
+ if related_obj:
540
+ related_ids.append(related_obj.id)
701
541
 
702
- # For one-to-many, we need to keep a list of related objects to be attached later
703
- # We'll store them in a special field that will be removed before creating the model
704
- result[f"_related_{key}"] = related_objects
542
+ # Update result with list of foreign key IDs
543
+ foreign_key_list_name = f"{key}_ids"
544
+ result[foreign_key_list_name] = related_ids
705
545
 
706
- # Remove the original field from the result
546
+ # Remove the relationship list from the result
707
547
  if key in result:
708
548
  del result[key]
709
549
 
@@ -722,8 +562,14 @@ class EasyModel(SQLModel):
722
562
  if key in result:
723
563
  del result[key]
724
564
 
565
+ # Remove any processed many-to-many fields from the result
566
+ # since we'll handle them separately after the object is created
567
+ for key in processed_m2m_fields:
568
+ if key in result:
569
+ del result[key]
570
+
725
571
  return result
726
-
572
+
727
573
  @classmethod
728
574
  async def _process_single_relationship_item(cls, session: AsyncSession, related_model: Type, item_data: Dict[str, Any]) -> Optional[Any]:
729
575
  """
@@ -742,10 +588,10 @@ class EasyModel(SQLModel):
742
588
  """
743
589
  # Look for unique fields in the related model to use for searching
744
590
  unique_fields = []
745
- for field_name, field_info in related_model.__fields__.items():
746
- if (hasattr(field_info, "field_info") and
747
- field_info.field_info.extra.get('unique', False)):
748
- unique_fields.append(field_name)
591
+ for name, field in related_model.__fields__.items():
592
+ if (hasattr(field, "field_info") and
593
+ field.field_info.extra.get('unique', False)):
594
+ unique_fields.append(name)
749
595
 
750
596
  # Create a search dictionary using unique fields
751
597
  search_dict = {}
@@ -855,6 +701,17 @@ class EasyModel(SQLModel):
855
701
  Returns:
856
702
  The updated model instance
857
703
  """
704
+ # Store many-to-many relationship data for later processing
705
+ many_to_many_data = {}
706
+ many_to_many_rels = cls._get_many_to_many_relationships()
707
+
708
+ # Extract many-to-many data before processing
709
+ for rel_name in many_to_many_rels:
710
+ if rel_name in data:
711
+ many_to_many_data[rel_name] = data[rel_name]
712
+ # Remove from original data
713
+ del data[rel_name]
714
+
858
715
  async with cls.get_session() as session:
859
716
  try:
860
717
  # Find the record(s) to update
@@ -895,6 +752,74 @@ class EasyModel(SQLModel):
895
752
  for key, value in data.items():
896
753
  setattr(record, key, value)
897
754
 
755
+ # Process many-to-many relationships if any
756
+ for rel_name, rel_data in many_to_many_data.items():
757
+ if isinstance(rel_data, list):
758
+ # First, get all existing links for this relation
759
+ junction_model, target_model = many_to_many_rels[rel_name]
760
+
761
+ from async_easy_model.auto_relationships import get_foreign_keys_from_model
762
+ foreign_keys = get_foreign_keys_from_model(junction_model)
763
+
764
+ # Find the foreign key fields for this model and the target model
765
+ this_model_fk = None
766
+ target_model_fk = None
767
+
768
+ for fk_field, fk_target in foreign_keys.items():
769
+ target_table = fk_target.split('.')[0]
770
+ if target_table == cls.__tablename__:
771
+ this_model_fk = fk_field
772
+ elif target_table == target_model.__tablename__:
773
+ target_model_fk = fk_field
774
+
775
+ if not this_model_fk or not target_model_fk:
776
+ logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
777
+ continue
778
+
779
+ # Get all existing junctions for this record
780
+ junction_stmt = select(junction_model).where(
781
+ getattr(junction_model, this_model_fk) == record.id
782
+ )
783
+ junction_result = await session.execute(junction_stmt)
784
+ existing_junctions = junction_result.scalars().all()
785
+
786
+ # Get the target IDs from the existing junctions
787
+ existing_target_ids = [getattr(junction, target_model_fk) for junction in existing_junctions]
788
+
789
+ # Track processed target IDs
790
+ processed_target_ids = set()
791
+
792
+ # Process each item in rel_data
793
+ for item_data in rel_data:
794
+ target_obj = await cls._process_single_relationship_item(
795
+ session, target_model, item_data
796
+ )
797
+
798
+ if not target_obj:
799
+ logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
800
+ continue
801
+
802
+ processed_target_ids.add(target_obj.id)
803
+
804
+ # Check if this link already exists
805
+ if target_obj.id not in existing_target_ids:
806
+ # Create new junction
807
+ junction_data = {
808
+ this_model_fk: record.id,
809
+ target_model_fk: target_obj.id
810
+ }
811
+ junction_obj = junction_model(**junction_data)
812
+ session.add(junction_obj)
813
+ logging.info(f"Created junction between {cls.__name__} {record.id} and {target_model.__name__} {target_obj.id}")
814
+
815
+ # Delete junctions for target IDs that weren't in the updated data
816
+ junctions_to_delete = [j for j in existing_junctions
817
+ if getattr(j, target_model_fk) not in processed_target_ids]
818
+
819
+ for junction in junctions_to_delete:
820
+ await session.delete(junction)
821
+ logging.info(f"Deleted junction between {cls.__name__} {record.id} and {target_model.__name__} {getattr(junction, target_model_fk)}")
822
+
898
823
  await session.flush()
899
824
  await session.commit()
900
825
 
@@ -908,9 +833,10 @@ class EasyModel(SQLModel):
908
833
  else:
909
834
  await session.refresh(record)
910
835
  return record
836
+
911
837
  except Exception as e:
912
- logging.error(f"Error updating record: {e}")
913
838
  await session.rollback()
839
+ logging.error(f"Error updating {cls.__name__}: {e}")
914
840
  raise
915
841
 
916
842
  @classmethod
@@ -926,7 +852,7 @@ class EasyModel(SQLModel):
926
852
  """
927
853
  async with cls.get_session() as session:
928
854
  try:
929
- # Find the record(s) to delete
855
+ # Find records to delete
930
856
  statement = select(cls)
931
857
  for field, value in criteria.items():
932
858
  if isinstance(value, str) and '*' in value:
@@ -943,43 +869,50 @@ class EasyModel(SQLModel):
943
869
  logging.warning(f"No records found with criteria: {criteria}")
944
870
  return 0
945
871
 
946
- # Get a list of related tables that might need to be cleared first
947
- # This helps with foreign key constraints
948
- relationship_fields = cls._get_auto_relationship_fields()
949
- to_many_relationships = []
872
+ # Check if there are many-to-many relationships that need cleanup
873
+ many_to_many_rels = cls._get_many_to_many_relationships()
950
874
 
951
- # Find to-many relationships that need to be handled first
952
- for rel_name in relationship_fields:
953
- rel_attr = getattr(cls, rel_name, None)
954
- if rel_attr and hasattr(rel_attr, 'property'):
955
- # Check if this is a to-many relationship (one-to-many or many-to-many)
956
- if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
957
- to_many_relationships.append(rel_name)
958
-
959
- # For each record, delete related records first (cascade delete)
875
+ # Delete each record and its related many-to-many junction records
876
+ count = 0
960
877
  for record in records:
961
- # First load all related collections
962
- if to_many_relationships:
963
- await session.refresh(record, attribute_names=to_many_relationships)
964
-
965
- # Delete related records in collections
966
- for rel_name in to_many_relationships:
967
- related_collection = getattr(record, rel_name, [])
968
- if related_collection:
969
- for related_item in related_collection:
970
- await session.delete(related_item)
878
+ # Clean up many-to-many junctions first
879
+ for rel_name, (junction_model, _) in many_to_many_rels.items():
880
+ # Get foreign keys from the junction model
881
+ from async_easy_model.auto_relationships import get_foreign_keys_from_model
882
+ foreign_keys = get_foreign_keys_from_model(junction_model)
883
+
884
+ # Find which foreign key refers to this model
885
+ this_model_fk = None
886
+ for fk_field, fk_target in foreign_keys.items():
887
+ target_table = fk_target.split('.')[0]
888
+ if target_table == cls.__tablename__:
889
+ this_model_fk = fk_field
890
+ break
891
+
892
+ if not this_model_fk:
893
+ continue
894
+
895
+ # Delete junction records for this record
896
+ delete_stmt = select(junction_model).where(
897
+ getattr(junction_model, this_model_fk) == record.id
898
+ )
899
+ junction_result = await session.execute(delete_stmt)
900
+ junctions = junction_result.scalars().all()
901
+
902
+ for junction in junctions:
903
+ await session.delete(junction)
904
+ logging.info(f"Deleted junction record for {cls.__name__} id={record.id}")
971
905
 
972
906
  # Now delete the main record
973
907
  await session.delete(record)
908
+ count += 1
974
909
 
975
- # Commit the changes
976
- await session.flush()
977
910
  await session.commit()
911
+ return count
978
912
 
979
- return len(records)
980
913
  except Exception as e:
981
- logging.error(f"Error deleting records: {e}")
982
914
  await session.rollback()
915
+ logging.error(f"Error deleting {cls.__name__}: {e}")
983
916
  raise
984
917
 
985
918
  def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
@@ -1058,21 +991,23 @@ class EasyModel(SQLModel):
1058
991
  first: bool = False,
1059
992
  include_relationships: bool = True,
1060
993
  order_by: Optional[Union[str, List[str]]] = None,
994
+ max_depth: int = 2,
1061
995
  limit: Optional[int] = None
1062
996
  ) -> Union[Optional[T], List[T]]:
1063
997
  """
1064
- Retrieve record(s) by matching attribute values.
998
+ Select records based on criteria.
1065
999
 
1066
1000
  Args:
1067
- criteria: Dictionary of search criteria
1068
- all: If True, return all matching records, otherwise return only the first one
1001
+ criteria: Dictionary of field values to filter by
1002
+ all: If True, return all matching records. If False, return only the first match.
1069
1003
  first: If True, return only the first record (equivalent to all=False)
1070
1004
  include_relationships: If True, eagerly load all relationships
1071
1005
  order_by: Field(s) to order by. Can be a string or list of strings.
1072
1006
  Prefix with '-' for descending order (e.g. '-created_at')
1007
+ max_depth: Maximum depth for loading nested relationships (when include_relationships=True)
1073
1008
  limit: Maximum number of records to retrieve (if all=True)
1074
1009
  If limit > 1, all is automatically set to True
1075
-
1010
+
1076
1011
  Returns:
1077
1012
  A single model instance, a list of instances, or None if not found
1078
1013
  """
@@ -1083,73 +1018,98 @@ class EasyModel(SQLModel):
1083
1018
  # If limit is specified and > 1, set all to True
1084
1019
  if limit is not None and limit > 1:
1085
1020
  all = True
1021
+
1086
1022
  # If first is specified, set all to False (first takes precedence)
1087
1023
  if first:
1088
1024
  all = False
1089
-
1025
+
1090
1026
  async with cls.get_session() as session:
1091
1027
  # Build the query
1092
1028
  statement = select(cls)
1093
1029
 
1094
- # Apply criteria
1095
- for field, value in criteria.items():
1030
+ # Apply criteria filters
1031
+ for key, value in criteria.items():
1096
1032
  if isinstance(value, str) and '*' in value:
1097
1033
  # Handle LIKE queries (convert '*' wildcard to '%')
1098
1034
  like_value = value.replace('*', '%')
1099
- statement = statement.where(getattr(cls, field).like(like_value))
1035
+ statement = statement.where(getattr(cls, key).like(like_value))
1100
1036
  else:
1101
1037
  # Regular equality check
1102
- statement = statement.where(getattr(cls, field) == value)
1038
+ statement = statement.where(getattr(cls, key) == value)
1103
1039
 
1104
1040
  # Apply ordering
1105
1041
  if order_by:
1106
- statement = cls._apply_order_by(statement, order_by)
1042
+ order_clauses = []
1043
+ if isinstance(order_by, str):
1044
+ order_by = [order_by]
1045
+
1046
+ for field_name in order_by:
1047
+ if field_name.startswith("-"):
1048
+ # Descending order
1049
+ field_name = field_name[1:] # Remove the "-" prefix
1050
+ # Handle relationship field ordering with dot notation
1051
+ if "." in field_name:
1052
+ rel_name, attr_name = field_name.split(".", 1)
1053
+ if hasattr(cls, rel_name):
1054
+ rel_model = getattr(cls, rel_name)
1055
+ if hasattr(rel_model, "property"):
1056
+ target_model = rel_model.property.entity.class_
1057
+ if hasattr(target_model, attr_name):
1058
+ order_clauses.append(getattr(target_model, attr_name).desc())
1059
+ else:
1060
+ order_clauses.append(getattr(cls, field_name).desc())
1061
+ else:
1062
+ # Ascending order
1063
+ # Handle relationship field ordering with dot notation
1064
+ if "." in field_name:
1065
+ rel_name, attr_name = field_name.split(".", 1)
1066
+ if hasattr(cls, rel_name):
1067
+ rel_model = getattr(cls, rel_name)
1068
+ if hasattr(rel_model, "property"):
1069
+ target_model = rel_model.property.entity.class_
1070
+ if hasattr(target_model, attr_name):
1071
+ order_clauses.append(getattr(target_model, attr_name).asc())
1072
+ else:
1073
+ order_clauses.append(getattr(cls, field_name).asc())
1074
+
1075
+ if order_clauses:
1076
+ statement = statement.order_by(*order_clauses)
1107
1077
 
1108
1078
  # Apply limit
1109
1079
  if limit:
1110
1080
  statement = statement.limit(limit)
1111
1081
 
1112
- # Include relationships if requested
1082
+ # Load relationships if requested
1113
1083
  if include_relationships:
1114
1084
  for rel_name in cls._get_auto_relationship_fields():
1115
1085
  statement = statement.options(selectinload(getattr(cls, rel_name)))
1116
1086
 
1117
- # Execute the query
1118
1087
  result = await session.execute(statement)
1119
1088
 
1120
1089
  if all:
1121
- # Return all results
1122
- instances = result.scalars().all()
1090
+ objects = result.scalars().all()
1123
1091
 
1124
- # Materialize relationships if requested - this ensures they're fully loaded
1125
- if include_relationships:
1126
- for instance in instances:
1127
- # For each relationship, access it once to ensure it's loaded
1128
- for rel_name in cls._get_auto_relationship_fields():
1129
- try:
1130
- # This will force loading the relationship while session is active
1131
- _ = getattr(instance, rel_name)
1132
- except Exception:
1133
- # Skip if the relationship can't be loaded
1134
- pass
1092
+ # Load nested relationships if requested
1093
+ if include_relationships and objects and max_depth > 1:
1094
+ loaded_objects = []
1095
+ for obj in objects:
1096
+ loaded_obj = await cls._load_relationships_recursively(
1097
+ session, obj, max_depth
1098
+ )
1099
+ loaded_objects.append(loaded_obj)
1100
+ return loaded_objects
1135
1101
 
1136
- return instances
1102
+ return objects
1137
1103
  else:
1138
- # Return only the first result
1139
- instance = result.scalars().first()
1104
+ obj = result.scalars().first()
1140
1105
 
1141
- # Materialize relationships if requested and instance exists
1142
- if include_relationships and instance:
1143
- # For each relationship, access it once to ensure it's loaded
1144
- for rel_name in cls._get_auto_relationship_fields():
1145
- try:
1146
- # This will force loading the relationship while session is active
1147
- _ = getattr(instance, rel_name)
1148
- except Exception:
1149
- # Skip if the relationship can't be loaded
1150
- pass
1151
-
1152
- return instance
1106
+ # Load nested relationships if requested
1107
+ if include_relationships and obj and max_depth > 1:
1108
+ obj = await cls._load_relationships_recursively(
1109
+ session, obj, max_depth
1110
+ )
1111
+
1112
+ return obj
1153
1113
 
1154
1114
  @classmethod
1155
1115
  async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
@@ -1208,6 +1168,195 @@ class EasyModel(SQLModel):
1208
1168
  # Use the enhanced insert method to handle all relationships
1209
1169
  return await cls.insert(insert_data, include_relationships=True)
1210
1170
 
1171
@classmethod
def _get_many_to_many_relationships(cls) -> Dict[str, Tuple[Type['EasyModel'], Type['EasyModel']]]:
    """
    Get all many-to-many relationships for this model.

    A relationship counts as many-to-many when its SQLAlchemy property has a
    ``secondary`` (association) table that maps back to a registered junction
    model via ``get_model_by_table_name``.

    Returns:
        Dictionary mapping relationship field names to tuples of
        (junction_model, target_model). Relationships whose secondary table
        cannot be resolved to a model are omitted.
    """
    from async_easy_model.auto_relationships import get_model_by_table_name

    many_to_many_relationships = {}

    for rel_name in cls._get_auto_relationship_fields():
        rel_attr = getattr(cls, rel_name, None)
        prop = getattr(rel_attr, "prop", None)
        secondary = getattr(prop, "secondary", None)
        if secondary is None:
            # Not a relationship, or a relationship without an association table.
            continue

        # `secondary` may still be the table-name string it was declared with.
        # Once mappers are configured, SQLAlchemy replaces it with the resolved
        # Table object, so fall back to that object's `.name`.
        if isinstance(secondary, str):
            table_name = secondary
        else:
            table_name = getattr(secondary, "name", None)
        if not table_name:
            continue

        junction_model = get_model_by_table_name(table_name)
        if junction_model:
            target_model = prop.mapper.class_
            many_to_many_relationships[rel_name] = (junction_model, target_model)

    return many_to_many_relationships
1202
+
1203
@classmethod
async def _process_many_to_many_relationship(
    cls,
    session: AsyncSession,
    parent_obj: 'EasyModel',
    rel_name: str,
    items: List[Dict[str, Any]]
) -> None:
    """
    Process a many-to-many relationship for an object.

    Each entry of ``items`` is resolved to a target-model instance through
    ``_process_single_relationship_item``. For every resolved target, a
    junction row linking ``parent_obj.id`` and the target's ``id`` is added to
    the session, unless a matching junction row is already found by a query.
    This method does not call flush or commit itself.

    Args:
        session: The database session
        parent_obj: The parent object (e.g., Book)
        rel_name: The name of the relationship (e.g., 'tags')
        items: List of data dictionaries for the related items

    Returns:
        None. Problems (not a many-to-many relationship, junction foreign keys
        not found, an item that could not be processed) are logged as warnings.
    """
    # Get information about this many-to-many relationship
    many_to_many_rels = cls._get_many_to_many_relationships()
    if rel_name not in many_to_many_rels:
        logging.warning(f"Relationship {rel_name} is not a many-to-many relationship")
        return

    junction_model, target_model = many_to_many_rels[rel_name]

    # Get the foreign key fields from the junction model that reference this model and the target model
    from async_easy_model.auto_relationships import get_foreign_keys_from_model
    foreign_keys = get_foreign_keys_from_model(junction_model)

    this_model_fk = None
    target_model_fk = None

    for fk_field, fk_target in foreign_keys.items():
        # NOTE(review): assumes fk_target is formatted "table.column".
        target_table = fk_target.split('.')[0]
        # Claim each side at most once. In a self-referential relationship both
        # foreign keys point at the same table: the first one seen becomes this
        # model's side and the second one the target's side.
        if target_table == cls.__tablename__ and this_model_fk is None:
            this_model_fk = fk_field
        elif target_table == target_model.__tablename__ and target_model_fk is None:
            target_model_fk = fk_field

    if not this_model_fk or not target_model_fk:
        logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
        return

    # Process each related item
    for item_data in items:
        # First, create or find the target model instance
        target_obj = await cls._process_single_relationship_item(
            session, target_model, item_data
        )

        if not target_obj:
            logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
            continue

        # Only add a junction row if this exact link is not already present
        junction_stmt = select(junction_model).where(
            getattr(junction_model, this_model_fk) == parent_obj.id,
            getattr(junction_model, target_model_fk) == target_obj.id
        )
        junction_result = await session.execute(junction_stmt)
        existing_junction = junction_result.scalars().first()

        if not existing_junction:
            junction_data = {
                this_model_fk: parent_obj.id,
                target_model_fk: target_obj.id
            }
            junction_obj = junction_model(**junction_data)
            session.add(junction_obj)
            logging.info(f"Created junction between {cls.__name__} {parent_obj.id} and {target_model.__name__} {target_obj.id}")
1279
+
1280
+ @classmethod
1281
+ async def _load_relationships_recursively(cls, session, obj, max_depth=2, current_depth=0, visited_ids=None):
1282
+ """
1283
+ Recursively load all relationships for an object and its related objects.
1284
+
1285
+ Args:
1286
+ session: SQLAlchemy session
1287
+ obj: The object to load relationships for
1288
+ max_depth: Maximum depth to recurse to prevent infinite loops
1289
+ current_depth: Current recursion depth (internal use)
1290
+ visited_ids: Set of already visited object IDs to prevent cycles
1291
+
1292
+ Returns:
1293
+ The object with all relationships loaded
1294
+ """
1295
+ if visited_ids is None:
1296
+ visited_ids = set()
1297
+
1298
+ # Use object ID and class for tracking instead of the object itself (which isn't hashable)
1299
+ obj_key = (obj.__class__.__name__, obj.id)
1300
+
1301
+ # Stop if we've reached max depth or already visited this object
1302
+ if current_depth >= max_depth or obj_key in visited_ids:
1303
+ return obj
1304
+
1305
+ # Mark as visited to prevent cycles
1306
+ visited_ids.add(obj_key)
1307
+
1308
+ # Load all relationship fields for this object
1309
+ obj_class = obj.__class__
1310
+ relationship_fields = obj_class._get_auto_relationship_fields()
1311
+
1312
+ # For each relationship, load it and recurse
1313
+ for rel_name in relationship_fields:
1314
+ try:
1315
+ # Fetch the objects using selectinload
1316
+ stmt = select(obj_class).where(obj_class.id == obj.id)
1317
+ stmt = stmt.options(selectinload(getattr(obj_class, rel_name)))
1318
+ result = await session.execute(stmt)
1319
+ refreshed_obj = result.scalars().first()
1320
+
1321
+ # Get the loaded relationship
1322
+ related_objs = getattr(refreshed_obj, rel_name, None)
1323
+
1324
+ # Update the object's relationship
1325
+ setattr(obj, rel_name, related_objs)
1326
+
1327
+ # Skip if no related objects
1328
+ if related_objs is None:
1329
+ continue
1330
+
1331
+ # Recurse for related objects
1332
+ if isinstance(related_objs, list):
1333
+ for related_obj in related_objs:
1334
+ if hasattr(related_obj.__class__, '_get_auto_relationship_fields'):
1335
+ # Only recurse if the object has an ID (is persistent)
1336
+ if hasattr(related_obj, 'id') and related_obj.id is not None:
1337
+ await cls._load_relationships_recursively(
1338
+ session,
1339
+ related_obj,
1340
+ max_depth,
1341
+ current_depth + 1,
1342
+ visited_ids
1343
+ )
1344
+ else:
1345
+ if hasattr(related_objs.__class__, '_get_auto_relationship_fields'):
1346
+ # Only recurse if the object has an ID (is persistent)
1347
+ if hasattr(related_objs, 'id') and related_objs.id is not None:
1348
+ await cls._load_relationships_recursively(
1349
+ session,
1350
+ related_objs,
1351
+ max_depth,
1352
+ current_depth + 1,
1353
+ visited_ids
1354
+ )
1355
+ except Exception as e:
1356
+ logging.warning(f"Error loading relationship {rel_name}: {e}")
1357
+
1358
+ return obj
1359
+
1211
1360
  # Register an event listener to update 'updated_at' on instance modifications.
1212
1361
  @event.listens_for(Session, "before_flush")
1213
1362
  def _update_updated_at(session, flush_context, instances):