pykitool 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pykitool/__init__.py +0 -0
- pykitool/base/__init__.py +0 -0
- pykitool/base/cache.py +352 -0
- pykitool/base/enums.py +102 -0
- pykitool/base/exception.py +6 -0
- pykitool/base/response.py +82 -0
- pykitool/base/tlog.py +230 -0
- pykitool/sqliter/__init__.py +0 -0
- pykitool/sqliter/exception.py +30 -0
- pykitool/sqliter/middleware.py +84 -0
- pykitool/sqliter/plus.py +125 -0
- pykitool/sqliter/repo.py +188 -0
- pykitool/utils/__init__.py +0 -0
- pykitool/utils/cbfile.py +697 -0
- pykitool/utils/cbrequest.py +473 -0
- pykitool/utils/cbruntime.py +870 -0
- pykitool/utils/cbutils.py +518 -0
- pykitool-0.0.1.dist-info/METADATA +36 -0
- pykitool-0.0.1.dist-info/RECORD +21 -0
- pykitool-0.0.1.dist-info/WHEEL +5 -0
- pykitool-0.0.1.dist-info/top_level.txt +1 -0
pykitool/base/tlog.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
import builtins
|
|
2
|
+
import inspect
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
import urllib3
|
|
8
|
+
|
|
9
|
+
# Silence every Python warning for the whole process (this runs at import time).
warnings.filterwarnings("ignore")

# Disable urllib3's warnings; with no argument urllib3 uses its default warning
# category. Intended to silence the warnings emitted for insecure (unverified) HTTPS requests.
urllib3.disable_warnings()
|
|
14
|
+
|
|
15
|
+
# 关闭
|
|
16
|
+
import logging
|
|
17
|
+
import threading
|
|
18
|
+
import time
|
|
19
|
+
|
|
20
|
+
from loguru import logger
|
|
21
|
+
|
|
22
|
+
# Loguru format string shared by the console and file sinks configured in LoguruLogger.
# NOTE(review): this module-level name shadows the builtin format() inside this module.
format = "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level:<5}</level> {thread:<5} <cyan>{file}:{line}</cyan> - {message}"
# Alternative, more verbose format kept for reference:
# format = "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | <level>{level:<7}</level> | <yellow>{thread}</yellow> | <cyan>{file}:{function}:{line}</cyan> - <level>{message}</level>"

# Prefixes of stdlib `logging` logger names that InterceptHandler forwards to loguru.
# Records from any other logger are not forwarded (they are only echoed via print()
# when the configured level is DEBUG). Empty by default, so nothing is forwarded.
WHITELIST = []
|
|
28
|
+
|
|
29
|
+
# ================================ Directory ================================

# Replaceable zero-argument callable returning the log directory.
# Defaults to "logs" (relative to the working directory); override with set_log_dir_getter().
_log_dir_getter: callable = lambda: "logs"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Default directory
def get_log_dir() -> str:
    """Return the log directory reported by the currently installed getter."""
    getter = _log_dir_getter
    return getter()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Let callers inject a custom log-directory getter
def set_log_dir_getter(fn: callable):
    """Install a custom zero-argument function that returns the log directory.

    The getter is read by ``get_log_dir()`` whenever the loguru sinks are
    (re)built. The sinks are only rebuilt by ``update_cfg()``, so call it
    afterwards for the change to take effect.

    Example::

        # Read dynamically from configuration
        set_log_dir_getter(lambda: const.DIR_LOGS)
        update_cfg()

        # Use a fixed directory
        set_log_dir_getter(lambda: "/var/log/myapp")
        update_cfg()

    :param fn: Zero-argument callable returning the directory path as a string.
    :raises TypeError: If ``fn`` is not callable. Rejecting it here avoids a
        confusing failure later, when the sinks are configured.
    """
    global _log_dir_getter
    if not callable(fn):
        raise TypeError(f"log dir getter must be callable, got {type(fn).__name__}")
    _log_dir_getter = fn
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# ================================ Level ================================

# Replaceable zero-argument callable returning the log level name.
# Defaults to "INFO"; override with set_log_level_getter().
_log_level_getter: callable = lambda: "INFO"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Default level
def get_log_level() -> str:
    """Return the log level name reported by the currently installed getter."""
    getter = _log_level_getter
    return getter()
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Let callers inject a custom log-level getter
def set_log_level_getter(fn: callable):
    """Install a custom zero-argument function that returns the log level name.

    The getter is read by ``get_log_level()``, both when the loguru sinks are
    (re)built and by ``InterceptHandler`` on every record. The sink levels
    only change when ``update_cfg()`` is called, so call it afterwards.

    Example::

        # Read dynamically from configuration
        set_log_level_getter(lambda: config.LOG_LEVEL)
        update_cfg()

        # Use a fixed level
        set_log_level_getter(lambda: "DEBUG")
        update_cfg()

    :param fn: Zero-argument callable returning a level name such as ``"INFO"``.
    :raises TypeError: If ``fn`` is not callable.
    """
    global _log_level_getter
    if not callable(fn):
        raise TypeError(f"log level getter must be callable, got {type(fn).__name__}")
    _log_level_getter = fn
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Interceptor: bridges stdlib `logging` records into loguru
class InterceptHandler(logging.Handler):
    """Forward stdlib ``logging`` records from whitelisted loggers to loguru.

    A record is forwarded only when its logger name starts with one of the
    prefixes in ``WHITELIST``. Other records are dropped; when the configured
    level is ``"DEBUG"`` they are echoed with ``print()`` first.
    """

    def emit(self, record):
        # Only handle records whose logger belongs to a whitelisted module
        if not any(record.name.startswith(prefix) for prefix in WHITELIST):
            if get_log_level() == "DEBUG":
                print(f"{record.name} -> {record.getMessage()}")
            return

        # Use loguru's level name when it knows this level, otherwise the numeric level
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk up past the stdlib logging (and frozen importlib) frames so loguru
        # attributes the message to the code that actually logged it. A fixed depth
        # is wrong whenever the call path through `logging` differs.
        frame, depth = inspect.currentframe(), 0
        while frame is not None:
            filename = frame.f_code.co_filename
            is_logging = filename == logging.__file__
            is_frozen = "importlib" in filename and "_bootstrap" in filename
            if depth > 0 and not (is_logging or is_frozen):
                break
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# Startup timer
class StartupTimer:
    """Record elapsed-time checkpoints, e.g. to profile application startup.

    Each ``track()`` call stores a ``(tag, delta, total)`` tuple in ``records``.
    ``delta`` is the number of seconds since the previous checkpoint (or since
    construction). ``total`` is the number of seconds since construction. Both
    are measured with ``time.perf_counter()``.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.last_mark = self.start_time
        self.records = []
        # Guards records/last_mark so concurrent track()/dump() calls stay consistent.
        self.lock = threading.Lock()

    def _get_caller_info(self):
        """Return ``"filename:lineno"`` of the code that called ``track()``.

        Returns ``"<unknown>"`` when frame introspection is unavailable;
        ``inspect.currentframe()`` may return ``None`` on some interpreters.
        """
        frame = inspect.currentframe()
        if frame is None:
            return "<unknown>"
        try:
            # Two hops up: this helper -> track() -> track()'s caller.
            caller = frame.f_back.f_back
            return f"{caller.f_code.co_filename}:{caller.f_lineno}"
        finally:
            # Break the frame -> local -> frame reference cycle deterministically.
            del frame

    # Track a checkpoint
    def track(self, tag=None, print: bool = True):
        """Record a checkpoint and optionally log it.

        :param tag: Label for the checkpoint. Defaults to the caller's ``file:line``.
        :param print: When true, also log the checkpoint via loguru ``logger.info``.
            This parameter name shadows the builtin ``print`` inside this method.
        """
        with self.lock:
            now = time.perf_counter()
            delta = now - self.last_mark
            total = now - self.start_time

            if tag is None:
                tag = self._get_caller_info()

            self.records.append((tag, delta, total))

            self.last_mark = now
            if print:
                logger.info(f"[StartupTimer] {tag}: +{delta*1000:07.2f} ms (total {total*1000:07.2f} ms)")

    # Total elapsed time
    def total_time(self):
        """Return seconds elapsed since this timer was created."""
        return time.perf_counter() - self.start_time

    # Summary output
    def dump(self):
        """Print every recorded checkpoint as one formatted row per record."""
        # Snapshot under the lock so a concurrent track() cannot change the list mid-iteration.
        with self.lock:
            snapshot = list(self.records)
        for tag, delta, total in snapshot:
            print(f"{tag:12} | +{delta*1000:07.2f} ms | total {total*1000:07.2f} ms")
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Initialization
class LoguruLogger:
    """Process-wide singleton that configures loguru and routes stdlib logging into it.

    The first instantiation does three things:

    - builds the loguru sinks (see ``_configure_logger``);
    - clears the handlers of every logger already registered with the stdlib
      ``logging`` manager;
    - installs ``InterceptHandler`` as the root logger's only handler.

    Later instantiations return the same instance without reconfiguring.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            # Singleton: create and configure only once
            cls._instance = super().__new__(cls)
            cls._instance._configure_logger()
            # Remove the handlers already attached to every registered stdlib logger.
            # Snapshot the names first: getLogger() may touch the manager's dict.
            for name in list(logging.root.manager.loggerDict):
                logging.getLogger(name).handlers = []
            # Let loguru take over stdlib logging. force=True (Python 3.8+) replaces
            # any handlers already on the root logger; without it basicConfig() is
            # silently a no-op when the root logger was configured earlier.
            logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
        return cls._instance

    def _configure_logger(self):
        """(Re)build loguru's sinks from the current level and directory getters.

        Removes every existing loguru handler, including loguru's default sink.
        Then adds a colorized stdout sink and a file sink in ``get_log_dir()``
        (created if missing), both at ``get_log_level()``.
        """
        current_level = get_log_level()
        self.logger = logger
        # Drop all existing loguru handlers before re-adding ours
        self.logger.remove()
        # Console sink
        self.logger.add(sys.stdout, level=current_level, format=format, colorize=True)
        # File sink, rotated daily at midnight (no size-based rotation is configured)
        log_dir = get_log_dir()
        os.makedirs(log_dir, exist_ok=True)
        self.logger.add(
            os.path.join(log_dir, "creator_{time:YYYY-MM-DD}.log"),
            rotation="00:00",
            level=current_level,
            format=format,
            colorize=False,
        )

    def update_log_level(self):
        """Re-apply the configuration so new getter values (level/directory) take effect."""
        self._configure_logger()
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# Retry callback
def retry_print(retry_state):
    """Log a warning describing the upcoming retry.

    Presumably used as a tenacity ``before_sleep`` hook: it reads
    ``retry_state.retry_object.stop``, ``fn``, ``attempt_number``, ``outcome``
    and ``next_action``. Missing pieces are rendered as placeholders instead
    of raising. For example, ``fn`` is ``None`` when retrying via an iterator,
    and partials have no ``__name__``.
    """
    stop = retry_state.retry_object.stop
    # Only stop_after_attempt-style strategies expose a maximum attempt count
    max_attempts = getattr(stop, "max_attempt_number", "?")

    fn = retry_state.fn
    fn_name = "<unknown>" if fn is None else getattr(fn, "__name__", repr(fn))

    outcome = retry_state.outcome
    error = outcome.exception() if outcome is not None else None

    next_action = retry_state.next_action
    wait = next_action.sleep if next_action is not None else "?"

    logger.warning(f"Retrying {retry_state.attempt_number}/{max_attempts} time(s), {fn_name} -> {error}, waiting {wait} seconds before retrying.")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Disable
def disable_print():
    """Replace ``builtins.print`` with a no-op and return the previous ``print``.

    Pass the returned callable to ``restore_print()`` to undo the change.
    """
    previous = builtins.print

    # Accepts whatever print() would and discards it
    def no_print(*_args, **_kwargs):
        return None

    builtins.print = no_print
    return previous
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# Restore
def restore_print(original_print):
    """Reinstall ``original_print`` as ``builtins.print``.

    A falsy argument (e.g. ``None``) is ignored and leaves ``print`` untouched.
    """
    if not original_print:
        return
    builtins.print = original_print
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Re-apply configuration
def update_cfg():
    """Rebuild loguru's sinks from the current getters, if the singleton exists."""
    instance = LoguruLogger._instance
    if instance:
        instance.update_log_level()
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# Configure logging as a side effect of importing this module. LoguruLogger is a
# singleton, so only the first import does any work.
if not LoguruLogger._instance:
    LoguruLogger()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
if __name__ == "__main__":
    # Plain print() is not routed through loguru and does not fill "{}" placeholders,
    # so format the string explicitly.
    print("message.{}".format(11))
    # Stdlib logging call. InterceptHandler forwards it only if the logger name
    # ("root") starts with a WHITELIST prefix.
    logging.info("message %s", 22)
    # Emit one message at each loguru level
    logger.trace("message.{}", 0)
    logger.debug("message.{}", 1)
    logger.info("message.{}", 2)
    logger.success("message.{}", 3)
    logger.warning("message.{}", 4)
    logger.error("message.{}", 5)
    logger.critical("message.{}", 6)
    # logger.exception() attaches the active exception's traceback, so demo it inside an except block
    try:
        1 / 0
    except ZeroDivisionError:
        logger.exception("message.{}", 7)
|
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
class MissingSessionError(Exception):
    """
    Exception raised when the user tries to access a database session before
    it is created (i.e. outside a request or a ``with db():`` block).
    """

    def __init__(self):
        msg = """
        No session found! Either you are not currently in a request context,
        or you need to manually create a session context by using a `db`
        instance as a context manager e.g.:

        with db():
            db.session.query(User).all()
        """
        super().__init__(msg)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SessionNotInitialisedError(Exception):
    """
    Raised when a DB session is requested before DBSessionMiddleware has
    installed the session factory.
    """

    def __init__(self):
        message = """
        Session not initialised! Ensure that DBSessionMiddleware has been
        initialised before attempting database access.
        """
        super().__init__(message)
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from contextvars import ContextVar
|
|
2
|
+
from typing import Dict, Optional, Union
|
|
3
|
+
|
|
4
|
+
from sqlalchemy.engine import Engine
|
|
5
|
+
from sqlalchemy.engine.url import URL
|
|
6
|
+
from sqlalchemy.orm import sessionmaker as sessionmaker_
|
|
7
|
+
from sqlmodel import Session, create_engine
|
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
9
|
+
from starlette.requests import Request
|
|
10
|
+
from starlette.types import ASGIApp
|
|
11
|
+
|
|
12
|
+
from pykitool.sqliter.exception import MissingSessionError, SessionNotInitialisedError
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class sessionmaker(sessionmaker_):
    """SQLAlchemy ``sessionmaker`` that builds SQLModel ``Session`` objects by default.

    An explicitly passed ``class_`` keyword argument still takes precedence.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("class_", Session)
        super().__init__(*args, **kwargs)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Process-wide session factory. Assigned in DBSessionMiddleware.__init__; None until then.
_Session: Optional[sessionmaker] = None
# Session bound to the current context. Set by DBSession.__enter__ and reset in
# __exit__; None outside a `db()` block.
_session: ContextVar[Optional[Session]] = ContextVar("_session", default=None)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DBSessionMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that wraps every request in its own ``db()`` session context.

    Constructing it installs the module-wide session factory used by ``db``.
    At least one of ``db_url`` / ``custom_engine`` must be given. When both
    are given, ``custom_engine`` wins and ``engine_args`` is ignored.
    """

    def __init__(self, app: ASGIApp, db_url: Optional[Union[str, URL]] = None, custom_engine: Optional[Engine] = None, engine_args: Dict = None, session_args: Dict = None, commit_on_exit: bool = False):
        super().__init__(app)
        global _Session
        if not (custom_engine or db_url):
            raise ValueError("You need to pass a db_url or a custom_engine parameter.")
        self.commit_on_exit = commit_on_exit
        if custom_engine:
            engine = custom_engine
        else:
            engine = create_engine(db_url, **(engine_args or {}))
        _Session = sessionmaker(bind=engine, **(session_args or {}))

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        # The session context is closed (and optionally committed) once the response is produced
        with db(commit_on_exit=self.commit_on_exit):
            return await call_next(request)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DBSessionMeta(type):
    """Metaclass that exposes ``session`` as a class-level property.

    This lets callers write ``db.session`` instead of ``db().session``.
    """

    @property
    def session(self) -> Session:
        """Return the session bound to the current context.

        :raises SessionNotInitialisedError: if no session factory has been installed.
        :raises MissingSessionError: if called outside a ``db()`` context.
        """
        if _Session is None:
            raise SessionNotInitialisedError
        current = _session.get()
        if current is not None:
            return current
        raise MissingSessionError
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DBSession(metaclass=DBSessionMeta):
    """Context manager that binds a fresh session to the current context.

    Usage::

        with db():
            db.session.query(User).all()

    :param session_args: Extra keyword arguments passed to the session factory.
    :param commit_on_exit: Commit when the ``with`` block exits without an exception.
    """

    def __init__(self, session_args: Dict = None, commit_on_exit: bool = False):
        self.token = None
        self.session_args = session_args or {}
        self.commit_on_exit = commit_on_exit

    def __enter__(self):
        if not isinstance(_Session, sessionmaker):
            raise SessionNotInitialisedError
        self.token = _session.set(_Session(**self.session_args))
        return type(self)

    def __exit__(self, exc_type, *_):
        sess = _session.get()
        try:
            if exc_type is not None:
                # Discard the work done inside the failed block
                sess.rollback()
            elif self.commit_on_exit:
                # Commit only on a clean exit; never right after a rollback
                sess.commit()
        finally:
            # Always release the session and restore the previous context value,
            # even when rollback()/commit() raises.
            sess.close()
            _session.reset(self.token)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Public entry point: `with db(): ...` opens a session; `db.session` accesses it.
db: DBSessionMeta = DBSession
|
pykitool/sqliter/plus.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from typing import Any, Dict, Tuple, Union
|
|
2
|
+
|
|
3
|
+
import sqlmodel
|
|
4
|
+
from pydantic import ConfigDict
|
|
5
|
+
from sqlalchemy.engine import Engine
|
|
6
|
+
from sqlalchemy.sql import Select
|
|
7
|
+
from sqlmodel import SQLModel, select, text
|
|
8
|
+
from sqlmodel.sql.expression import SelectOfScalar
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EngineException(Exception):
    """Raised when a model's scope has no engine registered via ``set_engine``."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Class-level property descriptor. Chaining @classmethod with @property no longer
# works on Python 3.13, so this replaces that idiom.
class _ClassPropertyDescriptor:
    """Read-only descriptor that returns ``func(owner_class)`` on every attribute access."""

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, cls=None):
        owner = type(obj) if cls is None else cls
        return self.func(owner)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def classproperty(func):
    """Decorator turning ``func(cls)`` into a read-only class-level property."""
    descriptor = _ClassPropertyDescriptor(func)
    return descriptor
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SQLModelPlus(SQLModel):
    """SQLModel base class with Active-Record style helpers.

    Engines are registered per *scope*. A subclass may define ``__scope__``
    to choose its engine; otherwise the ``"default"`` scope is used. Every
    helper opens and closes its own short-lived session.
    """

    # Let pydantic skip the classproperty descriptors instead of treating them as fields.
    model_config = ConfigDict(ignored_types=(_ClassPropertyDescriptor,))

    # Scope name -> engine registry. The same dict object is shared by every
    # subclass unless a subclass defines its own.
    __engines__: Dict[str, Engine] = {}

    @classproperty
    def __get_scope(cls) -> str:
        # Name-mangled to _SQLModelPlus__get_scope; resolves the subclass's __scope__.
        return str(cls.__scope__) if hasattr(cls, "__scope__") else "default"

    @classmethod
    def create_tables(cls, *args, **kwargs):
        """Create every table in ``cls.metadata`` on this scope's engine.

        :raises EngineException: if no engine is registered for the scope.
        """
        engine = cls.__engines__.get(cls.__get_scope)
        if engine is None:
            # Same error as Session() instead of an obscure failure inside create_all(None)
            raise EngineException("Engine is not initialized. Use `.set_engine` method to set engine.")
        cls.metadata.create_all(engine, *args, **kwargs)

    @classmethod
    def find_by_id(cls, ident: Union[Dict[str, Any], Tuple[Any], Any]):
        """Look up a row of this model by primary key using ``Session.get``."""
        with cls.Session() as session:
            return session.get(cls, ident)

    def save(self):
        """Insert this instance; if that fails, fall back to ``update()`` (merge).

        Any ``Exception`` from ``create()`` triggers the fallback, not only
        duplicate-key errors.
        """
        try:
            return self.create()
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
            return self.update()

    def create(self):
        """Add this instance in a new session, commit, refresh it and return it."""
        with self.__class__.Session() as session:
            session.add(self)
            session.commit()
            session.refresh(self)
            return self

    def update(self):
        """Merge this instance into a new session, commit, and return the merged copy."""
        with self.__class__.Session() as session:
            updated_instance = session.merge(self)
            session.commit()
            session.refresh(updated_instance)
            return updated_instance

    def delete(self):
        """Delete this instance's row in a new session, commit, and return the instance."""
        with self.__class__.Session() as session:
            session.delete(self)
            session.commit()
            return self

    @classmethod
    def set_engine(cls, engine: Engine) -> None:
        """Register ``engine`` for this class's scope."""
        cls.__engines__[cls.__get_scope] = engine

    @classmethod
    def get_engine(cls) -> Engine:
        """Return the engine registered for this class's scope, or ``None``."""
        return cls.__engines__.get(cls.__get_scope)

    @classmethod
    def query(
        cls,
        statement: Union[SelectOfScalar, str],
        params: Union[Dict[str, Any], Tuple[Any], None] = None,
    ):
        """Build a lazily executed :class:`Query` for a select statement or raw SQL string."""
        # None replaces the shared mutable `{}` default; Query still receives an empty dict.
        return Query(model_cls=cls, statement=statement, params={} if params is None else params)

    # Get a database Session instance
    @classmethod
    def Session(cls) -> sqlmodel.Session:
        """Open a new ``sqlmodel.Session`` bound to this scope's engine.

        :raises EngineException: if no engine is registered for the scope.
        """
        engine: Engine | None = cls.__engines__.get(cls.__get_scope)
        if engine is None:
            raise EngineException("Engine is not initialized. Use `.set_engine` method to set engine.")
        return sqlmodel.Session(bind=engine)

    @classproperty
    def select(cls) -> SelectOfScalar:
        # `select` here resolves to the module-level sqlmodel.select, not this property
        return select(cls)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class Query:
    """Lazily executed statement bound to a :class:`SQLModelPlus` model's engine.

    Each access to ``all`` / ``first`` opens a new session and runs the statement.
    """

    def __init__(
        self,
        model_cls: type[SQLModelPlus],
        statement: Union[SelectOfScalar, str],
        params: Union[Dict[str, Any], Tuple[Any], None] = None,
    ):
        self.model_cls = model_cls
        # Only plain strings are wrapped in text(). Any statement object (Select,
        # SelectOfScalar, an existing TextClause, ...) is used as-is instead of
        # being passed to text(), which only accepts strings.
        self.statement = text(statement) if isinstance(statement, str) else statement
        self.params = params

    @property
    def all(self):
        """Execute the statement and return all result rows."""
        with self.model_cls.Session() as session:
            return session.exec(self.statement, params=self.params).all()

    @property
    def first(self):
        """Execute the statement and return the first result row (or ``None``)."""
        with self.model_cls.Session() as session:
            return session.exec(self.statement, params=self.params).first()
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Public API of this module; Query and classproperty are not exported.
__all__ = ["SQLModelPlus", "EngineException"]
|
pykitool/sqliter/repo.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sqlmodel import Session, SQLModel, delete, func, select, update
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@contextmanager
def reuse_session_or_new(db_engine=None, session: Optional[Session] = None):
    """
    Yield ``session`` if one is given, otherwise a new ``Session(db_engine)``.

    A session created here is closed when the ``with`` block exits. A
    caller-supplied session is left open for the caller to manage.

    :param db_engine: Engine used to create a session when ``session`` is None.
    :param session: An existing session to reuse.
    :raises ValueError: if both ``session`` and ``db_engine`` are None.
    """
    if session is not None:
        # Borrowed session: hand it out untouched, the caller owns its lifecycle
        yield session
        return
    if db_engine is None:
        raise ValueError("No session and no db_engine provided to create a session.")
    owned = Session(db_engine)
    try:
        yield owned
    finally:
        owned.close()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RecordNotFoundError(LookupError):
    """Raised by the ``*_or_404`` helpers when no record matches the given id.

    Carries HTTP-style ``status_code`` (always 404) and ``detail`` attributes.
    """

    status_code = 404

    def __init__(self, detail):
        super().__init__(detail)
        self.detail = detail


class SQLModelRepo:
    def __init__(self, model: SQLModel, db_engine, init_stmt=None, session=None):
        """Generic repository for SQLModel.

        Args:
            model (SQLModel): The SQLModel class (table) for which
                the repo is instantiated.
            db_engine: The SQLAlchemy engine linked to the database.
            init_stmt: Optional base ``select`` statement (built by ``filter()``)
                that the query and bulk-update/delete methods start from.
            session: Optional existing session to reuse. When None, each call
                opens and closes its own session.

        Usage:
            users_repo = SQLModelRepo(model=User, db_engine=engine)
            users_repo.get_by_id(1)
        """
        self.model = model
        self._init_stmt = init_stmt
        self.db_engine = db_engine
        self.session = session

    def __call__(self, session):
        """Return a copy of this repo bound to ``session``, keeping any filter statement."""
        return SQLModelRepo(model=self.model, db_engine=self.db_engine, init_stmt=self._init_stmt, session=session)

    def create(self, **kwargs):
        """Create a new record and save to the database."""
        instance = self.model(**kwargs)
        with reuse_session_or_new(self.db_engine, self.session) as session:
            session.add(instance)
            session.commit()
            session.refresh(instance)
            return instance

    def get_by_id(self, id, *fields):
        """Fetch an object by its primary key (``id`` column), within this repo's filter."""
        stmt = self.init_stmt(*fields)
        with reuse_session_or_new(self.db_engine, self.session) as session:
            return session.exec(stmt.where(self.model.id == id)).first()

    def save(self, instance):
        """Add ``instance`` to the session, commit, and refresh it."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            session.add(instance)
            session.commit()
            session.refresh(instance)

    def save_or_update(self, instance):
        """Copy ``instance``'s fields onto the existing row with the same id, or insert it."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            existing_obj = session.exec(select(self.model).where(self.model.id == instance.id)).first()
            if existing_obj:
                for k, v in instance.model_dump().items():
                    setattr(existing_obj, k, v)
                session.add(existing_obj)
                instance = existing_obj
            else:
                session.add(instance)
            session.commit()
            session.refresh(instance)

    def update(self, id, **kwargs):
        """Partially update the record whose ``id`` column equals ``id``."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            update_stmt = update(self.model).where(self.model.id == id).values(**kwargs)
            session.execute(update_stmt)
            session.commit()

    def update_all(self, **kwargs):
        """Partially update every record matched by this repo's filter (all rows if unfiltered)."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            update_stmt = self._apply_filter(update(self.model)).values(**kwargs)
            session.execute(update_stmt)
            session.commit()

    def delete(self, instance):
        """Delete an object from the database."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            session.delete(instance)
            session.commit()

    def delete_all(self):
        """Delete every record matched by this repo's filter (all rows if unfiltered)."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            delete_stmt = self._apply_filter(delete(self.model))
            session.execute(delete_stmt)
            session.commit()

    def _apply_filter(self, stmt):
        """Copy this repo's WHERE criteria, if any, onto an UPDATE/DELETE statement."""
        # `is not None` rather than truthiness: SQLAlchemy clause objects raise
        # TypeError on bool(). A filter with no criteria has whereclause None and
        # must not become `WHERE NULL`.
        if self._init_stmt is not None:
            where_clause = self._init_stmt.whereclause
            if where_clause is not None:
                stmt = stmt.where(where_clause)
        return stmt

    def filter(self, *filters, _fields=(), **kwargs) -> "SQLModelRepo":
        """Return a new repo whose base statement adds ``filters`` and ``column == value`` pairs."""
        # kwargs keys are always strings, so each one names a model attribute
        equalities = [getattr(self.model, k) == v for k, v in kwargs.items()]
        stmt = self.init_stmt(*_fields).where(*filters, *equalities)
        return SQLModelRepo(init_stmt=stmt, model=self.model, db_engine=self.db_engine, session=self.session)

    def paginate(self, offset: int, limit: int, order_by: str, desc: bool = False) -> list:
        """Return one page of results ordered by the ``order_by`` attribute."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            return self._paginate(session, offset, limit, order_by, desc)

    def paginate_with_total(self, offset: int, limit: int, order_by: str, desc: bool = False) -> (list, int):  # type: ignore
        """Paginate results and fetch total count

        Returns:
            tuple(list, int) - Items and total count.
        """
        with reuse_session_or_new(self.db_engine, self.session) as session:
            count_stmt = select(func.count()).select_from(self.init_stmt().subquery())
            count = session.execute(count_stmt).scalar()
            results = self._paginate(session, offset, limit, order_by, desc)
            return results, count

    def _paginate(self, session, offset: int, limit: int, order_by: str, desc: bool = False) -> list:
        """Run the base statement ordered by ``order_by`` (descending if ``desc``) with offset/limit."""
        column = getattr(self.model, order_by)
        if desc:
            column = column.desc()
        return session.exec(self.init_stmt().order_by(column).offset(offset).limit(limit)).all()

    def all(self) -> list:
        """Get all results"""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            return session.exec(self.init_stmt()).all()

    def count(self):
        """Get total results count"""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            count_stmt = select(func.count()).select_from(self.init_stmt().subquery())
            return session.execute(count_stmt).scalar()

    def first(self):
        """Return the first result of the base statement (or ``None``)."""
        with reuse_session_or_new(self.db_engine, self.session) as session:
            return session.exec(self.init_stmt()).first()

    def get_or_404(self, id):
        """Return the record with this id or raise :class:`RecordNotFoundError` (status 404)."""
        obj = self.get_by_id(id)
        if obj is None:
            # The original raised Exception(status_code=..., detail=...), which is itself a TypeError
            raise RecordNotFoundError(f"{self.model.__name__.title()} with id {id} not found")
        return obj

    def delete_or_404(self, id):
        """Delete the record with this id, raising :class:`RecordNotFoundError` if absent."""
        obj = self.get_or_404(id)
        self.delete(obj)

    def update_or_404(self, id, **kwargs):
        """Partially update the record with this id, raising :class:`RecordNotFoundError` if absent."""
        self.get_or_404(id)
        self.update(id, **kwargs)

    def _get_select_obj(self, fields=None):
        """Return the entities to select: the whole model, or the named attributes."""
        return [self.model] if not fields else [getattr(self.model, f) for f in fields]

    def init_stmt(self, *fields):
        """Return the filter statement if one is set, else a fresh ``select`` of ``fields``/the model."""
        if self._init_stmt is not None:
            return self._init_stmt
        return select(*self._get_select_obj(fields))
|
|
File without changes
|