fastapi-rtk 0.2.60__py3-none-any.whl → 1.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. fastapi_rtk/__init__.py +0 -1
  2. fastapi_rtk/_version.py +1 -0
  3. fastapi_rtk/api/model_rest_api.py +182 -87
  4. fastapi_rtk/auth/auth.py +0 -9
  5. fastapi_rtk/backends/sqla/db.py +32 -7
  6. fastapi_rtk/backends/sqla/filters.py +16 -0
  7. fastapi_rtk/backends/sqla/interface.py +11 -62
  8. fastapi_rtk/backends/sqla/model.py +16 -1
  9. fastapi_rtk/bases/db.py +20 -2
  10. fastapi_rtk/bases/file_manager.py +12 -0
  11. fastapi_rtk/bases/filter.py +1 -1
  12. fastapi_rtk/cli/cli.py +61 -0
  13. fastapi_rtk/cli/commands/security.py +6 -6
  14. fastapi_rtk/const.py +1 -1
  15. fastapi_rtk/db.py +3 -0
  16. fastapi_rtk/dependencies.py +110 -64
  17. fastapi_rtk/fastapi_react_toolkit.py +123 -172
  18. fastapi_rtk/file_managers/s3_file_manager.py +63 -32
  19. fastapi_rtk/lang/messages.pot +12 -12
  20. fastapi_rtk/lang/translations/de/LC_MESSAGES/messages.mo +0 -0
  21. fastapi_rtk/lang/translations/de/LC_MESSAGES/messages.po +12 -12
  22. fastapi_rtk/lang/translations/en/LC_MESSAGES/messages.mo +0 -0
  23. fastapi_rtk/lang/translations/en/LC_MESSAGES/messages.po +12 -12
  24. fastapi_rtk/manager.py +10 -14
  25. fastapi_rtk/schemas.py +6 -4
  26. fastapi_rtk/security/sqla/apis.py +20 -5
  27. fastapi_rtk/security/sqla/models.py +8 -23
  28. fastapi_rtk/security/sqla/security_manager.py +367 -10
  29. fastapi_rtk/utils/async_task_runner.py +119 -30
  30. fastapi_rtk/utils/csv_json_converter.py +242 -39
  31. fastapi_rtk/utils/hooks.py +7 -4
  32. fastapi_rtk/utils/self_dependencies.py +1 -1
  33. fastapi_rtk/version.py +6 -1
  34. fastapi_rtk-1.0.18.dist-info/METADATA +28 -0
  35. {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/RECORD +38 -38
  36. {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/WHEEL +1 -2
  37. fastapi_rtk-0.2.60.dist-info/METADATA +0 -25
  38. fastapi_rtk-0.2.60.dist-info/top_level.txt +0 -1
  39. {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/entry_points.txt +0 -0
  40. {fastapi_rtk-0.2.60.dist-info → fastapi_rtk-1.0.18.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,8 @@ import typing
3
3
  from datetime import datetime
4
4
 
5
5
  import fastapi_users.exceptions
6
+ import sqlalchemy
7
+ import sqlalchemy.orm
6
8
  from pydantic import BaseModel, Field
7
9
  from sqlalchemy import select
8
10
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -11,7 +13,7 @@ from sqlalchemy.orm import Session
11
13
  from ...const import logger
12
14
  from ...globals import g
13
15
  from ...setting import Setting
14
- from ...utils import merge_schema, safe_call
16
+ from ...utils import T, lazy, merge_schema, safe_call
15
17
  from .models import Api, Permission, PermissionApi, Role, User
16
18
 
17
19
  __all__ = ["SecurityManager"]
@@ -28,6 +30,21 @@ class SecurityManager:
28
30
  """
29
31
 
30
32
#: Toolkit instance this security manager is bound to, if any.
toolkit: typing.Optional["FastAPIReactToolkit"]

#: The built-in roles defined in the settings (``Setting.ROLES``), wrapped in
#: ``lazy`` -- presumably so the setting is read on access rather than at
#: class-creation time; confirm against ``utils.lazy``.
#:
#: Format::
#:
#:     {
#:         "role_name": [
#:             ("api_name1|api_name2|...", "permission_name1|permission_name2|..."),
#:             ...
#:         ],
#:         ...
#:     }
builtin_roles = lazy(lambda: Setting.ROLES)
31
48
 
32
49
  def __init__(self, toolkit: typing.Optional["FastAPIReactToolkit"] = None) -> None:
33
50
  self.toolkit = toolkit
@@ -385,6 +402,273 @@ class SecurityManager:
385
402
  -----------------------------------------
386
403
  """
387
404
 
405
def has_access_in_builtin_roles(
    self, role_name: str, api_name: str, permission_name: str
):
    """
    Tell whether a built-in role grants ``permission_name`` on ``api_name``.

    Every entry of a built-in role is a pair of ``|``-separated strings: the
    first lists API names, the second lists permission names. Access is
    granted only when one single entry mentions both the API and the
    permission.

    Args:
        role_name (str): Name of the built-in role to inspect.
        api_name (str): Name of the API being accessed.
        permission_name (str): Name of the permission being requested.

    Returns:
        bool: True when the role grants the permission on the API, else False.
    """
    if role_name not in self.builtin_roles:
        return False

    return any(
        api_name in apis.split("|") and permission_name in perms.split("|")
        for apis, perms in self.builtin_roles[role_name]
    )
428
+
429
def get_roles_from_builtin_roles(self):
    """
    List the names of every built-in role.

    Returns:
        list[str]: A fresh list of role names, in the iteration order of the
        built-in roles mapping.
    """
    return [role_name for role_name in self.builtin_roles]
437
+
438
+ def get_api_permission_tuples_from_builtin_roles(self, role: str | None = None):
439
+ """
440
+ Retrieves the API-permission tuples from the built-in roles.
441
+
442
+ Args:
443
+ role (str | None, optional): The name of the role to filter by. If None, retrieves from all roles. Defaults to None.
444
+
445
+ Returns:
446
+ list[tuple[str, str]]: The list of API-permission tuples.
447
+ """
448
+ api_permission_tuples = list[tuple[str, str]]()
449
+ for role_name, role_api_permission_list in self.builtin_roles.items():
450
+ if role is not None and role != role_name:
451
+ continue
452
+
453
+ for api, perm in role_api_permission_list:
454
+ api_permission_tuples.append((api, perm))
455
+ return api_permission_tuples
456
+
457
def get_role_and_api_permission_tuples_from_builtin_roles(self):
    """
    Pair every built-in role name with its own list of (API, permission) tuples.

    Returns:
        list[tuple[str, list[tuple[str, str]]]]: One ``(role_name, pairs)``
        item per built-in role, where ``pairs`` is a new list of tuples.
    """
    return [
        (role_name, [(api, perm) for api, perm in entries])
        for role_name, entries in self.builtin_roles.items()
    ]
471
+
472
async def create_roles(
    self,
    roles: list[str],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
):
    """
    Creates new roles with the given names. Existing roles are not duplicated.

    Args:
        roles (list[str]): The names of the roles to create.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed and None is returned. Defaults to True.

    Returns:
        list[Role] | None: The roles that already existed followed by the newly
        created ones, or None if an error occurred and ``raise_exception`` is
        False.

    Raises:
        Exception: Any error raised while querying or committing, when
            ``raise_exception`` is True.
    """
    return await self._create_entities(
        Role,
        roles,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda role: logger.info(f"ADDING ROLE {role}"),
    )
500
+
501
async def create_permissions(
    self,
    permissions: list[str],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
):
    """
    Creates new permissions with the given names. Existing permissions are not duplicated.

    Args:
        permissions (list[str]): The names of the permissions to create.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed and None is returned. Defaults to True.

    Returns:
        list[Permission] | None: The permissions that already existed followed
        by the newly created ones, or None if an error occurred and
        ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, when
            ``raise_exception`` is True.
    """
    return await self._create_entities(
        Permission,
        permissions,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda permission: logger.info(
            f"ADDING PERMISSION {permission}"
        ),
    )
531
+
532
async def create_apis(
    self,
    apis: list[str],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
):
    """
    Creates new APIs with the given names. Existing APIs are not duplicated.

    Args:
        apis (list[str]): The names of the APIs to create.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed and None is returned. Defaults to True.

    Returns:
        list[Api] | None: The APIs that already existed followed by the newly
        created ones, or None if an error occurred and ``raise_exception`` is
        False.

    Raises:
        Exception: Any error raised while querying or committing, when
            ``raise_exception`` is True.
    """
    return await self._create_entities(
        Api,
        apis,
        session=session,
        raise_exception=raise_exception,
        on_after_create=lambda api: logger.info(f"ADDING API {api}"),
    )
560
+
561
async def associate_list_of_permission_with_api(
    self,
    permission_api_tuples: list[tuple[Permission, Api]],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
):
    """
    Associates a list of permissions with APIs. Existing associations are not duplicated.

    Pairs already associated in the database are reused. The remaining pairs
    get a new ``PermissionApi`` (with an empty role list), and the session is
    committed. A pair repeated in the input is only created once.

    Args:
        permission_api_tuples (list[tuple[Permission, Api]]): A list of tuples
            containing Permission and Api objects to associate.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed and None is returned. Defaults to True.

    Returns:
        list[PermissionApi] | None: Associations already present in the
        database followed by the ones created by this call. Returns ``[]`` for
        empty input without touching the database, and None if an error
        occurred and ``raise_exception`` is False.

    Raises:
        Exception: Any error raised while querying or committing, when
            ``raise_exception`` is True.
    """
    if not permission_api_tuples:
        # `or_()` without clauses is deprecated in SQLAlchemy and compiles to
        # no WHERE criterion at all, which would return every PermissionApi.
        return []

    try:
        if not session:
            async with db.session() as session:
                # Propagate the result so both code paths return the same thing.
                return await self.associate_list_of_permission_with_api(
                    permission_api_tuples,
                    session=session,
                    raise_exception=raise_exception,
                )

        pair_conditions = [
            sqlalchemy.and_(
                PermissionApi.permission_id == permission.id,
                PermissionApi.api_id == api.id,
            )
            for permission, api in permission_api_tuples
        ]
        query = (
            select(PermissionApi)
            .options(
                sqlalchemy.orm.joinedload(PermissionApi.api),
                sqlalchemy.orm.joinedload(PermissionApi.permission),
                sqlalchemy.orm.selectinload(PermissionApi.roles),
            )
            .where(sqlalchemy.or_(*pair_conditions))
        )
        result = await safe_call(session.scalars(query))
        existing_permission_apis = list(result.all())

        # Index existing associations by (permission id, api id) so each
        # membership test is O(1) instead of a scan over all rows.
        seen_pairs = {
            (permission_api.permission.id, permission_api.api.id)
            for permission_api in existing_permission_apis
        }

        new_permission_apis = list[PermissionApi]()
        for permission, api in permission_api_tuples:
            pair = (permission.id, api.id)
            if pair in seen_pairs:
                continue
            # Record the pair so a duplicate later in the input is skipped too.
            seen_pairs.add(pair)
            permission_api = PermissionApi(permission=permission, api=api, roles=[])
            session.add(permission_api)
            new_permission_apis.append(permission_api)
            logger.info(f"ASSOCIATING PERMISSION {permission} WITH API {api}")
        await safe_call(session.commit())
        return existing_permission_apis + new_permission_apis
    except Exception:
        if not raise_exception:
            return None
        raise
631
+
632
async def associate_list_of_role_with_permission_api(
    self,
    role_permission_api_tuples: list[tuple[Role, PermissionApi]],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
):
    """
    Associates a list of roles with permission APIs. Existing associations are not duplicated.

    For each ``(role, permission_api)`` pair, the role is appended to
    ``permission_api.roles`` unless it is already there. The session is then
    committed.

    Args:
        role_permission_api_tuples (list[tuple[Role, PermissionApi]]): A list of
            tuples containing Role and PermissionApi objects to associate.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session will be created. Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed. Defaults to True.

    Returns:
        None

    Raises:
        Exception: Any error raised while linking or committing, when
            ``raise_exception`` is True.
    """
    try:
        if not session:
            async with db.session() as session:
                await self.associate_list_of_role_with_permission_api(
                    role_permission_api_tuples,
                    session=session,
                    raise_exception=raise_exception,
                )
            return

        for role, permission_api in role_permission_api_tuples:
            # NOTE(review): assumes `permission_api.roles` is already loaded
            # (e.g. via selectinload). Lazy-loading a relationship through an
            # AsyncSession is not supported by SQLAlchemy; confirm at callers.
            if role in permission_api.roles:
                continue
            permission_api.roles.append(role)
            logger.info(
                f"ASSOCIATING ROLE {role} WITH PERMISSION API {permission_api}"
            )
        await safe_call(session.commit())
    except Exception:
        if not raise_exception:
            return
        raise
671
+
388
672
  async def cleanup(self, *, session: AsyncSession | Session = None):
389
673
  """
390
674
  Cleanup unused permissions from apis and roles.
@@ -404,13 +688,12 @@ class SecurityManager:
404
688
  "FastAPIReactToolkit instance not provided, you must provide it to use this function."
405
689
  )
406
690
 
407
- api_permission_tuples = (Setting.ROLES).values()
691
+ api_permission_tuples = self.get_api_permission_tuples_from_builtin_roles()
408
692
  apis = [api.__class__.__name__ for api in self.toolkit.apis]
409
693
  permissions = self.toolkit.total_permissions()
410
- for api_permission_tuple in api_permission_tuples:
411
- for api, permission in api_permission_tuple:
412
- apis.append(api)
413
- permissions.append(permission)
694
+ for api, permission in api_permission_tuples:
695
+ apis.append(api)
696
+ permissions.append(permission)
414
697
 
415
698
  # Clean up unused permissions
416
699
  unused_permissions = await safe_call(
@@ -428,14 +711,20 @@ class SecurityManager:
428
711
  logger.info(f"DELETING API {api} AND ITS ASSOCIATIONS")
429
712
  await safe_call(session.delete(api))
430
713
 
431
- roles = Setting.ROLES.keys()
714
+ roles = self.get_roles_from_builtin_roles()
432
715
  if g.admin_role is not None:
433
716
  roles.append(g.admin_role)
434
717
 
435
718
  # Clean up existing permission-apis, that are no longer connected to any roles
436
- unused_permission_apis = await safe_call(session.scalars(select(PermissionApi)))
437
- for permission_api in unused_permission_apis:
438
- for role in list(permission_api.roles) or []:
719
+ permission_apis_in_db = await safe_call(
720
+ session.scalars(
721
+ select(PermissionApi).options(
722
+ sqlalchemy.orm.selectinload(PermissionApi.roles)
723
+ )
724
+ )
725
+ )
726
+ for permission_api in permission_apis_in_db:
727
+ for role in permission_api.roles:
439
728
  if role.name not in roles:
440
729
  permission_api.roles.remove(role)
441
730
  logger.info(
@@ -443,3 +732,71 @@ class SecurityManager:
443
732
  )
444
733
 
445
734
  await safe_call(session.commit())
735
+
736
+ """
737
+ -----------------------------------------
738
+ HELPER FUNCTIONS
739
+ -----------------------------------------
740
+ """
741
+
742
async def _create_entities(
    self,
    entity_class: typing.Type[T],
    names: list[str],
    *,
    session: AsyncSession | Session | None = None,
    raise_exception: bool = True,
    on_after_create: typing.Optional[
        typing.Callable[[T], typing.Awaitable[None] | None]
    ] = None,
):
    """
    Helper function to create new entities with the given names.
    Existing entities are not duplicated.

    ``names`` is de-duplicated first (keeping first-seen order), so a name
    repeated in the input yields at most one new entity.

    Args:
        entity_class (typing.Type[T]): The entity class to create. It must
            expose a ``name`` column and accept ``name=`` in its constructor.
        names (list[str]): The names of the entities to create.
        session (AsyncSession | Session | None, optional): The database session
            to use. If not given, a new session is opened via ``db.session()``.
            Defaults to None.
        raise_exception (bool, optional): When True, errors are re-raised; when
            False, errors are swallowed and None is returned. Defaults to True.
        on_after_create (Optional[Callable[[T], Awaitable[None] | None]], optional):
            Callback invoked through ``safe_call`` for each newly created entity,
            before the commit. Defaults to None.

    Returns:
        list[T] | None: Entities that already existed followed by the newly
        created ones, or None if an error occurred and ``raise_exception`` is
        False.

    Raises:
        Exception: If an error occurs and raise_exception is True.
    """
    try:
        if not session:
            async with db.session() as session:
                return await self._create_entities(
                    entity_class,
                    names,
                    session=session,
                    raise_exception=raise_exception,
                    # Forward the callback; dropping it here meant nothing was
                    # reported when the caller did not supply a session.
                    on_after_create=on_after_create,
                )

        # De-duplicate while preserving order so one call never inserts the
        # same name twice.
        unique_names = list(dict.fromkeys(names))

        # Check for existing entities
        result = await safe_call(
            session.scalars(
                select(entity_class).where(entity_class.name.in_(unique_names))
            )
        )
        existing_entities = list(result.all())
        existing_names = {entity.name for entity in existing_entities}

        new_entities = list[T]()
        for name in unique_names:
            if name in existing_names:
                continue
            entity = entity_class(name=name)
            session.add(entity)
            new_entities.append(entity)
            if on_after_create:
                await safe_call(on_after_create(entity))
        await safe_call(session.commit())
        return existing_entities + new_entities
    except Exception:
        if not raise_exception:
            return None
        raise
@@ -2,12 +2,15 @@ import asyncio
2
2
  import contextvars
3
3
  import functools
4
4
  import inspect
5
+ import traceback as tb
5
6
  import typing
6
7
 
7
8
  from .prettify_dict import prettify_dict
8
9
 
9
10
  __all__ = ["AsyncTaskRunner"]
10
11
 
12
+ T = typing.TypeVar("T")
13
+
11
14
 
12
15
  class CallerInfo(typing.TypedDict):
13
16
  """
@@ -35,26 +38,37 @@ class AsyncTaskException(AsyncTaskRunnerException):
35
38
  self.caller = caller
36
39
 
37
40
 
38
class ResultAccessor(typing.Generic[T]):
    """
    Helper class that can be awaited OR accessed directly.

    Awaiting the accessor awaits the underlying task. ``str()``, ``repr()``
    and attribute access are forwarded to the task's cached ``_result``. Before
    the task has produced a result, they raise ``RuntimeError``.
    """

    def __init__(self, task: "AsyncTask[T]"):
        self._task = task

    def __await__(self):
        """Allow awaiting: await task.result"""
        return self._task.__await__()

    def __repr__(self) -> str:
        """Direct access without await - returns value or raises error."""
        return repr(self._ensure_result())

    def __str__(self) -> str:
        """Direct access without await - returns value or raises error."""
        return str(self._ensure_result())

    # Make it behave like the actual value when accessed
    def __getattr__(self, name):
        # `__getattr__` only runs for attributes missing from the instance.
        # `_task` is missing only on instances built without `__init__`
        # (e.g. by `copy`/`pickle`). Forwarding it would call `_ensure_result`,
        # which reads `self._task` again and recurses forever, so fail cleanly.
        if name == "_task":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._ensure_result(), name)

    def _ensure_result(self):
        """
        Return the task's cached result.

        Raises:
            RuntimeError: If the task has not stored a result yet.
        """
        if not hasattr(self._task, "_result"):
            raise RuntimeError(
                "Task has not been executed yet. Await the task to get the result."
            )
        return self._task._result
69
+
70
+
71
+ class AsyncTask(typing.Generic[T]):
58
72
  """
59
73
  Represents a task to be run asynchronously.
60
74
 
@@ -63,19 +77,51 @@ class AsyncTask:
63
77
 
64
78
def __init__(
    self,
    task: typing.Callable[..., typing.Awaitable[T]] | typing.Awaitable[T],
    /,
    caller: CallerInfo | None = None,
    tags: list[str] | None = None,
):
    """
    Store a task together with its origin and labels; nothing is run here.

    Args:
        task: Either a callable producing an awaitable, or an awaitable itself.
        caller (CallerInfo | None, optional): Where the task was registered;
            attached to any ``AsyncTaskException`` the task raises.
        tags (list[str] | None, optional): Labels for the task. A falsy value
            is replaced by a fresh empty list.
    """
    self.tags = tags if tags else []
    self.caller = caller
    self.task = task
74
88
 
89
@property
def result(self):
    """
    Expose this task's outcome through a ``ResultAccessor``.

    The accessor may be awaited, which runs the task if needed, or read
    directly once the task has finished.

    Returns:
        ResultAccessor[T]: A new accessor bound to this task.
    """
    accessor = ResultAccessor(self)
    return accessor
98
+
75
99
def __call__(self):
    """
    Produce the coroutine that runs the task (or returns its cached result).

    Returns:
        Coroutine: Awaiting it yields the task's result.
    """
    runner = self._run
    return runner()
101
+
102
def __await__(self):
    """Make the task awaitable: ``await task`` runs it or reuses its cached result."""
    pending = self._run()
    return pending.__await__()
104
+
105
async def _run(self):
    """
    Execute the task once and remember what it returned.

    Later calls return the remembered value without running the task again.
    Any exception from the task is wrapped in ``AsyncTaskException`` carrying
    ``self.caller`` and chained to the original error.

    Returns:
        T: The result of the task.
    """
    if hasattr(self, "_result"):
        return self._result

    try:
        awaitable = self.task() if callable(self.task) else self.task
        self._result = await awaitable
    except Exception as exc:
        raise AsyncTaskException(
            str(exc), original_exception=exc, caller=self.caller
        ) from exc
    return self._result
79
125
 
80
126
 
81
127
  class AsyncTaskRunner:
@@ -143,19 +189,59 @@ class AsyncTaskRunner:
143
189
  """
144
190
  self.run_tasks_even_if_exception = run_tasks_even_if_exception
145
191
 
192
+ @typing.overload
146
193
  @classmethod
147
194
  def add_task(
148
195
  cls,
149
- *tasks: typing.Callable[[], typing.Coroutine | None] | typing.Coroutine,
196
+ task: typing.Callable[..., typing.Awaitable[T]]
197
+ | typing.Awaitable[T]
198
+ | AsyncTask[T],
199
+ *,
200
+ tags: list[str] | None = None,
201
+ instance: "AsyncTaskRunner | None" = None,
202
+ ) -> AsyncTask[T]: ...
203
+
204
+ @typing.overload
205
+ @classmethod
206
+ def add_task(
207
+ cls,
208
+ task1: typing.Callable[..., typing.Awaitable[T]]
209
+ | typing.Awaitable[T]
210
+ | AsyncTask[T],
211
+ task2: typing.Callable[..., typing.Awaitable[T]]
212
+ | typing.Awaitable[T]
213
+ | AsyncTask[T],
214
+ *tasks: typing.Callable[..., typing.Awaitable[T]]
215
+ | typing.Awaitable[T]
216
+ | AsyncTask[T],
217
+ tags: list[str] | None = None,
218
+ instance: "AsyncTaskRunner | None" = None,
219
+ ) -> list[AsyncTask[T]]: ...
220
+ @classmethod
221
+ def add_task(
222
+ cls,
223
+ *tasks: typing.Callable[..., typing.Awaitable[T]]
224
+ | typing.Awaitable[T]
225
+ | AsyncTask[T],
150
226
  tags: list[str] | None = None,
151
227
  instance: "AsyncTaskRunner | None" = None,
152
228
  ):
229
+ """
230
+ Adds one or more tasks to the current context's task list. The tasks will be executed when exiting the context.
231
+
232
+ Args:
233
+ tags (list[str] | None, optional): Tags to associate with the tasks. Defaults to None.
234
+ instance (AsyncTaskRunner | None, optional): The AsyncTaskRunner instance to add tasks to. Defaults to None.
235
+
236
+ Returns:
237
+ AsyncTask[T] | list[AsyncTask[T]]: The added task(s).
238
+ """
153
239
  task_list = cls._get_current_task_list("add_task", instance=instance)
154
240
  # Get caller info
155
241
  frame = inspect.currentframe()
156
242
  caller = inspect.getouterframes(frame, 2)[1] if frame else None
157
243
 
158
- async_tasks = [
244
+ async_tasks: list[AsyncTask[T]] = [
159
245
  AsyncTask(
160
246
  task,
161
247
  caller=CallerInfo(
@@ -170,6 +256,7 @@ class AsyncTaskRunner:
170
256
  for task in tasks
171
257
  ]
172
258
  task_list.extend(async_tasks)
259
+ return async_tasks if len(async_tasks) > 1 else async_tasks[0]
173
260
 
174
261
  @classmethod
175
262
  def remove_tasks_by_tag(
@@ -214,16 +301,16 @@ class AsyncTaskRunner:
214
301
  return
215
302
 
216
303
  if tasks:
217
- exceptions: list[AsyncTaskException] = []
304
+ exceptions_with_index = list[tuple[AsyncTaskException, int]]()
218
305
  futures = await asyncio.gather(
219
306
  *(task() for task in tasks), return_exceptions=True
220
307
  )
221
- for future in futures:
308
+ for index, future in enumerate(futures):
222
309
  if isinstance(future, AsyncTaskException):
223
310
  # Handle exceptions from tasks
224
- exceptions.append(future)
311
+ exceptions_with_index.append((future, index))
225
312
 
226
- if exceptions:
313
+ if exceptions_with_index:
227
314
  raise AsyncTaskRunnerException(
228
315
  f"\n{
229
316
  prettify_dict(
@@ -233,9 +320,11 @@ class AsyncTaskRunner:
233
320
  f'Task {index + 1}': {
234
321
  'message': str(exc),
235
322
  'caller': exc.caller,
236
- 'traceback': exc.original_exception.__traceback__,
323
+ 'traceback': ''.join(
324
+ tb.format_exception(exc.original_exception)
325
+ ),
237
326
  }
238
- for index, exc in enumerate(exceptions)
327
+ for exc, index in exceptions_with_index
239
328
  },
240
329
  }
241
330
  )