zrb 1.0.0b1__py3-none-any.whl → 1.0.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zrb/__main__.py +0 -3
- zrb/builtin/__init__.py +3 -0
- zrb/builtin/group.py +1 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/config.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_task.py +66 -21
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/add_entity_util.py +67 -41
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service.py +69 -15
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/my_entity_service_factory.py +2 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/repository/my_entity_db_repository.py +0 -10
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/repository/my_entity_repository.py +37 -16
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/module/my_module/service/my_entity/repository/my_entity_repository_factory.py +2 -2
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/app_template/schema/my_entity.py +16 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/client_method.py +57 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/gateway_subroute.py +63 -28
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/add_module_task.py +1 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/my_module_api_client.py +6 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/{any_client.py → my_module_client.py} +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/my_module_client_factory.py +11 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/my_module_direct_client.py +5 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/route.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/module_task_definition.py +2 -2
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/task.py +4 -4
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/util.py +47 -20
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/app_factory.py +29 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_db_repository.py +185 -101
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/base_service.py +99 -108
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/{db_engine.py → db_engine_factory.py} +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/error.py +12 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/logger_factory.py +10 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/parser_factory.py +7 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/util/app.py +47 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/util/parser.py +105 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/util/user_agent.py +58 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/config.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/main.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_api_client.py +16 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client.py +163 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_client_factory.py +9 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/auth_direct_client.py +15 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_auth_tables.py +160 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration_metadata.py +18 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/route.py +5 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service.py +117 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/permission_service_factory.py +11 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/repository/permission_db_repository.py +26 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/repository/permission_repository.py +61 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/permission/repository/permission_repository_factory.py +13 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/__init__.py +0 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_db_repository.py +75 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_repository.py +59 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/repository/role_repository_factory.py +13 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service.py +105 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/role/role_service_factory.py +7 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_db_repository.py +42 -13
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository.py +38 -17
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/repository/user_repository_factory.py +2 -2
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service.py +69 -17
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/service/user/user_service_factory.py +2 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/route.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/subroute/auth.py +198 -28
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/gateway/util/view.py +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/requirements.txt +1 -1
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/permission.py +17 -5
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/role.py +50 -4
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/session.py +52 -0
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/schema/user.py +30 -5
- zrb/builtin/random.py +61 -0
- zrb/cmd/cmd_val.py +6 -5
- zrb/runner/cli.py +10 -1
- zrb/runner/web_util/token.py +7 -3
- zrb/task/base_task.py +24 -2
- zrb/task/cmd_task.py +7 -5
- zrb/util/cmd/command.py +1 -0
- zrb/util/file.py +7 -1
- {zrb-1.0.0b1.dist-info → zrb-1.0.0b2.dist-info}/METADATA +1 -1
- {zrb-1.0.0b1.dist-info → zrb-1.0.0b2.dist-info}/RECORD +80 -61
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/entity/template/any_client_method.py +0 -27
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/api_client.py +0 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/direct_client.py +0 -5
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/_zrb/module/template/app_template/module/my_module/client/factory.py +0 -9
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/app.py +0 -57
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/any_client.py +0 -33
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/api_client.py +0 -7
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/direct_client.py +0 -6
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/client/factory.py +0 -9
- zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/module/auth/migration/versions/3093c7336477_add_user_table.py +0 -37
- /zrb/builtin/project/add/fastapp/fastapp_template/my_app_name/common/{view.py → util/view.py} +0 -0
- {zrb-1.0.0b1.dist-info → zrb-1.0.0b2.dist-info}/WHEEL +0 -0
- {zrb-1.0.0b1.dist-info → zrb-1.0.0b2.dist-info}/entry_points.txt +0 -0
@@ -1,12 +1,17 @@
|
|
1
|
+
import datetime
|
2
|
+
from contextlib import asynccontextmanager
|
1
3
|
from typing import Any, Callable, Generic, Type, TypeVar
|
2
4
|
|
3
|
-
|
4
|
-
from
|
5
|
+
import ulid
|
6
|
+
from my_app_name.common.error import InvalidValueError, NotFoundError
|
7
|
+
from my_app_name.common.parser_factory import parse_filter_param, parse_sort_param
|
8
|
+
from sqlalchemy import Engine, delete, func, insert, select, update
|
5
9
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
6
|
-
from
|
10
|
+
from sqlalchemy.sql import Select
|
11
|
+
from sqlmodel import Session, SQLModel
|
7
12
|
|
8
13
|
DBModel = TypeVar("DBModel", bound=SQLModel)
|
9
|
-
ResponseModel = TypeVar("
|
14
|
+
ResponseModel = TypeVar("ResponseModel", bound=SQLModel)
|
10
15
|
CreateModel = TypeVar("CreateModel", bound=SQLModel)
|
11
16
|
UpdateModel = TypeVar("UpdateModel", bound=SQLModel)
|
12
17
|
|
@@ -23,112 +28,191 @@ class BaseDBRepository(Generic[DBModel, ResponseModel, CreateModel, UpdateModel]
|
|
23
28
|
self.engine = engine
|
24
29
|
self.is_async = isinstance(engine, AsyncEngine)
|
25
30
|
|
26
|
-
def
|
27
|
-
return self.
|
31
|
+
def _select(self) -> Select:
|
32
|
+
return select(self.db_model)
|
28
33
|
|
29
|
-
|
30
|
-
|
31
|
-
for key, preprocessor in self.column_preprocessors.items():
|
32
|
-
if key in data_dict:
|
33
|
-
data_dict[key] = preprocessor(data_dict[key])
|
34
|
-
db_instance = self.db_model(**data_dict)
|
35
|
-
if self.is_async:
|
36
|
-
async with AsyncSession(self.engine) as session:
|
37
|
-
session.add(db_instance)
|
38
|
-
await session.commit()
|
39
|
-
await session.refresh(db_instance)
|
40
|
-
else:
|
41
|
-
with Session(self.engine) as session:
|
42
|
-
session.add(db_instance)
|
43
|
-
session.commit()
|
44
|
-
session.refresh(db_instance)
|
45
|
-
return self._to_response(db_instance)
|
34
|
+
def _rows_to_responses(self, rows: list[tuple[Any]]) -> list[ResponseModel]:
|
35
|
+
return [self.response_model.model_validate(row[0]) for row in rows]
|
46
36
|
|
47
|
-
|
48
|
-
if
|
49
|
-
async with AsyncSession(self.engine) as session:
|
50
|
-
db_instance = await session.get(self.db_model, item_id)
|
51
|
-
else:
|
52
|
-
with Session(self.engine) as session:
|
53
|
-
db_instance = session.get(self.db_model, item_id)
|
54
|
-
if not db_instance:
|
37
|
+
def _ensure_one(self, responses: list[ResponseModel]) -> ResponseModel:
|
38
|
+
if not responses:
|
55
39
|
raise NotFoundError(f"{self.entity_name} not found")
|
56
|
-
|
40
|
+
if len(responses) > 1:
|
41
|
+
raise InvalidValueError(f"Duplicate {self.entity_name}")
|
42
|
+
return responses[0]
|
57
43
|
|
58
|
-
|
59
|
-
|
60
|
-
statement = select(self.db_model).offset(offset).limit(page_size)
|
44
|
+
@asynccontextmanager
|
45
|
+
async def _session_scope(self):
|
61
46
|
if self.is_async:
|
62
47
|
async with AsyncSession(self.engine) as session:
|
63
|
-
|
64
|
-
|
48
|
+
async with session.begin():
|
49
|
+
yield session
|
65
50
|
else:
|
66
51
|
with Session(self.engine) as session:
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
async def
|
71
|
-
update_data = data.model_dump(exclude_unset=True)
|
72
|
-
for key, value in update_data.items():
|
73
|
-
if key in self.column_preprocessors:
|
74
|
-
update_data[key] = self.column_preprocessors[key](value)
|
75
|
-
if self.is_async:
|
76
|
-
async with AsyncSession(self.engine) as session:
|
77
|
-
db_instance = await session.get(self.db_model, item_id)
|
78
|
-
if not db_instance:
|
79
|
-
raise NotFoundError(f"{self.entity_name} not found")
|
80
|
-
for key, value in update_data.items():
|
81
|
-
setattr(db_instance, key, value)
|
82
|
-
session.add(db_instance)
|
83
|
-
await session.commit()
|
84
|
-
await session.refresh(db_instance)
|
85
|
-
else:
|
86
|
-
with Session(self.engine) as session:
|
87
|
-
db_instance = session.get(self.db_model, item_id)
|
88
|
-
if not db_instance:
|
89
|
-
raise NotFoundError(f"{self.entity_name} not found")
|
90
|
-
for key, value in update_data.items():
|
91
|
-
setattr(db_instance, key, value)
|
92
|
-
session.add(db_instance)
|
93
|
-
session.commit()
|
94
|
-
session.refresh(db_instance)
|
95
|
-
return self._to_response(db_instance)
|
96
|
-
|
97
|
-
async def delete(self, item_id: str) -> ResponseModel:
|
52
|
+
with session.begin():
|
53
|
+
yield session
|
54
|
+
|
55
|
+
async def _commit(self, session: Session | AsyncSession):
|
98
56
|
if self.is_async:
|
99
|
-
|
100
|
-
db_instance = await session.get(self.db_model, item_id)
|
101
|
-
if not db_instance:
|
102
|
-
raise NotFoundError(f"{self.entity_name} not found")
|
103
|
-
await session.delete(db_instance)
|
104
|
-
await session.commit()
|
57
|
+
await session.commit()
|
105
58
|
else:
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
raise NotFoundError(f"{self.entity_name} not found")
|
110
|
-
session.delete(db_instance)
|
111
|
-
session.commit()
|
112
|
-
return self._to_response(db_instance)
|
113
|
-
|
114
|
-
async def create_bulk(self, data_list: list[CreateModel]) -> list[ResponseModel]:
|
115
|
-
db_instances = []
|
116
|
-
for data in data_list:
|
117
|
-
data_dict = data.model_dump(exclude_unset=True)
|
118
|
-
for key, preprocessor in self.column_preprocessors.items():
|
119
|
-
if key in data_dict:
|
120
|
-
data_dict[key] = preprocessor(data_dict[key])
|
121
|
-
db_instances.append(self.db_model(**data_dict))
|
59
|
+
session.commit()
|
60
|
+
|
61
|
+
async def _execute_statement(self, session, statement: Any) -> Any:
|
122
62
|
if self.is_async:
|
123
|
-
|
124
|
-
session.add_all(db_instances)
|
125
|
-
await session.commit()
|
126
|
-
for instance in db_instances:
|
127
|
-
await session.refresh(instance)
|
63
|
+
return await session.execute(statement)
|
128
64
|
else:
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
65
|
+
return session.execute(statement)
|
66
|
+
|
67
|
+
async def get_by_id(self, id: str) -> ResponseModel:
|
68
|
+
statement = self._select().where(self.db_model.id == id)
|
69
|
+
async with self._session_scope() as session:
|
70
|
+
result = await self._execute_statement(session, statement)
|
71
|
+
responses = self._rows_to_responses(result.all())
|
72
|
+
return self._ensure_one(responses)
|
73
|
+
|
74
|
+
async def get_by_ids(self, id_list: list[str]) -> list[ResponseModel]:
|
75
|
+
statement = self._select().where(self.db_model.id.in_(id_list))
|
76
|
+
async with self._session_scope() as session:
|
77
|
+
result = await self._execute_statement(session, statement)
|
78
|
+
return [
|
79
|
+
self.db_model(**entity.model_dump())
|
80
|
+
for entity in result.scalars().all()
|
81
|
+
]
|
82
|
+
|
83
|
+
async def count(self, filter: str | None = None) -> int:
|
84
|
+
count_statement = select(func.count(1)).select_from(self.db_model)
|
85
|
+
if filter:
|
86
|
+
filter_param = parse_filter_param(self.db_model, filter)
|
87
|
+
count_statement = count_statement.where(*filter_param)
|
88
|
+
async with self._session_scope() as session:
|
89
|
+
result = await self._execute_statement(session, count_statement)
|
90
|
+
return result.scalar_one()
|
91
|
+
|
92
|
+
async def get(
|
93
|
+
self,
|
94
|
+
page: int = 1,
|
95
|
+
page_size: int = 10,
|
96
|
+
filter: str | None = None,
|
97
|
+
sort: str | None = None,
|
98
|
+
) -> list[ResponseModel]:
|
99
|
+
offset = (page - 1) * page_size
|
100
|
+
statement = self._select().offset(offset).limit(page_size)
|
101
|
+
if filter:
|
102
|
+
filter_param = parse_filter_param(self.db_model, filter)
|
103
|
+
statement = statement.where(*filter_param)
|
104
|
+
if sort:
|
105
|
+
sort_param = parse_sort_param(self.db_model, sort)
|
106
|
+
statement = statement.order_by(*sort_param)
|
107
|
+
async with self._session_scope() as session:
|
108
|
+
result = await self._execute_statement(session, statement)
|
109
|
+
return [
|
110
|
+
self.db_model(**entity.model_dump())
|
111
|
+
for entity in result.scalars().all()
|
112
|
+
]
|
113
|
+
|
114
|
+
def _model_to_data_dict(
|
115
|
+
self, data: SQLModel, **additional_data: Any
|
116
|
+
) -> dict[str, Any]:
|
117
|
+
data_dict = data.model_dump(exclude_unset=True)
|
118
|
+
data_dict.update(additional_data)
|
119
|
+
for key, preprocessor in self.column_preprocessors.items():
|
120
|
+
if key not in data_dict:
|
121
|
+
continue
|
122
|
+
if not hasattr(self.db_model, key):
|
123
|
+
raise InvalidValueError(f"Invalid {self.entity_name} property: {key}")
|
124
|
+
data_dict[key] = preprocessor(data_dict[key])
|
125
|
+
return data_dict
|
126
|
+
|
127
|
+
async def create(self, data: CreateModel) -> DBModel:
|
128
|
+
now = datetime.datetime.now(datetime.timezone.utc)
|
129
|
+
data_dict = self._model_to_data_dict(data, created_at=now, id=ulid.new().str)
|
130
|
+
async with self._session_scope() as session:
|
131
|
+
await self._execute_statement(
|
132
|
+
session, insert(self.db_model).values(**data_dict)
|
133
|
+
)
|
134
|
+
statement = select(self.db_model).where(self.db_model.id == data_dict["id"])
|
135
|
+
result = await self._execute_statement(session, statement)
|
136
|
+
created_entity = result.scalar_one_or_none()
|
137
|
+
if created_entity is None:
|
138
|
+
raise NotFoundError(f"{self.entity_name} not found after creation")
|
139
|
+
return self.db_model(**created_entity.model_dump())
|
140
|
+
|
141
|
+
async def create_bulk(self, data_list: list[CreateModel]) -> list[DBModel]:
|
142
|
+
now = datetime.datetime.now(datetime.timezone.utc)
|
143
|
+
data_dicts = [
|
144
|
+
self._model_to_data_dict(data, created_at=now, id=ulid.new().str)
|
145
|
+
for data in data_list
|
146
|
+
]
|
147
|
+
async with self._session_scope() as session:
|
148
|
+
await self._execute_statement(
|
149
|
+
session, insert(self.db_model).values(data_dicts)
|
150
|
+
)
|
151
|
+
id_list = [d["id"] for d in data_dicts]
|
152
|
+
statement = select(self.db_model).where(self.db_model.id.in_(id_list))
|
153
|
+
result = await self._execute_statement(session, statement)
|
154
|
+
return [
|
155
|
+
self.db_model(**entity.model_dump())
|
156
|
+
for entity in result.scalars().all()
|
157
|
+
]
|
158
|
+
|
159
|
+
async def delete(self, id: str) -> DBModel:
|
160
|
+
async with self._session_scope() as session:
|
161
|
+
statement = select(self.db_model).where(self.db_model.id == id)
|
162
|
+
result = await self._execute_statement(session, statement)
|
163
|
+
entity = result.scalar_one_or_none()
|
164
|
+
if not entity:
|
165
|
+
raise NotFoundError(f"{self.entity_name} not found")
|
166
|
+
await self._execute_statement(
|
167
|
+
session, delete(self.db_model).where(self.db_model.id == id)
|
168
|
+
)
|
169
|
+
return self.db_model(**entity.model_dump())
|
170
|
+
|
171
|
+
async def delete_bulk(self, id_list: list[str]) -> list[DBModel]:
|
172
|
+
async with self._session_scope() as session:
|
173
|
+
statement = select(self.db_model).where(self.db_model.id.in_(id_list))
|
174
|
+
result = await self._execute_statement(session, statement)
|
175
|
+
entities = result.scalars().all()
|
176
|
+
await self._execute_statement(
|
177
|
+
session, delete(self.db_model).where(self.db_model.id.in_(id_list))
|
178
|
+
)
|
179
|
+
return [self.db_model(**entity.model_dump()) for entity in entities]
|
180
|
+
|
181
|
+
async def update(self, id: str, data: UpdateModel) -> DBModel:
|
182
|
+
now = datetime.datetime.now(datetime.timezone.utc)
|
183
|
+
update_data = self._model_to_data_dict(data, updated_at=now)
|
184
|
+
async with self._session_scope() as session:
|
185
|
+
statement = (
|
186
|
+
update(self.db_model)
|
187
|
+
.where(self.db_model.id == id)
|
188
|
+
.values(**update_data)
|
189
|
+
)
|
190
|
+
await self._execute_statement(session, statement)
|
191
|
+
result = await self._execute_statement(
|
192
|
+
session, select(self.db_model).where(self.db_model.id == id)
|
193
|
+
)
|
194
|
+
updated_instance = result.scalar_one_or_none()
|
195
|
+
if not updated_instance:
|
196
|
+
raise NotFoundError(f"{self.entity_name} not found")
|
197
|
+
return self.db_model(**updated_instance.model_dump())
|
198
|
+
|
199
|
+
async def update_bulk(self, id_list: list[str], data: UpdateModel) -> list[DBModel]:
|
200
|
+
now = datetime.datetime.now(datetime.timezone.utc)
|
201
|
+
update_data = self._model_to_data_dict(data, updated_at=now)
|
202
|
+
update_data = {k: v for k, v in update_data.items() if v is not None}
|
203
|
+
if not update_data:
|
204
|
+
raise InvalidValueError("No valid update data provided")
|
205
|
+
async with self._session_scope() as session:
|
206
|
+
statement = (
|
207
|
+
update(self.db_model)
|
208
|
+
.where(self.db_model.id.in_(id_list))
|
209
|
+
.values(**update_data)
|
210
|
+
)
|
211
|
+
await self._execute_statement(session, statement)
|
212
|
+
result = await self._execute_statement(
|
213
|
+
session, select(self.db_model).where(self.db_model.id.in_(id_list))
|
214
|
+
)
|
215
|
+
return [
|
216
|
+
self.db_model(**entity.model_dump())
|
217
|
+
for entity in result.scalars().all()
|
218
|
+
]
|
@@ -1,13 +1,13 @@
|
|
1
|
+
import inspect
|
1
2
|
from enum import Enum
|
2
|
-
from functools import partial
|
3
|
+
from functools import partial
|
4
|
+
from logging import Logger
|
3
5
|
from typing import Any, Callable, Sequence
|
4
6
|
|
5
7
|
import httpx
|
6
|
-
from fastapi import APIRouter, params
|
7
|
-
from
|
8
|
-
from
|
9
|
-
from fastapi.utils import generate_unique_id
|
10
|
-
from starlette.responses import JSONResponse, Response
|
8
|
+
from fastapi import APIRouter, Depends, params
|
9
|
+
from my_app_name.common.error import ClientAPIError
|
10
|
+
from pydantic import BaseModel
|
11
11
|
|
12
12
|
|
13
13
|
class RouteParam:
|
@@ -17,62 +17,37 @@ class RouteParam:
|
|
17
17
|
response_model: Any,
|
18
18
|
status_code: int | None = None,
|
19
19
|
tags: list[str | Enum] | None = None,
|
20
|
-
dependencies: Sequence[params.Depends] | None = None,
|
21
20
|
summary: str | None = None,
|
22
21
|
description: str = "",
|
23
|
-
response_description: str = "",
|
24
|
-
responses: dict[int | str, dict[str, Any]] | None = None,
|
25
22
|
deprecated: bool | None = None,
|
26
23
|
methods: set[str] | list[str] | None = None,
|
27
|
-
operation_id: str | None = None,
|
28
|
-
response_model_include: IncEx | None = None,
|
29
|
-
response_model_exclude: IncEx | None = None,
|
30
|
-
response_model_by_alias: bool = True,
|
31
|
-
response_model_exclude_unset: bool = False,
|
32
|
-
response_model_exclude_defaults: bool = False,
|
33
|
-
response_model_exclude_none: bool = False,
|
34
|
-
include_in_schema: bool = True,
|
35
|
-
response_class: type[Response] = Response,
|
36
|
-
name: str | None = None,
|
37
|
-
openapi_extra: dict[str, Any] | None = None,
|
38
|
-
generate_unique_id_function: Callable[[APIRoute], str] | None = None,
|
39
24
|
func: Callable | None = None,
|
40
25
|
):
|
41
26
|
self.path = path
|
42
27
|
self.response_model = response_model
|
43
28
|
self.status_code = status_code
|
44
29
|
self.tags = tags
|
45
|
-
self.dependencies = dependencies
|
46
30
|
self.summary = summary
|
47
31
|
self.description = description
|
48
|
-
self.response_description = response_description
|
49
|
-
self.responses = responses
|
50
32
|
self.deprecated = deprecated
|
51
33
|
self.methods = methods
|
52
|
-
self.operation_id = operation_id
|
53
|
-
self.response_model_include = response_model_include
|
54
|
-
self.response_model_exclude = response_model_exclude
|
55
|
-
self.response_model_by_alias = response_model_by_alias
|
56
|
-
self.response_model_exclude_unset = response_model_exclude_unset
|
57
|
-
self.response_model_exclude_defaults = response_model_exclude_defaults
|
58
|
-
self.response_model_exclude_none = response_model_exclude_none
|
59
|
-
self.include_in_schema = include_in_schema
|
60
|
-
self.response_class = response_class
|
61
|
-
self.name = name
|
62
|
-
self.openapi_extra = openapi_extra
|
63
|
-
self.generate_unique_id_function = generate_unique_id_function
|
64
34
|
self.func = func
|
65
35
|
|
66
36
|
|
67
37
|
class BaseService:
|
68
38
|
_route_params: dict[str, RouteParam] = {}
|
69
39
|
|
70
|
-
def __init__(self):
|
40
|
+
def __init__(self, logger: Logger):
|
41
|
+
self._logger = logger
|
71
42
|
self._route_params: dict[str, RouteParam] = {}
|
72
43
|
for name, method in self.__class__.__dict__.items():
|
73
44
|
if hasattr(method, "__route_param__"):
|
74
45
|
self._route_params[name] = getattr(method, "__route_param__")
|
75
46
|
|
47
|
+
@property
|
48
|
+
def logger(self) -> Logger:
|
49
|
+
return self._logger
|
50
|
+
|
76
51
|
@classmethod
|
77
52
|
def route(
|
78
53
|
cls,
|
@@ -84,63 +59,40 @@ class BaseService:
|
|
84
59
|
dependencies: Sequence[params.Depends] | None = None,
|
85
60
|
summary: str | None = None,
|
86
61
|
description: str = None,
|
87
|
-
response_description: str = "Successful Response",
|
88
|
-
responses: dict[int | str, dict[str, Any]] | None = None,
|
89
62
|
deprecated: bool | None = None,
|
90
63
|
methods: set[str] | list[str] | None = None,
|
91
|
-
operation_id: str | None = None,
|
92
|
-
response_model_include: IncEx | None = None,
|
93
|
-
response_model_exclude: IncEx | None = None,
|
94
|
-
response_model_by_alias: bool = True,
|
95
|
-
response_model_exclude_unset: bool = False,
|
96
|
-
response_model_exclude_defaults: bool = False,
|
97
|
-
response_model_exclude_none: bool = False,
|
98
|
-
include_in_schema: bool = True,
|
99
|
-
response_class: type[Response] = JSONResponse,
|
100
|
-
name: str | None = None,
|
101
|
-
openapi_extra: dict[str, Any] | None = None,
|
102
|
-
generate_unique_id_function: Callable[[APIRoute], str] = generate_unique_id,
|
103
64
|
):
|
104
65
|
"""
|
105
66
|
Decorator to register a method with its HTTP details.
|
106
67
|
"""
|
107
68
|
|
108
69
|
def decorator(func: Callable):
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
70
|
+
# Check for Depends in function parameters
|
71
|
+
sig = inspect.signature(func)
|
72
|
+
for param in sig.parameters.values():
|
73
|
+
if param.annotation is Depends or (
|
74
|
+
hasattr(param.annotation, "__origin__")
|
75
|
+
and param.annotation.__origin__ is Depends
|
76
|
+
):
|
77
|
+
raise ValueError(
|
78
|
+
f"Depends is not allowed in function parameters. Found in {func.__name__}" # noqa
|
79
|
+
)
|
113
80
|
# Inject __route_param__ property to the method
|
114
81
|
# Method with __route_param__ property will automatically
|
115
82
|
# registered to self._route_param and will be automatically exposed
|
116
83
|
# into DirectClient and APIClient
|
117
|
-
|
84
|
+
func.__route_param__ = RouteParam(
|
118
85
|
path=path,
|
119
86
|
response_model=response_model,
|
120
87
|
status_code=status_code,
|
121
88
|
tags=tags,
|
122
|
-
dependencies=dependencies,
|
123
89
|
summary=summary,
|
124
90
|
description=description,
|
125
|
-
response_description=response_description,
|
126
|
-
responses=responses,
|
127
91
|
deprecated=deprecated,
|
128
92
|
methods=methods,
|
129
|
-
operation_id=operation_id,
|
130
|
-
response_model_include=response_model_include,
|
131
|
-
response_model_exclude=response_model_exclude,
|
132
|
-
response_model_by_alias=response_model_by_alias,
|
133
|
-
response_model_exclude_unset=response_model_exclude_unset,
|
134
|
-
response_model_exclude_defaults=response_model_exclude_defaults,
|
135
|
-
response_model_exclude_none=response_model_exclude_none,
|
136
|
-
include_in_schema=include_in_schema,
|
137
|
-
response_class=response_class,
|
138
|
-
name=name,
|
139
|
-
openapi_extra=openapi_extra,
|
140
|
-
generate_unique_id_function=generate_unique_id_function,
|
141
93
|
func=func,
|
142
94
|
)
|
143
|
-
return
|
95
|
+
return func
|
144
96
|
|
145
97
|
return decorator
|
146
98
|
|
@@ -149,10 +101,10 @@ class BaseService:
|
|
149
101
|
Dynamically create a direct client class.
|
150
102
|
"""
|
151
103
|
_methods = self._route_params
|
152
|
-
DirectClient =
|
104
|
+
DirectClient = _create_client_class("DirectClient")
|
153
105
|
for name, details in _methods.items():
|
154
106
|
func = details.func
|
155
|
-
client_method =
|
107
|
+
client_method = _create_direct_client_method(self._logger, func, self)
|
156
108
|
# Use __get__ to make a bounded method,
|
157
109
|
# ensuring that client_method use DirectClient as `self`
|
158
110
|
setattr(DirectClient, name, client_method.__get__(DirectClient))
|
@@ -163,10 +115,10 @@ class BaseService:
|
|
163
115
|
Dynamically create an API client class.
|
164
116
|
"""
|
165
117
|
_methods = self._route_params
|
166
|
-
APIClient =
|
118
|
+
APIClient = _create_client_class("APIClient")
|
167
119
|
# Dynamically generate methods
|
168
120
|
for name, param in _methods.items():
|
169
|
-
client_method =
|
121
|
+
client_method = _create_api_client_method(self._logger, param, base_url)
|
170
122
|
# Use __get__ to make a bounded method,
|
171
123
|
# ensuring that client_method use APIClient as `self`
|
172
124
|
setattr(APIClient, name, client_method.__get__(APIClient))
|
@@ -186,29 +138,14 @@ class BaseService:
|
|
186
138
|
response_model=route_param.response_model,
|
187
139
|
status_code=route_param.status_code,
|
188
140
|
tags=route_param.tags,
|
189
|
-
dependencies=route_param.dependencies,
|
190
141
|
summary=route_param.summary,
|
191
142
|
description=route_param.description,
|
192
|
-
response_description=route_param.response_description,
|
193
|
-
responses=route_param.responses,
|
194
143
|
deprecated=route_param.deprecated,
|
195
144
|
methods=route_param.methods,
|
196
|
-
operation_id=route_param.operation_id,
|
197
|
-
response_model_include=route_param.response_model_include,
|
198
|
-
response_model_exclude=route_param.response_model_exclude,
|
199
|
-
response_model_by_alias=route_param.response_model_by_alias,
|
200
|
-
response_model_exclude_unset=route_param.response_model_exclude_unset,
|
201
|
-
response_model_exclude_defaults=route_param.response_model_exclude_defaults,
|
202
|
-
response_model_exclude_none=route_param.response_model_exclude_none,
|
203
|
-
include_in_schema=route_param.include_in_schema,
|
204
|
-
response_class=route_param.response_class,
|
205
|
-
name=route_param.name,
|
206
|
-
openapi_extra=route_param.openapi_extra,
|
207
|
-
generate_unique_id_function=route_param.generate_unique_id_function,
|
208
145
|
)
|
209
146
|
|
210
147
|
|
211
|
-
def
|
148
|
+
def _create_client_class(name):
|
212
149
|
class Client:
|
213
150
|
pass
|
214
151
|
|
@@ -216,30 +153,84 @@ def create_client_class(name):
|
|
216
153
|
return Client
|
217
154
|
|
218
155
|
|
219
|
-
def
|
156
|
+
def _create_direct_client_method(logger: Logger, func: Callable, service: BaseService):
|
220
157
|
async def client_method(self, *args, **kwargs):
|
221
158
|
return await func(service, *args, **kwargs)
|
222
159
|
|
223
160
|
return client_method
|
224
161
|
|
225
162
|
|
226
|
-
def
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
163
|
+
def _create_api_client_method(logger: Logger, param: RouteParam, base_url: str):
|
164
|
+
async def client_method(*args, **kwargs):
|
165
|
+
url = base_url + param.path
|
166
|
+
method = (
|
167
|
+
param.methods[0].lower()
|
168
|
+
if isinstance(param.methods, list)
|
169
|
+
else param.methods.lower()
|
170
|
+
)
|
171
|
+
# Get the signature of the original function
|
172
|
+
sig = inspect.signature(param.func)
|
173
|
+
# Bind the arguments to the signature
|
174
|
+
bound_args = sig.bind(*args, **kwargs)
|
175
|
+
bound_args.apply_defaults()
|
176
|
+
# Analyze parameters
|
177
|
+
params = list(sig.parameters.values())
|
178
|
+
body_params = [
|
179
|
+
p
|
180
|
+
for p in params
|
181
|
+
if p.name != "self" and p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
182
|
+
]
|
183
|
+
# Prepare the request
|
184
|
+
path_params = {}
|
185
|
+
query_params = {}
|
186
|
+
body = {}
|
187
|
+
for name, value in bound_args.arguments.items():
|
188
|
+
if name == "self":
|
189
|
+
continue
|
190
|
+
if f"{{{name}}}" in param.path:
|
191
|
+
path_params[name] = value
|
192
|
+
elif isinstance(value, BaseModel):
|
193
|
+
body = _parse_api_param(value)
|
194
|
+
elif method in ["get", "delete"]:
|
195
|
+
query_params[name] = _parse_api_param(value)
|
196
|
+
elif len(body_params) == 1 and name == body_params[0].name:
|
197
|
+
# If there's only one body parameter, use its value directly
|
198
|
+
body = _parse_api_param(value)
|
199
|
+
else:
|
200
|
+
body[name] = _parse_api_param(value)
|
201
|
+
# Format the URL with path parameters
|
202
|
+
url = url.format(**path_params)
|
203
|
+
logger.info(
|
204
|
+
f"Sending request to {url} with method {method}, json={body}, params={query_params}" # noqa
|
205
|
+
)
|
231
206
|
async with httpx.AsyncClient() as client:
|
232
|
-
|
233
|
-
|
234
|
-
response = await client.post(url, json=kwargs)
|
235
|
-
elif "put" in _methods:
|
236
|
-
response = await client.put(url, json=kwargs)
|
237
|
-
elif "delete" in _methods:
|
238
|
-
response = await client.delete(url, json=kwargs)
|
207
|
+
if method in ["get", "delete"]:
|
208
|
+
response = await getattr(client, method)(url, params=query_params)
|
239
209
|
else:
|
240
|
-
response = await client
|
241
|
-
|
242
|
-
|
210
|
+
response = await getattr(client, method)(
|
211
|
+
url, json=body, params=query_params
|
212
|
+
)
|
213
|
+
logger.info(
|
214
|
+
f"Received response: status={response.status_code}, content={response.content}"
|
215
|
+
)
|
216
|
+
if response.status_code >= 400:
|
217
|
+
error_detail = (
|
218
|
+
response.json()
|
219
|
+
if response.headers.get("content-type") == "application/json"
|
220
|
+
else response.text
|
221
|
+
)
|
222
|
+
raise ClientAPIError(response.status_code, error_detail)
|
243
223
|
return response.json()
|
244
224
|
|
245
225
|
return client_method
|
226
|
+
|
227
|
+
|
228
|
+
def _parse_api_param(data: Any) -> Any:
|
229
|
+
if isinstance(data, BaseModel):
|
230
|
+
return data.model_dump()
|
231
|
+
elif isinstance(data, list):
|
232
|
+
return [_parse_api_param(item) for item in data]
|
233
|
+
elif isinstance(data, dict):
|
234
|
+
return {key: _parse_api_param(value) for key, value in data.items()}
|
235
|
+
else:
|
236
|
+
return data
|
@@ -6,3 +6,15 @@ from fastapi import HTTPException
|
|
6
6
|
class NotFoundError(HTTPException):
|
7
7
|
def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None:
|
8
8
|
super().__init__(404, {"message": message}, headers)
|
9
|
+
|
10
|
+
|
11
|
+
class InvalidValueError(HTTPException):
|
12
|
+
def __init__(self, message: str, headers: Dict[str, str] | None = None) -> None:
|
13
|
+
super().__init__(422, {"message": message}, headers)
|
14
|
+
|
15
|
+
|
16
|
+
class ClientAPIError(HTTPException):
|
17
|
+
def __init__(
|
18
|
+
self, status_code: int, message: str, headers: Dict[str, str] | None = None
|
19
|
+
) -> None:
|
20
|
+
super().__init__(status_code, {"message": message}, headers)
|