fastapi-rtk 0.2.60__py3-none-any.whl → 1.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_rtk/__init__.py +0 -1
- fastapi_rtk/_version.py +1 -0
- fastapi_rtk/api/model_rest_api.py +182 -87
- fastapi_rtk/auth/auth.py +0 -9
- fastapi_rtk/backends/sqla/db.py +32 -7
- fastapi_rtk/backends/sqla/filters.py +16 -0
- fastapi_rtk/backends/sqla/interface.py +11 -62
- fastapi_rtk/backends/sqla/model.py +16 -1
- fastapi_rtk/bases/db.py +20 -2
- fastapi_rtk/bases/file_manager.py +12 -0
- fastapi_rtk/bases/filter.py +1 -1
- fastapi_rtk/cli/cli.py +61 -0
- fastapi_rtk/cli/commands/security.py +6 -6
- fastapi_rtk/const.py +1 -1
- fastapi_rtk/db.py +3 -0
- fastapi_rtk/dependencies.py +110 -64
- fastapi_rtk/fastapi_react_toolkit.py +123 -172
- fastapi_rtk/file_managers/s3_file_manager.py +63 -32
- fastapi_rtk/lang/messages.pot +12 -12
- fastapi_rtk/lang/translations/de/LC_MESSAGES/messages.mo +0 -0
- fastapi_rtk/lang/translations/de/LC_MESSAGES/messages.po +12 -12
- fastapi_rtk/lang/translations/en/LC_MESSAGES/messages.mo +0 -0
- fastapi_rtk/lang/translations/en/LC_MESSAGES/messages.po +12 -12
- fastapi_rtk/manager.py +10 -14
- fastapi_rtk/schemas.py +6 -4
- fastapi_rtk/security/sqla/apis.py +20 -5
- fastapi_rtk/security/sqla/models.py +8 -23
- fastapi_rtk/security/sqla/security_manager.py +367 -10
- fastapi_rtk/utils/async_task_runner.py +119 -30
- fastapi_rtk/utils/csv_json_converter.py +242 -39
- fastapi_rtk/utils/hooks.py +7 -4
- fastapi_rtk/utils/self_dependencies.py +1 -1
- fastapi_rtk/version.py +6 -1
- fastapi_rtk-1.0.18.dist-info/METADATA +28 -0
- {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/RECORD +38 -38
- {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/WHEEL +1 -2
- fastapi_rtk-0.2.60.dist-info/METADATA +0 -25
- fastapi_rtk-0.2.60.dist-info/top_level.txt +0 -1
- {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/entry_points.txt +0 -0
- {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,6 +3,8 @@ import typing
|
|
|
3
3
|
from datetime import datetime
|
|
4
4
|
|
|
5
5
|
import fastapi_users.exceptions
|
|
6
|
+
import sqlalchemy
|
|
7
|
+
import sqlalchemy.orm
|
|
6
8
|
from pydantic import BaseModel, Field
|
|
7
9
|
from sqlalchemy import select
|
|
8
10
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
@@ -11,7 +13,7 @@ from sqlalchemy.orm import Session
|
|
|
11
13
|
from ...const import logger
|
|
12
14
|
from ...globals import g
|
|
13
15
|
from ...setting import Setting
|
|
14
|
-
from ...utils import merge_schema, safe_call
|
|
16
|
+
from ...utils import T, lazy, merge_schema, safe_call
|
|
15
17
|
from .models import Api, Permission, PermissionApi, Role, User
|
|
16
18
|
|
|
17
19
|
__all__ = ["SecurityManager"]
|
|
@@ -28,6 +30,21 @@ class SecurityManager:
|
|
|
28
30
|
"""
|
|
29
31
|
|
|
30
32
|
toolkit: typing.Optional["FastAPIReactToolkit"]
|
|
33
|
+
builtin_roles = lazy(lambda: Setting.ROLES)
|
|
34
|
+
"""
|
|
35
|
+
The built-in roles defined in the settings.
|
|
36
|
+
|
|
37
|
+
Format:
|
|
38
|
+
```python
|
|
39
|
+
{
|
|
40
|
+
"role_name": [
|
|
41
|
+
(api_name1|api_name2|..., permission_name1|permission_name2|...),
|
|
42
|
+
...
|
|
43
|
+
],
|
|
44
|
+
...
|
|
45
|
+
}
|
|
46
|
+
```
|
|
47
|
+
"""
|
|
31
48
|
|
|
32
49
|
def __init__(self, toolkit: typing.Optional["FastAPIReactToolkit"] = None) -> None:
|
|
33
50
|
self.toolkit = toolkit
|
|
@@ -385,6 +402,273 @@ class SecurityManager:
|
|
|
385
402
|
-----------------------------------------
|
|
386
403
|
"""
|
|
387
404
|
|
|
405
|
+
def has_access_in_builtin_roles(
    self, role_name: str, api_name: str, permission_name: str
):
    """
    Tells whether a built-in role grants a permission on an API.

    Every entry of a built-in role is an ``(apis, permissions)`` pair whose
    elements are ``|``-separated name lists. Access is granted when one and
    the same entry lists both ``api_name`` and ``permission_name``.

    Args:
        role_name (str): The name of the built-in role to inspect.
        api_name (str): The name of the API being accessed.
        permission_name (str): The name of the permission being requested.

    Returns:
        bool: True if some entry of the role covers the API/permission pair, False otherwise (including when the role is not a built-in role).
    """
    roles = self.builtin_roles
    if role_name in roles:
        # Exact name matching after splitting, so "Api" never matches "ApiA".
        return any(
            api_name in apis.split("|") and permission_name in perms.split("|")
            for apis, perms in roles[role_name]
        )
    return False
|
|
428
|
+
|
|
429
|
+
def get_roles_from_builtin_roles(self):
    """
    Lists the names of every built-in role.

    A fresh list is built on each call, so callers may extend it freely.

    Returns:
        list[str]: The built-in role names, in the iteration order of the built-in roles mapping.
    """
    return [role_name for role_name in self.builtin_roles]
|
|
437
|
+
|
|
438
|
+
def get_api_permission_tuples_from_builtin_roles(self, role: str | None = None):
|
|
439
|
+
"""
|
|
440
|
+
Retrieves the API-permission tuples from the built-in roles.
|
|
441
|
+
|
|
442
|
+
Args:
|
|
443
|
+
role (str | None, optional): The name of the role to filter by. If None, retrieves from all roles. Defaults to None.
|
|
444
|
+
|
|
445
|
+
Returns:
|
|
446
|
+
list[tuple[str, str]]: The list of API-permission tuples.
|
|
447
|
+
"""
|
|
448
|
+
api_permission_tuples = list[tuple[str, str]]()
|
|
449
|
+
for role_name, role_api_permission_list in self.builtin_roles.items():
|
|
450
|
+
if role is not None and role != role_name:
|
|
451
|
+
continue
|
|
452
|
+
|
|
453
|
+
for api, perm in role_api_permission_list:
|
|
454
|
+
api_permission_tuples.append((api, perm))
|
|
455
|
+
return api_permission_tuples
|
|
456
|
+
|
|
457
|
+
def get_role_and_api_permission_tuples_from_builtin_roles(self):
    """
    Groups the API/permission pairs of the built-in roles by role.

    Returns:
        list[tuple[str, list[tuple[str, str]]]]: One ``(role_name, [(api, permission), ...])`` item per built-in role, in mapping order; roles without entries get an empty list.
    """
    return [
        (role_name, [(api, perm) for api, perm in entries])
        for role_name, entries in self.builtin_roles.items()
    ]
|
|
471
|
+
|
|
472
|
+
async def create_roles(
    self,
    roles: list[str],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
):
    """
    Creates new roles with the given names. Existing roles are not duplicated.

    Args:
        roles (list[str]): The names of the roles to create.
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.

    Returns:
        list[Role] | None: The roles matching ``roles``: those that already existed followed by the newly created ones. Returns None if an error occurred and ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, propagated only when ``raise_exception`` is True.
    """
    return await self._create_entities(
        Role,
        roles,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda role: logger.info(f"ADDING ROLE {role}"),
    )
|
|
500
|
+
|
|
501
|
+
async def create_permissions(
    self,
    permissions: list[str],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
):
    """
    Creates new permissions with the given names. Existing permissions are not duplicated.

    Args:
        permissions (list[str]): The names of the permissions to create.
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.

    Returns:
        list[Permission] | None: The permissions matching ``permissions``: those that already existed followed by the newly created ones. Returns None if an error occurred and ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, propagated only when ``raise_exception`` is True.
    """
    return await self._create_entities(
        Permission,
        permissions,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda permission: logger.info(
            f"ADDING PERMISSION {permission}"
        ),
    )
|
|
531
|
+
|
|
532
|
+
async def create_apis(
    self,
    apis: list[str],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
):
    """
    Creates new APIs with the given names. Existing APIs are not duplicated.

    Args:
        apis (list[str]): The names of the APIs to create.
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.

    Returns:
        list[Api] | None: The APIs matching ``apis``: those that already existed followed by the newly created ones. Returns None if an error occurred and ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, propagated only when ``raise_exception`` is True.
    """
    return await self._create_entities(
        Api,
        apis,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda api: logger.info(f"ADDING API {api}"),
    )
|
|
560
|
+
|
|
561
|
+
async def associate_list_of_permission_with_api(
    self,
    permission_api_tuples: list[tuple[Permission, Api]],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
):
    """
    Associates a list of permissions with APIs. Existing associations are not duplicated.

    Args:
        permission_api_tuples (list[tuple[Permission, Api]]): A list of tuples containing Permission and Api objects to associate. Repeated pairs are collapsed into one association.
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.

    Returns:
        list[PermissionApi] | None: The associations matching the given pairs: those already in the database (loaded with their ``api``, ``permission`` and ``roles``) followed by the newly created ones (created with an empty ``roles`` list). Returns an empty list for empty input, or None if an error occurred and ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, propagated only when ``raise_exception`` is True.
    """
    try:
        # `or_()` without arguments renders as blank SQL, which would drop the
        # WHERE clause and match every PermissionApi row; nothing to do anyway.
        if not permission_api_tuples:
            return []

        if not session:
            async with db.session() as session:
                # Propagate the nested result so the return value does not
                # depend on whether the caller supplied a session.
                return await self.associate_list_of_permission_with_api(
                    permission_api_tuples,
                    session=session,
                    raise_exception=raise_exception,
                )

        # Collapse repeated pairs (first occurrence wins) so the same
        # association is never inserted twice in one call.
        # NOTE(review): assumes the Permission/Api objects are persisted and
        # have ids, as required by the id-based WHERE clause below.
        requested: dict[tuple, tuple[Permission, Api]] = {}
        for permission, api in permission_api_tuples:
            requested.setdefault((permission.id, api.id), (permission, api))

        query = (
            select(PermissionApi)
            .options(
                sqlalchemy.orm.joinedload(PermissionApi.api),
                sqlalchemy.orm.joinedload(PermissionApi.permission),
                sqlalchemy.orm.selectinload(PermissionApi.roles),
            )
            .where(
                sqlalchemy.or_(
                    *(
                        sqlalchemy.and_(
                            PermissionApi.permission_id == permission_id,
                            PermissionApi.api_id == api_id,
                        )
                        for permission_id, api_id in requested
                    )
                )
            )
        )
        result = await safe_call(session.scalars(query))
        existing_permission_apis = list(result.all())
        # Set lookup instead of rescanning the existing rows for every pair.
        existing_keys = {
            (permission_api.permission.id, permission_api.api.id)
            for permission_api in existing_permission_apis
        }

        new_permission_apis = list[PermissionApi]()
        for key, (permission, api) in requested.items():
            if key in existing_keys:
                continue
            permission_api = PermissionApi(permission=permission, api=api, roles=[])
            session.add(permission_api)
            new_permission_apis.append(permission_api)
            logger.info(f"ASSOCIATING PERMISSION {permission} WITH API {api}")
        await safe_call(session.commit())
        return existing_permission_apis + new_permission_apis
    except Exception:
        if not raise_exception:
            return None
        raise
|
|
631
|
+
|
|
632
|
+
async def associate_list_of_role_with_permission_api(
    self,
    role_permission_api_tuples: list[tuple[Role, PermissionApi]],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
):
    """
    Associates a list of roles with permission APIs. Existing associations are not duplicated.

    A role is appended to ``permission_api.roles`` only if it is not already
    in that list, so repeated pairs in the input are also harmless.

    Args:
        role_permission_api_tuples (list[tuple[Role, PermissionApi]]): A list of tuples containing Role and PermissionApi objects to associate. The ``roles`` collection of each PermissionApi is read and modified, so it must be loadable in the session (e.g. eagerly loaded).
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.

    Returns:
        None

    Raises:
        Exception: Any error raised while updating or committing, propagated only when ``raise_exception`` is True.
    """
    try:
        # Nothing to associate: avoid opening a session just to commit nothing.
        if not role_permission_api_tuples:
            return

        if not session:
            async with db.session() as session:
                await self.associate_list_of_role_with_permission_api(
                    role_permission_api_tuples,
                    session=session,
                    raise_exception=raise_exception,
                )
                return

        for role, permission_api in role_permission_api_tuples:
            if role not in permission_api.roles:
                permission_api.roles.append(role)
                logger.info(
                    f"ASSOCIATING ROLE {role} WITH PERMISSION API {permission_api}"
                )
        await safe_call(session.commit())
    except Exception:
        if not raise_exception:
            return
        raise
|
|
671
|
+
|
|
388
672
|
async def cleanup(self, *, session: AsyncSession | Session = None):
|
|
389
673
|
"""
|
|
390
674
|
Cleanup unused permissions from apis and roles.
|
|
@@ -404,13 +688,12 @@ class SecurityManager:
|
|
|
404
688
|
"FastAPIReactToolkit instance not provided, you must provide it to use this function."
|
|
405
689
|
)
|
|
406
690
|
|
|
407
|
-
api_permission_tuples =
|
|
691
|
+
api_permission_tuples = self.get_api_permission_tuples_from_builtin_roles()
|
|
408
692
|
apis = [api.__class__.__name__ for api in self.toolkit.apis]
|
|
409
693
|
permissions = self.toolkit.total_permissions()
|
|
410
|
-
for
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
permissions.append(permission)
|
|
694
|
+
for api, permission in api_permission_tuples:
|
|
695
|
+
apis.append(api)
|
|
696
|
+
permissions.append(permission)
|
|
414
697
|
|
|
415
698
|
# Clean up unused permissions
|
|
416
699
|
unused_permissions = await safe_call(
|
|
@@ -428,14 +711,20 @@ class SecurityManager:
|
|
|
428
711
|
logger.info(f"DELETING API {api} AND ITS ASSOCIATIONS")
|
|
429
712
|
await safe_call(session.delete(api))
|
|
430
713
|
|
|
431
|
-
roles =
|
|
714
|
+
roles = self.get_roles_from_builtin_roles()
|
|
432
715
|
if g.admin_role is not None:
|
|
433
716
|
roles.append(g.admin_role)
|
|
434
717
|
|
|
435
718
|
# Clean up existing permission-apis, that are no longer connected to any roles
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
719
|
+
permission_apis_in_db = await safe_call(
|
|
720
|
+
session.scalars(
|
|
721
|
+
select(PermissionApi).options(
|
|
722
|
+
sqlalchemy.orm.selectinload(PermissionApi.roles)
|
|
723
|
+
)
|
|
724
|
+
)
|
|
725
|
+
)
|
|
726
|
+
for permission_api in permission_apis_in_db:
|
|
727
|
+
for role in permission_api.roles:
|
|
439
728
|
if role.name not in roles:
|
|
440
729
|
permission_api.roles.remove(role)
|
|
441
730
|
logger.info(
|
|
@@ -443,3 +732,71 @@ class SecurityManager:
|
|
|
443
732
|
)
|
|
444
733
|
|
|
445
734
|
await safe_call(session.commit())
|
|
735
|
+
|
|
736
|
+
"""
|
|
737
|
+
-----------------------------------------
|
|
738
|
+
HELPER FUNCTIONS
|
|
739
|
+
-----------------------------------------
|
|
740
|
+
"""
|
|
741
|
+
|
|
742
|
+
async def _create_entities(
    self,
    entity_class: typing.Type[T],
    names: list[str],
    *,
    session: AsyncSession | Session = None,
    raise_exception: bool = True,
    on_after_create: typing.Optional[
        typing.Callable[[T], typing.Awaitable[None] | None]
    ] = None,
):
    """
    Helper function to create new entities with the given names.
    Existing entities are not duplicated.

    Args:
        entity_class (typing.Type[T]): The entity class to create. It must have a ``name`` column and accept ``name=`` in its constructor.
        names (list[str]): The names of the entities to create. Repeated names are collapsed (first occurrence wins).
        session (AsyncSession | Session, optional): The database session to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When set to True, raises an exception if an error occurs, otherwise returns None. Defaults to True.
        on_after_create (Optional[Callable[[T], Awaitable[None] | None]], optional): An optional callback invoked (and awaited if it returns an awaitable) for each newly added entity, before the commit. Defaults to None.

    Returns:
        list[T] | None: The entities matching ``names``: those already in the database followed by the newly created ones. Returns an empty list for empty input, or None if an error occurred and ``raise_exception`` is False.

    Raises:
        Exception: If an error occurs and raise_exception is True.
    """
    try:
        # Nothing to create: skip the session and the database round-trip.
        if not names:
            return []

        if not session:
            async with db.session() as session:
                return await self._create_entities(
                    entity_class,
                    names,
                    session=session,
                    raise_exception=raise_exception,
                    # Must be forwarded, otherwise the callback is silently
                    # skipped whenever the caller did not pass a session.
                    on_after_create=on_after_create,
                )

        # Deduplicate while keeping the caller's order, so one call never
        # inserts the same name twice.
        unique_names = list(dict.fromkeys(names))

        # Check for existing entities
        result = await safe_call(
            session.scalars(
                select(entity_class).where(entity_class.name.in_(unique_names))
            )
        )
        existing_entities = list(result.all())
        existing_names = {entity.name for entity in existing_entities}

        new_entities = list[T]()
        for name in unique_names:
            if name in existing_names:
                continue
            entity = entity_class(name=name)
            session.add(entity)
            new_entities.append(entity)
            if on_after_create:
                await safe_call(on_after_create(entity))
        await safe_call(session.commit())
        return existing_entities + new_entities
    except Exception:
        if not raise_exception:
            return None
        raise
|
|
@@ -2,12 +2,15 @@ import asyncio
|
|
|
2
2
|
import contextvars
|
|
3
3
|
import functools
|
|
4
4
|
import inspect
|
|
5
|
+
import traceback as tb
|
|
5
6
|
import typing
|
|
6
7
|
|
|
7
8
|
from .prettify_dict import prettify_dict
|
|
8
9
|
|
|
9
10
|
__all__ = ["AsyncTaskRunner"]
|
|
10
11
|
|
|
12
|
+
T = typing.TypeVar("T")
|
|
13
|
+
|
|
11
14
|
|
|
12
15
|
class CallerInfo(typing.TypedDict):
|
|
13
16
|
"""
|
|
@@ -35,26 +38,37 @@ class AsyncTaskException(AsyncTaskRunnerException):
|
|
|
35
38
|
self.caller = caller
|
|
36
39
|
|
|
37
40
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
except Exception as e:
|
|
52
|
-
raise AsyncTaskException(str(e), original_exception=e, caller=caller) from e
|
|
41
|
+
class ResultAccessor(typing.Generic[T]):
    """
    Helper class that can be awaited OR accessed directly.

    ``await accessor`` awaits the underlying task (running it if it has not
    produced a result yet). Without awaiting, ``str()``, ``repr()`` and
    attribute access are forwarded to the task's cached ``_result``; using
    them before the task has a result raises ``RuntimeError``.
    """

    def __init__(self, task: "AsyncTask[T]"):
        self._task = task

    def __await__(self):
        """Allow awaiting: await task.result"""
        return self._task.__await__()

    def __repr__(self) -> str:
        """Direct access without await - returns value or raises error."""
        return repr(self._ensure_result())

    def __str__(self) -> str:
        """Direct access without await - returns value or raises error."""
        return str(self._ensure_result())

    # Make it behave like the actual value when accessed
    def __getattr__(self, name):
        # `__getattr__` only runs for names not found on the accessor itself.
        # `_task` is missing on instances built without `__init__` (copy and
        # pickle do this); forwarding it would recurse until RecursionError.
        # Dunder names are probes from Python's own protocols (e.g.
        # `copy` checking `__setstate__`) and must not be answered by the
        # wrapped value.
        if name == "_task" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._ensure_result(), name)

    def _ensure_result(self):
        """Return the task's cached result, or raise RuntimeError if there is none yet."""
        if not hasattr(self._task, "_result"):
            raise RuntimeError(
                "Task has not been executed yet. Await the task to get the result."
            )
        return self._task._result
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class AsyncTask(typing.Generic[T]):
|
|
58
72
|
"""
|
|
59
73
|
Represents a task to be run asynchronously.
|
|
60
74
|
|
|
@@ -63,19 +77,51 @@ class AsyncTask:
|
|
|
63
77
|
|
|
64
78
|
def __init__(
    self,
    task: typing.Callable[..., typing.Awaitable[T]] | typing.Awaitable[T],
    /,
    caller: CallerInfo | None = None,
    tags: list[str] | None = None,
):
    """
    Wraps a unit of asynchronous work.

    Args:
        task: Either a callable returning an awaitable (called when the task runs) or an awaitable itself (awaited directly). Positional-only.
        caller (CallerInfo | None, optional): Information about where the task was added, kept for error reporting. Defaults to None.
        tags (list[str] | None, optional): Tags attached to the task; a falsy value is replaced by a fresh empty list. Defaults to None.
    """
    # Independent assignments; order is irrelevant.
    self.tags = tags or []
    self.caller = caller
    self.task = task
|
|
74
88
|
|
|
89
|
+
@property
def result(self):
    """
    Gives access to the outcome of this task.

    A new accessor is built on every access; it reads the task's cached
    result rather than storing its own copy.

    Returns:
        ResultAccessor[T]: An object that can be awaited to obtain the result, or used directly once the task has completed.
    """
    accessor = ResultAccessor(self)
    return accessor
|
|
98
|
+
|
|
75
99
|
def __call__(self):
    """
    Returns the (not yet awaited) coroutine that runs this task.

    Returns:
        Coroutine: The coroutine produced by ``_run``; awaiting it yields the task's result.
    """
    pending_run = self._run()
    return pending_run
|
|
101
|
+
|
|
102
|
+
def __await__(self):
    """
    Makes the task itself awaitable: ``await task`` runs it, or returns the cached result if it already completed.
    """
    pending_run = self._run()
    return pending_run.__await__()
|
|
104
|
+
|
|
105
|
+
async def _run(self):
    """
    Executes the wrapped work once and memoizes a successful result.

    If ``self.task`` is callable it is called and its return value awaited;
    otherwise ``self.task`` is awaited directly. A successful result is stored
    in ``self._result`` and returned immediately on later calls. Failures are
    not cached.

    Returns:
        T: The result of the task.

    Raises:
        AsyncTaskException: Wraps any exception raised while running the task, chained to the original one.
    """
    if not hasattr(self, "_result"):
        try:
            source = self.task
            awaitable = source() if callable(source) else source
            self._result = await awaitable
        except Exception as e:
            raise AsyncTaskException(
                str(e), original_exception=e, caller=self.caller
            ) from e
    return self._result
|
|
79
125
|
|
|
80
126
|
|
|
81
127
|
class AsyncTaskRunner:
|
|
@@ -143,19 +189,59 @@ class AsyncTaskRunner:
|
|
|
143
189
|
"""
|
|
144
190
|
self.run_tasks_even_if_exception = run_tasks_even_if_exception
|
|
145
191
|
|
|
192
|
+
# Overload: a single task returns the single AsyncTask that wraps it.
# `task` is positional-only because the implementation only accepts
# `*tasks`; as positional-or-keyword, `add_task(task=...)` would type-check
# but raise TypeError at runtime.
@typing.overload
@classmethod
def add_task(
    cls,
    task: typing.Callable[..., typing.Awaitable[T]]
    | typing.Awaitable[T]
    | AsyncTask[T],
    /,
    *,
    tags: list[str] | None = None,
    instance: "AsyncTaskRunner | None" = None,
) -> AsyncTask[T]: ...
|
|
203
|
+
|
|
204
|
+
# Overload: two or more tasks return a list of AsyncTask objects, one per task.
# `task1`/`task2` are positional-only to match the `*tasks` implementation,
# which cannot receive them by keyword.
@typing.overload
@classmethod
def add_task(
    cls,
    task1: typing.Callable[..., typing.Awaitable[T]]
    | typing.Awaitable[T]
    | AsyncTask[T],
    task2: typing.Callable[..., typing.Awaitable[T]]
    | typing.Awaitable[T]
    | AsyncTask[T],
    /,
    *tasks: typing.Callable[..., typing.Awaitable[T]]
    | typing.Awaitable[T]
    | AsyncTask[T],
    tags: list[str] | None = None,
    instance: "AsyncTaskRunner | None" = None,
) -> list[AsyncTask[T]]: ...

|
|
220
|
+
@classmethod
|
|
221
|
+
def add_task(
|
|
222
|
+
cls,
|
|
223
|
+
*tasks: typing.Callable[..., typing.Awaitable[T]]
|
|
224
|
+
| typing.Awaitable[T]
|
|
225
|
+
| AsyncTask[T],
|
|
150
226
|
tags: list[str] | None = None,
|
|
151
227
|
instance: "AsyncTaskRunner | None" = None,
|
|
152
228
|
):
|
|
229
|
+
"""
|
|
230
|
+
Adds one or more tasks to the current context's task list. The tasks will be executed when exiting the context.
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
tags (list[str] | None, optional): Tags to associate with the tasks. Defaults to None.
|
|
234
|
+
instance (AsyncTaskRunner | None, optional): The AsyncTaskRunner instance to add tasks to. Defaults to None.
|
|
235
|
+
|
|
236
|
+
Returns:
|
|
237
|
+
AsyncTask[T] | list[AsyncTask[T]]: The added task(s).
|
|
238
|
+
"""
|
|
153
239
|
task_list = cls._get_current_task_list("add_task", instance=instance)
|
|
154
240
|
# Get caller info
|
|
155
241
|
frame = inspect.currentframe()
|
|
156
242
|
caller = inspect.getouterframes(frame, 2)[1] if frame else None
|
|
157
243
|
|
|
158
|
-
async_tasks = [
|
|
244
|
+
async_tasks: list[AsyncTask[T]] = [
|
|
159
245
|
AsyncTask(
|
|
160
246
|
task,
|
|
161
247
|
caller=CallerInfo(
|
|
@@ -170,6 +256,7 @@ class AsyncTaskRunner:
|
|
|
170
256
|
for task in tasks
|
|
171
257
|
]
|
|
172
258
|
task_list.extend(async_tasks)
|
|
259
|
+
return async_tasks if len(async_tasks) > 1 else async_tasks[0]
|
|
173
260
|
|
|
174
261
|
@classmethod
|
|
175
262
|
def remove_tasks_by_tag(
|
|
@@ -214,16 +301,16 @@ class AsyncTaskRunner:
|
|
|
214
301
|
return
|
|
215
302
|
|
|
216
303
|
if tasks:
|
|
217
|
-
|
|
304
|
+
exceptions_with_index = list[tuple[AsyncTaskException, int]]()
|
|
218
305
|
futures = await asyncio.gather(
|
|
219
306
|
*(task() for task in tasks), return_exceptions=True
|
|
220
307
|
)
|
|
221
|
-
for future in futures:
|
|
308
|
+
for index, future in enumerate(futures):
|
|
222
309
|
if isinstance(future, AsyncTaskException):
|
|
223
310
|
# Handle exceptions from tasks
|
|
224
|
-
|
|
311
|
+
exceptions_with_index.append((future, index))
|
|
225
312
|
|
|
226
|
-
if
|
|
313
|
+
if exceptions_with_index:
|
|
227
314
|
raise AsyncTaskRunnerException(
|
|
228
315
|
f"\n{
|
|
229
316
|
prettify_dict(
|
|
@@ -233,9 +320,11 @@ class AsyncTaskRunner:
|
|
|
233
320
|
f'Task {index + 1}': {
|
|
234
321
|
'message': str(exc),
|
|
235
322
|
'caller': exc.caller,
|
|
236
|
-
'traceback':
|
|
323
|
+
'traceback': ''.join(
|
|
324
|
+
tb.format_exception(exc.original_exception)
|
|
325
|
+
),
|
|
237
326
|
}
|
|
238
|
-
for
|
|
327
|
+
for exc, index in exceptions_with_index
|
|
239
328
|
},
|
|
240
329
|
}
|
|
241
330
|
)
|