async-easy-model 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- async_easy_model/__init__.py +1 -1
- async_easy_model/auto_relationships.py +132 -2
- async_easy_model/model.py +516 -367
- {async_easy_model-0.2.3.dist-info → async_easy_model-0.2.4.dist-info}/METADATA +1 -1
- async_easy_model-0.2.4.dist-info/RECORD +10 -0
- async_easy_model-0.2.3.dist-info/RECORD +0 -10
- {async_easy_model-0.2.3.dist-info → async_easy_model-0.2.4.dist-info}/LICENSE +0 -0
- {async_easy_model-0.2.3.dist-info → async_easy_model-0.2.4.dist-info}/WHEEL +0 -0
- {async_easy_model-0.2.3.dist-info → async_easy_model-0.2.4.dist-info}/top_level.txt +0 -0
async_easy_model/model.py
CHANGED
@@ -10,6 +10,7 @@ from datetime import datetime, timezone as tz
|
|
10
10
|
import inspect
|
11
11
|
import json
|
12
12
|
import logging
|
13
|
+
import re
|
13
14
|
|
14
15
|
T = TypeVar("T", bound="EasyModel")
|
15
16
|
|
@@ -212,8 +213,9 @@ class EasyModel(SQLModel):
|
|
212
213
|
@classmethod
|
213
214
|
async def all(
|
214
215
|
cls: Type[T],
|
215
|
-
include_relationships: bool = True,
|
216
|
-
order_by: Optional[Union[str, List[str]]] = None
|
216
|
+
include_relationships: bool = True,
|
217
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
218
|
+
max_depth: int = 2
|
217
219
|
) -> List[T]:
|
218
220
|
"""
|
219
221
|
Retrieve all records of this model.
|
@@ -222,29 +224,20 @@ class EasyModel(SQLModel):
|
|
222
224
|
include_relationships: If True, eagerly load all relationships
|
223
225
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
224
226
|
Prefix with '-' for descending order (e.g. '-created_at')
|
227
|
+
max_depth: Maximum depth for loading nested relationships
|
225
228
|
|
226
229
|
Returns:
|
227
230
|
A list of all model instances
|
228
231
|
"""
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
# Apply ordering
|
233
|
-
statement = cls._apply_order_by(statement, order_by)
|
234
|
-
|
235
|
-
if include_relationships:
|
236
|
-
# Get all relationship attributes, including auto-detected ones
|
237
|
-
for rel_name in cls._get_auto_relationship_fields():
|
238
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
239
|
-
|
240
|
-
result = await session.execute(statement)
|
241
|
-
return result.scalars().all()
|
232
|
+
return await cls.select({}, all=True, include_relationships=include_relationships,
|
233
|
+
order_by=order_by, max_depth=max_depth)
|
242
234
|
|
243
235
|
@classmethod
|
244
236
|
async def first(
|
245
237
|
cls: Type[T],
|
246
|
-
include_relationships: bool = True,
|
247
|
-
order_by: Optional[Union[str, List[str]]] = None
|
238
|
+
include_relationships: bool = True,
|
239
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
240
|
+
max_depth: int = 2
|
248
241
|
) -> Optional[T]:
|
249
242
|
"""
|
250
243
|
Retrieve the first record of this model.
|
@@ -253,56 +246,37 @@ class EasyModel(SQLModel):
|
|
253
246
|
include_relationships: If True, eagerly load all relationships
|
254
247
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
255
248
|
Prefix with '-' for descending order (e.g. '-created_at')
|
249
|
+
max_depth: Maximum depth for loading nested relationships
|
256
250
|
|
257
251
|
Returns:
|
258
252
|
The first model instance or None if no records exist
|
259
253
|
"""
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
# Apply ordering
|
264
|
-
statement = cls._apply_order_by(statement, order_by)
|
265
|
-
|
266
|
-
if include_relationships:
|
267
|
-
# Get all relationship attributes, including auto-detected ones
|
268
|
-
for rel_name in cls._get_auto_relationship_fields():
|
269
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
270
|
-
|
271
|
-
result = await session.execute(statement)
|
272
|
-
return result.scalars().first()
|
254
|
+
return await cls.select({}, first=True, include_relationships=include_relationships,
|
255
|
+
order_by=order_by, max_depth=max_depth)
|
273
256
|
|
274
257
|
@classmethod
|
275
258
|
async def limit(
|
276
259
|
cls: Type[T],
|
277
260
|
count: int,
|
278
|
-
include_relationships: bool = True,
|
279
|
-
order_by: Optional[Union[str, List[str]]] = None
|
261
|
+
include_relationships: bool = True,
|
262
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
263
|
+
max_depth: int = 2
|
280
264
|
) -> List[T]:
|
281
265
|
"""
|
282
|
-
Retrieve a limited number of records
|
266
|
+
Retrieve a limited number of records.
|
283
267
|
|
284
268
|
Args:
|
285
|
-
count: Maximum number of records to
|
269
|
+
count: Maximum number of records to return
|
286
270
|
include_relationships: If True, eagerly load all relationships
|
287
271
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
288
272
|
Prefix with '-' for descending order (e.g. '-created_at')
|
273
|
+
max_depth: Maximum depth for loading nested relationships
|
289
274
|
|
290
275
|
Returns:
|
291
|
-
A list of model instances
|
276
|
+
A list of model instances
|
292
277
|
"""
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
# Apply ordering
|
297
|
-
statement = cls._apply_order_by(statement, order_by)
|
298
|
-
|
299
|
-
if include_relationships:
|
300
|
-
# Get all relationship attributes, including auto-detected ones
|
301
|
-
for rel_name in cls._get_auto_relationship_fields():
|
302
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
303
|
-
|
304
|
-
result = await session.execute(statement)
|
305
|
-
return result.scalars().all()
|
278
|
+
return await cls.select({}, all=True, include_relationships=include_relationships,
|
279
|
+
order_by=order_by, limit=count, max_depth=max_depth)
|
306
280
|
|
307
281
|
@classmethod
|
308
282
|
async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
|
@@ -327,6 +301,20 @@ class EasyModel(SQLModel):
|
|
327
301
|
else:
|
328
302
|
return await session.get(cls, id)
|
329
303
|
|
304
|
+
@classmethod
|
305
|
+
def _get_unique_fields(cls) -> List[str]:
|
306
|
+
"""
|
307
|
+
Get all fields with unique=True constraint
|
308
|
+
|
309
|
+
Returns:
|
310
|
+
List of field names that have unique constraints
|
311
|
+
"""
|
312
|
+
unique_fields = []
|
313
|
+
for name, field in cls.__fields__.items():
|
314
|
+
if name != 'id' and hasattr(field, "field_info") and field.field_info.extra.get('unique', False):
|
315
|
+
unique_fields.append(name)
|
316
|
+
return unique_fields
|
317
|
+
|
330
318
|
@classmethod
|
331
319
|
async def get_by_attribute(
|
332
320
|
cls: Type[T],
|
@@ -391,192 +379,76 @@ class EasyModel(SQLModel):
|
|
391
379
|
return result.scalars().first()
|
392
380
|
|
393
381
|
@classmethod
|
394
|
-
async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
|
382
|
+
async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True, max_depth: int = 2) -> Union[T, List[T]]:
|
395
383
|
"""
|
396
384
|
Insert one or more records.
|
397
385
|
|
398
386
|
Args:
|
399
387
|
data: Dictionary of field values or a list of dictionaries for multiple records
|
400
388
|
include_relationships: If True, return the instance(s) with relationships loaded
|
389
|
+
max_depth: Maximum depth for loading nested relationships
|
401
390
|
|
402
391
|
Returns:
|
403
392
|
The created model instance(s)
|
404
393
|
"""
|
405
|
-
|
394
|
+
if not data:
|
395
|
+
return None
|
396
|
+
|
397
|
+
# Handle single dict or list of dicts
|
406
398
|
if isinstance(data, list):
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
unique_criteria = {field: processed_item[field]
|
427
|
-
for field in unique_fields
|
428
|
-
if field in processed_item}
|
429
|
-
|
430
|
-
if unique_criteria:
|
431
|
-
# Try to find existing record with these unique values
|
432
|
-
statement = select(cls)
|
433
|
-
for field, value in unique_criteria.items():
|
434
|
-
statement = statement.where(getattr(cls, field) == value)
|
435
|
-
result = await session.execute(statement)
|
436
|
-
existing_obj = result.scalars().first()
|
437
|
-
|
438
|
-
if existing_obj:
|
439
|
-
# Update existing object with new values
|
440
|
-
for key, value in processed_item.items():
|
441
|
-
if key != 'id': # Don't update ID
|
442
|
-
setattr(existing_obj, key, value)
|
443
|
-
obj = existing_obj
|
444
|
-
else:
|
445
|
-
# Create new object
|
446
|
-
obj = cls(**processed_item)
|
447
|
-
session.add(obj)
|
448
|
-
|
449
|
-
# Flush to get the ID for this object
|
450
|
-
await session.flush()
|
451
|
-
|
452
|
-
# Now handle any one-to-many relationships
|
453
|
-
for rel_name, related_objects in related_fields.items():
|
454
|
-
# Check if the relationship attribute exists in the class (not the instance)
|
455
|
-
if hasattr(cls, rel_name):
|
456
|
-
# Get the relationship attribute from the class
|
457
|
-
rel_attr = getattr(cls, rel_name)
|
458
|
-
|
459
|
-
# Check if it's a SQLAlchemy relationship
|
460
|
-
if hasattr(rel_attr, 'property') and hasattr(rel_attr.property, 'back_populates'):
|
461
|
-
back_attr = rel_attr.property.back_populates
|
462
|
-
|
463
|
-
# For each related object, set the back reference to this object
|
464
|
-
for related_obj in related_objects:
|
465
|
-
setattr(related_obj, back_attr, obj)
|
466
|
-
# Make sure the related object is in the session
|
467
|
-
session.add(related_obj)
|
468
|
-
|
469
|
-
objects.append(obj)
|
470
|
-
except Exception as e:
|
471
|
-
logging.error(f"Error inserting record: {e}")
|
472
|
-
await session.rollback()
|
473
|
-
raise
|
399
|
+
results = []
|
400
|
+
for item in data:
|
401
|
+
result = await cls.insert(item, include_relationships, max_depth)
|
402
|
+
results.append(result)
|
403
|
+
return results
|
404
|
+
|
405
|
+
# Store many-to-many relationship data for later processing
|
406
|
+
many_to_many_data = {}
|
407
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
408
|
+
|
409
|
+
# Extract many-to-many data before processing other relationships
|
410
|
+
for rel_name in many_to_many_rels:
|
411
|
+
if rel_name in data:
|
412
|
+
many_to_many_data[rel_name] = data[rel_name]
|
413
|
+
|
414
|
+
# Process relationships to convert nested objects to foreign keys
|
415
|
+
async with cls.get_session() as session:
|
416
|
+
try:
|
417
|
+
processed_data = await cls._process_relationships_for_insert(session, data)
|
474
418
|
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
#
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
related_fields = {}
|
498
|
-
for key in list(processed_data.keys()):
|
499
|
-
if key.startswith("_related_"):
|
500
|
-
rel_name = key[9:] # Remove "_related_" prefix
|
501
|
-
related_fields[rel_name] = processed_data.pop(key)
|
502
|
-
|
503
|
-
# Check if a record with unique constraints already exists
|
504
|
-
unique_fields = cls._get_unique_fields()
|
505
|
-
existing_obj = None
|
506
|
-
|
507
|
-
if unique_fields:
|
508
|
-
unique_criteria = {field: processed_data[field]
|
509
|
-
for field in unique_fields
|
510
|
-
if field in processed_data}
|
511
|
-
|
512
|
-
if unique_criteria:
|
513
|
-
# Try to find existing record with these unique values
|
514
|
-
statement = select(cls)
|
515
|
-
for field, value in unique_criteria.items():
|
516
|
-
statement = statement.where(getattr(cls, field) == value)
|
517
|
-
result = await session.execute(statement)
|
518
|
-
existing_obj = result.scalars().first()
|
519
|
-
|
520
|
-
if existing_obj:
|
521
|
-
# Update existing object with new values
|
522
|
-
for key, value in processed_data.items():
|
523
|
-
if key != 'id': # Don't update ID
|
524
|
-
setattr(existing_obj, key, value)
|
525
|
-
obj = existing_obj
|
526
|
-
else:
|
527
|
-
# Create new object
|
528
|
-
obj = cls(**processed_data)
|
529
|
-
session.add(obj)
|
530
|
-
|
531
|
-
await session.flush() # Flush to get the ID
|
532
|
-
|
533
|
-
# Now handle any one-to-many relationships
|
534
|
-
for rel_name, related_objects in related_fields.items():
|
535
|
-
# Check if the relationship attribute exists in the class (not the instance)
|
536
|
-
if hasattr(cls, rel_name):
|
537
|
-
# Get the relationship attribute from the class
|
538
|
-
rel_attr = getattr(cls, rel_name)
|
539
|
-
|
540
|
-
# Check if it's a SQLAlchemy relationship
|
541
|
-
if hasattr(rel_attr, 'property') and hasattr(rel_attr.property, 'back_populates'):
|
542
|
-
back_attr = rel_attr.property.back_populates
|
543
|
-
|
544
|
-
# For each related object, set the back reference to this object
|
545
|
-
for related_obj in related_objects:
|
546
|
-
setattr(related_obj, back_attr, obj)
|
547
|
-
# Make sure the related object is in the session
|
548
|
-
session.add(related_obj)
|
549
|
-
|
550
|
-
await session.commit()
|
419
|
+
# Create the model instance
|
420
|
+
obj = cls(**processed_data)
|
421
|
+
session.add(obj)
|
422
|
+
|
423
|
+
# Flush to get the object ID
|
424
|
+
await session.flush()
|
425
|
+
|
426
|
+
# Now process many-to-many relationships if any
|
427
|
+
for rel_name, rel_data in many_to_many_data.items():
|
428
|
+
if isinstance(rel_data, list):
|
429
|
+
await cls._process_many_to_many_relationship(
|
430
|
+
session, obj, rel_name, rel_data
|
431
|
+
)
|
432
|
+
|
433
|
+
# Commit the transaction
|
434
|
+
await session.commit()
|
435
|
+
|
436
|
+
if include_relationships:
|
437
|
+
# Reload with relationships
|
438
|
+
return await cls._load_relationships_recursively(session, obj, max_depth)
|
439
|
+
else:
|
440
|
+
return obj
|
551
441
|
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
except Exception as e:
|
563
|
-
logging.error(f"Error inserting record: {e}")
|
564
|
-
await session.rollback()
|
565
|
-
raise
|
566
|
-
|
567
|
-
@classmethod
|
568
|
-
def _get_unique_fields(cls) -> List[str]:
|
569
|
-
"""
|
570
|
-
Get all fields with unique=True constraint
|
571
|
-
|
572
|
-
Returns:
|
573
|
-
List of field names that have unique constraints
|
574
|
-
"""
|
575
|
-
unique_fields = []
|
576
|
-
for name, field in cls.__fields__.items():
|
577
|
-
if name != 'id' and hasattr(field, 'field_info') and field.field_info.extra.get('unique', False):
|
578
|
-
unique_fields.append(name)
|
579
|
-
return unique_fields
|
442
|
+
except Exception as e:
|
443
|
+
await session.rollback()
|
444
|
+
logging.error(f"Error inserting {cls.__name__}: {e}")
|
445
|
+
if "UNIQUE constraint failed" in str(e):
|
446
|
+
field_match = re.search(r"UNIQUE constraint failed: \w+\.(\w+)", str(e))
|
447
|
+
if field_match:
|
448
|
+
field_name = field_match.group(1)
|
449
|
+
value = data.get(field_name)
|
450
|
+
raise ValueError(f"A record with {field_name}='{value}' already exists")
|
451
|
+
raise
|
580
452
|
|
581
453
|
@classmethod
|
582
454
|
async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
|
@@ -613,97 +485,65 @@ class EasyModel(SQLModel):
|
|
613
485
|
Returns:
|
614
486
|
Processed data dictionary with nested objects replaced by their foreign key IDs
|
615
487
|
"""
|
616
|
-
|
617
|
-
|
488
|
+
if not data:
|
489
|
+
return {}
|
490
|
+
|
491
|
+
result = dict(data)
|
618
492
|
|
619
|
-
# Get
|
493
|
+
# Get relationship fields for this model
|
620
494
|
relationship_fields = cls._get_auto_relationship_fields()
|
621
495
|
|
622
|
-
# Get
|
623
|
-
|
624
|
-
for field_name, field_info in cls.__fields__.items():
|
625
|
-
if field_name.endswith("_id") and hasattr(field_info, "field_info"):
|
626
|
-
if field_info.field_info.extra.get("foreign_key"):
|
627
|
-
foreign_key_fields.append(field_name)
|
496
|
+
# Get many-to-many relationships separately
|
497
|
+
many_to_many_relationships = cls._get_many_to_many_relationships()
|
628
498
|
|
629
|
-
#
|
630
|
-
|
631
|
-
|
499
|
+
# Set of field names already processed as many-to-many relationships
|
500
|
+
processed_m2m_fields = set()
|
501
|
+
|
502
|
+
# Process each relationship field in the input data
|
503
|
+
for key in list(result.keys()):
|
504
|
+
value = result[key]
|
505
|
+
|
506
|
+
# Skip empty values
|
632
507
|
if value is None:
|
633
508
|
continue
|
634
|
-
|
635
|
-
# Check if this is a relationship field (either by name or derived from foreign key)
|
636
|
-
is_rel_field = key in relationship_fields
|
637
|
-
related_key = f"{key}_id"
|
638
|
-
is_derived_rel = related_key in foreign_key_fields
|
639
509
|
|
640
|
-
#
|
641
|
-
if
|
642
|
-
#
|
510
|
+
# Check if this is a relationship field
|
511
|
+
if key in relationship_fields:
|
512
|
+
# Get the related model class
|
643
513
|
related_model = None
|
644
514
|
|
645
|
-
|
646
|
-
if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
|
647
|
-
# Get from relationship attribute
|
515
|
+
if hasattr(cls, key):
|
648
516
|
rel_attr = getattr(cls, key)
|
649
|
-
|
650
|
-
|
651
|
-
# Try to find it from foreign key definition
|
652
|
-
fk_definition = None
|
653
|
-
for field_name, field_info in cls.__fields__.items():
|
654
|
-
if field_name == related_key and hasattr(field_info, "field_info"):
|
655
|
-
fk_definition = field_info.field_info.extra.get("foreign_key")
|
656
|
-
break
|
657
|
-
|
658
|
-
if fk_definition:
|
659
|
-
# Parse foreign key definition (e.g. "users.id")
|
660
|
-
target_table, _ = fk_definition.split(".")
|
661
|
-
# Try to find the target model
|
662
|
-
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
|
663
|
-
related_model = get_model_by_table_name(target_table)
|
664
|
-
if not related_model:
|
665
|
-
# Try with the singular form
|
666
|
-
singular_table = singularize_name(target_table)
|
667
|
-
related_model = get_model_by_table_name(singular_table)
|
668
|
-
else:
|
669
|
-
# Try to infer from field name (e.g., "user_id" -> Users)
|
670
|
-
base_name = related_key[:-3] # Remove "_id"
|
671
|
-
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
|
672
|
-
|
673
|
-
# Try singular and plural forms
|
674
|
-
related_model = get_model_by_table_name(base_name)
|
675
|
-
if not related_model:
|
676
|
-
plural_table = pluralize_name(base_name)
|
677
|
-
related_model = get_model_by_table_name(plural_table)
|
678
|
-
if not related_model:
|
679
|
-
singular_table = singularize_name(base_name)
|
680
|
-
related_model = get_model_by_table_name(singular_table)
|
517
|
+
if hasattr(rel_attr, 'prop') and hasattr(rel_attr.prop, 'mapper'):
|
518
|
+
related_model = rel_attr.prop.mapper.class_
|
681
519
|
|
682
520
|
if not related_model:
|
683
|
-
logging.warning(f"Could not
|
521
|
+
logging.warning(f"Could not determine related model for {key}, skipping")
|
522
|
+
continue
|
523
|
+
|
524
|
+
# Check if this is a many-to-many relationship field
|
525
|
+
if key in many_to_many_relationships:
|
526
|
+
# Store this separately - we'll handle it after the main object is created
|
527
|
+
processed_m2m_fields.add(key)
|
684
528
|
continue
|
685
529
|
|
686
|
-
#
|
530
|
+
# Handle different relationship types based on data type
|
687
531
|
if isinstance(value, list):
|
688
532
|
# Handle one-to-many relationship (list of dictionaries)
|
689
|
-
|
690
|
-
|
533
|
+
related_ids = []
|
691
534
|
for item in value:
|
692
|
-
if
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
)
|
699
|
-
if related_obj:
|
700
|
-
related_objects.append(related_obj)
|
535
|
+
if isinstance(item, dict):
|
536
|
+
related_obj = await cls._process_single_relationship_item(
|
537
|
+
session, related_model, item
|
538
|
+
)
|
539
|
+
if related_obj:
|
540
|
+
related_ids.append(related_obj.id)
|
701
541
|
|
702
|
-
#
|
703
|
-
|
704
|
-
result[
|
542
|
+
# Update result with list of foreign key IDs
|
543
|
+
foreign_key_list_name = f"{key}_ids"
|
544
|
+
result[foreign_key_list_name] = related_ids
|
705
545
|
|
706
|
-
# Remove the
|
546
|
+
# Remove the relationship list from the result
|
707
547
|
if key in result:
|
708
548
|
del result[key]
|
709
549
|
|
@@ -722,8 +562,14 @@ class EasyModel(SQLModel):
|
|
722
562
|
if key in result:
|
723
563
|
del result[key]
|
724
564
|
|
565
|
+
# Remove any processed many-to-many fields from the result
|
566
|
+
# since we'll handle them separately after the object is created
|
567
|
+
for key in processed_m2m_fields:
|
568
|
+
if key in result:
|
569
|
+
del result[key]
|
570
|
+
|
725
571
|
return result
|
726
|
-
|
572
|
+
|
727
573
|
@classmethod
|
728
574
|
async def _process_single_relationship_item(cls, session: AsyncSession, related_model: Type, item_data: Dict[str, Any]) -> Optional[Any]:
|
729
575
|
"""
|
@@ -742,10 +588,10 @@ class EasyModel(SQLModel):
|
|
742
588
|
"""
|
743
589
|
# Look for unique fields in the related model to use for searching
|
744
590
|
unique_fields = []
|
745
|
-
for
|
746
|
-
if (hasattr(
|
747
|
-
|
748
|
-
unique_fields.append(
|
591
|
+
for name, field in related_model.__fields__.items():
|
592
|
+
if (hasattr(field, "field_info") and
|
593
|
+
field.field_info.extra.get('unique', False)):
|
594
|
+
unique_fields.append(name)
|
749
595
|
|
750
596
|
# Create a search dictionary using unique fields
|
751
597
|
search_dict = {}
|
@@ -855,6 +701,17 @@ class EasyModel(SQLModel):
|
|
855
701
|
Returns:
|
856
702
|
The updated model instance
|
857
703
|
"""
|
704
|
+
# Store many-to-many relationship data for later processing
|
705
|
+
many_to_many_data = {}
|
706
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
707
|
+
|
708
|
+
# Extract many-to-many data before processing
|
709
|
+
for rel_name in many_to_many_rels:
|
710
|
+
if rel_name in data:
|
711
|
+
many_to_many_data[rel_name] = data[rel_name]
|
712
|
+
# Remove from original data
|
713
|
+
del data[rel_name]
|
714
|
+
|
858
715
|
async with cls.get_session() as session:
|
859
716
|
try:
|
860
717
|
# Find the record(s) to update
|
@@ -895,6 +752,74 @@ class EasyModel(SQLModel):
|
|
895
752
|
for key, value in data.items():
|
896
753
|
setattr(record, key, value)
|
897
754
|
|
755
|
+
# Process many-to-many relationships if any
|
756
|
+
for rel_name, rel_data in many_to_many_data.items():
|
757
|
+
if isinstance(rel_data, list):
|
758
|
+
# First, get all existing links for this relation
|
759
|
+
junction_model, target_model = many_to_many_rels[rel_name]
|
760
|
+
|
761
|
+
from async_easy_model.auto_relationships import get_foreign_keys_from_model
|
762
|
+
foreign_keys = get_foreign_keys_from_model(junction_model)
|
763
|
+
|
764
|
+
# Find the foreign key fields for this model and the target model
|
765
|
+
this_model_fk = None
|
766
|
+
target_model_fk = None
|
767
|
+
|
768
|
+
for fk_field, fk_target in foreign_keys.items():
|
769
|
+
target_table = fk_target.split('.')[0]
|
770
|
+
if target_table == cls.__tablename__:
|
771
|
+
this_model_fk = fk_field
|
772
|
+
elif target_table == target_model.__tablename__:
|
773
|
+
target_model_fk = fk_field
|
774
|
+
|
775
|
+
if not this_model_fk or not target_model_fk:
|
776
|
+
logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
|
777
|
+
continue
|
778
|
+
|
779
|
+
# Get all existing junctions for this record
|
780
|
+
junction_stmt = select(junction_model).where(
|
781
|
+
getattr(junction_model, this_model_fk) == record.id
|
782
|
+
)
|
783
|
+
junction_result = await session.execute(junction_stmt)
|
784
|
+
existing_junctions = junction_result.scalars().all()
|
785
|
+
|
786
|
+
# Get the target IDs from the existing junctions
|
787
|
+
existing_target_ids = [getattr(junction, target_model_fk) for junction in existing_junctions]
|
788
|
+
|
789
|
+
# Track processed target IDs
|
790
|
+
processed_target_ids = set()
|
791
|
+
|
792
|
+
# Process each item in rel_data
|
793
|
+
for item_data in rel_data:
|
794
|
+
target_obj = await cls._process_single_relationship_item(
|
795
|
+
session, target_model, item_data
|
796
|
+
)
|
797
|
+
|
798
|
+
if not target_obj:
|
799
|
+
logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
|
800
|
+
continue
|
801
|
+
|
802
|
+
processed_target_ids.add(target_obj.id)
|
803
|
+
|
804
|
+
# Check if this link already exists
|
805
|
+
if target_obj.id not in existing_target_ids:
|
806
|
+
# Create new junction
|
807
|
+
junction_data = {
|
808
|
+
this_model_fk: record.id,
|
809
|
+
target_model_fk: target_obj.id
|
810
|
+
}
|
811
|
+
junction_obj = junction_model(**junction_data)
|
812
|
+
session.add(junction_obj)
|
813
|
+
logging.info(f"Created junction between {cls.__name__} {record.id} and {target_model.__name__} {target_obj.id}")
|
814
|
+
|
815
|
+
# Delete junctions for target IDs that weren't in the updated data
|
816
|
+
junctions_to_delete = [j for j in existing_junctions
|
817
|
+
if getattr(j, target_model_fk) not in processed_target_ids]
|
818
|
+
|
819
|
+
for junction in junctions_to_delete:
|
820
|
+
await session.delete(junction)
|
821
|
+
logging.info(f"Deleted junction between {cls.__name__} {record.id} and {target_model.__name__} {getattr(junction, target_model_fk)}")
|
822
|
+
|
898
823
|
await session.flush()
|
899
824
|
await session.commit()
|
900
825
|
|
@@ -908,9 +833,10 @@ class EasyModel(SQLModel):
|
|
908
833
|
else:
|
909
834
|
await session.refresh(record)
|
910
835
|
return record
|
836
|
+
|
911
837
|
except Exception as e:
|
912
|
-
logging.error(f"Error updating record: {e}")
|
913
838
|
await session.rollback()
|
839
|
+
logging.error(f"Error updating {cls.__name__}: {e}")
|
914
840
|
raise
|
915
841
|
|
916
842
|
@classmethod
|
@@ -926,7 +852,7 @@ class EasyModel(SQLModel):
|
|
926
852
|
"""
|
927
853
|
async with cls.get_session() as session:
|
928
854
|
try:
|
929
|
-
# Find
|
855
|
+
# Find records to delete
|
930
856
|
statement = select(cls)
|
931
857
|
for field, value in criteria.items():
|
932
858
|
if isinstance(value, str) and '*' in value:
|
@@ -943,43 +869,50 @@ class EasyModel(SQLModel):
|
|
943
869
|
logging.warning(f"No records found with criteria: {criteria}")
|
944
870
|
return 0
|
945
871
|
|
946
|
-
#
|
947
|
-
|
948
|
-
relationship_fields = cls._get_auto_relationship_fields()
|
949
|
-
to_many_relationships = []
|
872
|
+
# Check if there are many-to-many relationships that need cleanup
|
873
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
950
874
|
|
951
|
-
#
|
952
|
-
|
953
|
-
rel_attr = getattr(cls, rel_name, None)
|
954
|
-
if rel_attr and hasattr(rel_attr, 'property'):
|
955
|
-
# Check if this is a to-many relationship (one-to-many or many-to-many)
|
956
|
-
if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
|
957
|
-
to_many_relationships.append(rel_name)
|
958
|
-
|
959
|
-
# For each record, delete related records first (cascade delete)
|
875
|
+
# Delete each record and its related many-to-many junction records
|
876
|
+
count = 0
|
960
877
|
for record in records:
|
961
|
-
#
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
878
|
+
# Clean up many-to-many junctions first
|
879
|
+
for rel_name, (junction_model, _) in many_to_many_rels.items():
|
880
|
+
# Get foreign keys from the junction model
|
881
|
+
from async_easy_model.auto_relationships import get_foreign_keys_from_model
|
882
|
+
foreign_keys = get_foreign_keys_from_model(junction_model)
|
883
|
+
|
884
|
+
# Find which foreign key refers to this model
|
885
|
+
this_model_fk = None
|
886
|
+
for fk_field, fk_target in foreign_keys.items():
|
887
|
+
target_table = fk_target.split('.')[0]
|
888
|
+
if target_table == cls.__tablename__:
|
889
|
+
this_model_fk = fk_field
|
890
|
+
break
|
891
|
+
|
892
|
+
if not this_model_fk:
|
893
|
+
continue
|
894
|
+
|
895
|
+
# Delete junction records for this record
|
896
|
+
delete_stmt = select(junction_model).where(
|
897
|
+
getattr(junction_model, this_model_fk) == record.id
|
898
|
+
)
|
899
|
+
junction_result = await session.execute(delete_stmt)
|
900
|
+
junctions = junction_result.scalars().all()
|
901
|
+
|
902
|
+
for junction in junctions:
|
903
|
+
await session.delete(junction)
|
904
|
+
logging.info(f"Deleted junction record for {cls.__name__} id={record.id}")
|
971
905
|
|
972
906
|
# Now delete the main record
|
973
907
|
await session.delete(record)
|
908
|
+
count += 1
|
974
909
|
|
975
|
-
# Commit the changes
|
976
|
-
await session.flush()
|
977
910
|
await session.commit()
|
911
|
+
return count
|
978
912
|
|
979
|
-
return len(records)
|
980
913
|
except Exception as e:
|
981
|
-
logging.error(f"Error deleting records: {e}")
|
982
914
|
await session.rollback()
|
915
|
+
logging.error(f"Error deleting {cls.__name__}: {e}")
|
983
916
|
raise
|
984
917
|
|
985
918
|
def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
|
@@ -1058,21 +991,23 @@ class EasyModel(SQLModel):
|
|
1058
991
|
first: bool = False,
|
1059
992
|
include_relationships: bool = True,
|
1060
993
|
order_by: Optional[Union[str, List[str]]] = None,
|
994
|
+
max_depth: int = 2,
|
1061
995
|
limit: Optional[int] = None
|
1062
996
|
) -> Union[Optional[T], List[T]]:
|
1063
997
|
"""
|
1064
|
-
|
998
|
+
Select records based on criteria.
|
1065
999
|
|
1066
1000
|
Args:
|
1067
|
-
criteria: Dictionary of
|
1068
|
-
all: If True, return all matching records,
|
1001
|
+
criteria: Dictionary of field values to filter by
|
1002
|
+
all: If True, return all matching records. If False, return only the first match.
|
1069
1003
|
first: If True, return only the first record (equivalent to all=False)
|
1070
1004
|
include_relationships: If True, eagerly load all relationships
|
1071
1005
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
1072
1006
|
Prefix with '-' for descending order (e.g. '-created_at')
|
1007
|
+
max_depth: Maximum depth for loading nested relationships (when include_relationships=True)
|
1073
1008
|
limit: Maximum number of records to retrieve (if all=True)
|
1074
1009
|
If limit > 1, all is automatically set to True
|
1075
|
-
|
1010
|
+
|
1076
1011
|
Returns:
|
1077
1012
|
A single model instance, a list of instances, or None if not found
|
1078
1013
|
"""
|
@@ -1083,73 +1018,98 @@ class EasyModel(SQLModel):
|
|
1083
1018
|
# If limit is specified and > 1, set all to True
|
1084
1019
|
if limit is not None and limit > 1:
|
1085
1020
|
all = True
|
1021
|
+
|
1086
1022
|
# If first is specified, set all to False (first takes precedence)
|
1087
1023
|
if first:
|
1088
1024
|
all = False
|
1089
|
-
|
1025
|
+
|
1090
1026
|
async with cls.get_session() as session:
|
1091
1027
|
# Build the query
|
1092
1028
|
statement = select(cls)
|
1093
1029
|
|
1094
|
-
# Apply criteria
|
1095
|
-
for
|
1030
|
+
# Apply criteria filters
|
1031
|
+
for key, value in criteria.items():
|
1096
1032
|
if isinstance(value, str) and '*' in value:
|
1097
1033
|
# Handle LIKE queries (convert '*' wildcard to '%')
|
1098
1034
|
like_value = value.replace('*', '%')
|
1099
|
-
statement = statement.where(getattr(cls,
|
1035
|
+
statement = statement.where(getattr(cls, key).like(like_value))
|
1100
1036
|
else:
|
1101
1037
|
# Regular equality check
|
1102
|
-
statement = statement.where(getattr(cls,
|
1038
|
+
statement = statement.where(getattr(cls, key) == value)
|
1103
1039
|
|
1104
1040
|
# Apply ordering
|
1105
1041
|
if order_by:
|
1106
|
-
|
1042
|
+
order_clauses = []
|
1043
|
+
if isinstance(order_by, str):
|
1044
|
+
order_by = [order_by]
|
1045
|
+
|
1046
|
+
for field_name in order_by:
|
1047
|
+
if field_name.startswith("-"):
|
1048
|
+
# Descending order
|
1049
|
+
field_name = field_name[1:] # Remove the "-" prefix
|
1050
|
+
# Handle relationship field ordering with dot notation
|
1051
|
+
if "." in field_name:
|
1052
|
+
rel_name, attr_name = field_name.split(".", 1)
|
1053
|
+
if hasattr(cls, rel_name):
|
1054
|
+
rel_model = getattr(cls, rel_name)
|
1055
|
+
if hasattr(rel_model, "property"):
|
1056
|
+
target_model = rel_model.property.entity.class_
|
1057
|
+
if hasattr(target_model, attr_name):
|
1058
|
+
order_clauses.append(getattr(target_model, attr_name).desc())
|
1059
|
+
else:
|
1060
|
+
order_clauses.append(getattr(cls, field_name).desc())
|
1061
|
+
else:
|
1062
|
+
# Ascending order
|
1063
|
+
# Handle relationship field ordering with dot notation
|
1064
|
+
if "." in field_name:
|
1065
|
+
rel_name, attr_name = field_name.split(".", 1)
|
1066
|
+
if hasattr(cls, rel_name):
|
1067
|
+
rel_model = getattr(cls, rel_name)
|
1068
|
+
if hasattr(rel_model, "property"):
|
1069
|
+
target_model = rel_model.property.entity.class_
|
1070
|
+
if hasattr(target_model, attr_name):
|
1071
|
+
order_clauses.append(getattr(target_model, attr_name).asc())
|
1072
|
+
else:
|
1073
|
+
order_clauses.append(getattr(cls, field_name).asc())
|
1074
|
+
|
1075
|
+
if order_clauses:
|
1076
|
+
statement = statement.order_by(*order_clauses)
|
1107
1077
|
|
1108
1078
|
# Apply limit
|
1109
1079
|
if limit:
|
1110
1080
|
statement = statement.limit(limit)
|
1111
1081
|
|
1112
|
-
#
|
1082
|
+
# Load relationships if requested
|
1113
1083
|
if include_relationships:
|
1114
1084
|
for rel_name in cls._get_auto_relationship_fields():
|
1115
1085
|
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
1116
1086
|
|
1117
|
-
# Execute the query
|
1118
1087
|
result = await session.execute(statement)
|
1119
1088
|
|
1120
1089
|
if all:
|
1121
|
-
|
1122
|
-
instances = result.scalars().all()
|
1090
|
+
objects = result.scalars().all()
|
1123
1091
|
|
1124
|
-
#
|
1125
|
-
if include_relationships:
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
# Skip if the relationship can't be loaded
|
1134
|
-
pass
|
1092
|
+
# Load nested relationships if requested
|
1093
|
+
if include_relationships and objects and max_depth > 1:
|
1094
|
+
loaded_objects = []
|
1095
|
+
for obj in objects:
|
1096
|
+
loaded_obj = await cls._load_relationships_recursively(
|
1097
|
+
session, obj, max_depth
|
1098
|
+
)
|
1099
|
+
loaded_objects.append(loaded_obj)
|
1100
|
+
return loaded_objects
|
1135
1101
|
|
1136
|
-
return
|
1102
|
+
return objects
|
1137
1103
|
else:
|
1138
|
-
|
1139
|
-
instance = result.scalars().first()
|
1104
|
+
obj = result.scalars().first()
|
1140
1105
|
|
1141
|
-
#
|
1142
|
-
if include_relationships and
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
except Exception:
|
1149
|
-
# Skip if the relationship can't be loaded
|
1150
|
-
pass
|
1151
|
-
|
1152
|
-
return instance
|
1106
|
+
# Load nested relationships if requested
|
1107
|
+
if include_relationships and obj and max_depth > 1:
|
1108
|
+
obj = await cls._load_relationships_recursively(
|
1109
|
+
session, obj, max_depth
|
1110
|
+
)
|
1111
|
+
|
1112
|
+
return obj
|
1153
1113
|
|
1154
1114
|
@classmethod
|
1155
1115
|
async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
|
@@ -1208,6 +1168,195 @@ class EasyModel(SQLModel):
|
|
1208
1168
|
# Use the enhanced insert method to handle all relationships
|
1209
1169
|
return await cls.insert(insert_data, include_relationships=True)
|
1210
1170
|
|
1171
|
+
@classmethod
|
1172
|
+
def _get_many_to_many_relationships(cls) -> Dict[str, Tuple[Type['EasyModel'], Type['EasyModel']]]:
|
1173
|
+
"""
|
1174
|
+
Get all many-to-many relationships for this model.
|
1175
|
+
|
1176
|
+
Returns:
|
1177
|
+
Dictionary mapping relationship field names to tuples of (junction_model, target_model)
|
1178
|
+
"""
|
1179
|
+
from async_easy_model.auto_relationships import get_model_by_table_name
|
1180
|
+
|
1181
|
+
many_to_many_relationships = {}
|
1182
|
+
|
1183
|
+
# Check if this is a class attribute rather than an instance attribute
|
1184
|
+
relationship_fields = cls._get_auto_relationship_fields()
|
1185
|
+
|
1186
|
+
for rel_name in relationship_fields:
|
1187
|
+
if not hasattr(cls, rel_name):
|
1188
|
+
continue
|
1189
|
+
|
1190
|
+
rel_attr = getattr(cls, rel_name)
|
1191
|
+
|
1192
|
+
# Check if this is a many-to-many relationship by looking for secondary table
|
1193
|
+
if hasattr(rel_attr, 'prop') and hasattr(rel_attr.prop, 'secondary'):
|
1194
|
+
secondary = rel_attr.prop.secondary
|
1195
|
+
if isinstance(secondary, str): # For string table names (our implementation)
|
1196
|
+
junction_model = get_model_by_table_name(secondary)
|
1197
|
+
if junction_model:
|
1198
|
+
target_model = rel_attr.prop.mapper.class_
|
1199
|
+
many_to_many_relationships[rel_name] = (junction_model, target_model)
|
1200
|
+
|
1201
|
+
return many_to_many_relationships
|
1202
|
+
|
1203
|
+
@classmethod
|
1204
|
+
async def _process_many_to_many_relationship(
|
1205
|
+
cls,
|
1206
|
+
session: AsyncSession,
|
1207
|
+
parent_obj: 'EasyModel',
|
1208
|
+
rel_name: str,
|
1209
|
+
items: List[Dict[str, Any]]
|
1210
|
+
) -> None:
|
1211
|
+
"""
|
1212
|
+
Process a many-to-many relationship for an object.
|
1213
|
+
|
1214
|
+
Args:
|
1215
|
+
session: The database session
|
1216
|
+
parent_obj: The parent object (e.g., Book)
|
1217
|
+
rel_name: The name of the relationship (e.g., 'tags')
|
1218
|
+
items: List of data dictionaries for the related items
|
1219
|
+
|
1220
|
+
Returns:
|
1221
|
+
None
|
1222
|
+
"""
|
1223
|
+
# Get information about this many-to-many relationship
|
1224
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
1225
|
+
if rel_name not in many_to_many_rels:
|
1226
|
+
logging.warning(f"Relationship {rel_name} is not a many-to-many relationship")
|
1227
|
+
return
|
1228
|
+
|
1229
|
+
junction_model, target_model = many_to_many_rels[rel_name]
|
1230
|
+
|
1231
|
+
# Get the foreign key fields from the junction model that reference this model and the target model
|
1232
|
+
from async_easy_model.auto_relationships import get_foreign_keys_from_model
|
1233
|
+
foreign_keys = get_foreign_keys_from_model(junction_model)
|
1234
|
+
|
1235
|
+
# Find the foreign key fields for this model and the target model
|
1236
|
+
this_model_fk = None
|
1237
|
+
target_model_fk = None
|
1238
|
+
|
1239
|
+
for fk_field, fk_target in foreign_keys.items():
|
1240
|
+
target_table = fk_target.split('.')[0]
|
1241
|
+
if target_table == cls.__tablename__:
|
1242
|
+
this_model_fk = fk_field
|
1243
|
+
elif target_table == target_model.__tablename__:
|
1244
|
+
target_model_fk = fk_field
|
1245
|
+
|
1246
|
+
if not this_model_fk or not target_model_fk:
|
1247
|
+
logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
|
1248
|
+
return
|
1249
|
+
|
1250
|
+
# Process each related item
|
1251
|
+
for item_data in items:
|
1252
|
+
# First, create or find the target model instance
|
1253
|
+
target_obj = await cls._process_single_relationship_item(
|
1254
|
+
session, target_model, item_data
|
1255
|
+
)
|
1256
|
+
|
1257
|
+
if not target_obj:
|
1258
|
+
logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
|
1259
|
+
continue
|
1260
|
+
|
1261
|
+
# Now create a junction record linking the parent to the target
|
1262
|
+
# Check if this link already exists
|
1263
|
+
junction_stmt = select(junction_model).where(
|
1264
|
+
getattr(junction_model, this_model_fk) == parent_obj.id,
|
1265
|
+
getattr(junction_model, target_model_fk) == target_obj.id
|
1266
|
+
)
|
1267
|
+
junction_result = await session.execute(junction_stmt)
|
1268
|
+
existing_junction = junction_result.scalars().first()
|
1269
|
+
|
1270
|
+
if not existing_junction:
|
1271
|
+
# Create new junction
|
1272
|
+
junction_data = {
|
1273
|
+
this_model_fk: parent_obj.id,
|
1274
|
+
target_model_fk: target_obj.id
|
1275
|
+
}
|
1276
|
+
junction_obj = junction_model(**junction_data)
|
1277
|
+
session.add(junction_obj)
|
1278
|
+
logging.info(f"Created junction between {cls.__name__} {parent_obj.id} and {target_model.__name__} {target_obj.id}")
|
1279
|
+
|
1280
|
+
@classmethod
|
1281
|
+
async def _load_relationships_recursively(cls, session, obj, max_depth=2, current_depth=0, visited_ids=None):
|
1282
|
+
"""
|
1283
|
+
Recursively load all relationships for an object and its related objects.
|
1284
|
+
|
1285
|
+
Args:
|
1286
|
+
session: SQLAlchemy session
|
1287
|
+
obj: The object to load relationships for
|
1288
|
+
max_depth: Maximum depth to recurse to prevent infinite loops
|
1289
|
+
current_depth: Current recursion depth (internal use)
|
1290
|
+
visited_ids: Set of already visited object IDs to prevent cycles
|
1291
|
+
|
1292
|
+
Returns:
|
1293
|
+
The object with all relationships loaded
|
1294
|
+
"""
|
1295
|
+
if visited_ids is None:
|
1296
|
+
visited_ids = set()
|
1297
|
+
|
1298
|
+
# Use object ID and class for tracking instead of the object itself (which isn't hashable)
|
1299
|
+
obj_key = (obj.__class__.__name__, obj.id)
|
1300
|
+
|
1301
|
+
# Stop if we've reached max depth or already visited this object
|
1302
|
+
if current_depth >= max_depth or obj_key in visited_ids:
|
1303
|
+
return obj
|
1304
|
+
|
1305
|
+
# Mark as visited to prevent cycles
|
1306
|
+
visited_ids.add(obj_key)
|
1307
|
+
|
1308
|
+
# Load all relationship fields for this object
|
1309
|
+
obj_class = obj.__class__
|
1310
|
+
relationship_fields = obj_class._get_auto_relationship_fields()
|
1311
|
+
|
1312
|
+
# For each relationship, load it and recurse
|
1313
|
+
for rel_name in relationship_fields:
|
1314
|
+
try:
|
1315
|
+
# Fetch the objects using selectinload
|
1316
|
+
stmt = select(obj_class).where(obj_class.id == obj.id)
|
1317
|
+
stmt = stmt.options(selectinload(getattr(obj_class, rel_name)))
|
1318
|
+
result = await session.execute(stmt)
|
1319
|
+
refreshed_obj = result.scalars().first()
|
1320
|
+
|
1321
|
+
# Get the loaded relationship
|
1322
|
+
related_objs = getattr(refreshed_obj, rel_name, None)
|
1323
|
+
|
1324
|
+
# Update the object's relationship
|
1325
|
+
setattr(obj, rel_name, related_objs)
|
1326
|
+
|
1327
|
+
# Skip if no related objects
|
1328
|
+
if related_objs is None:
|
1329
|
+
continue
|
1330
|
+
|
1331
|
+
# Recurse for related objects
|
1332
|
+
if isinstance(related_objs, list):
|
1333
|
+
for related_obj in related_objs:
|
1334
|
+
if hasattr(related_obj.__class__, '_get_auto_relationship_fields'):
|
1335
|
+
# Only recurse if the object has an ID (is persistent)
|
1336
|
+
if hasattr(related_obj, 'id') and related_obj.id is not None:
|
1337
|
+
await cls._load_relationships_recursively(
|
1338
|
+
session,
|
1339
|
+
related_obj,
|
1340
|
+
max_depth,
|
1341
|
+
current_depth + 1,
|
1342
|
+
visited_ids
|
1343
|
+
)
|
1344
|
+
else:
|
1345
|
+
if hasattr(related_objs.__class__, '_get_auto_relationship_fields'):
|
1346
|
+
# Only recurse if the object has an ID (is persistent)
|
1347
|
+
if hasattr(related_objs, 'id') and related_objs.id is not None:
|
1348
|
+
await cls._load_relationships_recursively(
|
1349
|
+
session,
|
1350
|
+
related_objs,
|
1351
|
+
max_depth,
|
1352
|
+
current_depth + 1,
|
1353
|
+
visited_ids
|
1354
|
+
)
|
1355
|
+
except Exception as e:
|
1356
|
+
logging.warning(f"Error loading relationship {rel_name}: {e}")
|
1357
|
+
|
1358
|
+
return obj
|
1359
|
+
|
1211
1360
|
# Register an event listener to update 'updated_at' on instance modifications.
|
1212
1361
|
@event.listens_for(Session, "before_flush")
|
1213
1362
|
def _update_updated_at(session, flush_context, instances):
|