agentic-kit-common 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agentic-kit-common might be problematic; see the package's page on its registry for more details.

@@ -0,0 +1,53 @@
1
+ from typing import List, Union
2
+
3
+ import sqlparse
4
+ from sqlalchemy import text
5
+ from sqlalchemy.orm import Session
6
+
7
+
8
def is_readonly_sql(sql: str) -> bool:
    """Return True only if every statement in *sql* is a read-only query.

    Allowed root keyword (first keyword of each statement):
        SELECT / WITH / VALUES / EXPLAIN / DESCRIBE / SHOW
    Rejected anywhere in a statement:
        INSERT / UPDATE / DELETE / CREATE / ALTER / DROP / TRUNCATE / LOAD /
        REPLACE / LOCK / UNLOCK / GRANT / REVOKE / EXEC / EXECUTE / CALL

    Every statement of a ``;``-separated string is checked, so an input such
    as ``"SELECT 1; DROP TABLE t"`` is rejected. Empty, comment-only and
    non-string input is rejected as well.

    Args:
        sql: Raw SQL text, possibly containing several statements.

    Returns:
        True if all statements are read-only, otherwise False.
    """
    forbidden = {"INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "TRUNCATE",
                 "LOAD", "REPLACE", "LOCK", "UNLOCK", "GRANT", "REVOKE", "EXEC", "EXECUTE", "CALL"}
    allowed_root = {"SELECT", "WITH", "EXPLAIN", "DESCRIBE", "SHOW", "VALUES"}

    if not isinstance(sql, str):
        return False

    sql_clean = sqlparse.format(sql.strip(), strip_comments=True)
    # Drop empty fragments, e.g. the one produced after a trailing ';'.
    statements = [stmt for stmt in sqlparse.parse(sql_clean) if str(stmt).strip(" \t\r\n;")]
    if not statements:
        return False

    for statement in statements:
        # `ttype in Keyword` matches Keyword and all of its sub-types
        # (DML, DDL, CTE, ...). An exact-match tuple would miss Keyword.CTE,
        # so WITH queries would lose their root keyword.
        keywords = [
            token.normalized.upper()
            for token in statement.flatten()
            if token.ttype in sqlparse.tokens.Keyword
        ]
        if not keywords or keywords[0] not in allowed_root:
            return False
        if any(keyword in forbidden for keyword in keywords):
            return False
    return True
29
+
30
+
31
def _rows_from_result(result) -> List[dict]:
    """Convert a SQLAlchemy result into a list of ``{column: value}`` dicts.

    Statements that return no rows (e.g. INSERT/UPDATE run with
    ``query_only=False``) yield ``[]``. Calling ``keys()`` on them would raise.
    """
    if not result.returns_rows:
        return []
    columns = list(result.keys())
    return [dict(zip(columns, row)) for row in result.fetchall()]


def session_sql_execute(db_session: Session, sql_text: Union[str, List], query_only: bool = True):
    """Execute raw SQL on ``db_session`` and return the rows as dicts.

    Nothing is committed here; the caller owns the transaction.

    Args:
        db_session: SQLAlchemy session used to run the statement(s).
        sql_text: A single SQL string, or a list/tuple of SQL strings.
        query_only: When True (default), statements rejected by
            ``is_readonly_sql`` are not executed.

    Returns:
        - For a string: a list of ``{column: value}`` dicts. The list is
          empty when the statement is rejected or returns no rows.
        - For a list/tuple: one such row list per *executed* statement.
          Rejected statements are skipped, so positions need not match the
          input.
        - For any other input type: ``[]``.
    """
    def _run(statement) -> List[dict]:
        # str() mirrors the original f-string conversion for non-str items.
        return _rows_from_result(db_session.execute(text(str(statement))))

    if isinstance(sql_text, str):
        if query_only and not is_readonly_sql(sql_text):
            return []
        return _run(sql_text)

    if isinstance(sql_text, (list, tuple)):
        return [
            _run(statement)
            for statement in sql_text
            if not query_only or is_readonly_sql(statement)
        ]

    return []
@@ -0,0 +1,245 @@
1
+ import logging
2
+ from contextlib import contextmanager
3
+ from threading import Lock
4
+ from typing import Dict, Optional, Any, List
5
+
6
+ from sqlalchemy import create_engine, text
7
+ from sqlalchemy.exc import OperationalError
8
+ from sqlalchemy.orm import sessionmaker, scoped_session
9
+
10
+ from .schema import DatabaseEngineModel, _DEFAULT_CONFIG
11
+
12
# Logger for this module.
# NOTE(review): calling basicConfig() in an importable library configures the
# root logger as an import side effect. It can also turn the application's own
# later basicConfig() call into a no-op. Consider leaving logging configuration
# to the application.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
15
+
16
+
17
class DatabaseEngineManager:
    """Registry of named SQLAlchemy engines and their session factories.

    Every engine registered through :meth:`add_engine` gets its own
    ``scoped_session`` factory. ``add_engine``, ``remove_engine`` and
    ``set_default_engine`` are serialized with an internal lock; the
    read-only accessors do not take the lock.
    """

    def __init__(self, init_engines: Optional[List[DatabaseEngineModel]] = None):
        """Create the manager and optionally register ``init_engines``.

        Args:
            init_engines: Engine definitions passed one by one to
                :meth:`add_engine`. Failures are logged, not raised.
        """
        self._engines: Dict[str, Any] = {}
        self._sessions: Dict[str, Any] = {}
        self._engine_configs: Dict[str, Dict] = {}
        # Empty string means "no default engine selected".
        self._default_engine_name = ""
        self._lock = Lock()

        # Fallback engine/pool options (plain dict) used when none are supplied.
        self._default_config = _DEFAULT_CONFIG.model_dump()

        if init_engines:
            self._init_engines(init_engines=init_engines)

    def _init_engines(self, init_engines: List[DatabaseEngineModel]):
        """Register each engine definition via :meth:`add_engine`."""
        for engine in init_engines:
            self.add_engine(engine)

    def _create_engine(self, database_uri: str, database_name: str, config: Optional[Dict] = None) -> Any:
        """Create a SQLAlchemy engine for ``<database_uri>/<database_name>``.

        Args:
            database_uri: Server URL without the database name, e.g.
                ``dialect+driver://user:password@host``.
            database_name: Database name appended to ``database_uri``.
            config: Engine/pool options. ``None`` uses the manager defaults;
                missing keys fall back to the literals below.

        Returns:
            The new engine.

        Raises:
            Exception: Whatever ``create_engine`` raises, after it is logged.
        """
        if config is None:
            config = self._default_config.copy()

        # Join server URI and database name with exactly one '/'.
        if database_uri.endswith('/'):
            full_url = f"{database_uri}{database_name}"
        else:
            full_url = f"{database_uri}/{database_name}"

        try:
            engine = create_engine(
                url=full_url,
                echo=config.get("echo", False),
                pool_size=config.get("pool_size", 10),
                max_overflow=config.get("max_overflow", 5),
                pool_pre_ping=config.get("pool_pre_ping", True),
                pool_recycle=config.get("pool_recycle", 1800),
            )
        except Exception as e:
            logger.error(f"创建引擎失败: {database_name}, 错误: {e}")
            raise

        # The URI normally embeds credentials: only ever log a masked form.
        masked_url = engine.url.render_as_string(hide_password=True)
        logger.info(f"成功创建引擎: {database_name} -> {masked_url}")
        return engine

    def add_engine(self, engine_info: DatabaseEngineModel) -> bool:
        """Create and register the engine described by ``engine_info``.

        Returns:
            True on success. False if the name is already registered or the
            engine could not be created (the error is logged).
        """
        with self._lock:
            engine_name = engine_info.engine_name
            database_uri = engine_info.database_uri
            database_name = engine_info.database_name
            # model_dump() already returns a fresh dict.
            config = engine_info.config.model_dump()

            if engine_name in self._engines:
                logger.warning(f"引擎 '{engine_name}' 已存在")
                return False

            try:
                engine = self._create_engine(database_uri, database_name, config)

                # Session registry bound to this engine.
                session_factory = scoped_session(
                    sessionmaker(
                        bind=engine,
                        expire_on_commit=False,
                        autocommit=False,
                        autoflush=False,
                    )
                )

                self._engines[engine_name] = engine
                self._sessions[engine_name] = session_factory
                self._engine_configs[engine_name] = {
                    "database_uri": database_uri,
                    "database_name": database_name,
                    "config": config or self._default_config.copy(),
                }

                logger.info(f"成功添加引擎: {engine_name}")
                return True

            except Exception as e:
                logger.error(f"添加引擎失败: {engine_name}, 错误: {e}")
                return False

    def remove_engine(self, engine_name: str) -> bool:
        """Unregister ``engine_name`` and dispose of its engine.

        Also clears the default-engine selection if it pointed at this engine.

        Returns:
            True on success. False if the engine is unknown or cleanup failed
            (the error is logged).
        """
        with self._lock:
            if engine_name not in self._engines:
                logger.warning(f"引擎 '{engine_name}' 不存在")
                return False

            try:
                if engine_name in self._sessions:
                    # remove() closes and discards this registry's session.
                    # Session.close_all() (deprecated) would close every
                    # Session in the process, including other engines'.
                    self._sessions[engine_name].remove()
                    del self._sessions[engine_name]

                self._engines[engine_name].dispose()
                del self._engines[engine_name]

                self._engine_configs.pop(engine_name, None)

                # Never leave the default pointing at a removed engine.
                if engine_name == self._default_engine_name:
                    self._default_engine_name = ''

                logger.info(f"成功移除引擎: {engine_name}")
                return True

            except Exception as e:
                logger.error(f"移除引擎失败: {engine_name}, 错误: {e}")
                return False

    def get_engine(self, engine_name: Optional[str] = None) -> Any:
        """Return the engine ``engine_name`` (default engine if ``None``).

        Raises:
            ValueError: If the engine is not registered.
        """
        if engine_name is None:
            engine_name = self._default_engine_name

        if engine_name not in self._engines:
            raise ValueError(f"数据库引擎 '{engine_name}' 不存在")

        return self._engines[engine_name]

    def get_session_factory(self, engine_name: Optional[str] = None) -> Any:
        """Return the ``scoped_session`` factory for ``engine_name`` (default if ``None``).

        Raises:
            ValueError: If the engine is not registered.
        """
        if engine_name is None:
            engine_name = self._default_engine_name

        if engine_name not in self._sessions:
            raise ValueError(f"数据库引擎 '{engine_name}' 不存在")

        return self._sessions[engine_name]

    @staticmethod
    def _safe_rollback(session) -> None:
        """Roll back ``session``, logging (not raising) a rollback failure.

        This keeps the original error as the exception that propagates.
        """
        try:
            session.rollback()
        except Exception as rollback_error:
            logger.warning(f"回滚失败: {rollback_error}")

    @contextmanager
    def get_db_session(self,
                       engine_name: Optional[str] = None,
                       auto_commit_by_exit: bool = False,
                       auto_close: bool = True):
        """Context manager that yields a session for ``engine_name``.

        Args:
            engine_name: Registered engine name; ``None`` selects the default.
            auto_commit_by_exit: Commit when the ``with`` block exits normally.
            auto_close: Close the session when the ``with`` block exits.

        Raises:
            ValueError: If the engine is not registered.
            Exception: Any error from the ``with`` block (or from the commit)
                is re-raised after the session has been rolled back.
        """
        session_factory = self.get_session_factory(engine_name)
        session = session_factory()

        try:
            yield session
            if auto_commit_by_exit:
                session.commit()
        except OperationalError as e:
            # A @contextmanager generator may yield only once. The caller's
            # block cannot be replayed here (a second yield raises
            # RuntimeError), so roll back and let the caller decide to retry.
            logger.warning(f"数据库连接异常: {e}")
            self._safe_rollback(session)
            raise
        except Exception as e:
            self._safe_rollback(session)
            logger.error(f"数据库操作异常: {e}")
            raise
        finally:
            if auto_close:
                session.close()

    def set_default_engine(self, engine_name: str) -> bool:
        """Make ``engine_name`` the default engine.

        Returns:
            True on success; False if the engine is not registered.
        """
        with self._lock:
            if engine_name not in self._engines:
                logger.warning(f"引擎 '{engine_name}' 不存在")
                return False

            self._default_engine_name = engine_name
            logger.info(f"设置默认引擎为: {engine_name}")
            return True

    def list_engines(self) -> List[str]:
        """Return the names of all registered engines."""
        return list(self._engines.keys())

    def get_engine_info(self, engine_name: str) -> Optional[Dict]:
        """Return a copy of the stored settings for ``engine_name``, or ``None``.

        The dict holds ``database_uri``, ``database_name``, ``config`` and
        ``is_default``. ``database_uri`` is stored as given, so it may
        contain credentials.
        """
        if engine_name not in self._engine_configs:
            return None

        info = self._engine_configs[engine_name].copy()
        info["is_default"] = (engine_name == self._default_engine_name)
        return info

    def health_check(self, engine_name: Optional[str] = None) -> Dict[str, bool]:
        """Run ``SELECT 1`` on one engine, or on all engines if ``engine_name`` is falsy.

        Returns:
            Mapping of engine name to success flag. An unknown
            ``engine_name`` yields an empty dict.
        """
        results = {}

        if engine_name:
            engines_to_check = [engine_name] if engine_name in self._engines else []
        else:
            engines_to_check = list(self._engines.keys())

        for name in engines_to_check:
            try:
                with self.get_db_session(name, auto_close=True) as session:
                    session.execute(text("SELECT 1"))
                results[name] = True
                logger.debug(f"健康检查通过: {name}")
            except Exception as e:
                results[name] = False
                logger.error(f"健康检查失败: {name}, 错误: {e}")

        return results
239
+
240
+
241
# Factory helper for the database engine manager.
def initialize_database_engine_manager(init_engines: Optional[List[DatabaseEngineModel]] = None):
    """Build a new :class:`DatabaseEngineManager`.

    Args:
        init_engines: Optional engine definitions registered on construction.

    Returns:
        A freshly created manager. Each call returns a new instance.
    """
    return DatabaseEngineManager(init_engines=init_engines)
@@ -0,0 +1,19 @@
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
class DatabaseEngineConfigModel(BaseModel):
    """Engine and connection-pool options used when building a SQLAlchemy engine."""

    # Echo every SQL statement and its parameters to stdout.
    echo: bool = Field(default=False, description="数据库的 SQL 语句 + 参数 打印到标准输出")
    # Ping each connection when it is checked out of the pool (a "heartbeat").
    pool_pre_ping: bool = Field(default=True, description="每次从池里拿连接前,先发一句做“心跳”")
    # Number of persistent connections the pool keeps open.
    pool_size: int = Field(default=10, description="连接池里 长期保持的“永久”连接 数量")
    # Extra temporary connections allowed once pool_size is exhausted.
    max_overflow: int = Field(default=5, description="当 pool_size 用光后,最多还能再新建多少条“临时”连接")
    # Age after which a pooled connection is closed and replaced
    # (seconds, following SQLAlchemy's pool_recycle).
    pool_recycle: int = Field(default=1800, description="一条连接被 复用多久之后强制回收(关闭并新建)")


# Module-wide default option set.
_DEFAULT_CONFIG = DatabaseEngineConfigModel()
13
+
14
+
15
class DatabaseEngineModel(BaseModel):
    """Definition of one named database engine."""

    # Unique name used to look the engine up.
    engine_name: str = Field(..., description="数据库引擎名字")
    # Server URL without the database name, e.g. dialect+driver://user:password@host
    database_uri: str = Field(..., description="数据库uri地址")
    # Database name appended to database_uri.
    database_name: str = Field(..., description="数据库名称")
    # Engine/pool options; falls back to the module-level _DEFAULT_CONFIG.
    config: DatabaseEngineConfigModel = Field(_DEFAULT_CONFIG, description="配置信息")
@@ -1,6 +1,7 @@
1
1
  import abc
2
2
 
3
3
  from langchain_community.embeddings import XinferenceEmbeddings
4
+ from langchain_openai import OpenAIEmbeddings
4
5
 
5
6
 
6
7
  class EmbeddingFactoryBase:
@@ -26,14 +27,28 @@ class XinferenceEmbeddingFactory(EmbeddingFactoryBase):
26
27
  return embedding
27
28
 
28
29
 
30
class VllmEmbeddingFactory(EmbeddingFactoryBase):
    """Build embeddings served by a vLLM OpenAI-compatible endpoint."""

    provider: str = 'vllm'

    @classmethod
    def create_embedding(cls, base_url: str, model_uid: str, api_key: str = 'api_key', dims: int = 1024):
        """Create an ``OpenAIEmbeddings`` client pointed at a vLLM server.

        Args:
            base_url: OpenAI-compatible base URL of the vLLM server.
            model_uid: Model name as served by vLLM.
            api_key: Key sent with requests (placeholder by default).
            dims: Accepted for interface compatibility with the other
                factories; it is not forwarded to the client.

        Returns:
            A configured ``OpenAIEmbeddings`` instance.
        """
        embedding = OpenAIEmbeddings(
            openai_api_key=api_key,  # vLLM does not validate the key
            openai_api_base=base_url,
            model=model_uid,
            # By default OpenAIEmbeddings pre-tokenizes inputs with tiktoken
            # (OpenAI vocabulary) and sends token ids. A vLLM model would
            # decode those ids with its own tokenizer and embed garbage, so
            # send raw strings instead.
            check_embedding_ctx_length=False,
        )
        return embedding
41
+
42
+
29
43
  _FACTORY_LIST = [
30
- XinferenceEmbeddingFactory
44
+ XinferenceEmbeddingFactory,
45
+ VllmEmbeddingFactory,
31
46
  ]
32
47
 
33
48
 
34
49
  class EmbeddingFactory:
35
50
  @classmethod
36
- def create_embedding(cls, base_url: str, model_uid: str, api_key: str = '', dims: int = 1024, provider: str = 'xinference'):
51
+ def create_embedding(cls, base_url: str, model_uid: str, api_key: str = 'api_key', dims: int = 1024, provider: str = 'xinference'):
37
52
  for factory in _FACTORY_LIST:
38
53
  if factory.provider == provider:
39
54
  return factory.create_embedding(base_url=base_url, model_uid=model_uid, api_key=api_key, dims=dims)
@@ -4,7 +4,7 @@ from typing import List, Dict
4
4
  from langchain_core.embeddings import Embeddings
5
5
  from pydantic import BaseModel
6
6
  from pymilvus import connections, Collection, CollectionSchema, FieldSchema
7
- from pymilvus.orm import utility
7
+ from pymilvus import db, utility
8
8
 
9
9
  from ..schema import default_search_params, default_index_params_auto, default_index_params_vector, default_fields, \
10
10
  default_output_fields, default_search_field, default_query_fields
@@ -18,14 +18,6 @@ class MilvusManager:
18
18
 
19
19
  collection_name: str
20
20
 
21
- # search_field: str
22
- #
23
- # query_fields: List[str]
24
- #
25
- # output_fields: List[str]
26
- #
27
- # fields: List[FieldSchema]
28
- #
29
21
  model_cls: BaseModel
30
22
 
31
23
  @classmethod
@@ -33,9 +25,9 @@ class MilvusManager:
33
25
  cls,
34
26
  embed_model: Embeddings,
35
27
  vector_store_uri: str,
36
- # database_name: str,
37
- # collection_name: str,
38
28
  # model_cls: BaseModel,
29
+ database_name: str = None,
30
+ collection_name: str = None,
39
31
  search_field: str = None,
40
32
  query_fields: List[str] = None,
41
33
  output_fields: List[str] = None,
@@ -62,8 +54,8 @@ class MilvusManager:
62
54
  return cls(
63
55
  embed_model=embed_model,
64
56
  vector_store_uri=vector_store_uri,
65
- # database_name=database_name,
66
- # collection_name=collection_name,
57
+ database_name=database_name,
58
+ collection_name=collection_name,
67
59
  search_field=search_field,
68
60
  query_fields=query_fields,
69
61
  output_fields=output_fields,
@@ -78,8 +70,8 @@ class MilvusManager:
78
70
  self,
79
71
  embed_model: Embeddings,
80
72
  vector_store_uri: str,
81
- # database_name: str,
82
- # collection_name: str,
73
+ database_name: str,
74
+ collection_name: str,
83
75
  search_field: str,
84
76
  query_fields: List[str],
85
77
  output_fields: List[str],
@@ -89,8 +81,15 @@ class MilvusManager:
89
81
  index_params_auto: Dict,
90
82
  search_params: Dict,
91
83
  ):
92
- # self.database_name = database_name
93
- # self.collection_name = collection_name
84
+ if database_name is None:
85
+ self.database_name = self.__class__.database_name
86
+ else:
87
+ self.database_name = database_name
88
+ if collection_name is None:
89
+ self.collection_name = self.__class__.collection_name
90
+ else:
91
+ self.collection_name = collection_name
92
+
94
93
  self.embed_model = embed_model
95
94
  self.vector_store_uri = vector_store_uri
96
95
  self.search_field = search_field
@@ -106,6 +105,16 @@ class MilvusManager:
106
105
  self.do_init()
107
106
 
108
107
  def do_init(self):
108
+ # 1. 先连到系统默认库,才能操作 database 级别
109
+ res = connections.connect(alias="default", uri=self.vector_store_uri) # 不指定 db_name 即连 default
110
+ # 2. 验证连接成功
111
+ if not connections.has_connection("default"):
112
+ raise RuntimeError("Milvus 连接失败")
113
+ # 3. 判断并创建目标库
114
+ if self.database_name not in db.list_database():
115
+ db.create_database(self.database_name) # pymilvus ≥2.2.9
116
+ print(f"database {self.database_name} created")
117
+
109
118
  res = connections.connect(
110
119
  alias="default",
111
120
  db_name=self.database_name,
@@ -201,13 +210,15 @@ class MilvusManager:
201
210
  def insert(self, items: list[dict]):
202
211
  uids = []
203
212
  vectors = []
213
+ owner_uids = []
204
214
  for item in items:
205
215
  if 'uid' not in item or 'content' not in item:
206
216
  raise Exception(f'insert items must include [uid, content], illegal for {item}')
207
217
  uids.append(item['uid'])
218
+ owner_uids.append(item['owner_uid'])
208
219
  vectors.append(self.embed_model.embed_query(item['content']))
209
220
 
210
- insert_fields = [uids, vectors]
221
+ insert_fields = [uids, owner_uids, vectors]
211
222
  insert_result = self.collection.insert(insert_fields)
212
223
  self.collection.flush()
213
224
  return insert_result
@@ -3,11 +3,12 @@ from pymilvus import FieldSchema, DataType
3
3
  default_fields = [
4
4
  FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
5
5
  FieldSchema(name="uid", dtype=DataType.VARCHAR, max_length=36),
6
+ FieldSchema(name="owner_uid", dtype=DataType.VARCHAR, max_length=64),
6
7
  FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=1024),
7
8
  ]
8
- default_output_fields = ['pk', 'uid']
9
+ default_output_fields = ['pk', 'uid', 'owner_uid']
9
10
  default_search_field = "vectors"
10
- default_query_fields = ['uid']
11
+ default_query_fields = ['uid', 'owner_uid']
11
12
 
12
13
  default_index_params_vector = {
13
14
  'metric_type': 'COSINE', # IP 或 COSINE
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agentic-kit-common
3
- Version: 0.0.8
3
+ Version: 0.0.10
4
4
  Summary: Common utilities and tools for agentic kit ecosystem
5
5
  Home-page:
6
6
  Author: manson
@@ -3,26 +3,30 @@ agentic_kit_common/minio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZ
3
3
  agentic_kit_common/minio/minio_manager.py,sha256=PzoUWn0YqXTHx1UClbkwLkIUmv5O4aSDc7eVf08qOzs,21494
4
4
  agentic_kit_common/orm/__init__.py,sha256=wqY81g3P7FftFvLK5SaxFJzNNxrQwtmcb4RCXOSAZa8,71
5
5
  agentic_kit_common/orm/base.py,sha256=QIura_i2nIY2XeA3-KkO2loLNbEAoJK2qx0hu_8nhYU,2277
6
+ agentic_kit_common/orm/execution.py,sha256=sCEcInFGFZ9ZiIs0eHKbQPtN4Z02v7vGdFMyVEWygcU,1951
6
7
  agentic_kit_common/orm/manager.py,sha256=lWgFk5fUu_9m6yN_fskWHSYGaW30ty--KZKj8AuvIh0,4852
8
+ agentic_kit_common/orm/multi_session.py,sha256=CAEnOLl0I-r77JAknKQ2sERF40OkHD7BOrYVYoHuQV4,8814
9
+ agentic_kit_common/orm/schema.py,sha256=ukdVP71NE_JO5_HOe_FApJjBAKUsDluoElVn5vNSGJQ,1023
7
10
  agentic_kit_common/orm/session.py,sha256=LX4ZUKJNXdcQ0KqRt5L0pJX-QG-tDyKXEDijcPUkGD0,2716
8
11
  agentic_kit_common/vector/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
12
  agentic_kit_common/vector/embedding/__init__.py,sha256=mp--cfCDzTn5hNxzcIAU4m0g4F0d0rrZErFVWsGDvx4,40
10
- agentic_kit_common/vector/embedding/embedding.py,sha256=k1hHNkBcucu6lSe2hYSrCal-ML26Eo6Bfe8cVkXSww8,1380
13
+ agentic_kit_common/vector/embedding/embedding.py,sha256=ZEvl2w2-PrAph60u2wJItj4lb68XxK805tRzOp49MdI,1861
11
14
  agentic_kit_common/vector/manager/__init__.py,sha256=w2uAmKGRx9Nv3QySXIAgxQGlYR549AQvbBuBFsQttKI,42
12
- agentic_kit_common/vector/manager/milvus_manager.py,sha256=HaoKKsyMOfmA3jX45XEVhmK6SprjTmwULOyhCqqAIGg,7840
15
+ agentic_kit_common/vector/manager/milvus_manager.py,sha256=gDwWlkrnIUFqMoDNsvqnPAKTyAUnyxD-fh7jy93LcQY,8568
13
16
  agentic_kit_common/vector/schema/__init__.py,sha256=F8WCi4ybnujFW-qHcDI9KOqzw7i8r07dmBUO2R2_Iu0,228
14
17
  agentic_kit_common/vector/schema/base.py,sha256=oo5OSFCX9UeQuWskiebqctF1dyXISD7_czzoKYoHnKg,102
15
- agentic_kit_common/vector/schema/milvus_schema.py,sha256=V8igxe6_FmZ46jicNdKxBbf0vl3bwTC4AyRf9gqB_VA,684
18
+ agentic_kit_common/vector/schema/milvus_schema.py,sha256=4kIUtIIE2vcKP1s1Ghn_eikuzDb1lsSeAaTwrsNzUlQ,784
16
19
  agentic_kit_common/web/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
20
  agentic_kit_common/web/http/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
21
  agentic_kit_common/web/http/response.py,sha256=uxd_MRHsFfQ0pUeHITDQD8tuOY6fo6IJ0MO6uL62AlM,1347
19
22
  test/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
23
  test/config.py,sha256=VMcvfWnuWZKr_p0rC57i90jTq-XVsRuVzfEIe1C7Drg,3804
21
24
  test/settings.py,sha256=vU1dqUvGzshi6MG7JbGfW4jKfUsfAObgUfky-LMjUNs,344
22
- test/test_embedding.py,sha256=78M6t9tt6-oCfdQZfOartCIkUFZMvogcu5UKJh1lJb8,609
25
+ test/test_embedding.py,sha256=xcfHDHGL2_tpXp_VaLnNDAWvKdkc1K1BjiYoV5WxtFY,900
23
26
  test/test_minio.py,sha256=TOkX8A2pPkjrwAIH88xBZmFpDc1ZgTk1QS6mtubOZ-Y,2308
24
- test/test_vector.py,sha256=4sJ7bXt5iRhc6j1Ig0bZ1j0huPU-clnsnT7O0-uFbkw,1311
25
- agentic_kit_common-0.0.8.dist-info/METADATA,sha256=mLi0procHOdh1DiCFCstlsD1RM-wAMWEKFpHTRMbLxo,7496
26
- agentic_kit_common-0.0.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
27
- agentic_kit_common-0.0.8.dist-info/top_level.txt,sha256=nEKDlp84vqKSVWssGcxyuIsTqWLhMo45xqMs2GK4Dgg,24
28
- agentic_kit_common-0.0.8.dist-info/RECORD,,
27
+ test/test_orm.py,sha256=8fGCU7BWaD5sDbg0fgYN0Saf_hi7t-q8svHCHLQDceo,1620
28
+ test/test_vector.py,sha256=Z3Bwvrw0XGBWXHYXk9pkF1cjnUZSZk0gmLsgQIIaZuY,1717
29
+ agentic_kit_common-0.0.10.dist-info/METADATA,sha256=KIp7HPRyzygUBrXu8H9B4xu8NOFiD6_B5PTaYVN20Fc,7497
30
+ agentic_kit_common-0.0.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ agentic_kit_common-0.0.10.dist-info/top_level.txt,sha256=nEKDlp84vqKSVWssGcxyuIsTqWLhMo45xqMs2GK4Dgg,24
32
+ agentic_kit_common-0.0.10.dist-info/RECORD,,
test/test_embedding.py CHANGED
@@ -6,13 +6,22 @@ project_root = Path(__file__).parent.parent
6
6
  sys.path.insert(0, str(project_root))
7
7
 
8
8
  from agentic_kit_common.vector.embedding import EmbeddingFactory
9
+ from langchain_openai import OpenAIEmbeddings
9
10
  from .settings import global_settings
10
11
 
11
12
 
12
13
  class MyTestCase(unittest.TestCase):
13
14
  def test_embedding(self):
14
- embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
15
+ # embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
16
+
17
+ embedding = EmbeddingFactory.create_embedding(
18
+ base_url=global_settings.embedding_base_url,
19
+ model_uid=global_settings.embedding_model_uid,
20
+ provider='vllm'
21
+ )
22
+
15
23
  text = '你好'
24
+ # text = 'hello, world'
16
25
  text_vector = embedding.embed_query(text=text)
17
26
  print(text_vector)
18
27
 
test/test_orm.py ADDED
@@ -0,0 +1,47 @@
1
+ import sys
2
+ import unittest
3
+ from pathlib import Path
4
+
5
+ from agentic_kit_common.orm.execution import session_sql_execute
6
+ from agentic_kit_common.orm.multi_session import initialize_database_engine_manager
7
+ from agentic_kit_common.orm.schema import DatabaseEngineModel
8
+
9
# Put the repository root at the front of the import path.
# NOTE(review): this runs after the agentic_kit_common imports above, so it
# cannot help resolve them. Move it above those imports if that was the intent.
project_root = Path(__file__).parent.parent
sys.path[:0] = [str(project_root)]
11
+
12
+
13
class MyTestCase(unittest.TestCase):
    def test_multi_session(self):
        """Smoke-test DatabaseEngineManager and session_sql_execute against a real DB.

        Connection settings come from the environment, so no credentials are
        committed to the repository:

        * ``AGENTIC_KIT_TEST_DATABASE_URI``: server URI without the database
          name, e.g. ``mysql+mysqlconnector://user:password@host``
          (URL-encode special characters in the password). Required; the
          test is skipped when it is unset.
        * ``AGENTIC_KIT_TEST_DATABASE_NAME``: database name (default
          ``czailab_llm``).
        """
        import os

        database_uri = os.environ.get('AGENTIC_KIT_TEST_DATABASE_URI')
        if not database_uri:
            self.skipTest('AGENTIC_KIT_TEST_DATABASE_URI is not set')
        database_name = os.environ.get('AGENTIC_KIT_TEST_DATABASE_NAME', 'czailab_llm')

        engine = DatabaseEngineModel(
            engine_name='default',
            database_name=database_name,
            database_uri=database_uri,
        )

        manager = initialize_database_engine_manager()
        self.assertTrue(manager.add_engine(engine_info=engine))

        self.assertIn('default', manager.list_engines())
        self.assertIsNotNone(manager.get_engine('default'))
        # get_engine_info() includes the raw URI (with password): assert on it
        # rather than printing it.
        info = manager.get_engine_info('default')
        self.assertEqual(info['database_name'], database_name)

        with manager.get_db_session(engine_name='default') as session:
            # One statement per list element. mysql-connector does not accept
            # several ';'-separated statements in one execute() by default.
            results = session_sql_execute(
                db_session=session,
                sql_text=['select * from tenant', 'select * from rerank_model'],
            )
            self.assertIsInstance(results, list)
            for rows in results:
                self.assertIsInstance(rows, list)
                for row in rows:
                    self.assertIsInstance(row, dict)
                    print(row)


if __name__ == '__main__':
    unittest.main()
test/test_vector.py CHANGED
@@ -17,11 +17,11 @@ from .settings import global_settings
17
17
 
18
18
  logging.basicConfig(level=logging.DEBUG)
19
19
 
20
- embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
20
+ embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid, provider='vllm')
21
21
 
22
22
 
23
23
  class McpServerManager(MilvusManager):
24
- database_name: str = 'skill_center'
24
+ database_name: str = 'nl2sql'
25
25
 
26
26
  collection_name: str = 'mcp_server'
27
27
 
@@ -31,8 +31,16 @@ class McpServerManager(MilvusManager):
31
31
  class MyTestCase(unittest.TestCase):
32
32
  def test_vector(self):
33
33
  manager = McpServerManager.create(embed_model=embedding, vector_store_uri=global_settings.milvus_url)
34
+ # manager = MilvusManager.create(embed_model=embedding, vector_store_uri=global_settings.milvus_url, database_name='nl2sql', collection_name='mcp_server2')
35
+ # manager.insert([{
36
+ # 'uid': 'xxxx',
37
+ # 'content': 'hello',
38
+ # 'owner_uid': 'xxxx2'
39
+ # }])
34
40
 
35
- manager.query()
41
+ owner_uid = 'xxxx2'
42
+ res = manager.query(expr=f'owner_uid == "{owner_uid}"')
43
+ print(res)
36
44
 
37
45
  # manager.search('你好')
38
46
 
@@ -43,9 +51,9 @@ class MyTestCase(unittest.TestCase):
43
51
  # 'content': 'hello',
44
52
  # }])
45
53
 
46
- manager.delete_by_uid(uid='2050d9a9-b125-42b9-80ce-bd9eee89d9eb')
54
+ # manager.delete_by_uid(uid='2050d9a9-b125-42b9-80ce-bd9eee89d9eb')
47
55
 
48
- manager.query()
56
+ # manager.query()
49
57
 
50
58
 
51
59
  if __name__ == '__main__':