agentic-kit-common 0.0.3__tar.gz → 0.0.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/PKG-INFO +9 -2
- agentic_kit_common-0.0.17/agentic_kit_common/llm/__init__.py +2 -0
- agentic_kit_common-0.0.17/agentic_kit_common/llm/openai.py +33 -0
- agentic_kit_common-0.0.17/agentic_kit_common/llm/utils.py +61 -0
- agentic_kit_common-0.0.17/agentic_kit_common/log/logger.py +69 -0
- agentic_kit_common-0.0.17/agentic_kit_common/mongodb/__init__.py +2 -0
- agentic_kit_common-0.0.17/agentic_kit_common/mongodb/mongodb_manager.py +294 -0
- agentic_kit_common-0.0.17/agentic_kit_common/orm/execution.py +90 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/manager.py +10 -6
- agentic_kit_common-0.0.17/agentic_kit_common/orm/multi_session.py +249 -0
- agentic_kit_common-0.0.17/agentic_kit_common/orm/schema.py +21 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/session.py +1 -0
- agentic_kit_common-0.0.17/agentic_kit_common/sms/ali_sms_client.py +54 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/embedding/__init__.py +1 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/embedding/embedding.py +55 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/manager/__init__.py +1 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/manager/milvus_manager.py +237 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/__init__.py +3 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/base.py +7 -0
- agentic_kit_common-0.0.17/agentic_kit_common/vector/schema/milvus_schema.py +26 -0
- agentic_kit_common-0.0.17/agentic_kit_common/web/__init__.py +0 -0
- agentic_kit_common-0.0.17/agentic_kit_common/web/http/__init__.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/PKG-INFO +9 -2
- agentic_kit_common-0.0.17/agentic_kit_common.egg-info/SOURCES.txt +44 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/requires.txt +7 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/setup.py +19 -2
- agentic_kit_common-0.0.17/test/__init__.py +0 -0
- agentic_kit_common-0.0.17/test/settings.py +18 -0
- agentic_kit_common-0.0.17/test/test_embedding.py +30 -0
- agentic_kit_common-0.0.17/test/test_orm.py +47 -0
- agentic_kit_common-0.0.17/test/test_vector.py +60 -0
- agentic_kit_common-0.0.3/agentic_kit_common.egg-info/SOURCES.txt +0 -20
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/README.md +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/__init__.py +0 -0
- {agentic_kit_common-0.0.3/agentic_kit_common/minio → agentic_kit_common-0.0.17/agentic_kit_common/log}/__init__.py +0 -0
- {agentic_kit_common-0.0.3/agentic_kit_common/web → agentic_kit_common-0.0.17/agentic_kit_common/minio}/__init__.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/minio/minio_manager.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/__init__.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/orm/base.py +0 -0
- {agentic_kit_common-0.0.3/agentic_kit_common/web/http → agentic_kit_common-0.0.17/agentic_kit_common/sms}/__init__.py +0 -0
- {agentic_kit_common-0.0.3/test → agentic_kit_common-0.0.17/agentic_kit_common/vector}/__init__.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common/web/http/response.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/dependency_links.txt +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/agentic_kit_common.egg-info/top_level.txt +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/setup.cfg +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/test/config.py +0 -0
- {agentic_kit_common-0.0.3 → agentic_kit_common-0.0.17}/test/test_minio.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: agentic-kit-common
|
|
3
|
-
Version: 0.0.3
|
|
4
|
-
Summary: Common utilities and tools for agentic kit ecosystem
|
|
3
|
+
Version: 0.0.17
|
|
4
|
+
Summary: Common utilities and tools for agentic kit ecosystem
|
|
5
5
|
Home-page:
|
|
6
6
|
Author: manson
|
|
7
7
|
Author-email: manson.li3307@gmail.com
|
|
@@ -24,8 +24,15 @@ Requires-Dist: langchain_core
|
|
|
24
24
|
Requires-Dist: langgraph
|
|
25
25
|
Requires-Dist: langchain_community
|
|
26
26
|
Requires-Dist: langchain_experimental
|
|
27
|
+
Requires-Dist: langchain-openai
|
|
27
28
|
Requires-Dist: mysql-connector-python
|
|
28
29
|
Requires-Dist: sqlalchemy
|
|
30
|
+
Requires-Dist: sqlglot
|
|
31
|
+
Requires-Dist: pymilvus
|
|
32
|
+
Requires-Dist: xinference_client
|
|
33
|
+
Requires-Dist: pymongo
|
|
34
|
+
Requires-Dist: aliyun-python-sdk-core
|
|
35
|
+
Requires-Dist: aliyun-python-sdk-dysmsapi
|
|
29
36
|
Dynamic: author
|
|
30
37
|
Dynamic: author-email
|
|
31
38
|
Dynamic: classifier
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
from langchain_openai import ChatOpenAI
|
|
5
|
+
|
|
6
|
+
load_dotenv()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Defaults for the OpenAI-compatible client, read once from the process
# environment when this module is imported.
model_name = os.environ.get("MODEL_NAME")
openai_api_base = os.environ.get("OPENAI_API_BASE")
# Falls back to the placeholder 'API_KEY' when no key is configured.
openai_api_key = os.environ.get("OPENAI_API_KEY", 'API_KEY')
temperature = float(os.environ.get("TEMPERATURE", 0.2))
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def create_openai_llm(**kwargs):
    """Build a ``ChatOpenAI`` client from keyword overrides and module defaults.

    The keywords ``model_name``, ``openai_api_base``, ``openai_api_key`` and
    ``temperature`` are consumed here; each one falls back to the module-level
    variable of the same name when it is not passed. Every remaining keyword
    argument is forwarded to ``ChatOpenAI`` unchanged.

    :return: the constructed ``ChatOpenAI`` instance
    :raises AssertionError: if neither the call nor the module defaults supply
        a model name or an API base URL
    """
    _model_name = kwargs.pop('model_name', model_name)
    _openai_api_base = kwargs.pop('openai_api_base', openai_api_base)
    # An empty or None key is replaced by the placeholder 'API_KEY'.
    _openai_api_key = kwargs.pop('openai_api_key', openai_api_key) or 'API_KEY'
    _temperature = kwargs.pop('temperature', temperature)

    # Validate the values that are actually passed to the client (not the
    # module-level defaults), so explicit overrides work when the environment
    # is not configured. AssertionError is raised explicitly to keep the
    # exception type callers already see while surviving ``python -O``.
    if _model_name is None:
        raise AssertionError(
            "model_name is required: pass model_name=... or set MODEL_NAME"
        )
    if _openai_api_base is None:
        raise AssertionError(
            "openai_api_base is required: pass openai_api_base=... or set OPENAI_API_BASE"
        )

    llm = ChatOpenAI(
        model_name=_model_name,
        openai_api_base=_openai_api_base,
        openai_api_key=_openai_api_key,
        temperature=_temperature,
        **kwargs
    )
    return llm
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
from typing import Any, Dict, List, Union
|
|
4
|
+
|
|
5
|
+
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def combine_simple_context(system: str, user: Union[str, None] = None):
    """Build a minimal chat context from a system prompt and an optional user prompt.

    :param system: content of the leading ``SystemMessage``
    :param user: content of the ``HumanMessage`` appended after it; skipped
        when ``None`` or empty
    :return: list holding the system message, followed by the human message
        when ``user`` is given
    """
    context: List[Union[BaseMessage, dict[str, Any]]] = [SystemMessage(content=system)]
    if user:
        # The human turn carries the user's text, not a copy of the system prompt.
        context.append(HumanMessage(content=user))
    return context
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def fix_json_response(text: str, *, parse: bool = False) -> Union[Dict, List, str, int, float, bool, None]:
    """Strip Markdown code fences from a model reply and validate it as JSON.

    A leading ``` / ```json / ```JSON fence and a trailing ``` fence are
    removed, surrounding whitespace is trimmed, and the remainder is decoded
    with ``json.loads`` to check that it is valid JSON.

    :param text: raw reply, possibly wrapped in a fenced code block; any
        non-string value is returned unchanged
    :param parse: when ``False`` (default) the cleaned JSON *string* is
        returned; when ``True`` the decoded Python object is returned instead
    :return: the cleaned JSON string (or the decoded object when ``parse`` is
        true); the original ``text`` when the cleaned content is not valid JSON
    """
    if not isinstance(text, str):
        return text

    # 1. Remove the opening and closing code fences (single- or multi-line).
    cleaned = re.sub(r'^\s*```(?:json|JSON)?\s*\n?', '', text)
    cleaned = re.sub(r'\n?\s*```\s*$', '', cleaned)

    # 2. Trim leading/trailing whitespace.
    cleaned = cleaned.strip()

    # 3. Decode to validate; keep the decoded value so it is not parsed twice.
    try:
        decoded = json.loads(cleaned)
    except json.JSONDecodeError:
        # 4. Fallback: not valid JSON, hand back the untouched input.
        return text
    return decoded if parse else cleaned
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# ----------------- Usage example -----------------
if __name__ == "__main__":
    # Fenced JSON, fenced non-JSON, bare text and a non-string input.
    samples = (
        '```json\n{"a": 1, "b": "hello"}\n```',
        "```JSON{'key': 'value'}```",
        "```\n[1, 2, 3]```",
        "plain text",
        {"already_dict": 1},
    )
    separator = "-" * 30
    for sample in samples:
        print("原始:", repr(sample))
        print("修复:", repr(fix_json_response(sample)))
        print(separator)
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from logging.handlers import TimedRotatingFileHandler
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Union
|
|
7
|
+
|
|
8
|
+
from dotenv import load_dotenv
|
|
9
|
+
|
|
10
|
+
load_dotenv()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Root directory for log files, taken from the LOGGER_ROOT_PATH environment
# variable. When it is unset or empty, log_dir stays None.
logger_root_path = os.environ.get("LOGGER_ROOT_PATH")
log_dir = None

if logger_root_path:
    # Append "logs", inserting a '/' only when the root does not already end with one.
    log_dir = logger_root_path + ('logs' if logger_root_path.endswith('/') else '/logs')
    Path(log_dir).mkdir(parents=True, exist_ok=True)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LogUtils:
    """Write messages to a per-name log file and to the console.

    Log files are named ``<log_dir>/<name>.log`` and rotate at midnight,
    keeping seven backups. Nothing is written when the module-level
    ``log_dir`` is not set.
    """

    # Logger names whose handlers have already been built, mapped to the file
    # handler created for that name.
    _handlers = {}

    @classmethod
    def log(cls, name: str = "log", message: str = "", log_uuid: Union[str, None] = None, level: Union[int, str] = logging.INFO) -> None:
        """Write ``message`` to the log file selected by ``name``.

        :param name: logger name; also the base name of the log file
        :param message: text to record
        :param log_uuid: identifier printed at the start of the line; a fresh
            UUID4 is generated for each call when it is omitted or empty
        :param level: logging level, as an int or a level name
        :return: None. Failures are reported as a warning on this module's
            logger and never raised to the caller.
        """
        if not log_dir:
            return

        try:
            logger = logging.getLogger(name)
            logger.setLevel(level)
            logger.propagate = False
            if name not in cls._handlers:
                line_format = '%(uuid)s - %(asctime)s - %(name)s - %(levelname)s - %(message)s'

                # File handler: rotates daily at midnight, keeps 7 days.
                handler = TimedRotatingFileHandler(
                    f'{log_dir}/{name}.log',
                    when='midnight',
                    interval=1,
                    backupCount=7,
                    encoding='utf-8'
                )
                handler.setFormatter(logging.Formatter(line_format))
                handler.setLevel(level)

                # Console handler with the same format.
                stream_handler = logging.StreamHandler()
                stream_handler.setLevel(level)
                stream_handler.setFormatter(logging.Formatter(line_format))

                if not logger.hasHandlers():
                    logger.addHandler(handler)
                    logger.addHandler(stream_handler)
                cls._handlers[name] = handler

            # Generate the id per call; a default computed in the signature
            # would be shared by every call.
            req_id = log_uuid or str(uuid.uuid4())
            # Logger.log() only accepts an integer level, while setLevel() has
            # already resolved a level name to its number in logger.level.
            logger.log(logger.level, message, extra={'uuid': req_id})

            for handler in logger.handlers:
                handler.flush()
        except Exception:
            # Logging is best-effort and must not break the caller, but the
            # failure is surfaced instead of being dropped silently.
            logging.getLogger(__name__).warning(
                "LogUtils.log failed for logger %r", name, exc_info=True
            )
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any, List, Dict, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
from pymongo import MongoClient
|
|
6
|
+
from pymongo.collection import Collection
|
|
7
|
+
from pymongo.database import Database
|
|
8
|
+
from pymongo.errors import PyMongoError, AutoReconnect
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MongodbConfig(BaseModel):
    """Connection settings consumed by ``MongodbManager``."""

    # Server location and default namespace.
    host: str = Field(..., description="MongoDB服务器地址")
    port: Optional[int] = Field(default=27017, description="MongoDB服务器端口")
    database: Optional[str] = Field(default='default', description="数据库名称")
    collection_name: Optional[str] = Field(default='default', description="集合名称")

    # Authentication and transport.
    username: Optional[str] = Field(default=None, description="MongoDB用户名")
    password: Optional[str] = Field(default=None, description="MongoDB密码")
    ssl: Optional[bool] = Field(default=False, description="是否使用SSL连接")

    # Connection-pool sizing and timeouts (milliseconds).
    max_pool_size: Optional[int] = Field(default=20, description="最大连接池大小")
    min_pool_size: Optional[int] = Field(default=5, description="最小连接池大小")
    wait_queue_timeout_ms: Optional[int] = Field(default=60000, description="等待队列超时时间,单位毫秒")
    server_selection_timeout_ms: Optional[int] = Field(default=3000, description="服务器选择超时时间,单位毫秒")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MongodbManager:
    """Thin convenience wrapper around a ``MongoClient`` bound to one database."""

    @classmethod
    def create(cls, config: MongodbConfig):
        """Alternate constructor: build a manager from ``config``."""
        return cls(config=config)

    def __init__(self, config: MongodbConfig):
        """
        Create the MongoDB client and select the configured database.

        :param config: connection settings
        """

        self.host = config.host
        self.port = config.port
        self.database = config.database
        self.username = config.username
        self.password = config.password
        self.ssl = config.ssl
        self.max_pool_size = config.max_pool_size
        self.min_pool_size = config.min_pool_size
        self.wait_queue_timeout_ms = config.wait_queue_timeout_ms
        self.server_selection_timeout_ms = config.server_selection_timeout_ms

        uri = f'mongodb://{self.host}:{self.port}'
        client_options = {
            'maxPoolSize': self.max_pool_size,
            'minPoolSize': self.min_pool_size,
            'waitQueueTimeoutMS': self.wait_queue_timeout_ms,
            'serverSelectionTimeoutMS': self.server_selection_timeout_ms,
            'ssl': self.ssl,
        }
        if self.username and self.password:
            # Credentials go in as client options rather than being spliced
            # into the URI, where reserved characters such as '@', ':' or '/'
            # would need percent-escaping.
            client_options['username'] = self.username
            client_options['password'] = self.password

        self.client = MongoClient(uri, **client_options)
        self.db: Database = self.client[self.database]
        self.logger = logging.getLogger('MongodbManager')

    def get_collection(self, collection_name: str) -> Collection:
        """
        Return a collection of the selected database.

        :param collection_name: collection name
        :return: collection object
        """
        return self.db[collection_name]

    def create_collection(self, collection_name: str) -> Collection:
        """
        Create the collection if it does not exist yet, otherwise return it.

        :param collection_name: collection name
        :return: collection object
        """
        if collection_name not in self.db.list_collection_names():
            return self.db.create_collection(name=collection_name)
        return self.get_collection(collection_name)

    def insert_one(self, collection_name: str, document: Dict) -> Any:
        """
        Insert a single document.

        :param collection_name: collection name
        :param document: document to insert
        :return: result of the insert call
        """
        collection = self.get_collection(collection_name)
        return collection.insert_one(document)

    def insert_many(self, collection_name: str, documents: List[Dict]) -> Any:
        """
        Insert several documents.

        :param collection_name: collection name
        :param documents: documents to insert
        :return: result of the insert call
        """
        collection = self.get_collection(collection_name)
        return collection.insert_many(documents)

    def find_one(self, collection_name: str, filter: Optional[Dict] = None, exlude: Optional[Dict] = None) -> Any:
        """
        Fetch a single document.

        :param collection_name: collection name
        :param filter: query filter
        :param exlude: passed as the second positional argument of
            ``Collection.find_one`` (the projection)
        :return: result of the ``find_one`` call
        """
        collection = self.get_collection(collection_name)
        return collection.find_one(filter, exlude)

    def find(self, collection_name: str, filter: Optional[Dict] = None, limit: int = 0, skip: int = 0, exlude: Optional[Dict] = None) -> List[Dict]:
        """
        Fetch several documents.

        :param collection_name: collection name
        :param filter: query filter
        :param limit: value passed to the cursor's ``limit``
        :param skip: number of results to skip
        :param exlude: passed as the second positional argument of
            ``Collection.find`` (the projection)
        :return: list of matching documents
        """
        collection = self.get_collection(collection_name)
        return list(collection.find(filter, exlude).limit(limit).skip(skip))

    def update_one(self, collection_name: str, filter: Dict, update: Dict) -> Any:
        """
        Update a single document.

        :param collection_name: collection name
        :param filter: query filter
        :param update: update specification
        :return: result of the update call
        """
        collection = self.get_collection(collection_name)
        return collection.update_one(filter, update)

    def update_many(self, collection_name: str, filter: Dict, update: Dict) -> Any:
        """
        Update several documents.

        :param collection_name: collection name
        :param filter: query filter
        :param update: update specification
        :return: result of the update call
        """
        collection = self.get_collection(collection_name)
        return collection.update_many(filter, update)

    def delete_one(self, collection_name: str, filter: Dict) -> Any:
        """
        Delete a single document.

        :param collection_name: collection name
        :param filter: query filter
        :return: result of the delete call
        """
        collection = self.get_collection(collection_name)
        return collection.delete_one(filter)

    def delete_many(self, collection_name: str, filter: Dict) -> Any:
        """
        Delete several documents.

        :param collection_name: collection name
        :param filter: query filter
        :return: result of the delete call
        """
        collection = self.get_collection(collection_name)
        return collection.delete_many(filter)

    def count_documents(self, collection_name: str, filter: Optional[Dict] = None) -> int:
        """
        Count documents matching a filter.

        :param collection_name: collection name
        :param filter: query filter; ``None`` counts every document
        :return: number of matching documents
        """
        collection = self.get_collection(collection_name)
        # count_documents() needs a mapping, so the "no filter" default is
        # translated to an empty filter.
        return collection.count_documents({} if filter is None else filter)

    def create_index(self, collection_name: str, keys: List[Tuple[str, int]]) -> Any:
        """
        Create an index.

        :param collection_name: collection name
        :param keys: index keys, e.g. [("field1", 1), ("field2", -1)]
        :return: result of the ``create_index`` call
        """
        collection = self.get_collection(collection_name)
        return collection.create_index(keys)

    def drop_collection(self, collection_name: str) -> None:
        """
        Drop a collection.

        :param collection_name: collection name
        """
        self.db.drop_collection(collection_name)

    def close(self) -> None:
        """
        Close the MongoDB client connection.
        """
        self.client.close()

    def start_session(self):
        """
        Start a new client session.

        :return: session object
        """
        return self.client.start_session()

    def start_transaction(self, session):
        """
        Start a transaction on ``session``.

        :param session: session object
        """
        session.start_transaction()

    def commit_transaction(self, session):
        """
        Commit the transaction of ``session``.

        :param session: session object
        """
        session.commit_transaction()

    def abort_transaction(self, session):
        """
        Abort (roll back) the transaction of ``session``.

        :param session: session object
        """
        session.abort_transaction()

    def end_session(self, session):
        """
        End ``session``.

        :param session: session object
        """
        session.end_session()

    def bulk_write(self, collection_name: str, requests: List) -> Any:
        """
        Run a batch of write operations.

        :param collection_name: collection name
        :param requests: list of write requests
        :return: result of the ``bulk_write`` call
        """
        collection = self.get_collection(collection_name)
        return collection.bulk_write(requests)

    def find_with_pagination(self, collection_name: str, filter: Optional[Dict] = None, page: int = 1, page_size: int = 10) -> List[Dict]:
        """
        Fetch one page of documents.

        :param collection_name: collection name
        :param filter: query filter
        :param page: 1-based page number
        :param page_size: number of documents per page
        :return: list of documents on the requested page
        """
        collection = self.get_collection(collection_name)
        skip = (page - 1) * page_size
        return list(collection.find(filter).skip(skip).limit(page_size))

    def _execute_with_retry(self, func, *args, **kwargs):
        """
        Call ``func`` and retry when the driver reports ``AutoReconnect``.

        :param func: callable to run
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :return: return value of ``func``
        :raises PyMongoError: when every attempt failed with ``AutoReconnect``;
            the last such error is attached as the cause
        """
        max_retries = 3
        last_error = None
        for _ in range(max_retries):
            try:
                return func(*args, **kwargs)
            except AutoReconnect as e:
                self.logger.warning(f"AutoReconnect occurred: {e}")
                last_error = e
        # Chain the last failure so the root cause is not lost.
        raise PyMongoError(f"Failed after {max_retries} retries") from last_error
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
if __name__ == '__main__':
    # Manual smoke test: connect, insert one document, read one back.
    manager = MongodbManager.create(
        MongodbConfig(host='45.120.102.142', port=8883)
    )
    print(manager)
    print(manager.get_collection('default'))
    print(manager.insert_one('default', {'content': 'test content'}))
    print(manager.find_one('default'))
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from typing import List, Union, Optional, Set
|
|
2
|
+
|
|
3
|
+
import sqlglot
|
|
4
|
+
from sqlalchemy import text
|
|
5
|
+
from sqlalchemy.orm import Session
|
|
6
|
+
from sqlglot import expressions
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_operation_type(statement) -> str:
    """Return the upper-cased class name of ``statement`` as the SQL operation type."""
    statement_cls = type(statement)
    return statement_cls.__name__.upper()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_sql_operation_info(sql: str) -> Optional[dict]:
    """
    Summarise a SQL statement (intended for debugging and logging).

    The statement is parsed with the MySQL dialect.

    Returns:
        dict: ``operation`` (operation type), ``tables`` (names of the tables
        found in the statement) and ``has_wildcard`` (whether a SELECT
        contains ``*``); ``None`` when anything fails, in which case the
        error is printed.
    """
    try:
        statement = sqlglot.parse_one(sql, read="mysql")
        operation_type = get_operation_type(statement)

        def _unquote(identifier):
            # Drop a surrounding pair of backticks, if present.
            wrapped = identifier.startswith('`') and identifier.endswith('`')
            return identifier[1:-1] if wrapped else identifier

        tables = [_unquote(table.name) for table in statement.find_all(expressions.Table)]

        # Only SELECT statements are inspected for a wildcard.
        if isinstance(statement, expressions.Select):
            has_wildcard = bool(statement.find(expressions.Star))
        else:
            has_wildcard = False

        return {
            'operation': operation_type,
            'tables': tables,
            'has_wildcard': has_wildcard
        }
    except Exception as exc:
        print(f"获取SQL操作信息失败: {exc}")
        return None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def is_readonly_expression(node: expressions.Expression, allowed_operations: Optional[Set[str]] = None, enable_wildcard_check: bool = True):
    """Check an expression tree for write operations and other disallowed forms.

    :param node: root expression to inspect
    :param allowed_operations: when non-empty, the operation type of ``node``
        must be contained in it in lower- or upper-case form
    :param enable_wildcard_check: when true, a SELECT containing ``*`` is rejected
    :return: ``(True, None)`` when the expression passes every check, otherwise
        ``(False, reason)`` where ``reason`` is the offending node type, the
        rejected operation type string, or ``expressions.Star``
    """
    # Blacklist of node types that modify data or schema.
    write_types = (
        expressions.Insert,
        expressions.Update,
        expressions.Delete,
        expressions.Create,
        expressions.Alter,
        expressions.Drop,
        expressions.Replace,
        expressions.Merge,
    )
    for descendant in node.walk():
        descendant_type = type(descendant)
        if descendant_type in write_types:
            return False, descendant_type

    if allowed_operations:
        op_type = get_operation_type(node)
        # Accept the operation when either case variant is listed.
        spellings = (op_type.lower(), op_type.upper())
        if not any(spelling in allowed_operations for spelling in spellings):
            return False, op_type

    if enable_wildcard_check and isinstance(node, expressions.Select):
        if node.find(expressions.Star):
            return False, expressions.Star

    return True, None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def session_sql_execute(db_session: Session, sql_text: Union[str, List], format_result: bool = True):
    """Execute one SQL statement, or a list of statements, on ``db_session``.

    :param db_session: session used to execute the statements
    :param sql_text: a single SQL string, or a list of SQL strings that are
        executed one by one in order
    :param format_result: when true, each result is converted to a list of
        ``{column: value}`` dicts; otherwise the raw result object returned
        by ``db_session.execute`` is passed through
    :return: the result for a single string; a list with one result per
        statement for a list; ``[]`` for any other input type
    """
    def _run(statement: str):
        # Execute the statement handed to this helper, not the outer
        # ``sql_text`` (which is the whole list when a list was passed in).
        outcome = db_session.execute(text(f"{statement}"))
        if not format_result:
            return outcome
        columns = list(outcome.keys())
        return [dict(zip(columns, row)) for row in outcome.fetchall()]

    if isinstance(sql_text, str):
        return _run(sql_text)
    if isinstance(sql_text, list):
        return [_run(statement) for statement in sql_text]
    return []
|
|
@@ -14,8 +14,7 @@ class BaseOrmManager(object):
|
|
|
14
14
|
def get_by_id(cls, obj_id, close_session_after_curd=False):
|
|
15
15
|
with get_db_session() as db:
|
|
16
16
|
query = db.query(cls._model_cls) \
|
|
17
|
-
.filter(cls._model_cls.id == obj_id)
|
|
18
|
-
.filter(cls._model_cls.active == True)
|
|
17
|
+
.filter(cls._model_cls.id == obj_id)
|
|
19
18
|
|
|
20
19
|
if hasattr(cls._model_cls, 'active'):
|
|
21
20
|
query = query.filter(cls._model_cls.active == 1)
|
|
@@ -51,7 +50,7 @@ class BaseOrmManager(object):
|
|
|
51
50
|
return res
|
|
52
51
|
|
|
53
52
|
@classmethod
|
|
54
|
-
def get_list(cls, paginate=False, close_session_after_curd=False, **kwargs):
|
|
53
|
+
def get_list(cls, paginate=False, close_session_after_curd=False, render: bool = True, **kwargs):
|
|
55
54
|
with get_db_session() as db:
|
|
56
55
|
query = db.query(cls._model_cls)
|
|
57
56
|
# """分页查询示例"""
|
|
@@ -92,6 +91,8 @@ class BaseOrmManager(object):
|
|
|
92
91
|
|
|
93
92
|
# 查询记录
|
|
94
93
|
result = query.all()
|
|
94
|
+
if render:
|
|
95
|
+
result = [cls.render(item) for item in result]
|
|
95
96
|
pagination = {
|
|
96
97
|
'total': total,
|
|
97
98
|
'page': page,
|
|
@@ -102,6 +103,8 @@ class BaseOrmManager(object):
|
|
|
102
103
|
# return pagination
|
|
103
104
|
else:
|
|
104
105
|
res = query.all()
|
|
106
|
+
if render:
|
|
107
|
+
res = [cls.render(item) for item in res]
|
|
105
108
|
# return res
|
|
106
109
|
|
|
107
110
|
if close_session_after_curd:
|
|
@@ -124,9 +127,10 @@ class BaseOrmManager(object):
|
|
|
124
127
|
@classmethod
|
|
125
128
|
def soft_delete_obj(cls, obj):
|
|
126
129
|
"""软删除obj信息"""
|
|
127
|
-
cls.
|
|
128
|
-
|
|
129
|
-
|
|
130
|
+
if hasattr(cls._model_cls, 'active'):
|
|
131
|
+
cls.update_obj(obj, **{
|
|
132
|
+
'active': False
|
|
133
|
+
})
|
|
130
134
|
return obj
|
|
131
135
|
|
|
132
136
|
@classmethod
|