agentic-kit-common 0.0.3__tar.gz → 0.0.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/PKG-INFO +9 -2
  2. agentic_kit_common-0.0.17/agentic_kit_common/llm/__init__.py +2 -0
  3. agentic_kit_common-0.0.17/agentic_kit_common/llm/openai.py +33 -0
  4. agentic_kit_common-0.0.17/agentic_kit_common/llm/utils.py +61 -0
  5. agentic_kit_common-0.0.17/agentic_kit_common/log/logger.py +69 -0
  6. agentic_kit_common-0.0.17/agentic_kit_common/mongodb/__init__.py +2 -0
  7. agentic_kit_common-0.0.17/agentic_kit_common/mongodb/mongodb_manager.py +294 -0
  8. agentic_kit_common-0.0.17/agentic_kit_common/orm/execution.py +90 -0
  9. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/manager.py +10 -6
  10. agentic_kit_common-0.0.17/agentic_kit_common/orm/multi_session.py +249 -0
  11. agentic_kit_common-0.0.17/agentic_kit_common/orm/schema.py +21 -0
  12. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/session.py +1 -0
  13. agentic_kit_common-0.0.17/agentic_kit_common/sms/ali_sms_client.py +54 -0
  14. agentic_kit_common-0.0.17/agentic_kit_common/vector/embedding/__init__.py +1 -0
  15. agentic_kit_common-0.0.17/agentic_kit_common/vector/embedding/embedding.py +55 -0
  16. agentic_kit_common-0.0.17/agentic_kit_common/vector/manager/__init__.py +1 -0
  17. agentic_kit_common-0.0.17/agentic_kit_common/vector/manager/milvus_manager.py +237 -0
  18. agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/__init__.py +3 -0
  19. agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/base.py +7 -0
  20. agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/milvus_schema.py +26 -0
  21. agentic_kit_common-0.0.17/agentic_kit_common/web/__init__.py +0 -0
  22. agentic_kit_common-0.0.17/agentic_kit_common/web/http/__init__.py +0 -0
  23. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/PKG-INFO +9 -2
  24. agentic_kit_common-0.0.17/agentic_kit_common.egg-info/SOURCES.txt +44 -0
  25. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/requires.txt +7 -0
  26. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/setup.py +19 -2
  27. agentic_kit_common-0.0.17/test/__init__.py +0 -0
  28. agentic_kit_common-0.0.17/test/settings.py +18 -0
  29. agentic_kit_common-0.0.17/test/test_embedding.py +30 -0
  30. agentic_kit_common-0.0.17/test/test_orm.py +47 -0
  31. agentic_kit_common-0.0.17/test/test_vector.py +60 -0
  32. agentic_kit_common-0.0.3/agentic_kit_common.egg-info/SOURCES.txt +0 -20
  33. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/README.md +0 -0
  34. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/__init__.py +0 -0
  35. {agentic_kit_common-0.0.3/agentic_kit_common/minio → agentic_kit_common-0.0.17/agentic_kit_common/log}/__init__.py +0 -0
  36. {agentic_kit_common-0.0.3/agentic_kit_common/web → agentic_kit_common-0.0.17/agentic_kit_common/minio}/__init__.py +0 -0
  37. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/minio/minio_manager.py +0 -0
  38. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/__init__.py +0 -0
  39. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/base.py +0 -0
  40. {agentic_kit_common-0.0.3/agentic_kit_common/web/http → agentic_kit_common-0.0.17/agentic_kit_common/sms}/__init__.py +0 -0
  41. {agentic_kit_common-0.0.3/test → agentic_kit_common-0.0.17/agentic_kit_common/vector}/__init__.py +0 -0
  42. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/web/http/response.py +0 -0
  43. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/dependency_links.txt +0 -0
  44. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/top_level.txt +0 -0
  45. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/setup.cfg +0 -0
  46. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/test/config.py +0 -0
  47. {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/test/test_minio.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agentic-kit-common
3
- Version: 0.0.3
4
- Summary: Common utilities and tools for agentic kit ecosystem, including MinIO manager and other shared components
3
+ Version: 0.0.17
4
+ Summary: Common utilities and tools for agentic kit ecosystem
5
5
  Home-page:
6
6
  Author: manson
7
7
  Author-email: manson.li3307@gmail.com
@@ -24,8 +24,15 @@ Requires-Dist: langchain_core
24
24
  Requires-Dist: langgraph
25
25
  Requires-Dist: langchain_community
26
26
  Requires-Dist: langchain_experimental
27
+ Requires-Dist: langchain-openai
27
28
  Requires-Dist: mysql-connector-python
28
29
  Requires-Dist: sqlalchemy
30
+ Requires-Dist: sqlglot
31
+ Requires-Dist: pymilvus
32
+ Requires-Dist: xinference_client
33
+ Requires-Dist: pymongo
34
+ Requires-Dist: aliyun-python-sdk-core
35
+ Requires-Dist: aliyun-python-sdk-dysmsapi
29
36
  Dynamic: author
30
37
  Dynamic: author-email
31
38
  Dynamic: classifier
@@ -0,0 +1,2 @@
1
+ from .openai import create_openai_llm
2
+ from .utils import combine_simple_context
@@ -0,0 +1,33 @@
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from langchain_openai import ChatOpenAI
5
+
6
# Populate os.environ from a local .env file (when one exists) before the
# module-level defaults below are read.
load_dotenv()

# Defaults used by create_openai_llm(); each can be overridden per call via a
# keyword argument of the same name. They are read once, at import time.
model_name = os.environ.get("MODEL_NAME", None)
openai_api_base = os.environ.get("OPENAI_API_BASE", None)
# Falls back to the literal 'API_KEY' when unset (presumably a placeholder for
# OpenAI-compatible servers that do not validate keys -- confirm).
openai_api_key = os.environ.get("OPENAI_API_KEY", 'API_KEY')
temperature = float(os.environ.get("TEMPERATURE", 0.2))
13
+
14
+
15
def create_openai_llm(**kwargs):
    """Build a ``ChatOpenAI`` client, filling unspecified settings from module defaults.

    Recognised keyword arguments (each falls back to the module-level value
    read from the environment at import time):

    - ``model_name`` (env ``MODEL_NAME``)
    - ``openai_api_base`` (env ``OPENAI_API_BASE``)
    - ``openai_api_key`` (env ``OPENAI_API_KEY``); an empty/None value is
      replaced by the placeholder ``'API_KEY'``
    - ``temperature`` (env ``TEMPERATURE``)

    All remaining keyword arguments are forwarded unchanged to ``ChatOpenAI``.

    Raises:
        AssertionError: if no model name or no API base URL is available from
            either the arguments or the environment.
    """
    _model_name = kwargs.pop('model_name', model_name)
    _openai_api_base = kwargs.pop('openai_api_base', openai_api_base)
    # An empty/None key is replaced by a placeholder so the client can be built.
    _openai_api_key = kwargs.pop('openai_api_key', openai_api_key) or 'API_KEY'
    _temperature = kwargs.pop('temperature', temperature)

    # Validate the *resolved* values. Previously the module-level defaults were
    # checked, so explicitly passed model_name / openai_api_base were ignored
    # and the call failed whenever the environment variables were unset.
    assert _model_name is not None, "model_name is required (pass it or set MODEL_NAME)"
    assert _openai_api_base is not None, "openai_api_base is required (pass it or set OPENAI_API_BASE)"

    return ChatOpenAI(
        model_name=_model_name,
        openai_api_base=_openai_api_base,
        openai_api_key=_openai_api_key,
        temperature=_temperature,
        **kwargs
    )
@@ -0,0 +1,61 @@
1
+ import json
2
+ import re
3
+ from typing import Any, Dict, List, Union
4
+
5
+ from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
6
+
7
+
8
def combine_simple_context(system: str, user: Union[str, None] = None):
    """Build a chat context: a system message, optionally followed by a user message.

    Args:
        system: Content of the leading ``SystemMessage``.
        user: Optional user prompt; when truthy it is appended as a
            ``HumanMessage``.

    Returns:
        A list of LangChain messages, ``[SystemMessage]`` or
        ``[SystemMessage, HumanMessage]``.
    """
    context: List[Union[BaseMessage, dict[str, Any]]] = [SystemMessage(content=system)]
    if user:
        # Bug fix: the human turn previously repeated the system prompt
        # instead of carrying the user's text.
        context.append(HumanMessage(content=user))
    return context
13
+
14
+
15
def fix_json_response(text: str) -> Any:
    """
    Strip Markdown code fences (```json ... ```) from an LLM response.

    The fences are removed only if what remains is valid JSON. The result is
    the *cleaned JSON string*; it is NOT deserialized, so call ``json.loads``
    on it if Python objects are needed.

    Parameters
    ----------
    text : str
        Raw model output, possibly wrapped in ```json ... ``` or a bare
        ``` ... ``` fence. The ``json`` language tag is matched
        case-insensitively (``json``, ``JSON``, ``Json``, ...).

    Returns
    -------
    - the fence-stripped, whitespace-trimmed string when it parses as JSON;
    - ``text`` unchanged when the cleaned string is not valid JSON;
    - ``text`` unchanged (whatever its type) when it is not a ``str``.
    """
    if not isinstance(text, str):
        return text

    # 1. Remove an opening fence (optionally tagged "json", any case) and a
    #    closing fence; works for single-line and multi-line payloads.
    cleaned = re.sub(r'^\s*```(?:json)?\s*\n?', '', text, flags=re.IGNORECASE)
    cleaned = re.sub(r'\n?\s*```\s*$', '', cleaned)

    # 2. Trim surrounding whitespace.
    cleaned = cleaned.strip()

    # 3. Only accept the cleaned version if it is valid JSON; otherwise fall
    #    back to the untouched input.
    try:
        json.loads(cleaned)
    except json.JSONDecodeError:
        return text
    return cleaned
47
+
48
+
49
+ # ----------------- 使用示例 -----------------
50
if __name__ == "__main__":
    # Demo inputs: fenced JSON, a fenced but invalid (single-quoted) payload,
    # a bare fence, plain text, and a non-string value.
    samples = [
        '```json\n{"a": 1, "b": "hello"}\n```',
        "```JSON{'key': 'value'}```",
        "```\n[1, 2, 3]```",
        "plain text",
        {"already_dict": 1},
    ]
    separator = "-" * 30
    for sample in samples:
        fixed = fix_json_response(sample)
        print("原始:", repr(sample))
        print("修复:", repr(fixed))
        print(separator)
@@ -0,0 +1,69 @@
1
+ import logging
2
+ import os
3
+ import uuid
4
+ from logging.handlers import TimedRotatingFileHandler
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ from dotenv import load_dotenv
9
+
10
# Populate os.environ from a local .env file (when one exists).
load_dotenv()

# Directory holding the per-name log files. It stays None -- and LogUtils.log
# returns without logging -- unless LOGGER_ROOT_PATH is set.
logger_root_path = os.getenv("LOGGER_ROOT_PATH", None)
log_dir = None

if logger_root_path:
    # Append "logs" without doubling the separator when the root ends in "/".
    log_dir = logger_root_path + ("logs" if logger_root_path.endswith('/') else "/logs")
    Path(log_dir).mkdir(parents=True, exist_ok=True)
22
+
23
+
24
class LogUtils:
    """
    Best-effort helper that writes messages to per-name rotating log files.

    Each distinct ``name`` gets its own logger with a file handler writing
    ``{log_dir}/{name}.log`` (rotated at midnight, 7 backups kept) plus a
    console handler. Every record is prefixed with a correlation UUID.
    Nothing is logged when ``log_dir`` is not configured.
    """

    # Logger names already configured by this class (value: the file handler
    # attached, or None if the logger already had handlers of its own).
    _handlers = {}

    _FORMAT = '%(uuid)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s'

    @classmethod
    def log(cls, name: str = "log", message: str = "", log_uuid: Union[str, None] = None,
            level: Union[int, str] = logging.INFO) -> None:
        """
        Write ``message`` at ``level`` to the log file named after ``name``.

        :param name: logger name; also the base name of the log file.
        :param message: text to log.
        :param log_uuid: correlation id included in the record. A fresh UUID4
            is generated per call when omitted. (The default used to be
            evaluated once at import time, so every call without an explicit
            id shared the same UUID.)
        :param level: numeric level (e.g. ``logging.INFO``) or a level name
            such as ``"INFO"`` (case-insensitive).

        Errors are deliberately swallowed: logging must never break the caller.
        """
        if not log_dir:
            return

        try:
            # Logger.log() only accepts integer levels; translate names like "info".
            if isinstance(level, str):
                numeric_level = logging.getLevelName(level.upper())
                if isinstance(numeric_level, int):
                    level = numeric_level

            logger = logging.getLogger(name)
            logger.setLevel(level)
            logger.propagate = False

            if name not in cls._handlers:
                file_handler = None
                # Only create handlers when they will actually be attached:
                # TimedRotatingFileHandler opens its file on construction.
                if not logger.hasHandlers():
                    file_handler = TimedRotatingFileHandler(
                        f'{log_dir}/{name}.log',  # base file name
                        when='midnight',  # roll over every day at midnight
                        interval=1,
                        backupCount=7,  # keep 7 days
                        encoding='utf-8',
                    )
                    file_handler.setFormatter(logging.Formatter(cls._FORMAT))

                    # Console output
                    stream_handler = logging.StreamHandler()
                    stream_handler.setFormatter(logging.Formatter(cls._FORMAT))

                    # No per-handler level: previously the first call's level
                    # became a permanent floor for this name, silently dropping
                    # later lower-level messages. Filtering is left to the
                    # logger level set above.
                    logger.addHandler(file_handler)
                    logger.addHandler(stream_handler)
                cls._handlers[name] = file_handler

            req_id = log_uuid or str(uuid.uuid4())
            logger.log(level, message, extra={'uuid': req_id})

            for handler in logger.handlers:
                handler.flush()
        except Exception:
            # Deliberate best-effort: a logging failure must not propagate.
            pass
@@ -0,0 +1,2 @@
1
+ from .mongodb_manager import MongodbConfig
2
+ from .mongodb_manager import MongodbManager
@@ -0,0 +1,294 @@
1
+ import logging
2
+ from typing import Any, List, Dict, Optional, Tuple
3
+
4
+ from pydantic import BaseModel, Field
5
+ from pymongo import MongoClient
6
+ from pymongo.collection import Collection
7
+ from pymongo.database import Database
8
+ from pymongo.errors import PyMongoError, AutoReconnect
9
+
10
+
11
class MongodbConfig(BaseModel):
    """
    Connection and pool settings for ``MongodbManager``.

    Only ``host`` is required; every other field has a default.
    ``collection_name`` is not read by ``MongodbManager.__init__``.
    """

    # --- server / target ---
    host: str = Field(default=..., description="MongoDB服务器地址")
    port: Optional[int] = Field(default=27017, description="MongoDB服务器端口")
    database: Optional[str] = Field(default='default', description="数据库名称")
    collection_name: Optional[str] = Field(default='default', description="集合名称")

    # --- authentication / transport ---
    username: Optional[str] = Field(default=None, description="MongoDB用户名")
    password: Optional[str] = Field(default=None, description="MongoDB密码")
    ssl: Optional[bool] = Field(default=False, description="是否使用SSL连接")

    # --- connection pool (timeouts in milliseconds) ---
    max_pool_size: Optional[int] = Field(default=20, description="最大连接池大小")
    min_pool_size: Optional[int] = Field(default=5, description="最小连接池大小")
    wait_queue_timeout_ms: Optional[int] = Field(default=60000, description="等待队列超时时间,单位毫秒")
    server_selection_timeout_ms: Optional[int] = Field(default=3000, description="服务器选择超时时间,单位毫秒")
29
+
30
+
31
class MongodbManager:
    """
    Convenience wrapper around a pooled ``pymongo.MongoClient``.

    Every data helper takes a collection name and operates on the database
    selected by ``MongodbConfig.database``.
    """

    @classmethod
    def create(cls, config: MongodbConfig):
        """Alternate constructor; equivalent to ``MongodbManager(config)``."""
        return cls(config=config)

    def __init__(self, config: MongodbConfig):
        """
        Create the MongoDB client and select the configured database.

        :param config: connection and pool settings.
        """
        self.host = config.host
        self.port = config.port
        self.database = config.database
        self.username = config.username
        self.password = config.password
        self.ssl = config.ssl
        self.max_pool_size = config.max_pool_size
        self.min_pool_size = config.min_pool_size
        self.wait_queue_timeout_ms = config.wait_queue_timeout_ms
        self.server_selection_timeout_ms = config.server_selection_timeout_ms

        # Credentials are passed as keyword arguments rather than embedded in
        # the URI. URI credentials must be percent-escaped, so passwords
        # containing '@', ':' or '/' previously produced an invalid or
        # misparsed connection string.
        credentials = {}
        if self.username and self.password:
            credentials = {'username': self.username, 'password': self.password}

        self.client = MongoClient(
            f'mongodb://{self.host}:{self.port}',
            maxPoolSize=self.max_pool_size,
            minPoolSize=self.min_pool_size,
            waitQueueTimeoutMS=self.wait_queue_timeout_ms,
            serverSelectionTimeoutMS=self.server_selection_timeout_ms,
            ssl=self.ssl,
            **credentials,
        )
        self.db: Database = self.client[self.database]
        self.logger = logging.getLogger('MongodbManager')

    def get_collection(self, collection_name: str) -> Collection:
        """
        Return a collection handle (does not create anything on the server).

        :param collection_name: collection name
        :return: the collection object
        """
        return self.db[collection_name]

    def create_collection(self, collection_name: str) -> Collection:
        """
        Create the collection if it does not exist yet, then return it.

        :param collection_name: collection name
        :return: the (new or existing) collection object
        """
        if collection_name in self.db.list_collection_names():
            return self.get_collection(collection_name)
        return self.db.create_collection(name=collection_name)

    def insert_one(self, collection_name: str, document: Dict) -> Any:
        """
        Insert a single document.

        :param collection_name: collection name
        :param document: document to insert
        :return: pymongo insert result
        """
        collection = self.get_collection(collection_name)
        return collection.insert_one(document)

    def insert_many(self, collection_name: str, documents: List[Dict]) -> Any:
        """
        Insert several documents.

        :param collection_name: collection name
        :param documents: documents to insert
        :return: pymongo insert result
        """
        collection = self.get_collection(collection_name)
        return collection.insert_many(documents)

    def find_one(self, collection_name: str, filter: Optional[Dict] = None, exlude: Optional[Dict] = None) -> Any:
        """
        Return the first document matching ``filter`` (or None).

        :param collection_name: collection name
        :param filter: query filter; None matches every document
        :param exlude: passed as pymongo's projection argument
        :return: matching document or None
        """
        collection = self.get_collection(collection_name)
        return collection.find_one(filter, exlude)

    def find(self, collection_name: str, filter: Optional[Dict] = None, limit: int = 0, skip: int = 0, exlude: Optional[Dict] = None) -> List[Dict]:
        """
        Return all documents matching ``filter`` as a list.

        :param collection_name: collection name
        :param filter: query filter; None matches every document
        :param limit: maximum number of results (0 means no limit)
        :param skip: number of matching documents to skip
        :param exlude: passed as pymongo's projection argument
        :return: list of documents
        """
        collection = self.get_collection(collection_name)
        return list(collection.find(filter, exlude).skip(skip).limit(limit))

    def update_one(self, collection_name: str, filter: Dict, update: Dict) -> Any:
        """
        Update the first document matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter
        :param update: update document (e.g. ``{"$set": {...}}``)
        :return: pymongo update result
        """
        collection = self.get_collection(collection_name)
        return collection.update_one(filter, update)

    def update_many(self, collection_name: str, filter: Dict, update: Dict) -> Any:
        """
        Update every document matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter
        :param update: update document (e.g. ``{"$set": {...}}``)
        :return: pymongo update result
        """
        collection = self.get_collection(collection_name)
        return collection.update_many(filter, update)

    def delete_one(self, collection_name: str, filter: Dict) -> Any:
        """
        Delete the first document matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter
        :return: pymongo delete result
        """
        collection = self.get_collection(collection_name)
        return collection.delete_one(filter)

    def delete_many(self, collection_name: str, filter: Dict) -> Any:
        """
        Delete every document matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter
        :return: pymongo delete result
        """
        collection = self.get_collection(collection_name)
        return collection.delete_many(filter)

    def count_documents(self, collection_name: str, filter: Optional[Dict] = None) -> int:
        """
        Count documents matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter; None counts every document
        :return: number of matching documents
        """
        collection = self.get_collection(collection_name)
        # pymongo's count_documents requires a mapping; the default None used
        # to be forwarded as-is and rejected.
        return collection.count_documents(filter if filter is not None else {})

    def create_index(self, collection_name: str, keys: List[Tuple[str, int]]) -> Any:
        """
        Create an index.

        :param collection_name: collection name
        :param keys: index keys, e.g. [("field1", 1), ("field2", -1)]
        :return: name of the created index (as returned by pymongo)
        """
        collection = self.get_collection(collection_name)
        return collection.create_index(keys)

    def drop_collection(self, collection_name: str) -> None:
        """
        Drop a collection.

        :param collection_name: collection name
        """
        self.db.drop_collection(collection_name)

    def close(self) -> None:
        """Close the underlying MongoDB client."""
        self.client.close()

    def start_session(self):
        """
        Start a new client session.

        :return: the session object
        """
        return self.client.start_session()

    def start_transaction(self, session):
        """
        Start a transaction on ``session``.

        :param session: session object
        """
        session.start_transaction()

    def commit_transaction(self, session):
        """
        Commit the transaction running on ``session``.

        :param session: session object
        """
        session.commit_transaction()

    def abort_transaction(self, session):
        """
        Abort (roll back) the transaction running on ``session``.

        :param session: session object
        """
        session.abort_transaction()

    def end_session(self, session):
        """
        End ``session``.

        :param session: session object
        """
        session.end_session()

    def bulk_write(self, collection_name: str, requests: List) -> Any:
        """
        Execute a batch of write operations.

        :param collection_name: collection name
        :param requests: list of pymongo write request objects
        :return: pymongo bulk write result
        """
        collection = self.get_collection(collection_name)
        return collection.bulk_write(requests)

    def find_with_pagination(self, collection_name: str, filter: Optional[Dict] = None, page: int = 1, page_size: int = 10) -> List[Dict]:
        """
        Return one page of documents matching ``filter``.

        :param collection_name: collection name
        :param filter: query filter; None matches every document
        :param page: 1-based page number (values below 1 are treated as 1)
        :param page_size: documents per page
        :return: list of documents on the requested page
        """
        collection = self.get_collection(collection_name)
        # Clamp so page <= 0 does not produce a negative skip, which pymongo rejects.
        skip = max(page - 1, 0) * page_size
        return list(collection.find(filter).skip(skip).limit(page_size))

    def _execute_with_retry(self, func, *args, **kwargs):
        """
        Call ``func(*args, **kwargs)``, retrying up to 3 times on AutoReconnect.

        :param func: callable to execute
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :return: whatever ``func`` returns
        :raises PyMongoError: after the final failed attempt, chained to the
            last AutoReconnect so the root cause is preserved.
        """
        max_retries = 3
        last_error = None
        for _ in range(max_retries):
            try:
                return func(*args, **kwargs)
            except AutoReconnect as e:
                last_error = e
                self.logger.warning(f"AutoReconnect occurred: {e}")
        raise PyMongoError(f"Failed after {max_retries} retries") from last_error
280
+
281
+
282
if __name__ == '__main__':
    # Manual smoke test against a live server. Connection details come from
    # the environment instead of a server IP hard-coded in the shipped package.
    import os

    config = MongodbConfig(
        host=os.getenv('MONGODB_HOST', 'localhost'),
        port=int(os.getenv('MONGODB_PORT', '27017')),
    )
    manager = MongodbManager.create(config)
    try:
        print(manager)
        col = manager.get_collection('default')
        print(col)
        res = manager.insert_one('default', {'content': 'test content'})
        print(res)
        doc = manager.find_one('default')
        print(doc)
    finally:
        # Always release the client's connection pool.
        manager.close()
@@ -0,0 +1,90 @@
1
+ from typing import List, Union, Optional, Set
2
+
3
+ import sqlglot
4
+ from sqlalchemy import text
5
+ from sqlalchemy.orm import Session
6
+ from sqlglot import expressions
7
+
8
+
9
def get_operation_type(statement) -> str:
    """Return the upper-cased class name of ``statement`` (e.g. a sqlglot ``Select`` -> ``"SELECT"``)."""
    node_class_name = type(statement).__name__
    return node_class_name.upper()
12
+
13
+
14
def get_sql_operation_info(sql: str) -> Optional[dict]:
    """
    Summarise a single MySQL statement for debugging / logging.

    Returns:
        dict with keys:
          - ``operation``: upper-cased node class name (see get_operation_type);
          - ``tables``: names of every table referenced in the statement, with
            surrounding backticks removed;
          - ``has_wildcard``: True only for a top-level SELECT containing ``*``.
        ``None`` when parsing (or any later step) fails.
    """
    try:
        parsed = sqlglot.parse_one(sql, read="mysql")

        def _strip_backticks(name: str) -> str:
            quoted = name.startswith('`') and name.endswith('`')
            return name[1:-1] if quoted else name

        table_names = [_strip_backticks(tbl.name) for tbl in parsed.find_all(expressions.Table)]
        is_select = isinstance(parsed, expressions.Select)

        return {
            'operation': get_operation_type(parsed),
            'tables': table_names,
            'has_wildcard': bool(parsed.find(expressions.Star)) if is_select else False
        }
    except Exception as e:
        print(f"获取SQL操作信息失败: {e}")
        return None
40
+
41
+
42
def is_readonly_expression(node: expressions.Expression, allowed_operations: Optional[Set[str]] = None, enable_wildcard_check: bool = True):
    """
    Decide whether a parsed SQL expression tree may be run as read-only.

    Checks, in order:
      1. No node anywhere in the tree is an instance of a write-statement type
         (INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, REPLACE, MERGE).
      2. If ``allowed_operations`` is given, the top-level operation name
         (see get_operation_type) must appear in it in lower or upper case.
      3. If ``enable_wildcard_check`` is set, a top-level SELECT must not
         contain ``*``.

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, reason)``
        where ``reason`` is the offending node class, the rejected operation
        name string, or ``expressions.Star``.
    """
    # Write-statement node types (blacklist).
    write_types = (
        expressions.Insert,
        expressions.Update,
        expressions.Delete,
        expressions.Create,
        expressions.Alter,
        expressions.Drop,
        expressions.Replace,
        expressions.Merge,
    )
    for item in node.walk():
        # Some (older) sqlglot versions yield (node, parent, key) tuples from
        # walk(); newer ones yield the node itself. Normalise so this security
        # check cannot be silently bypassed by a library upgrade/downgrade.
        descendant = item[0] if isinstance(item, tuple) else item
        # isinstance rather than exact type equality, so any subclass of these
        # statement types is rejected as well.
        if isinstance(descendant, write_types):
            return False, type(descendant)

    if allowed_operations:
        op_type = get_operation_type(node)
        # Accept the operation name in either lower or upper case.
        if op_type.lower() not in allowed_operations and op_type.upper() not in allowed_operations:
            return False, op_type

    if enable_wildcard_check and isinstance(node, expressions.Select) and node.find(expressions.Star):
        return False, expressions.Star

    return True, None
69
+
70
+
71
def session_sql_execute(db_session: Session, sql_text: Union[str, List], format_result: bool = True):
    """
    Execute one raw SQL string, or a list of them, on a SQLAlchemy session.

    Args:
        db_session: session used to execute the statements (no commit is
            issued here).
        sql_text: a single SQL string, or a list of SQL strings executed in order.
        format_result: when True, each result is materialised as a list of
            ``{column: value}`` dicts; otherwise the raw result object is returned.

    Returns:
        For a string: that statement's result. For a list: a list with one
        result per statement. For any other input type: ``[]``.

    Warning:
        The SQL is executed verbatim. Never build it from untrusted input
        without validating it first.
    """
    def _execute_one(statement: str):
        # Bug fix: this previously executed the *outer* ``sql_text``, so list
        # input ran the stringified list instead of each statement.
        result = db_session.execute(text(f"{statement}"))
        if not format_result:
            return result
        columns = list(result.keys())
        return [dict(zip(columns, row)) for row in result.fetchall()]

    if isinstance(sql_text, str):
        return _execute_one(sql_text)
    if isinstance(sql_text, list):
        return [_execute_one(sub_sql_text) for sub_sql_text in sql_text]
    return []
@@ -14,8 +14,7 @@ class BaseOrmManager(object):
14
14
  def get_by_id(cls, obj_id, close_session_after_curd=False):
15
15
  with get_db_session() as db:
16
16
  query = db.query(cls._model_cls) \
17
- .filter(cls._model_cls.id == obj_id) \
18
- .filter(cls._model_cls.active == True)
17
+ .filter(cls._model_cls.id == obj_id)
19
18
 
20
19
  if hasattr(cls._model_cls, 'active'):
21
20
  query = query.filter(cls._model_cls.active == 1)
@@ -51,7 +50,7 @@ class BaseOrmManager(object):
51
50
  return res
52
51
 
53
52
  @classmethod
54
- def get_list(cls, paginate=False, close_session_after_curd=False, **kwargs):
53
+ def get_list(cls, paginate=False, close_session_after_curd=False, render: bool = True, **kwargs):
55
54
  with get_db_session() as db:
56
55
  query = db.query(cls._model_cls)
57
56
  # """分页查询示例"""
@@ -92,6 +91,8 @@ class BaseOrmManager(object):
92
91
 
93
92
  # 查询记录
94
93
  result = query.all()
94
+ if render:
95
+ result = [cls.render(item) for item in result]
95
96
  pagination = {
96
97
  'total': total,
97
98
  'page': page,
@@ -102,6 +103,8 @@ class BaseOrmManager(object):
102
103
  # return pagination
103
104
  else:
104
105
  res = query.all()
106
+ if render:
107
+ res = [cls.render(item) for item in res]
105
108
  # return res
106
109
 
107
110
  if close_session_after_curd:
@@ -124,9 +127,10 @@ class BaseOrmManager(object):
124
127
  @classmethod
125
128
  def soft_delete_obj(cls, obj):
126
129
  """软删除obj信息"""
127
- cls.update_obj(obj, **{
128
- 'active': False
129
- })
130
+ if hasattr(cls._model_cls, 'active'):
131
+ cls.update_obj(obj, **{
132
+ 'active': False
133
+ })
130
134
  return obj
131
135
 
132
136
  @classmethod