async-easy-model 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- async_easy_model/__init__.py +1 -1
- async_easy_model/auto_relationships.py +132 -2
- async_easy_model/model.py +683 -436
- {async_easy_model-0.2.2.dist-info → async_easy_model-0.2.4.dist-info}/METADATA +42 -12
- async_easy_model-0.2.4.dist-info/RECORD +10 -0
- async_easy_model-0.2.2.dist-info/RECORD +0 -10
- {async_easy_model-0.2.2.dist-info → async_easy_model-0.2.4.dist-info}/LICENSE +0 -0
- {async_easy_model-0.2.2.dist-info → async_easy_model-0.2.4.dist-info}/WHEEL +0 -0
- {async_easy_model-0.2.2.dist-info → async_easy_model-0.2.4.dist-info}/top_level.txt +0 -0
async_easy_model/model.py
CHANGED
@@ -10,6 +10,7 @@ from datetime import datetime, timezone as tz
|
|
10
10
|
import inspect
|
11
11
|
import json
|
12
12
|
import logging
|
13
|
+
import re
|
13
14
|
|
14
15
|
T = TypeVar("T", bound="EasyModel")
|
15
16
|
|
@@ -212,8 +213,9 @@ class EasyModel(SQLModel):
|
|
212
213
|
@classmethod
|
213
214
|
async def all(
|
214
215
|
cls: Type[T],
|
215
|
-
include_relationships: bool = True,
|
216
|
-
order_by: Optional[Union[str, List[str]]] = None
|
216
|
+
include_relationships: bool = True,
|
217
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
218
|
+
max_depth: int = 2
|
217
219
|
) -> List[T]:
|
218
220
|
"""
|
219
221
|
Retrieve all records of this model.
|
@@ -222,29 +224,20 @@ class EasyModel(SQLModel):
|
|
222
224
|
include_relationships: If True, eagerly load all relationships
|
223
225
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
224
226
|
Prefix with '-' for descending order (e.g. '-created_at')
|
227
|
+
max_depth: Maximum depth for loading nested relationships
|
225
228
|
|
226
229
|
Returns:
|
227
230
|
A list of all model instances
|
228
231
|
"""
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
# Apply ordering
|
233
|
-
statement = cls._apply_order_by(statement, order_by)
|
234
|
-
|
235
|
-
if include_relationships:
|
236
|
-
# Get all relationship attributes, including auto-detected ones
|
237
|
-
for rel_name in cls._get_auto_relationship_fields():
|
238
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
239
|
-
|
240
|
-
result = await session.execute(statement)
|
241
|
-
return result.scalars().all()
|
232
|
+
return await cls.select({}, all=True, include_relationships=include_relationships,
|
233
|
+
order_by=order_by, max_depth=max_depth)
|
242
234
|
|
243
235
|
@classmethod
|
244
236
|
async def first(
|
245
237
|
cls: Type[T],
|
246
|
-
include_relationships: bool = True,
|
247
|
-
order_by: Optional[Union[str, List[str]]] = None
|
238
|
+
include_relationships: bool = True,
|
239
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
240
|
+
max_depth: int = 2
|
248
241
|
) -> Optional[T]:
|
249
242
|
"""
|
250
243
|
Retrieve the first record of this model.
|
@@ -253,56 +246,37 @@ class EasyModel(SQLModel):
|
|
253
246
|
include_relationships: If True, eagerly load all relationships
|
254
247
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
255
248
|
Prefix with '-' for descending order (e.g. '-created_at')
|
249
|
+
max_depth: Maximum depth for loading nested relationships
|
256
250
|
|
257
251
|
Returns:
|
258
252
|
The first model instance or None if no records exist
|
259
253
|
"""
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
# Apply ordering
|
264
|
-
statement = cls._apply_order_by(statement, order_by)
|
265
|
-
|
266
|
-
if include_relationships:
|
267
|
-
# Get all relationship attributes, including auto-detected ones
|
268
|
-
for rel_name in cls._get_auto_relationship_fields():
|
269
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
270
|
-
|
271
|
-
result = await session.execute(statement)
|
272
|
-
return result.scalars().first()
|
254
|
+
return await cls.select({}, first=True, include_relationships=include_relationships,
|
255
|
+
order_by=order_by, max_depth=max_depth)
|
273
256
|
|
274
257
|
@classmethod
|
275
258
|
async def limit(
|
276
259
|
cls: Type[T],
|
277
260
|
count: int,
|
278
|
-
include_relationships: bool = True,
|
279
|
-
order_by: Optional[Union[str, List[str]]] = None
|
261
|
+
include_relationships: bool = True,
|
262
|
+
order_by: Optional[Union[str, List[str]]] = None,
|
263
|
+
max_depth: int = 2
|
280
264
|
) -> List[T]:
|
281
265
|
"""
|
282
|
-
Retrieve a limited number of records
|
266
|
+
Retrieve a limited number of records.
|
283
267
|
|
284
268
|
Args:
|
285
|
-
count: Maximum number of records to
|
269
|
+
count: Maximum number of records to return
|
286
270
|
include_relationships: If True, eagerly load all relationships
|
287
271
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
288
272
|
Prefix with '-' for descending order (e.g. '-created_at')
|
273
|
+
max_depth: Maximum depth for loading nested relationships
|
289
274
|
|
290
275
|
Returns:
|
291
|
-
A list of model instances
|
276
|
+
A list of model instances
|
292
277
|
"""
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
# Apply ordering
|
297
|
-
statement = cls._apply_order_by(statement, order_by)
|
298
|
-
|
299
|
-
if include_relationships:
|
300
|
-
# Get all relationship attributes, including auto-detected ones
|
301
|
-
for rel_name in cls._get_auto_relationship_fields():
|
302
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
303
|
-
|
304
|
-
result = await session.execute(statement)
|
305
|
-
return result.scalars().all()
|
278
|
+
return await cls.select({}, all=True, include_relationships=include_relationships,
|
279
|
+
order_by=order_by, limit=count, max_depth=max_depth)
|
306
280
|
|
307
281
|
@classmethod
|
308
282
|
async def get_by_id(cls: Type[T], id: int, include_relationships: bool = True) -> Optional[T]:
|
@@ -327,6 +301,20 @@ class EasyModel(SQLModel):
|
|
327
301
|
else:
|
328
302
|
return await session.get(cls, id)
|
329
303
|
|
304
|
+
@classmethod
|
305
|
+
def _get_unique_fields(cls) -> List[str]:
|
306
|
+
"""
|
307
|
+
Get all fields with unique=True constraint
|
308
|
+
|
309
|
+
Returns:
|
310
|
+
List of field names that have unique constraints
|
311
|
+
"""
|
312
|
+
unique_fields = []
|
313
|
+
for name, field in cls.__fields__.items():
|
314
|
+
if name != 'id' and hasattr(field, "field_info") and field.field_info.extra.get('unique', False):
|
315
|
+
unique_fields.append(name)
|
316
|
+
return unique_fields
|
317
|
+
|
330
318
|
@classmethod
|
331
319
|
async def get_by_attribute(
|
332
320
|
cls: Type[T],
|
@@ -391,193 +379,76 @@ class EasyModel(SQLModel):
|
|
391
379
|
return result.scalars().first()
|
392
380
|
|
393
381
|
@classmethod
|
394
|
-
async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True) -> Union[T, List[T]]:
|
382
|
+
async def insert(cls: Type[T], data: Union[Dict[str, Any], List[Dict[str, Any]]], include_relationships: bool = True, max_depth: int = 2) -> Union[T, List[T]]:
|
395
383
|
"""
|
396
384
|
Insert one or more records.
|
397
385
|
|
398
386
|
Args:
|
399
387
|
data: Dictionary of field values or a list of dictionaries for multiple records
|
400
388
|
include_relationships: If True, return the instance(s) with relationships loaded
|
389
|
+
max_depth: Maximum depth for loading nested relationships
|
401
390
|
|
402
391
|
Returns:
|
403
392
|
The created model instance(s)
|
404
393
|
"""
|
405
|
-
|
394
|
+
if not data:
|
395
|
+
return None
|
396
|
+
|
397
|
+
# Handle single dict or list of dicts
|
406
398
|
if isinstance(data, list):
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
processed_item = await cls._process_relationships_for_insert(session, item)
|
413
|
-
|
414
|
-
# Check if a record with unique constraints already exists
|
415
|
-
unique_fields = cls._get_unique_fields()
|
416
|
-
existing_obj = None
|
417
|
-
|
418
|
-
if unique_fields:
|
419
|
-
unique_criteria = {field: processed_item[field]
|
420
|
-
for field in unique_fields
|
421
|
-
if field in processed_item}
|
422
|
-
|
423
|
-
if unique_criteria:
|
424
|
-
# Try to find existing record with these unique values
|
425
|
-
statement = select(cls)
|
426
|
-
for field, value in unique_criteria.items():
|
427
|
-
statement = statement.where(getattr(cls, field) == value)
|
428
|
-
result = await session.execute(statement)
|
429
|
-
existing_obj = result.scalars().first()
|
430
|
-
|
431
|
-
if existing_obj:
|
432
|
-
# Update existing object with new values
|
433
|
-
for key, value in processed_item.items():
|
434
|
-
if key != 'id': # Don't update ID
|
435
|
-
setattr(existing_obj, key, value)
|
436
|
-
objects.append(existing_obj)
|
437
|
-
else:
|
438
|
-
# Create new object
|
439
|
-
obj = cls(**processed_item)
|
440
|
-
session.add(obj)
|
441
|
-
objects.append(obj)
|
442
|
-
except Exception as e:
|
443
|
-
logging.error(f"Error inserting record: {e}")
|
444
|
-
await session.rollback()
|
445
|
-
raise
|
446
|
-
|
447
|
-
try:
|
448
|
-
await session.flush()
|
449
|
-
await session.commit()
|
450
|
-
|
451
|
-
# Refresh with relationships if requested
|
452
|
-
if include_relationships:
|
453
|
-
for obj in objects:
|
454
|
-
await session.refresh(obj)
|
455
|
-
except Exception as e:
|
456
|
-
logging.error(f"Error committing transaction: {e}")
|
457
|
-
await session.rollback()
|
458
|
-
raise
|
459
|
-
|
460
|
-
return objects
|
461
|
-
else:
|
462
|
-
# Single record case
|
463
|
-
async with cls.get_session() as session:
|
464
|
-
try:
|
465
|
-
# Process relationships first
|
466
|
-
processed_data = await cls._process_relationships_for_insert(session, data)
|
467
|
-
|
468
|
-
# Check if a record with unique constraints already exists
|
469
|
-
unique_fields = cls._get_unique_fields()
|
470
|
-
existing_obj = None
|
471
|
-
|
472
|
-
if unique_fields:
|
473
|
-
unique_criteria = {field: processed_data[field]
|
474
|
-
for field in unique_fields
|
475
|
-
if field in processed_data}
|
476
|
-
|
477
|
-
if unique_criteria:
|
478
|
-
# Try to find existing record with these unique values
|
479
|
-
statement = select(cls)
|
480
|
-
for field, value in unique_criteria.items():
|
481
|
-
statement = statement.where(getattr(cls, field) == value)
|
482
|
-
result = await session.execute(statement)
|
483
|
-
existing_obj = result.scalars().first()
|
484
|
-
|
485
|
-
if existing_obj:
|
486
|
-
# Update existing object with new values
|
487
|
-
for key, value in processed_data.items():
|
488
|
-
if key != 'id': # Don't update ID
|
489
|
-
setattr(existing_obj, key, value)
|
490
|
-
obj = existing_obj
|
491
|
-
else:
|
492
|
-
# Create new object
|
493
|
-
obj = cls(**processed_data)
|
494
|
-
session.add(obj)
|
495
|
-
|
496
|
-
await session.flush() # Flush to get the ID
|
497
|
-
await session.commit()
|
498
|
-
|
499
|
-
if include_relationships:
|
500
|
-
# Refresh with relationships
|
501
|
-
statement = select(cls).where(cls.id == obj.id)
|
502
|
-
for rel_name in cls._get_auto_relationship_fields():
|
503
|
-
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
504
|
-
result = await session.execute(statement)
|
505
|
-
return result.scalars().first()
|
506
|
-
else:
|
507
|
-
await session.refresh(obj)
|
508
|
-
return obj
|
509
|
-
except Exception as e:
|
510
|
-
logging.error(f"Error inserting record: {e}")
|
511
|
-
await session.rollback()
|
512
|
-
raise
|
513
|
-
|
514
|
-
@classmethod
|
515
|
-
async def insert_with_related(
|
516
|
-
cls: Type[T],
|
517
|
-
data: Dict[str, Any],
|
518
|
-
related_data: Dict[str, List[Dict[str, Any]]] = None
|
519
|
-
) -> T:
|
520
|
-
"""
|
521
|
-
Create a model instance with related objects in a single transaction.
|
399
|
+
results = []
|
400
|
+
for item in data:
|
401
|
+
result = await cls.insert(item, include_relationships, max_depth)
|
402
|
+
results.append(result)
|
403
|
+
return results
|
522
404
|
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
405
|
+
# Store many-to-many relationship data for later processing
|
406
|
+
many_to_many_data = {}
|
407
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
408
|
+
|
409
|
+
# Extract many-to-many data before processing other relationships
|
410
|
+
for rel_name in many_to_many_rels:
|
411
|
+
if rel_name in data:
|
412
|
+
many_to_many_data[rel_name] = data[rel_name]
|
413
|
+
|
414
|
+
# Process relationships to convert nested objects to foreign keys
|
534
415
|
async with cls.get_session() as session:
|
535
|
-
|
536
|
-
|
537
|
-
session.add(obj)
|
538
|
-
await session.flush() # Flush to get the ID
|
539
|
-
|
540
|
-
# Create related objects
|
541
|
-
for rel_name, items_data in related_data.items():
|
542
|
-
if not hasattr(cls, rel_name):
|
543
|
-
continue
|
544
|
-
|
545
|
-
rel_attr = getattr(cls, rel_name)
|
546
|
-
if not hasattr(rel_attr, "property"):
|
547
|
-
continue
|
548
|
-
|
549
|
-
# Get the related model class and the back reference attribute
|
550
|
-
related_model = rel_attr.property.mapper.class_
|
551
|
-
back_populates = getattr(rel_attr.property, "back_populates", None)
|
416
|
+
try:
|
417
|
+
processed_data = await cls._process_relationships_for_insert(session, data)
|
552
418
|
|
553
|
-
# Create
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
419
|
+
# Create the model instance
|
420
|
+
obj = cls(**processed_data)
|
421
|
+
session.add(obj)
|
422
|
+
|
423
|
+
# Flush to get the object ID
|
424
|
+
await session.flush()
|
425
|
+
|
426
|
+
# Now process many-to-many relationships if any
|
427
|
+
for rel_name, rel_data in many_to_many_data.items():
|
428
|
+
if isinstance(rel_data, list):
|
429
|
+
await cls._process_many_to_many_relationship(
|
430
|
+
session, obj, rel_name, rel_data
|
431
|
+
)
|
432
|
+
|
433
|
+
# Commit the transaction
|
434
|
+
await session.commit()
|
435
|
+
|
436
|
+
if include_relationships:
|
437
|
+
# Reload with relationships
|
438
|
+
return await cls._load_relationships_recursively(session, obj, max_depth)
|
439
|
+
else:
|
440
|
+
return obj
|
441
|
+
|
442
|
+
except Exception as e:
|
443
|
+
await session.rollback()
|
444
|
+
logging.error(f"Error inserting {cls.__name__}: {e}")
|
445
|
+
if "UNIQUE constraint failed" in str(e):
|
446
|
+
field_match = re.search(r"UNIQUE constraint failed: \w+\.(\w+)", str(e))
|
447
|
+
if field_match:
|
448
|
+
field_name = field_match.group(1)
|
449
|
+
value = data.get(field_name)
|
450
|
+
raise ValueError(f"A record with {field_name}='{value}' already exists")
|
451
|
+
raise
|
581
452
|
|
582
453
|
@classmethod
|
583
454
|
async def _process_relationships_for_insert(cls: Type[T], session: AsyncSession, data: Dict[str, Any]) -> Dict[str, Any]:
|
@@ -591,6 +462,15 @@ class EasyModel(SQLModel):
|
|
591
462
|
"quantity": 2
|
592
463
|
})
|
593
464
|
|
465
|
+
It also handles lists of related objects for one-to-many relationships:
|
466
|
+
publisher = await Publisher.insert({
|
467
|
+
"name": "Example Publisher",
|
468
|
+
"authors": [
|
469
|
+
{"name": "Author 1", "email": "author1@example.com"},
|
470
|
+
{"name": "Author 2", "email": "author2@example.com"}
|
471
|
+
]
|
472
|
+
})
|
473
|
+
|
594
474
|
For each nested object:
|
595
475
|
1. Find the target model class
|
596
476
|
2. Check if an object with the same unique fields already exists
|
@@ -605,175 +485,208 @@ class EasyModel(SQLModel):
|
|
605
485
|
Returns:
|
606
486
|
Processed data dictionary with nested objects replaced by their foreign key IDs
|
607
487
|
"""
|
608
|
-
|
609
|
-
|
488
|
+
if not data:
|
489
|
+
return {}
|
490
|
+
|
491
|
+
result = dict(data)
|
610
492
|
|
611
|
-
# Get
|
493
|
+
# Get relationship fields for this model
|
612
494
|
relationship_fields = cls._get_auto_relationship_fields()
|
613
495
|
|
614
|
-
# Get
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
496
|
+
# Get many-to-many relationships separately
|
497
|
+
many_to_many_relationships = cls._get_many_to_many_relationships()
|
498
|
+
|
499
|
+
# Set of field names already processed as many-to-many relationships
|
500
|
+
processed_m2m_fields = set()
|
501
|
+
|
502
|
+
# Process each relationship field in the input data
|
503
|
+
for key in list(result.keys()):
|
504
|
+
value = result[key]
|
505
|
+
|
506
|
+
# Skip empty values
|
507
|
+
if value is None:
|
625
508
|
continue
|
626
|
-
|
627
|
-
# Check if this is a relationship field (either by name or derived from foreign key)
|
628
|
-
is_rel_field = key in relationship_fields
|
629
|
-
related_key = f"{key}_id"
|
630
|
-
is_derived_rel = related_key in foreign_key_fields
|
631
509
|
|
632
|
-
#
|
633
|
-
if
|
634
|
-
#
|
510
|
+
# Check if this is a relationship field
|
511
|
+
if key in relationship_fields:
|
512
|
+
# Get the related model class
|
635
513
|
related_model = None
|
636
514
|
|
637
|
-
|
638
|
-
if hasattr(cls, key) and hasattr(getattr(cls, key), 'property'):
|
639
|
-
# Get from relationship attribute
|
515
|
+
if hasattr(cls, key):
|
640
516
|
rel_attr = getattr(cls, key)
|
641
|
-
|
642
|
-
|
643
|
-
# Try to find it from foreign key definition
|
644
|
-
fk_definition = None
|
645
|
-
for field_name, field_info in cls.__fields__.items():
|
646
|
-
if field_name == related_key and hasattr(field_info, "field_info"):
|
647
|
-
fk_definition = field_info.field_info.extra.get("foreign_key")
|
648
|
-
break
|
649
|
-
|
650
|
-
if fk_definition:
|
651
|
-
# Parse foreign key definition (e.g. "users.id")
|
652
|
-
target_table, _ = fk_definition.split(".")
|
653
|
-
# Try to find the target model
|
654
|
-
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name
|
655
|
-
related_model = get_model_by_table_name(target_table)
|
656
|
-
if not related_model:
|
657
|
-
# Try with the singular form
|
658
|
-
singular_table = singularize_name(target_table)
|
659
|
-
related_model = get_model_by_table_name(singular_table)
|
660
|
-
else:
|
661
|
-
# Try to infer from field name (e.g., "user_id" -> Users)
|
662
|
-
base_name = related_key[:-3] # Remove "_id"
|
663
|
-
from async_easy_model.auto_relationships import get_model_by_table_name, singularize_name, pluralize_name
|
664
|
-
|
665
|
-
# Try singular and plural forms
|
666
|
-
related_model = get_model_by_table_name(base_name)
|
667
|
-
if not related_model:
|
668
|
-
plural_table = pluralize_name(base_name)
|
669
|
-
related_model = get_model_by_table_name(plural_table)
|
670
|
-
if not related_model:
|
671
|
-
singular_table = singularize_name(base_name)
|
672
|
-
related_model = get_model_by_table_name(singular_table)
|
517
|
+
if hasattr(rel_attr, 'prop') and hasattr(rel_attr.prop, 'mapper'):
|
518
|
+
related_model = rel_attr.prop.mapper.class_
|
673
519
|
|
674
520
|
if not related_model:
|
675
|
-
logging.warning(f"Could not
|
521
|
+
logging.warning(f"Could not determine related model for {key}, skipping")
|
676
522
|
continue
|
677
523
|
|
678
|
-
#
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
unique_fields.append(field_name)
|
684
|
-
|
685
|
-
# Create a search dictionary using unique fields
|
686
|
-
search_dict = {}
|
687
|
-
for field in unique_fields:
|
688
|
-
if field in value and value[field] is not None:
|
689
|
-
search_dict[field] = value[field]
|
690
|
-
|
691
|
-
# If no unique fields found but ID is provided, use it
|
692
|
-
if not search_dict and 'id' in value and value['id']:
|
693
|
-
search_dict = {'id': value['id']}
|
524
|
+
# Check if this is a many-to-many relationship field
|
525
|
+
if key in many_to_many_relationships:
|
526
|
+
# Store this separately - we'll handle it after the main object is created
|
527
|
+
processed_m2m_fields.add(key)
|
528
|
+
continue
|
694
529
|
|
695
|
-
#
|
696
|
-
if
|
697
|
-
|
530
|
+
# Handle different relationship types based on data type
|
531
|
+
if isinstance(value, list):
|
532
|
+
# Handle one-to-many relationship (list of dictionaries)
|
533
|
+
related_ids = []
|
534
|
+
for item in value:
|
535
|
+
if isinstance(item, dict):
|
536
|
+
related_obj = await cls._process_single_relationship_item(
|
537
|
+
session, related_model, item
|
538
|
+
)
|
539
|
+
if related_obj:
|
540
|
+
related_ids.append(related_obj.id)
|
541
|
+
|
542
|
+
# Update result with list of foreign key IDs
|
543
|
+
foreign_key_list_name = f"{key}_ids"
|
544
|
+
result[foreign_key_list_name] = related_ids
|
545
|
+
|
546
|
+
# Remove the relationship list from the result
|
547
|
+
if key in result:
|
548
|
+
del result[key]
|
698
549
|
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
550
|
+
elif isinstance(value, dict):
|
551
|
+
# Handle one-to-one relationship (single dictionary)
|
552
|
+
related_obj = await cls._process_single_relationship_item(
|
553
|
+
session, related_model, value
|
554
|
+
)
|
703
555
|
|
704
|
-
|
705
|
-
#
|
706
|
-
|
707
|
-
|
708
|
-
existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
|
709
|
-
|
710
|
-
existing_result = await session.execute(existing_stmt)
|
711
|
-
related_obj = existing_result.scalars().first()
|
556
|
+
if related_obj:
|
557
|
+
# Update the result with the foreign key ID
|
558
|
+
foreign_key_name = f"{key}_id"
|
559
|
+
result[foreign_key_name] = related_obj.id
|
712
560
|
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
561
|
+
# Remove the relationship dictionary from the result
|
562
|
+
if key in result:
|
563
|
+
del result[key]
|
564
|
+
|
565
|
+
# Remove any processed many-to-many fields from the result
|
566
|
+
# since we'll handle them separately after the object is created
|
567
|
+
for key in processed_m2m_fields:
|
568
|
+
if key in result:
|
569
|
+
del result[key]
|
570
|
+
|
571
|
+
return result
|
572
|
+
|
573
|
+
@classmethod
|
574
|
+
async def _process_single_relationship_item(cls, session: AsyncSession, related_model: Type, item_data: Dict[str, Any]) -> Optional[Any]:
|
575
|
+
"""
|
576
|
+
Process a single relationship item (dictionary).
|
577
|
+
|
578
|
+
This helper method is used by _process_relationships_for_insert to handle
|
579
|
+
both singular relationship objects and items within lists of relationships.
|
580
|
+
|
581
|
+
Args:
|
582
|
+
session: The database session to use
|
583
|
+
related_model: The related model class
|
584
|
+
item_data: Dictionary with field values for the related object
|
585
|
+
|
586
|
+
Returns:
|
587
|
+
The created or found related object, or None if processing failed
|
588
|
+
"""
|
589
|
+
# Look for unique fields in the related model to use for searching
|
590
|
+
unique_fields = []
|
591
|
+
for name, field in related_model.__fields__.items():
|
592
|
+
if (hasattr(field, "field_info") and
|
593
|
+
field.field_info.extra.get('unique', False)):
|
594
|
+
unique_fields.append(name)
|
595
|
+
|
596
|
+
# Create a search dictionary using unique fields
|
597
|
+
search_dict = {}
|
598
|
+
for field in unique_fields:
|
599
|
+
if field in item_data and item_data[field] is not None:
|
600
|
+
search_dict[field] = item_data[field]
|
601
|
+
|
602
|
+
# If no unique fields found but ID is provided, use it
|
603
|
+
if not search_dict and 'id' in item_data and item_data['id']:
|
604
|
+
search_dict = {'id': item_data['id']}
|
605
|
+
|
606
|
+
# Special case for products without uniqueness constraints
|
607
|
+
if not search_dict and related_model.__tablename__ == 'products' and 'name' in item_data:
|
608
|
+
search_dict = {'name': item_data['name']}
|
609
|
+
|
610
|
+
# Try to find an existing record
|
611
|
+
related_obj = None
|
612
|
+
if search_dict:
|
613
|
+
logging.info(f"Searching for existing {related_model.__name__} with {search_dict}")
|
614
|
+
|
615
|
+
try:
|
616
|
+
# Create a more appropriate search query based on unique fields
|
617
|
+
existing_stmt = select(related_model)
|
618
|
+
for field, field_value in search_dict.items():
|
619
|
+
existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
|
620
|
+
|
621
|
+
existing_result = await session.execute(existing_stmt)
|
622
|
+
related_obj = existing_result.scalars().first()
|
717
623
|
|
718
624
|
if related_obj:
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
# Update non-unique fields
|
730
|
-
current_val = getattr(related_obj, attr, None)
|
731
|
-
if current_val != attr_val:
|
732
|
-
setattr(related_obj, attr, attr_val)
|
625
|
+
logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id}")
|
626
|
+
except Exception as e:
|
627
|
+
logging.error(f"Error finding existing record: {e}")
|
628
|
+
|
629
|
+
if related_obj:
|
630
|
+
# Update the existing record with any non-unique field values
|
631
|
+
for attr, attr_val in item_data.items():
|
632
|
+
# Skip ID field
|
633
|
+
if attr == 'id':
|
634
|
+
continue
|
733
635
|
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
else:
|
738
|
-
# Create a new record
|
739
|
-
logging.info(f"Creating new {related_model.__name__} for {key}")
|
740
|
-
related_obj = related_model(**value)
|
741
|
-
session.add(related_obj)
|
742
|
-
|
743
|
-
# Ensure the object has an ID by flushing
|
744
|
-
try:
|
745
|
-
await session.flush()
|
746
|
-
except Exception as e:
|
747
|
-
logging.error(f"Error flushing session for {related_model.__name__}: {e}")
|
636
|
+
# Skip unique fields with different values to avoid constraint violations
|
637
|
+
if attr in unique_fields and getattr(related_obj, attr) != attr_val:
|
638
|
+
continue
|
748
639
|
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
640
|
+
# Update non-unique fields
|
641
|
+
current_val = getattr(related_obj, attr, None)
|
642
|
+
if current_val != attr_val:
|
643
|
+
setattr(related_obj, attr, attr_val)
|
644
|
+
|
645
|
+
# Add the updated object to the session
|
646
|
+
session.add(related_obj)
|
647
|
+
logging.info(f"Reusing existing {related_model.__name__} with ID: {related_obj.id}")
|
648
|
+
else:
|
649
|
+
# Create a new record
|
650
|
+
logging.info(f"Creating new {related_model.__name__}")
|
651
|
+
|
652
|
+
# Process nested relationships in this item first
|
653
|
+
if hasattr(related_model, '_process_relationships_for_insert'):
|
654
|
+
# This is a recursive call to process nested relationships
|
655
|
+
processed_item_data = await related_model._process_relationships_for_insert(
|
656
|
+
session, item_data
|
657
|
+
)
|
658
|
+
else:
|
659
|
+
processed_item_data = item_data
|
660
|
+
|
661
|
+
related_obj = related_model(**processed_item_data)
|
662
|
+
session.add(related_obj)
|
663
|
+
|
664
|
+
# Ensure the object has an ID by flushing
|
665
|
+
try:
|
666
|
+
await session.flush()
|
667
|
+
except Exception as e:
|
668
|
+
logging.error(f"Error flushing session for {related_model.__name__}: {e}")
|
669
|
+
|
670
|
+
# If there was a uniqueness error, try again to find the existing record
|
671
|
+
if "UNIQUE constraint failed" in str(e):
|
672
|
+
logging.info(f"UNIQUE constraint failed, trying to find existing record again")
|
767
673
|
|
768
|
-
#
|
769
|
-
|
770
|
-
|
674
|
+
# Try to find by any field provided in the search_dict
|
675
|
+
existing_stmt = select(related_model)
|
676
|
+
for field, field_value in search_dict.items():
|
677
|
+
existing_stmt = existing_stmt.where(getattr(related_model, field) == field_value)
|
771
678
|
|
772
|
-
#
|
773
|
-
|
774
|
-
|
679
|
+
# Execute the search query
|
680
|
+
existing_result = await session.execute(existing_stmt)
|
681
|
+
related_obj = existing_result.scalars().first()
|
682
|
+
|
683
|
+
if not related_obj:
|
684
|
+
# We couldn't find an existing record, re-raise the exception
|
685
|
+
raise
|
686
|
+
|
687
|
+
logging.info(f"Found existing {related_model.__name__} with ID: {related_obj.id} after constraint error")
|
775
688
|
|
776
|
-
return
|
689
|
+
return related_obj
|
777
690
|
|
778
691
|
@classmethod
|
779
692
|
async def update(cls: Type[T], data: Dict[str, Any], criteria: Dict[str, Any], include_relationships: bool = True) -> Optional[T]:
|
@@ -788,6 +701,17 @@ class EasyModel(SQLModel):
|
|
788
701
|
Returns:
|
789
702
|
The updated model instance
|
790
703
|
"""
|
704
|
+
# Store many-to-many relationship data for later processing
|
705
|
+
many_to_many_data = {}
|
706
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
707
|
+
|
708
|
+
# Extract many-to-many data before processing
|
709
|
+
for rel_name in many_to_many_rels:
|
710
|
+
if rel_name in data:
|
711
|
+
many_to_many_data[rel_name] = data[rel_name]
|
712
|
+
# Remove from original data
|
713
|
+
del data[rel_name]
|
714
|
+
|
791
715
|
async with cls.get_session() as session:
|
792
716
|
try:
|
793
717
|
# Find the record(s) to update
|
@@ -828,6 +752,74 @@ class EasyModel(SQLModel):
|
|
828
752
|
for key, value in data.items():
|
829
753
|
setattr(record, key, value)
|
830
754
|
|
755
|
+
# Process many-to-many relationships if any
|
756
|
+
for rel_name, rel_data in many_to_many_data.items():
|
757
|
+
if isinstance(rel_data, list):
|
758
|
+
# First, get all existing links for this relation
|
759
|
+
junction_model, target_model = many_to_many_rels[rel_name]
|
760
|
+
|
761
|
+
from async_easy_model.auto_relationships import get_foreign_keys_from_model
|
762
|
+
foreign_keys = get_foreign_keys_from_model(junction_model)
|
763
|
+
|
764
|
+
# Find the foreign key fields for this model and the target model
|
765
|
+
this_model_fk = None
|
766
|
+
target_model_fk = None
|
767
|
+
|
768
|
+
for fk_field, fk_target in foreign_keys.items():
|
769
|
+
target_table = fk_target.split('.')[0]
|
770
|
+
if target_table == cls.__tablename__:
|
771
|
+
this_model_fk = fk_field
|
772
|
+
elif target_table == target_model.__tablename__:
|
773
|
+
target_model_fk = fk_field
|
774
|
+
|
775
|
+
if not this_model_fk or not target_model_fk:
|
776
|
+
logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
|
777
|
+
continue
|
778
|
+
|
779
|
+
# Get all existing junctions for this record
|
780
|
+
junction_stmt = select(junction_model).where(
|
781
|
+
getattr(junction_model, this_model_fk) == record.id
|
782
|
+
)
|
783
|
+
junction_result = await session.execute(junction_stmt)
|
784
|
+
existing_junctions = junction_result.scalars().all()
|
785
|
+
|
786
|
+
# Get the target IDs from the existing junctions
|
787
|
+
existing_target_ids = [getattr(junction, target_model_fk) for junction in existing_junctions]
|
788
|
+
|
789
|
+
# Track processed target IDs
|
790
|
+
processed_target_ids = set()
|
791
|
+
|
792
|
+
# Process each item in rel_data
|
793
|
+
for item_data in rel_data:
|
794
|
+
target_obj = await cls._process_single_relationship_item(
|
795
|
+
session, target_model, item_data
|
796
|
+
)
|
797
|
+
|
798
|
+
if not target_obj:
|
799
|
+
logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
|
800
|
+
continue
|
801
|
+
|
802
|
+
processed_target_ids.add(target_obj.id)
|
803
|
+
|
804
|
+
# Check if this link already exists
|
805
|
+
if target_obj.id not in existing_target_ids:
|
806
|
+
# Create new junction
|
807
|
+
junction_data = {
|
808
|
+
this_model_fk: record.id,
|
809
|
+
target_model_fk: target_obj.id
|
810
|
+
}
|
811
|
+
junction_obj = junction_model(**junction_data)
|
812
|
+
session.add(junction_obj)
|
813
|
+
logging.info(f"Created junction between {cls.__name__} {record.id} and {target_model.__name__} {target_obj.id}")
|
814
|
+
|
815
|
+
# Delete junctions for target IDs that weren't in the updated data
|
816
|
+
junctions_to_delete = [j for j in existing_junctions
|
817
|
+
if getattr(j, target_model_fk) not in processed_target_ids]
|
818
|
+
|
819
|
+
for junction in junctions_to_delete:
|
820
|
+
await session.delete(junction)
|
821
|
+
logging.info(f"Deleted junction between {cls.__name__} {record.id} and {target_model.__name__} {getattr(junction, target_model_fk)}")
|
822
|
+
|
831
823
|
await session.flush()
|
832
824
|
await session.commit()
|
833
825
|
|
@@ -841,9 +833,10 @@ class EasyModel(SQLModel):
|
|
841
833
|
else:
|
842
834
|
await session.refresh(record)
|
843
835
|
return record
|
836
|
+
|
844
837
|
except Exception as e:
|
845
|
-
logging.error(f"Error updating record: {e}")
|
846
838
|
await session.rollback()
|
839
|
+
logging.error(f"Error updating {cls.__name__}: {e}")
|
847
840
|
raise
|
848
841
|
|
849
842
|
@classmethod
|
@@ -859,7 +852,7 @@ class EasyModel(SQLModel):
|
|
859
852
|
"""
|
860
853
|
async with cls.get_session() as session:
|
861
854
|
try:
|
862
|
-
# Find
|
855
|
+
# Find records to delete
|
863
856
|
statement = select(cls)
|
864
857
|
for field, value in criteria.items():
|
865
858
|
if isinstance(value, str) and '*' in value:
|
@@ -876,43 +869,50 @@ class EasyModel(SQLModel):
|
|
876
869
|
logging.warning(f"No records found with criteria: {criteria}")
|
877
870
|
return 0
|
878
871
|
|
879
|
-
#
|
880
|
-
|
881
|
-
relationship_fields = cls._get_auto_relationship_fields()
|
882
|
-
to_many_relationships = []
|
883
|
-
|
884
|
-
# Find to-many relationships that need to be handled first
|
885
|
-
for rel_name in relationship_fields:
|
886
|
-
rel_attr = getattr(cls, rel_name, None)
|
887
|
-
if rel_attr and hasattr(rel_attr, 'property'):
|
888
|
-
# Check if this is a to-many relationship (one-to-many or many-to-many)
|
889
|
-
if hasattr(rel_attr.property, 'uselist') and rel_attr.property.uselist:
|
890
|
-
to_many_relationships.append(rel_name)
|
872
|
+
# Check if there are many-to-many relationships that need cleanup
|
873
|
+
many_to_many_rels = cls._get_many_to_many_relationships()
|
891
874
|
|
892
|
-
#
|
875
|
+
# Delete each record and its related many-to-many junction records
|
876
|
+
count = 0
|
893
877
|
for record in records:
|
894
|
-
#
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
878
|
+
# Clean up many-to-many junctions first
|
879
|
+
for rel_name, (junction_model, _) in many_to_many_rels.items():
|
880
|
+
# Get foreign keys from the junction model
|
881
|
+
from async_easy_model.auto_relationships import get_foreign_keys_from_model
|
882
|
+
foreign_keys = get_foreign_keys_from_model(junction_model)
|
883
|
+
|
884
|
+
# Find which foreign key refers to this model
|
885
|
+
this_model_fk = None
|
886
|
+
for fk_field, fk_target in foreign_keys.items():
|
887
|
+
target_table = fk_target.split('.')[0]
|
888
|
+
if target_table == cls.__tablename__:
|
889
|
+
this_model_fk = fk_field
|
890
|
+
break
|
891
|
+
|
892
|
+
if not this_model_fk:
|
893
|
+
continue
|
894
|
+
|
895
|
+
# Delete junction records for this record
|
896
|
+
delete_stmt = select(junction_model).where(
|
897
|
+
getattr(junction_model, this_model_fk) == record.id
|
898
|
+
)
|
899
|
+
junction_result = await session.execute(delete_stmt)
|
900
|
+
junctions = junction_result.scalars().all()
|
901
|
+
|
902
|
+
for junction in junctions:
|
903
|
+
await session.delete(junction)
|
904
|
+
logging.info(f"Deleted junction record for {cls.__name__} id={record.id}")
|
904
905
|
|
905
906
|
# Now delete the main record
|
906
907
|
await session.delete(record)
|
908
|
+
count += 1
|
907
909
|
|
908
|
-
# Commit the changes
|
909
|
-
await session.flush()
|
910
910
|
await session.commit()
|
911
|
+
return count
|
911
912
|
|
912
|
-
return len(records)
|
913
913
|
except Exception as e:
|
914
|
-
logging.error(f"Error deleting records: {e}")
|
915
914
|
await session.rollback()
|
915
|
+
logging.error(f"Error deleting {cls.__name__}: {e}")
|
916
916
|
raise
|
917
917
|
|
918
918
|
def to_dict(self, include_relationships: bool = True, max_depth: int = 4) -> Dict[str, Any]:
|
@@ -991,21 +991,23 @@ class EasyModel(SQLModel):
|
|
991
991
|
first: bool = False,
|
992
992
|
include_relationships: bool = True,
|
993
993
|
order_by: Optional[Union[str, List[str]]] = None,
|
994
|
+
max_depth: int = 2,
|
994
995
|
limit: Optional[int] = None
|
995
996
|
) -> Union[Optional[T], List[T]]:
|
996
997
|
"""
|
997
|
-
|
998
|
+
Select records based on criteria.
|
998
999
|
|
999
1000
|
Args:
|
1000
|
-
criteria: Dictionary of field values to filter by
|
1001
|
-
all: If True, return all matching records,
|
1001
|
+
criteria: Dictionary of field values to filter by
|
1002
|
+
all: If True, return all matching records. If False, return only the first match.
|
1002
1003
|
first: If True, return only the first record (equivalent to all=False)
|
1003
1004
|
include_relationships: If True, eagerly load all relationships
|
1004
1005
|
order_by: Field(s) to order by. Can be a string or list of strings.
|
1005
1006
|
Prefix with '-' for descending order (e.g. '-created_at')
|
1007
|
+
max_depth: Maximum depth for loading nested relationships (when include_relationships=True)
|
1006
1008
|
limit: Maximum number of records to retrieve (if all=True)
|
1007
1009
|
If limit > 1, all is automatically set to True
|
1008
|
-
|
1010
|
+
|
1009
1011
|
Returns:
|
1010
1012
|
A single model instance, a list of instances, or None if not found
|
1011
1013
|
"""
|
@@ -1016,73 +1018,98 @@ class EasyModel(SQLModel):
|
|
1016
1018
|
# If limit is specified and > 1, set all to True
|
1017
1019
|
if limit is not None and limit > 1:
|
1018
1020
|
all = True
|
1021
|
+
|
1019
1022
|
# If first is specified, set all to False (first takes precedence)
|
1020
1023
|
if first:
|
1021
1024
|
all = False
|
1022
|
-
|
1025
|
+
|
1023
1026
|
async with cls.get_session() as session:
|
1024
1027
|
# Build the query
|
1025
1028
|
statement = select(cls)
|
1026
1029
|
|
1027
|
-
# Apply criteria
|
1028
|
-
for
|
1030
|
+
# Apply criteria filters
|
1031
|
+
for key, value in criteria.items():
|
1029
1032
|
if isinstance(value, str) and '*' in value:
|
1030
1033
|
# Handle LIKE queries (convert '*' wildcard to '%')
|
1031
1034
|
like_value = value.replace('*', '%')
|
1032
|
-
statement = statement.where(getattr(cls,
|
1035
|
+
statement = statement.where(getattr(cls, key).like(like_value))
|
1033
1036
|
else:
|
1034
1037
|
# Regular equality check
|
1035
|
-
statement = statement.where(getattr(cls,
|
1038
|
+
statement = statement.where(getattr(cls, key) == value)
|
1036
1039
|
|
1037
1040
|
# Apply ordering
|
1038
1041
|
if order_by:
|
1039
|
-
|
1042
|
+
order_clauses = []
|
1043
|
+
if isinstance(order_by, str):
|
1044
|
+
order_by = [order_by]
|
1045
|
+
|
1046
|
+
for field_name in order_by:
|
1047
|
+
if field_name.startswith("-"):
|
1048
|
+
# Descending order
|
1049
|
+
field_name = field_name[1:] # Remove the "-" prefix
|
1050
|
+
# Handle relationship field ordering with dot notation
|
1051
|
+
if "." in field_name:
|
1052
|
+
rel_name, attr_name = field_name.split(".", 1)
|
1053
|
+
if hasattr(cls, rel_name):
|
1054
|
+
rel_model = getattr(cls, rel_name)
|
1055
|
+
if hasattr(rel_model, "property"):
|
1056
|
+
target_model = rel_model.property.entity.class_
|
1057
|
+
if hasattr(target_model, attr_name):
|
1058
|
+
order_clauses.append(getattr(target_model, attr_name).desc())
|
1059
|
+
else:
|
1060
|
+
order_clauses.append(getattr(cls, field_name).desc())
|
1061
|
+
else:
|
1062
|
+
# Ascending order
|
1063
|
+
# Handle relationship field ordering with dot notation
|
1064
|
+
if "." in field_name:
|
1065
|
+
rel_name, attr_name = field_name.split(".", 1)
|
1066
|
+
if hasattr(cls, rel_name):
|
1067
|
+
rel_model = getattr(cls, rel_name)
|
1068
|
+
if hasattr(rel_model, "property"):
|
1069
|
+
target_model = rel_model.property.entity.class_
|
1070
|
+
if hasattr(target_model, attr_name):
|
1071
|
+
order_clauses.append(getattr(target_model, attr_name).asc())
|
1072
|
+
else:
|
1073
|
+
order_clauses.append(getattr(cls, field_name).asc())
|
1074
|
+
|
1075
|
+
if order_clauses:
|
1076
|
+
statement = statement.order_by(*order_clauses)
|
1040
1077
|
|
1041
1078
|
# Apply limit
|
1042
1079
|
if limit:
|
1043
1080
|
statement = statement.limit(limit)
|
1044
1081
|
|
1045
|
-
#
|
1082
|
+
# Load relationships if requested
|
1046
1083
|
if include_relationships:
|
1047
1084
|
for rel_name in cls._get_auto_relationship_fields():
|
1048
1085
|
statement = statement.options(selectinload(getattr(cls, rel_name)))
|
1049
1086
|
|
1050
|
-
# Execute the query
|
1051
1087
|
result = await session.execute(statement)
|
1052
1088
|
|
1053
1089
|
if all:
|
1054
|
-
|
1055
|
-
instances = result.scalars().all()
|
1090
|
+
objects = result.scalars().all()
|
1056
1091
|
|
1057
|
-
#
|
1058
|
-
if include_relationships:
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
# Skip if the relationship can't be loaded
|
1067
|
-
pass
|
1092
|
+
# Load nested relationships if requested
|
1093
|
+
if include_relationships and objects and max_depth > 1:
|
1094
|
+
loaded_objects = []
|
1095
|
+
for obj in objects:
|
1096
|
+
loaded_obj = await cls._load_relationships_recursively(
|
1097
|
+
session, obj, max_depth
|
1098
|
+
)
|
1099
|
+
loaded_objects.append(loaded_obj)
|
1100
|
+
return loaded_objects
|
1068
1101
|
|
1069
|
-
return
|
1102
|
+
return objects
|
1070
1103
|
else:
|
1071
|
-
|
1072
|
-
instance = result.scalars().first()
|
1104
|
+
obj = result.scalars().first()
|
1073
1105
|
|
1074
|
-
#
|
1075
|
-
if include_relationships and
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
except Exception:
|
1082
|
-
# Skip if the relationship can't be loaded
|
1083
|
-
pass
|
1084
|
-
|
1085
|
-
return instance
|
1106
|
+
# Load nested relationships if requested
|
1107
|
+
if include_relationships and obj and max_depth > 1:
|
1108
|
+
obj = await cls._load_relationships_recursively(
|
1109
|
+
session, obj, max_depth
|
1110
|
+
)
|
1111
|
+
|
1112
|
+
return obj
|
1086
1113
|
|
1087
1114
|
@classmethod
|
1088
1115
|
async def get_or_create(cls: Type[T], search_criteria: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None) -> Tuple[T, bool]:
|
@@ -1110,6 +1137,226 @@ class EasyModel(SQLModel):
|
|
1110
1137
|
new_record = await cls.insert(data)
|
1111
1138
|
return new_record, True
|
1112
1139
|
|
1140
|
+
@classmethod
async def insert_with_related(
    cls: Type[T],
    data: Dict[str, Any],
    related_data: Optional[Dict[str, List[Dict[str, Any]]]] = None
) -> T:
    """
    Create a model instance with related objects in a single transaction.

    The related payloads are merged into a copy of ``data`` under their
    relationship names. The combined dictionary is passed to ``cls.insert``
    with ``include_relationships=True``, which performs the actual creation.

    Args:
        data: Dictionary of field values for the main model. It is not mutated.
        related_data: Optional dictionary mapping relationship names to lists of
                      data dictionaries for creating related objects. Entries
                      whose list is empty (or otherwise falsy) are ignored. If a
                      relationship name also appears in ``data``, the value from
                      ``related_data`` wins.

    Returns:
        The created model instance, as returned by ``cls.insert``.
    """
    # Shallow copy so the caller's dictionary is left untouched.
    insert_data = dict(data)

    # Merge only non-empty relationship payloads; insert() handles nested creation.
    if related_data:
        insert_data.update(
            {rel_name: items for rel_name, items in related_data.items() if items}
        )

    # Delegate to the relationship-aware insert so everything happens in one call.
    return await cls.insert(insert_data, include_relationships=True)
|
1170
|
+
|
1171
|
+
@classmethod
def _get_many_to_many_relationships(cls) -> Dict[str, Tuple[Type['EasyModel'], Type['EasyModel']]]:
    """
    Get all many-to-many relationships for this model.

    A relationship is treated as many-to-many when its mapped property has a
    ``secondary`` (association) table for which ``get_model_by_table_name``
    returns a model. ``secondary`` is accepted in two forms. The first is a
    table-name string, as passed to ``relationship(secondary=...)``. The second
    is a ``Table`` object exposing ``.name``, which is what SQLAlchemy
    substitutes for the string once mappers are configured.

    Returns:
        Dictionary mapping relationship field names to tuples of
        (junction_model, target_model).
    """
    from async_easy_model.auto_relationships import get_model_by_table_name

    many_to_many_relationships = {}

    for rel_name in cls._get_auto_relationship_fields():
        # Class-level attribute; missing attributes simply yield None.
        rel_attr = getattr(cls, rel_name, None)
        prop = getattr(rel_attr, 'prop', None)
        secondary = getattr(prop, 'secondary', None)
        if secondary is None:
            # Not a relationship, or a plain one-to-many / many-to-one.
            continue

        # Before mapper configuration this is the string we passed in.
        # Afterwards it is the resolved Table, so fall back to its name.
        if isinstance(secondary, str):
            table_name = secondary
        else:
            table_name = getattr(secondary, 'name', None)
        if not table_name:
            continue

        junction_model = get_model_by_table_name(table_name)
        if junction_model:
            target_model = prop.mapper.class_
            many_to_many_relationships[rel_name] = (junction_model, target_model)

    return many_to_many_relationships
|
1202
|
+
|
1203
|
+
@classmethod
async def _process_many_to_many_relationship(
    cls,
    session: AsyncSession,
    parent_obj: 'EasyModel',
    rel_name: str,
    items: List[Dict[str, Any]]
) -> None:
    """
    Process a many-to-many relationship for an object.

    Each entry in ``items`` is resolved to a target instance through
    ``_process_single_relationship_item``. A junction row linking
    ``parent_obj`` to that target is then added to ``session`` unless the
    link already exists. The parent's existing links are read with a single
    query up front. Links added during this call are remembered, so a target
    repeated in ``items`` does not produce duplicate junction rows.

    This method does not call flush() or commit() itself.

    Args:
        session: The database session
        parent_obj: The parent object (e.g., Book); its ``id`` is used as the
                    junction foreign key
        rel_name: The name of the relationship (e.g., 'tags')
        items: List of data dictionaries for the related items

    Returns:
        None
    """
    # Get information about this many-to-many relationship
    many_to_many_rels = cls._get_many_to_many_relationships()
    if rel_name not in many_to_many_rels:
        logging.warning(f"Relationship {rel_name} is not a many-to-many relationship")
        return

    junction_model, target_model = many_to_many_rels[rel_name]

    # Get the junction model's foreign keys that reference this model and the target model
    from async_easy_model.auto_relationships import get_foreign_keys_from_model
    foreign_keys = get_foreign_keys_from_model(junction_model)

    this_model_fk = None
    target_model_fk = None
    for fk_field, fk_target in foreign_keys.items():
        # fk_target looks like "<table>.<column>"; compare on the table part.
        target_table = fk_target.split('.')[0]
        if target_table == cls.__tablename__:
            this_model_fk = fk_field
        elif target_table == target_model.__tablename__:
            target_model_fk = fk_field

    if not this_model_fk or not target_model_fk:
        logging.warning(f"Could not find foreign key fields for {rel_name} relationship")
        return

    if not items:
        return

    # Read the parent's existing links once instead of issuing one query per item.
    existing_stmt = select(junction_model).where(
        getattr(junction_model, this_model_fk) == parent_obj.id
    )
    existing_result = await session.execute(existing_stmt)
    linked_target_ids = {
        getattr(junction, target_model_fk)
        for junction in existing_result.scalars().all()
    }

    for item_data in items:
        # First, create or find the target model instance
        target_obj = await cls._process_single_relationship_item(
            session, target_model, item_data
        )

        if not target_obj:
            logging.warning(f"Failed to process {target_model.__name__} item for {rel_name}")
            continue

        # Skip links that already exist in the DB or were added earlier in this loop.
        if target_obj.id in linked_target_ids:
            continue

        junction_data = {
            this_model_fk: parent_obj.id,
            target_model_fk: target_obj.id
        }
        session.add(junction_model(**junction_data))
        linked_target_ids.add(target_obj.id)
        logging.info(f"Created junction between {cls.__name__} {parent_obj.id} and {target_model.__name__} {target_obj.id}")
|
1279
|
+
|
1280
|
+
@classmethod
async def _load_relationships_recursively(cls, session, obj, max_depth=2, current_depth=0, visited_ids=None):
    """
    Recursively load all relationships for an object and its related objects.

    For every auto-relationship field of ``obj``'s class, the row is
    re-selected with ``selectinload`` on that field. Each loaded related
    object that has a non-None ``id`` is then processed the same way, one
    level deeper. A failure on an individual relationship is logged as a
    warning and the remaining relationships are still processed.

    Args:
        session: SQLAlchemy (async) session used for the re-select queries
        obj: The object to load relationships for
        max_depth: Maximum depth to recurse to prevent infinite loops
        current_depth: Current recursion depth (internal use)
        visited_ids: Set of already visited (class name, id) keys, to prevent cycles

    Returns:
        ``obj`` itself, with its relationship attributes populated.
    """
    if visited_ids is None:
        visited_ids = set()

    # Track by (class name, id) rather than by the instance itself.
    obj_key = (obj.__class__.__name__, obj.id)

    # Stop if we've reached max depth or already visited this object
    if current_depth >= max_depth or obj_key in visited_ids:
        return obj
    visited_ids.add(obj_key)

    obj_class = obj.__class__
    for rel_name in obj_class._get_auto_relationship_fields():
        try:
            stmt = (
                select(obj_class)
                .where(obj_class.id == obj.id)
                .options(selectinload(getattr(obj_class, rel_name)))
            )
            result = await session.execute(stmt)
            refreshed_obj = result.scalars().first()
            if refreshed_obj is None:
                # Row no longer exists; assigning None to a collection would raise.
                continue

            related_objs = getattr(refreshed_obj, rel_name, None)

            # If obj belongs to this session, the identity map hands back obj itself and
            # the relationship is already populated. Re-assigning the same value would
            # only fire attribute events and mark obj dirty during a read.
            if refreshed_obj is not obj:
                setattr(obj, rel_name, related_objs)

            if related_objs is None:
                continue

            # To-many relationships yield a collection; to-one yields a single instance.
            if isinstance(related_objs, (list, set, tuple)):
                children = related_objs
            else:
                children = (related_objs,)

            for child in children:
                if not hasattr(child.__class__, '_get_auto_relationship_fields'):
                    continue
                # Only recurse into persistent objects (those that have an id)
                if getattr(child, 'id', None) is None:
                    continue
                await cls._load_relationships_recursively(
                    session,
                    child,
                    max_depth,
                    current_depth + 1,
                    visited_ids
                )
        except Exception as e:
            # Best-effort loading: one failing relationship must not abort the rest.
            logging.warning(f"Error loading relationship {rel_name}: {e}")

    return obj
|
1359
|
+
|
1113
1360
|
# Register an event listener to update 'updated_at' on instance modifications.
|
1114
1361
|
@event.listens_for(Session, "before_flush")
|
1115
1362
|
def _update_updated_at(session, flush_context, instances):
|