async-easy-model 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
async_easy_model/model.py CHANGED
@@ -10,6 +10,7 @@ from datetime import datetime, timezone as tz
10
10
  import inspect
11
11
  import json
12
12
  import logging
13
+ import re
13
14
 
14
15
  T = TypeVar("T", bound="EasyModel")
15
16
 
@@ -212,8 +213,9 @@ class EasyModel(SQLModel):
212
213
  @classmethod
213
214
  async def all(
214
215
  cls: Type[T],
215
- include_relationships: bool = True,
216
- order_by: Optional[Union[str, List[str]]] = None
216
+ include_relationships: bool = True,
217
+ order_by: Optional[Union[str, List[str]]] = None,
218
+ max_depth: int = 2
217
219
  ) -> List[T]:
218
220
  """
219
221
  Retrieve all records of this model.
@@ -222,29 +224,20 @@ class EasyModel(SQLModel):
222
224
  include_relationships: If True, eagerly load all relationships
223
225
  order_by: Field(s) to order by. Can be a string or list of strings.
224
226
  Prefix with '-' for descending order (e.g. '-created_at')
227
+ max_depth: Maximum depth for loading nested relationships
225
228
 
226
229
  Returns:
227
230
  A list of all model instances
228
231
  """
229
- async with cls.get_session() as session:
230
- statement = select(cls)
231
-
232
- # Apply ordering
233
- statement = cls._apply_order_by(statement, order_by)
234
-
235
- if include_relationships:
236
- # Get all relationship attributes, including auto-detected ones
237
- for rel_name in cls._get_auto_relationship_fields():
238
- statement = statement.options(selectinload(getattr(cls, rel_name)))
239
-
240
- result = await session.execute(statement)
241
- return result.scalars().all()
232
+ return await cls.select({}, all=True, include_relationships=include_relationships,
233
+ order_by=order_by, max_depth=max_depth)
242
234
 
243
235
  @classmethod
244
236
  async def first(
245
237
  cls: Type[T],
246
- include_relationships: bool = True,
247
- order_by: Optional[Union[str, List[str]]] = None
238
+ include_relationships: bool = True,
239
+ order_by: Optional[Union[str, List[str]]] = None,
240
+ max_depth: int = 2
248
241
  ) -> Optional[T]:
249
242
  """
250
243
  Retrieve the first record of this model.
@@ -253,56 +246,37 @@ class EasyModel(SQLModel):
253
246
  include_relationships: If True, eagerly load all relationships
254
247
  order_by: Field(s) to order by. Can be a string or list of strings.
255
248
  Prefix with '-' for descending order (e.g. '-created_at')
249
+ max_depth: Maximum depth for loading nested relationships
256
250
 
257
251
  Returns:
258
252
  The first model instance or None if no records exist
259
253
  """
260
- async with cls.get_session() as session:
261
- statement = select(cls)
262
-
263
- # Apply ordering
264
- statement = cls._apply_order_by(statement, order_by)
265
-
266
- if include_relationships:
267
- # Get all relationship attributes, including auto-detected ones
268
- for rel_name in cls._get_auto_relationship_fields():
269
- statement = statement.options(selectinload(getattr(cls, rel_name)))
270
-
271
- result = await session.execute(statement)
272
- return result.scalars().first()
254
+ return await cls.select({}, first=True, include_relationships=include_relationships,
255
+ order_by=order_by, max_depth=max_depth)
273
256
 
274
257
  @classmethod
275
258
  async def limit(
276
259
  cls: Type[T],
277
260
  count: int,
278
- include_relationships: bool = True,
279
- order_by: Optional[Union[str, List[str]]] = None
261
+ include_relationships: bool = True,
262
+ order_by: Optional[Union[str, List[str]]] = None,
263
+ max_depth: int = 2
280
264
  ) -> List[T]:
281
265
  """
282
- Retrieve a limited number of records of this model.
266
+ Retrieve a limited number of records.
283
267
 
284
268
  Args:
285
- count: Maximum number of records to retrieve
269
+ count: Maximum number of records to return
286
270
  include_relationships: If True, eagerly load all relationships
287
271
  order_by: Field(s) to order by. Can be a string or list of strings.
288
272
  Prefix with '-' for descending order (e.g. '-created_at')
273
+ max_depth: Maximum depth for loading nested relationships
289
274
 
290
275
  Returns:
291
- A list of model instances up to the specified count
276
+ A list of model instances
292
277
  """
293
- async with cls.get_session() as session:
294
- statement = select(cls).limit(count)
295
-
296
- # Apply ordering
297
- statement = cls._apply_order_by(statement, order_by)
298
-
299
- if include_relationships:
300
- # Get all relationship attributes, including auto-detected ones
301
- for rel_name in cls._get_auto_relationship_fields():
302
- statement = statement.options(selectinload(getattr(cls, rel_name)))
303
-
304
- result = await session.execute(statement)
305
- return result.scalars().all()
278
+ return await cls.select({}, all=True, include_relationships=include_relationships,
279
+ order_by=order_by, limit=count, max_depth=max_depth)
306
280
 
307
281
  @classmethod
308
282
  async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
@@ -327,6 +301,20 @@ class EasyModel(SQLModel):
327
301
  else:
328
302
  return await session.get(cls, id)
329
303
 
304
+ @classmethod
305
+ def _get_unique_fields(cls) -> List[str]:
306
+ """
307
+ Get all fields with unique=True constraint
308
+
309
+ Returns:
310
+ List of field names that have unique constraints
311
+ """
312
+ unique_fields = []
313
+ for name, field in cls.__fields__.items():
314
+ if name != 'id' and hasattr(field, "field_info") and field.field_info.extra.get('unique', False):
315
+ unique_fields.append(name)
316
+ return unique_fields
317
+
330
318
  @classmethod
331
319
  async def get_by_attribute(
332
320
  cls: Type[T],
@@ -391,193 +379,76 @@ class EasyModel(SQLModel):
391
379
  return result.scalars().first()
392
380
 
393
381
  @classmethod
394
- async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
382
+ async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True, max_depth: int = 2) -> Union[T, List[T]]:
395
383
  """
396
384
  Insert one or more records.
397
385
 
398
386
  Args:
399
387
  data: Dictionary of field values or a list of dictionaries for multiple records
400
388
  include_relationships: If True, return the instance(s) with relationships loaded
389
+ max_depth: Maximum depth for loading nested relationships
401
390
 
402
391
  Returns:
403
392
  The created model instance(s)
404
393
  """
405
- # Handle list of records
394
+ if not data:
395
+ return None
396
+
397
+ # Handle single dict or list of dicts
406
398
  if isinstance(data, list):
407
- objects = []
408
- async with cls.get_session() as session:
409
- for item in data:
410
- try:
411
- # Process relationships first
412
- processed_item = await cls._process_relationships_for_insert(session, item)
413
-
414
- # Check if a record with unique constraints already exists
415
- unique_fields = cls._get_unique_fields()
416
- existing_obj = None
417
-
418
- if unique_fields:
419
- unique_criteria = {field: processed_item[field]
420
- for field in unique_fields
421
- if field in processed_item}
422
-
423
- if unique_criteria:
424
- # Try to find existing record with these unique values
425
- statement = select(cls)
426
- for field, value in unique_criteria.items():
427
- statement = statement.where(getattr(cls, field) == value)
428
- result = await session.execute(statement)
429
- existing_obj = result.scalars().first()
430
-
431
- if existing_obj:
432
- # Update existing object with new values
433
- for key, value in processed_item.items():
434
- if key != 'id': # Don't update ID
435
- setattr(existing_obj, key, value)
436
- objects.append(existing_obj)
437
- else:
438
- # Create new object
439
- obj = cls(**processed_item)
440
- session.add(obj)
441
- objects.append(obj)
442
- except Exception as e:
443
- logging.error(f"Error inserting record: {e}")
444
- await session.rollback()
445
- raise
446
-
447
- try:
448
- await session.flush()
449
- await session.commit()
450
-
451
- # Refresh with relationships if requested
452
- if include_relationships:
453
- for obj in objects:
454
- await session.refresh(obj)
455
- except Exception as e:
456
- logging.error(f"Error committing transaction: {e}")
457
- await session.rollback()
458
- raise
459
-
460
- return objects
461
- else:
462
- # Single record case
463
- async with cls.get_session() as session:
464
- try:
465
- # Process relationships first
466
- processed_data = await cls._process_relationships_for_insert(session, data)
467
-
468
- # Check if a record with unique constraints already exists
469
- unique_fields = cls._get_unique_fields()
470
- existing_obj = None
471
-
472
- if unique_fields:
473
- unique_criteria = {field: processed_data[field]
474
- for field in unique_fields
475
- if field in processed_data}
476
-
477
- if unique_criteria:
478
- # Try to find existing record with these unique values
479
- statement = select(cls)
480
- for field, value in unique_criteria.items():
481
- statement = statement.where(getattr(cls, field) == value)
482
- result = await session.execute(statement)
483
- existing_obj = result.scalars().first()
484
-
485
- if existing_obj:
486
- # Update existing object with new values
487
- for key, value in processed_data.items():
488
- if key != 'id': # Don't update ID
489
- setattr(existing_obj, key, value)
490
- obj = existing_obj
491
- else:
492
- # Create new object
493
- obj = cls(**processed_data)
494
- session.add(obj)
495
-
496
- await session.flush() # Flush to get the ID
497
- await session.commit()
498
-
499
- if include_relationships:
500
- # Refresh with relationships
501
- statement = select(cls).where(cls.id == obj.id)
502
- for rel_name in cls._get_auto_relationship_fields():
503
- statement = statement.options(selectinload(getattr(cls, rel_name)))
504
- result = await session.execute(statement)
505
- return result.scalars().first()
506
- else:
507
- await session.refresh(obj)
508
- return obj
509
- except Exception as e:
510
- logging.error(f"Error inserting record: {e}")
511
- await session.rollback()
512
- raise
513
-
514
- @classmethod
515
- async def insert_with_related(
516
- cls: Type[T],
517
- data: Dict[str, Any],
518
- related_data: Dict[str, List[Dict[str, Any]]] = None
519
- ) -> T:
520
- """
521
- Create a model instance with related objects in a single transaction.
399
+ results = []
400
+ for item in data:
401
+ result = await cls.insert(item, include_relationships, max_depth)
402
+ results.append(result)
403
+ return results
522
404
 
523
- Args:
524
- data: Dictionary of field values for the main model
525
- related_data: Dictionary mapping relationship names to lists of data dictionaries
526
- for creating related objects
527
-
528
- Returns:
529
- The created model instance with relationships loaded
530
- """
531
- if related_data is None:
532
- related_data = {}
533
-
405
+ # Store many-to-many relationship data for later processing
406
+ many_to_many_data = {}
407
+ many_to_many_rels = cls._get_many_to_many_relationships()
408
+
409
+ # Extract many-to-many data before processing other relationships
410
+ for rel_name in many_to_many_rels:
411
+ if rel_name in data:
412
+ many_to_many_data[rel_name] = data[rel_name]
413
+
414
+ # Process relationships to convert nested objects to foreign keys
534
415
  async with cls.get_session() as session:
535
- # Create the main object
536
- obj = cls(**data)
537
- session.add(obj)
538
- await session.flush() # Flush to get the ID
539
-
540
- # Create related objects
541
- for rel_name, items_data in related_data.items():
542
- if not hasattr(cls, rel_name):
543
- continue
544
-
545
- rel_attr = getattr(cls, rel_name)
546
- if not hasattr(rel_attr, "property"):
547
- continue
548
-
549
- # Get the related model class and the back reference attribute
550
- related_model = rel_attr.property.mapper.class_
551
- back_populates = getattr(rel_attr.property, "back_populates", None)
416
+ try:
417
+ processed_data = await cls._process_relationships_for_insert(session, data)
552
418
 
553
- # Create each related object
554
- for item_data in items_data:
555
- # Set the back reference if it exists
556
- if back_populates:
557
- item_data[back_populates] = obj
558
-
559
- related_obj = related_model(**item_data)
560
- session.add(related_obj)
561
-
562
- await session.commit()
563
-
564
- # Refresh with relationships
565
- await session.refresh(obj, attribute_names=list(related_data.keys()))
566
- return obj
567
-
568
- @classmethod
569
- def _get_unique_fields(cls) -> List[str]:
570
- """
571
- Get all fields with unique=True constraint
572
-
573
- Returns:
574
- List of field names that have unique constraints
575
- """
576
- unique_fields = []
577
- for name, field in cls.__fields__.items():
578
- if name != 'id' and hasattr(field, 'field_info') and field.field_info.extra.get('unique', False):
579
- unique_fields.append(name)
580
- return unique_fields
419
+ # Create the model instance
420
+ obj = cls(**processed_data)
421
+ session.add(obj)
422
+
423
+ # Flush to get the object ID
424
+ await session.flush()
425
+
426
+ # Now process many-to-many relationships if any
427
+ for rel_name, rel_data in many_to_many_data.items():
428
+ if isinstance(rel_data, list):
429
+ await cls._process_many_to_many_relationship(
430
+ session, obj, rel_name, rel_data
431
+ )
432
+
433
+ # Commit the transaction
434
+ await session.commit()
435
+
436
+ if include_relationships:
437
+ # Reload with relationships
438
+ return await cls._load_relationships_recursively(session, obj, max_depth)
439
+ else:
440
+ return obj
441
+
442
+ except Exception as e:
443
+ await session.rollback()
444
+ logging.error(f"Error inserting {cls.__name__}: {e}")
445
+ if "UNIQUE constraint failed" in str(e):
446
+ field_match = re.search(r"UNIQUE constraint failed: \w+\.(\w+)", str(e))
447
+ if field_match:
448
+ field_name = field_match.group(1)
449
+ value = data.get(field_name)
450
+ raise ValueError(f"A record with {field_name}='{value}' already exists")
451
+ raise
581
452
 
582
453
  @classmethod
583
454
  async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -591,6 +462,15 @@ class EasyModel(SQLModel):
591
462
  "quantity": 2
592
463
  })
593
464
 
465
+ It also handles lists of related objects for one-to-many relationships:
466
+ publisher = await Publisher.insert({
467
+ "name": "Example Publisher",
468
+ "authors": [
469
+ {"name": "Author 1", "email": "author1@example.com"},
470
+ {"name": "Author 2", "email": "author2@example.com"}
471
+ ]
472
+ })
473
+
594
474
  For each nested object:
595
475
  1. Find the target model class
596
476
  2. Check if an object with the same unique fields already exists
@@ -605,175 +485,208 @@ class EasyModel(SQLModel):
605
485
  Returns:
606
486
  Processed data dictionary with nested objects replaced by their foreign key IDs
607
487
  """
608
- import copy
609
- result = copy.deepcopy(data)
488
+ if not data:
489
+ return {}
490
+
491
+ result = dict(data)
610
492
 
611
- # Get all relationship fields for this model
493
+ # Get relationship fields for this model
612
494
  relationship_fields = cls._get_auto_relationship_fields()
613
495
 
614
- # Get foreign key fields
615
- foreign_key_fields = []
616
- for field_name, field_info in cls.__fields__.items():
617
- if field_name.endswith("_id") and hasattr(field_info, "field_info"):
618
- if field_info.field_info.extra.get("foreign_key"):
619
- foreign_key_fields.append(field_name)
620
-
621
- # Handle nested relationship objects
622
- for key, value in data.items():
623
- # Skip if the value is not a dictionary
624
- if not isinstance(value, dict):
496
+ # Get many-to-many relationships separately
497
+ many_to_many_relationships = cls._get_many_to_many_relationships()
498
+
499
+ # Set of field names already processed as many-to-many relationships
500
+ processed_m2m_fields = set()
501
+
502
+ # Process each relationship field in the input data
503
+ for key in list(result.keys()):
504
+ value = result[key]
505
+
506
+ # Skip empty values
507
+ if value is None:
625
508
  continue
626
-
627
- # Check if this is a relationship field (either by name or derived from foreign key)
628
- is_rel_field = key in relationship_fields
629
- related_key = f"{key}_id"
630
- is_derived_rel = related_key in foreign_key_fields
631
509
 
632
- # If it's a relationship field or derived from a foreign key
633
- if is_rel_field or is_derived_rel or related_key in cls.__fields__:
634
- # Find the related model class
510
+ # Check if this is a relationship field
511
+ if key in relationship_fields:
512
+ # Get the related model class
635
513
  related_model = None
636
514
 
637
- # Try to get the related model from the attribute
638
- if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
639
- # Get from relationship attribute
515
+ if hasattr(cls, key):
640
516
  rel_attr = getattr(cls, key)
641
- related_model = rel_attr.property.mapper.class_
642
- else:
643
- # Try to find it from foreign key definition
644
- fk_definition = None
645
- for field_name, field_info in cls.__fields__.items():
646
- if field_name == related_key and hasattr(field_info, "field_info"):
647
- fk_definition = field_info.field_info.extra.get("foreign_key")
648
- break
649
-
650
- if fk_definition:
651
- # Parse foreign key definition (e.g. "users.id")
652
- target_table, _ = fk_definition.split(".")
653
- # Try to find the target model
654
- from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
655
- related_model = get_model_by_table_name(target_table)
656
- if not related_model:
657
- # Try with the singular form
658
- singular_table = singularize_name(target_table)
659
- related_model = get_model_by_table_name(singular_table)
660
- else:
661
- # Try to infer from field name (e.g., "user_id" -> Users)
662
- base_name = related_key[:-3] # Remove "_id"
663
- from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
664
-
665
- # Try singular and plural forms
666
- related_model = get_model_by_table_name(base_name)
667
- if not related_model:
668
- plural_table = pluralize_name(base_name)
669
- related_model = get_model_by_table_name(plural_table)
670
- if not related_model:
671
- singular_table = singularize_name(base_name)
672
- related_model = get_model_by_table_name(singular_table)
517
+ if hasattr(rel_attr, 'prop') and hasattr(rel_attr.prop, 'mapper'):
518
+ related_model = rel_attr.prop.mapper.class_
673
519
 
674
520
  if not related_model:
675
- logging.warning(f"Could not find related model for {key} in {cls.__name__}")
521
+ logging.warning(f"Could not determine related model for {key}, skipping")
676
522
  continue
677
523
 
678
- # Look for unique fields in the related model to use for searching
679
- unique_fields = []
680
- for field_name, field_info in related_model.__fields__.items():
681
- if (hasattr(field_info, "field_info") and
682
- field_info.field_info.extra.get('unique', False)):
683
- unique_fields.append(field_name)
684
-
685
- # Create a search dictionary using unique fields
686
- search_dict = {}
687
- for field in unique_fields:
688
- if field in value and value[field] is not None:
689
- search_dict[field] = value[field]
690
-
691
- # If no unique fields found but ID is provided, use it
692
- if not search_dict and 'id' in value and value['id']:
693
- search_dict = {'id': value['id']}
524
+ # Check if this is a many-to-many relationship field
525
+ if key in many_to_many_relationships:
526
+ # Store this separately - we'll handle it after the main object is created
527
+ processed_m2m_fields.add(key)
528
+ continue
694
529
 
695
- # Special case for products without uniqueness constraints
696
- if not search_dict and related_model.__tablename__ == 'products' and 'name' in value:
697
- search_dict = {'name': value['name']}
530
+ # Handle different relationship types based on data type
531
+ if isinstance(value, list):
532
+ # Handle one-to-many relationship (list of dictionaries)
533
+ related_ids = []
534
+ for item in value:
535
+ if isinstance(item, dict):
536
+ related_obj = await cls._process_single_relationship_item(
537
+ session, related_model, item
538
+ )
539
+ if related_obj:
540
+ related_ids.append(related_obj.id)
541
+
542
+ # Update result with list of foreign key IDs
543
+ foreign_key_list_name = f"{key}_ids"
544
+ result[foreign_key_list_name] = related_ids
545
+
546
+ # Remove the relationship list from the result
547
+ if key in result:
548
+ del result[key]
698
549
 
699
- # Try to find an existing record
700
- related_obj = None
701
- if search_dict:
702
- logging.info(f"Searching for existing {related_model.__name__} with {search_dict}")
550
+ elif isinstance(value, dict):
551
+ # Handle one-to-one relationship (single dictionary)
552
+ related_obj = await cls._process_single_relationship_item(
553
+ session, related_model, value
554
+ )
703
555
 
704
- try:
705
- # Create a more appropriate search query based on unique fields
706
- existing_stmt = select(related_model)
707
- for field, field_value in search_dict.items():
708
- existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
709
-
710
- existing_result = await session.execute(existing_stmt)
711
- related_obj = existing_result.scalars().first()
556
+ if related_obj:
557
+ # Update the result with the foreign key ID
558
+ foreign_key_name = f"{key}_id"
559
+ result[foreign_key_name] = related_obj.id
712
560
 
713
- if related_obj:
714
- logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id}")
715
- except Exception as e:
716
- logging.error(f"Error finding existing record: {e}")
561
+ # Remove the relationship dictionary from the result
562
+ if key in result:
563
+ del result[key]
564
+
565
+ # Remove any processed many-to-many fields from the result
566
+ # since we'll handle them separately after the object is created
567
+ for key in processed_m2m_fields:
568
+ if key in result:
569
+ del result[key]
570
+
571
+ return result
572
+
573
+ @classmethod
574
+ async def _process_single_relationship_item(cls, session: AsyncSession, related_model: Type, item_data: Dict[str, Any]) -> Optional[Any]:
575
+ """
576
+ Process a single relationship item (dictionary).
577
+
578
+ This helper method is used by _process_relationships_for_insert to handle
579
+ both singular relationship objects and items within lists of relationships.
580
+
581
+ Args:
582
+ session: The database session to use
583
+ related_model: The related model class
584
+ item_data: Dictionary with field values for the related object
585
+
586
+ Returns:
587
+ The created or found related object, or None if processing failed
588
+ """
589
+ # Look for unique fields in the related model to use for searching
590
+ unique_fields = []
591
+ for name, field in related_model.__fields__.items():
592
+ if (hasattr(field, "field_info") and
593
+ field.field_info.extra.get('unique', False)):
594
+ unique_fields.append(name)
595
+
596
+ # Create a search dictionary using unique fields
597
+ search_dict = {}
598
+ for field in unique_fields:
599
+ if field in item_data and item_data[field] is not None:
600
+ search_dict[field] = item_data[field]
601
+
602
+ # If no unique fields found but ID is provided, use it
603
+ if not search_dict and 'id' in item_data and item_data['id']:
604
+ search_dict = {'id': item_data['id']}
605
+
606
+ # Special case for products without uniqueness constraints
607
+ if not search_dict and related_model.__tablename__ == 'products' and 'name' in item_data:
608
+ search_dict = {'name': item_data['name']}
609
+
610
+ # Try to find an existing record
611
+ related_obj = None
612
+ if search_dict:
613
+ logging.info(f"Searching for existing {related_model.__name__} with {search_dict}")
614
+
615
+ try:
616
+ # Create a more appropriate search query based on unique fields
617
+ existing_stmt = select(related_model)
618
+ for field, field_value in search_dict.items():
619
+ existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
620
+
621
+ existing_result = await session.execute(existing_stmt)
622
+ related_obj = existing_result.scalars().first()
717
623
 
718
624
  if related_obj:
719
- # Update the existing record with any non-unique field values
720
- for attr, attr_val in value.items():
721
- # Skip ID field
722
- if attr == 'id':
723
- continue
724
-
725
- # Skip unique fields with different values to avoid constraint violations
726
- if attr in unique_fields and getattr(related_obj, attr) != attr_val:
727
- continue
728
-
729
- # Update non-unique fields
730
- current_val = getattr(related_obj, attr, None)
731
- if current_val != attr_val:
732
- setattr(related_obj, attr, attr_val)
625
+ logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id}")
626
+ except Exception as e:
627
+ logging.error(f"Error finding existing record: {e}")
628
+
629
+ if related_obj:
630
+ # Update the existing record with any non-unique field values
631
+ for attr, attr_val in item_data.items():
632
+ # Skip ID field
633
+ if attr == 'id':
634
+ continue
733
635
 
734
- # Add the updated object to the session
735
- session.add(related_obj)
736
- logging.info(f"Reusing existing {related_model.__name__} with ID: {related_obj.id}")
737
- else:
738
- # Create a new record
739
- logging.info(f"Creating new {related_model.__name__} for {key}")
740
- related_obj = related_model(**value)
741
- session.add(related_obj)
742
-
743
- # Ensure the object has an ID by flushing
744
- try:
745
- await session.flush()
746
- except Exception as e:
747
- logging.error(f"Error flushing session for {related_model.__name__}: {e}")
636
+ # Skip unique fields with different values to avoid constraint violations
637
+ if attr in unique_fields and getattr(related_obj, attr) != attr_val:
638
+ continue
748
639
 
749
- # If there was a uniqueness error, try again to find the existing record
750
- if "UNIQUE constraint failed" in str(e):
751
- logging.info(f"UNIQUE constraint failed, trying to find existing record again")
752
-
753
- # Try to find by any field provided in the search_dict
754
- existing_stmt = select(related_model)
755
- for field, field_value in search_dict.items():
756
- existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
757
-
758
- # Execute the search query
759
- existing_result = await session.execute(existing_stmt)
760
- related_obj = existing_result.scalars().first()
761
-
762
- if not related_obj:
763
- # We couldn't find an existing record, re-raise the exception
764
- raise
765
-
766
- logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id} after constraint error")
640
+ # Update non-unique fields
641
+ current_val = getattr(related_obj, attr, None)
642
+ if current_val != attr_val:
643
+ setattr(related_obj, attr, attr_val)
644
+
645
+ # Add the updated object to the session
646
+ session.add(related_obj)
647
+ logging.info(f"Reusing existing {related_model.__name__} with ID: {related_obj.id}")
648
+ else:
649
+ # Create a new record
650
+ logging.info(f"Creating new {related_model.__name__}")
651
+
652
+ # Process nested relationships in this item first
653
+ if hasattr(related_model, '_process_relationships_for_insert'):
654
+ # This is a recursive call to process nested relationships
655
+ processed_item_data = await related_model._process_relationships_for_insert(
656
+ session, item_data
657
+ )
658
+ else:
659
+ processed_item_data = item_data
660
+
661
+ related_obj = related_model(**processed_item_data)
662
+ session.add(related_obj)
663
+
664
+ # Ensure the object has an ID by flushing
665
+ try:
666
+ await session.flush()
667
+ except Exception as e:
668
+ logging.error(f"Error flushing session for {related_model.__name__}: {e}")
669
+
670
+ # If there was a uniqueness error, try again to find the existing record
671
+ if "UNIQUE constraint failed" in str(e):
672
+ logging.info(f"UNIQUE constraint failed, trying to find existing record again")
767
673
 
768
- # Update the result with the foreign key ID
769
- foreign_key_name = f"{key}_id"
770
- result[foreign_key_name] = related_obj.id
674
+ # Try to find by any field provided in the search_dict
675
+ existing_stmt = select(related_model)
676
+ for field, field_value in search_dict.items():
677
+ existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
771
678
 
772
- # Remove the relationship dictionary from the result
773
- if key in result:
774
- del result[key]
679
+ # Execute the search query
680
+ existing_result = await session.execute(existing_stmt)
681
+ related_obj = existing_result.scalars().first()
682
+
683
+ if not related_obj:
684
+ # We couldn't find an existing record, re-raise the exception
685
+ raise
686
+
687
+ logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id} after constraint error")
775
688
 
776
- return result
689
+ return related_obj
777
690
 
778
691
  @classmethod
779
692
  async def update(cls: Type[T], data: Dict[str, Any], criteria: Dict[str, Any], include_relationships: bool = True) -> Optional[T]:
@@ -788,6 +701,17 @@ class EasyModel(SQLModel):
788
701
  Returns:
789
702
  The updated model instance
790
703
  """
704
+ # Store many-to-many relationship data for later processing
705
+ many_to_many_data = {}
706
+ many_to_many_rels = cls._get_many_to_many_relationships()
707
+
708
+ # Extract many-to-many data before processing
709
+ for rel_name in many_to_many_rels:
710
+ if rel_name in data:
711
+ many_to_many_data[rel_name] = data[rel_name]
712
+ # Remove from original data
713
+ del data[rel_name]
714
+
791
715
  async with cls.get_session() as session:
792
716
  try:
793
717
  # Find the record(s) to update
@@ -828,6 +752,74 @@ class EasyModel(SQLModel):
828
752
  for key, value in data.items():
829
753
  setattr(record, key, value)
830
754
 
755
+ # Process many-to-many relationships if any
756
+ for rel_name, rel_data in many_to_many_data.items():
757
+ if isinstance(rel_data, list):
758
+ # First, get all existing links for this relation
759
+ junction_model, target_model = many_to_many_rels[rel_name]
760
+
761
+ from async_easy_model.auto_relationships import get_foreign_keys_from_model
762
+ foreign_keys = get_foreign_keys_from_model(junction_model)
763
+
764
+ # Find the foreign key fields for this model and the target model
765
+ this_model_fk = None
766
+ target_model_fk = None
767
+
768
+ for fk_field, fk_target in foreign_keys.items():
769
+ target_table = fk_target.split('.')[0]
770
+ if target_table == cls.__tablename__:
771
+ this_model_fk = fk_field
772
+ elif target_table == target_model.__tablename__:
773
+ target_model_fk = fk_field
774
+
775
+ if not this_model_fk or not target_model_fk:
776
+ logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
777
+ continue
778
+
779
+ # Get all existing junctions for this record
780
+ junction_stmt = select(junction_model).where(
781
+ getattr(junction_model, this_model_fk) == record.id
782
+ )
783
+ junction_result = await session.execute(junction_stmt)
784
+ existing_junctions = junction_result.scalars().all()
785
+
786
+ # Get the target IDs from the existing junctions
787
+ existing_target_ids = [getattr(junction, target_model_fk) for junction in existing_junctions]
788
+
789
+ # Track processed target IDs
790
+ processed_target_ids = set()
791
+
792
+ # Process each item in rel_data
793
+ for item_data in rel_data:
794
+ target_obj = await cls._process_single_relationship_item(
795
+ session, target_model, item_data
796
+ )
797
+
798
+ if not target_obj:
799
+ logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
800
+ continue
801
+
802
+ processed_target_ids.add(target_obj.id)
803
+
804
+ # Check if this link already exists
805
+ if target_obj.id not in existing_target_ids:
806
+ # Create new junction
807
+ junction_data = {
808
+ this_model_fk: record.id,
809
+ target_model_fk: target_obj.id
810
+ }
811
+ junction_obj = junction_model(**junction_data)
812
+ session.add(junction_obj)
813
+ logging.info(f"Created junction between {cls.__name__} {record.id} and {target_model.__name__} {target_obj.id}")
814
+
815
+ # Delete junctions for target IDs that weren't in the updated data
816
+ junctions_to_delete = [j for j in existing_junctions
817
+ if getattr(j, target_model_fk) not in processed_target_ids]
818
+
819
+ for junction in junctions_to_delete:
820
+ await session.delete(junction)
821
+ logging.info(f"Deleted junction between {cls.__name__} {record.id} and {target_model.__name__} {getattr(junction, target_model_fk)}")
822
+
831
823
  await session.flush()
832
824
  await session.commit()
833
825
 
@@ -841,9 +833,10 @@ class EasyModel(SQLModel):
841
833
  else:
842
834
  await session.refresh(record)
843
835
  return record
836
+
844
837
  except Exception as e:
845
- logging.error(f"Error updating record: {e}")
846
838
  await session.rollback()
839
+ logging.error(f"Error updating {cls.__name__}: {e}")
847
840
  raise
848
841
 
849
842
  @classmethod
@@ -859,7 +852,7 @@ class EasyModel(SQLModel):
859
852
  """
860
853
  async with cls.get_session() as session:
861
854
  try:
862
- # Find the record(s) to delete
855
+ # Find records to delete
863
856
  statement = select(cls)
864
857
  for field, value in criteria.items():
865
858
  if isinstance(value, str) and '*' in value:
@@ -876,43 +869,50 @@ class EasyModel(SQLModel):
876
869
  logging.warning(f"No records found with criteria: {criteria}")
877
870
  return 0
878
871
 
879
- # Get a list of related tables that might need to be cleared first
880
- # This helps with foreign key constraints
881
- relationship_fields = cls._get_auto_relationship_fields()
882
- to_many_relationships = []
883
-
884
- # Find to-many relationships that need to be handled first
885
- for rel_name in relationship_fields:
886
- rel_attr = getattr(cls, rel_name, None)
887
- if rel_attr and hasattr(rel_attr, 'property'):
888
- # Check if this is a to-many relationship (one-to-many or many-to-many)
889
- if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
890
- to_many_relationships.append(rel_name)
872
+ # Check if there are many-to-many relationships that need cleanup
873
+ many_to_many_rels = cls._get_many_to_many_relationships()
891
874
 
892
- # For each record, delete related records first (cascade delete)
875
+ # Delete each record and its related many-to-many junction records
876
+ count = 0
893
877
  for record in records:
894
- # First load all related collections
895
- if to_many_relationships:
896
- await session.refresh(record, attribute_names=to_many_relationships)
897
-
898
- # Delete related records in collections
899
- for rel_name in to_many_relationships:
900
- related_collection = getattr(record, rel_name, [])
901
- if related_collection:
902
- for related_item in related_collection:
903
- await session.delete(related_item)
878
+ # Clean up many-to-many junctions first
879
+ for rel_name, (junction_model, _) in many_to_many_rels.items():
880
+ # Get foreign keys from the junction model
881
+ from async_easy_model.auto_relationships import get_foreign_keys_from_model
882
+ foreign_keys = get_foreign_keys_from_model(junction_model)
883
+
884
+ # Find which foreign key refers to this model
885
+ this_model_fk = None
886
+ for fk_field, fk_target in foreign_keys.items():
887
+ target_table = fk_target.split('.')[0]
888
+ if target_table == cls.__tablename__:
889
+ this_model_fk = fk_field
890
+ break
891
+
892
+ if not this_model_fk:
893
+ continue
894
+
895
+ # Delete junction records for this record
896
+ delete_stmt = select(junction_model).where(
897
+ getattr(junction_model, this_model_fk) == record.id
898
+ )
899
+ junction_result = await session.execute(delete_stmt)
900
+ junctions = junction_result.scalars().all()
901
+
902
+ for junction in junctions:
903
+ await session.delete(junction)
904
+ logging.info(f"Deleted junction record for {cls.__name__} id={record.id}")
904
905
 
905
906
  # Now delete the main record
906
907
  await session.delete(record)
908
+ count += 1
907
909
 
908
- # Commit the changes
909
- await session.flush()
910
910
  await session.commit()
911
+ return count
911
912
 
912
- return len(records)
913
913
  except Exception as e:
914
- logging.error(f"Error deleting records: {e}")
915
914
  await session.rollback()
915
+ logging.error(f"Error deleting {cls.__name__}: {e}")
916
916
  raise
917
917
 
918
918
  def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
@@ -991,21 +991,23 @@ class EasyModel(SQLModel):
991
991
  first: bool = False,
992
992
  include_relationships: bool = True,
993
993
  order_by: Optional[Union[str, List[str]]] = None,
994
+ max_depth: int = 2,
994
995
  limit: Optional[int] = None
995
996
  ) -> Union[Optional[T], List[T]]:
996
997
  """
997
- Retrieve record(s) by matching attribute values.
998
+ Select records based on criteria.
998
999
 
999
1000
  Args:
1000
- criteria: Dictionary of field values to filter by (field=value)
1001
- all: If True, return all matching records, otherwise return only the first one
1001
+ criteria: Dictionary of field values to filter by
1002
+ all: If True, return all matching records. If False, return only the first match.
1002
1003
  first: If True, return only the first record (equivalent to all=False)
1003
1004
  include_relationships: If True, eagerly load all relationships
1004
1005
  order_by: Field(s) to order by. Can be a string or list of strings.
1005
1006
  Prefix with '-' for descending order (e.g. '-created_at')
1007
+ max_depth: Maximum depth for loading nested relationships (when include_relationships=True)
1006
1008
  limit: Maximum number of records to retrieve (if all=True)
1007
1009
  If limit > 1, all is automatically set to True
1008
-
1010
+
1009
1011
  Returns:
1010
1012
  A single model instance, a list of instances, or None if not found
1011
1013
  """
@@ -1016,73 +1018,98 @@ class EasyModel(SQLModel):
1016
1018
  # If limit is specified and > 1, set all to True
1017
1019
  if limit is not None and limit > 1:
1018
1020
  all = True
1021
+
1019
1022
  # If first is specified, set all to False (first takes precedence)
1020
1023
  if first:
1021
1024
  all = False
1022
-
1025
+
1023
1026
  async with cls.get_session() as session:
1024
1027
  # Build the query
1025
1028
  statement = select(cls)
1026
1029
 
1027
- # Apply criteria
1028
- for field, value in criteria.items():
1030
+ # Apply criteria filters
1031
+ for key, value in criteria.items():
1029
1032
  if isinstance(value, str) and '*' in value:
1030
1033
  # Handle LIKE queries (convert '*' wildcard to '%')
1031
1034
  like_value = value.replace('*', '%')
1032
- statement = statement.where(getattr(cls, field).like(like_value))
1035
+ statement = statement.where(getattr(cls, key).like(like_value))
1033
1036
  else:
1034
1037
  # Regular equality check
1035
- statement = statement.where(getattr(cls, field) == value)
1038
+ statement = statement.where(getattr(cls, key) == value)
1036
1039
 
1037
1040
  # Apply ordering
1038
1041
  if order_by:
1039
- statement = cls._apply_order_by(statement, order_by)
1042
+ order_clauses = []
1043
+ if isinstance(order_by, str):
1044
+ order_by = [order_by]
1045
+
1046
+ for field_name in order_by:
1047
+ if field_name.startswith("-"):
1048
+ # Descending order
1049
+ field_name = field_name[1:] # Remove the "-" prefix
1050
+ # Handle relationship field ordering with dot notation
1051
+ if "." in field_name:
1052
+ rel_name, attr_name = field_name.split(".", 1)
1053
+ if hasattr(cls, rel_name):
1054
+ rel_model = getattr(cls, rel_name)
1055
+ if hasattr(rel_model, "property"):
1056
+ target_model = rel_model.property.entity.class_
1057
+ if hasattr(target_model, attr_name):
1058
+ order_clauses.append(getattr(target_model, attr_name).desc())
1059
+ else:
1060
+ order_clauses.append(getattr(cls, field_name).desc())
1061
+ else:
1062
+ # Ascending order
1063
+ # Handle relationship field ordering with dot notation
1064
+ if "." in field_name:
1065
+ rel_name, attr_name = field_name.split(".", 1)
1066
+ if hasattr(cls, rel_name):
1067
+ rel_model = getattr(cls, rel_name)
1068
+ if hasattr(rel_model, "property"):
1069
+ target_model = rel_model.property.entity.class_
1070
+ if hasattr(target_model, attr_name):
1071
+ order_clauses.append(getattr(target_model, attr_name).asc())
1072
+ else:
1073
+ order_clauses.append(getattr(cls, field_name).asc())
1074
+
1075
+ if order_clauses:
1076
+ statement = statement.order_by(*order_clauses)
1040
1077
 
1041
1078
  # Apply limit
1042
1079
  if limit:
1043
1080
  statement = statement.limit(limit)
1044
1081
 
1045
- # Include relationships if requested
1082
+ # Load relationships if requested
1046
1083
  if include_relationships:
1047
1084
  for rel_name in cls._get_auto_relationship_fields():
1048
1085
  statement = statement.options(selectinload(getattr(cls, rel_name)))
1049
1086
 
1050
- # Execute the query
1051
1087
  result = await session.execute(statement)
1052
1088
 
1053
1089
  if all:
1054
- # Return all results
1055
- instances = result.scalars().all()
1090
+ objects = result.scalars().all()
1056
1091
 
1057
- # Materialize relationships if requested - this ensures they're fully loaded
1058
- if include_relationships:
1059
- for instance in instances:
1060
- # For each relationship, access it once to ensure it's loaded
1061
- for rel_name in cls._get_auto_relationship_fields():
1062
- try:
1063
- # This will force loading the relationship while session is active
1064
- _ = getattr(instance, rel_name)
1065
- except Exception:
1066
- # Skip if the relationship can't be loaded
1067
- pass
1092
+ # Load nested relationships if requested
1093
+ if include_relationships and objects and max_depth > 1:
1094
+ loaded_objects = []
1095
+ for obj in objects:
1096
+ loaded_obj = await cls._load_relationships_recursively(
1097
+ session, obj, max_depth
1098
+ )
1099
+ loaded_objects.append(loaded_obj)
1100
+ return loaded_objects
1068
1101
 
1069
- return instances
1102
+ return objects
1070
1103
  else:
1071
- # Return only the first result
1072
- instance = result.scalars().first()
1104
+ obj = result.scalars().first()
1073
1105
 
1074
- # Materialize relationships if requested and instance exists
1075
- if include_relationships and instance:
1076
- # For each relationship, access it once to ensure it's loaded
1077
- for rel_name in cls._get_auto_relationship_fields():
1078
- try:
1079
- # This will force loading the relationship while session is active
1080
- _ = getattr(instance, rel_name)
1081
- except Exception:
1082
- # Skip if the relationship can't be loaded
1083
- pass
1084
-
1085
- return instance
1106
+ # Load nested relationships if requested
1107
+ if include_relationships and obj and max_depth > 1:
1108
+ obj = await cls._load_relationships_recursively(
1109
+ session, obj, max_depth
1110
+ )
1111
+
1112
+ return obj
1086
1113
 
1087
1114
  @classmethod
1088
1115
  async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
@@ -1110,6 +1137,226 @@ class EasyModel(SQLModel):
1110
1137
  new_record = await cls.insert(data)
1111
1138
  return new_record, True
1112
1139
 
1140
@classmethod
async def insert_with_related(
    cls: Type[T],
    data: Dict[str, Any],
    related_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
) -> T:
    """
    Create a model instance with related objects in a single transaction.

    The related payloads are merged into a shallow copy of ``data`` under
    their relationship names and passed to ``cls.insert`` with
    ``include_relationships=True``. The caller's ``data`` dict is not mutated.

    Args:
        data: Dictionary of field values for the main model
        related_data: Dictionary mapping relationship names to lists of data
            dictionaries for creating related objects. Relationships whose
            list is empty (falsy) are skipped. An entry here overrides a key
            of the same name in ``data``. ``None`` means no related objects.

    Returns:
        The result of ``cls.insert`` — the created model instance with
        relationships loaded
    """
    # Shallow copy (accepts any Mapping) so the caller's dict stays untouched.
    insert_data = dict(data)

    # Only forward relationships that actually carry items.
    insert_data.update(
        {rel_name: items for rel_name, items in (related_data or {}).items() if items}
    )

    # Delegate to insert, which handles creating/linking the related objects.
    return await cls.insert(insert_data, include_relationships=True)
1170
+
1171
@classmethod
def _get_many_to_many_relationships(cls) -> Dict[str, Tuple[Type['EasyModel'], Type['EasyModel']]]:
    """
    Get all many-to-many relationships for this model.

    A relationship is reported when its SQLAlchemy property has a
    ``secondary`` (junction) table whose name maps back to a registered
    model via ``get_model_by_table_name``.

    Returns:
        Dictionary mapping relationship field names to tuples of
        (junction_model, target_model)
    """
    from async_easy_model.auto_relationships import get_model_by_table_name

    many_to_many_relationships = {}

    for rel_name in cls._get_auto_relationship_fields():
        rel_attr = getattr(cls, rel_name, None)
        prop = getattr(rel_attr, 'prop', None)
        secondary = getattr(prop, 'secondary', None)
        if secondary is None:
            # Not a relationship, or not one that goes through a junction table.
            continue

        # ``secondary`` may still be the table name string given to
        # relationship(). Once SQLAlchemy configures the mappers it is resolved
        # to a Table object, so accept anything exposing a string ``name`` too.
        if isinstance(secondary, str):
            table_name = secondary
        else:
            table_name = getattr(secondary, 'name', None)
        if not isinstance(table_name, str):
            continue

        junction_model = get_model_by_table_name(table_name)
        if junction_model:
            target_model = prop.mapper.class_
            many_to_many_relationships[rel_name] = (junction_model, target_model)

    return many_to_many_relationships
1202
+
1203
@classmethod
async def _process_many_to_many_relationship(
    cls,
    session: AsyncSession,
    parent_obj: 'EasyModel',
    rel_name: str,
    items: List[Dict[str, Any]]
) -> None:
    """
    Link ``parent_obj`` to related objects through a many-to-many junction model.

    Each entry of ``items`` is resolved to a target object via
    ``_process_single_relationship_item``. A junction row is added to
    ``session`` for every target not already linked to ``parent_obj``.
    Existing links are never removed, and nothing is committed here.

    Args:
        session: The database session
        parent_obj: The parent object (e.g., Book)
        rel_name: The name of the relationship (e.g., 'tags')
        items: List of data dictionaries for the related items

    Returns:
        None. Logs a warning and returns early if ``rel_name`` is not a
        many-to-many relationship, or if the junction's foreign keys to this
        model and the target model cannot both be identified.
    """
    many_to_many_rels = cls._get_many_to_many_relationships()
    if rel_name not in many_to_many_rels:
        logging.warning(f"Relationship {rel_name} is not a many-to-many relationship")
        return

    junction_model, target_model = many_to_many_rels[rel_name]

    # Work out which junction columns point at this model and at the target model.
    from async_easy_model.auto_relationships import get_foreign_keys_from_model
    foreign_keys = get_foreign_keys_from_model(junction_model)

    this_model_fk = None
    target_model_fk = None
    for fk_field, fk_target in foreign_keys.items():
        target_table = fk_target.split('.')[0]
        if target_table == cls.__tablename__:
            this_model_fk = fk_field
        elif target_table == target_model.__tablename__:
            target_model_fk = fk_field

    if not this_model_fk or not target_model_fk:
        logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
        return

    # Fetch the already-linked target ids with a single query instead of one
    # existence check per item.
    existing_stmt = select(junction_model).where(
        getattr(junction_model, this_model_fk) == parent_obj.id
    )
    existing_result = await session.execute(existing_stmt)
    linked_target_ids = {
        getattr(junction, target_model_fk)
        for junction in existing_result.scalars().all()
    }

    for item_data in items:
        # Create or find the target model instance.
        target_obj = await cls._process_single_relationship_item(
            session, target_model, item_data
        )
        if not target_obj:
            logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
            continue

        if target_obj.id in linked_target_ids:
            continue

        junction_data = {
            this_model_fk: parent_obj.id,
            target_model_fk: target_obj.id
        }
        session.add(junction_model(**junction_data))
        # Record the new link so a repeated item does not add a duplicate row.
        linked_target_ids.add(target_obj.id)
        logging.info(f"Created junction between {cls.__name__} {parent_obj.id} and {target_model.__name__} {target_obj.id}")
1279
+
1280
@classmethod
async def _load_relationships_recursively(cls, session, obj, max_depth=2, current_depth=0, visited_ids=None):
    """
    Recursively load all relationships for an object and its related objects.

    For each auto-detected relationship of ``obj``'s class:
    - ``obj`` is re-selected with ``selectinload`` for that relationship.
    - The loaded value is assigned back onto ``obj``.
    - Each persistent related object is processed the same way one level
      deeper. "Persistent" means it has a non-None ``id`` and provides
      ``_get_auto_relationship_fields``.

    Objects are tracked by ``(class name, id)`` so cycles terminate.
    Failures loading an individual relationship are logged as warnings and
    skipped (best effort).

    Args:
        session: SQLAlchemy (async) session used for the queries
        obj: The object to load relationships for
        max_depth: Relationships are loaded for objects at depths
            ``0 .. max_depth - 1``
        current_depth: Current recursion depth (internal use)
        visited_ids: Set of ``(class name, id)`` keys already processed
            (internal use)

    Returns:
        ``obj`` itself, with its relationships loaded
    """
    if visited_ids is None:
        visited_ids = set()

    # Track by (class name, id) because the ORM object itself isn't hashable.
    obj_key = (obj.__class__.__name__, obj.id)
    if current_depth >= max_depth or obj_key in visited_ids:
        return obj
    visited_ids.add(obj_key)

    obj_class = obj.__class__
    for rel_name in obj_class._get_auto_relationship_fields():
        try:
            stmt = (
                select(obj_class)
                .where(obj_class.id == obj.id)
                .options(selectinload(getattr(obj_class, rel_name)))
            )
            result = await session.execute(stmt)
            refreshed_obj = result.scalars().first()
            if refreshed_obj is None:
                # Row not found: leave obj's relationship untouched instead of
                # overwriting it with None (which could null a foreign key on flush).
                continue

            related_objs = getattr(refreshed_obj, rel_name, None)
            setattr(obj, rel_name, related_objs)
            if related_objs is None:
                continue

            # Handle to-one and to-many relationships with the same loop.
            children = related_objs if isinstance(related_objs, list) else [related_objs]
            for child in children:
                if not hasattr(child.__class__, '_get_auto_relationship_fields'):
                    continue
                # Only recurse into persistent objects (those with an id).
                if getattr(child, 'id', None) is None:
                    continue
                await cls._load_relationships_recursively(
                    session, child, max_depth, current_depth + 1, visited_ids
                )
        except Exception as e:
            logging.warning(f"Error loading relationship {rel_name}: {e}")

    return obj
1359
+
1113
1360
  # Register an event listener to update 'updated_at' on instance modifications.
1114
1361
  @event.listens_for(Session, "before_flush")
1115
1362
  def _update_updated_at(session, flush_context, instances):