agentic-kit-common 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of agentic-kit-common might be problematic; the registry's notice for this release (its "more details" link was not preserved in this extract) explains why.

File without changes
@@ -0,0 +1,16 @@
1
+ import logging
2
+
3
+
4
+ def setup_logger(tag, level):
5
+ logger = logging.getLogger(tag)
6
+ logger.setLevel(level)
7
+
8
+ if not logger.handlers:
9
+ handler = logging.StreamHandler()
10
+ formatter = logging.Formatter(
11
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12
+ )
13
+ handler.setFormatter(formatter)
14
+ logger.addHandler(handler)
15
+
16
+ return logger
@@ -1,53 +1,90 @@
1
- from typing import List, Union
1
+ from typing import List, Union, Optional, Set
2
2
 
3
- import sqlparse
3
+ import sqlglot
4
4
  from sqlalchemy import text
5
5
  from sqlalchemy.orm import Session
6
+ from sqlglot import expressions
6
7
 
7
8
 
8
- def is_readonly_sql(sql: str) -> bool:
9
+ def get_operation_type(statement) -> str:
10
+ """获取SQL操作类型"""
11
+ return type(statement).__name__.upper()
12
+
13
+
14
+ def get_sql_operation_info(sql: str) -> Optional[dict]:
9
15
  """
10
- 只允许:
11
- SELECT / WITH / VALUES / EXPLAIN / DESCRIBE / SHOW
12
- 拒绝:
13
- INSERT/UPDATE/DELETE/CREATE/ALTER/DROP/TRUNCATE/LOAD/REPLACE/LOCK/UNLOCK/GRANT/REVOKE/EXECUTE/CALL
16
+ 获取SQL操作信息(用于调试和日志记录)
17
+
18
+ Returns:
19
+ dict: 包含操作类型、涉及表等信息
14
20
  """
15
- sql_clean = sqlparse.format(sql.strip(), strip_comments=True)
16
- tokens = [t.normalized for t in sqlparse.parse(sql_clean)[0].flatten()
17
- if t.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.DML, sqlparse.tokens.DDL)]
21
+ try:
22
+ statement = sqlglot.parse_one(sql, read="mysql")
23
+ operation_type = get_operation_type(statement)
24
+
25
+ tables = []
26
+ for table in statement.find_all(expressions.Table):
27
+ table_name = table.name
28
+ if table_name.startswith('`') and table_name.endswith('`'):
29
+ table_name = table_name[1:-1]
30
+ tables.append(table_name)
31
+
32
+ return {
33
+ 'operation': operation_type,
34
+ 'tables': tables,
35
+ 'has_wildcard': bool(statement.find(expressions.Star)) if isinstance(statement, expressions.Select) else False
36
+ }
37
+ except Exception as e:
38
+ print(f"获取SQL操作信息失败: {e}")
39
+ return None
40
+
41
+
42
+ def is_readonly_expression(node: expressions.Expression, allowed_operations: Optional[Set[str]] = None, enable_wildcard_check: bool = True):
43
+ """递归检查表达式树中是否出现写操作节点"""
44
+ # 写操作黑名单节点类型
45
+ write_types = {
46
+ expressions.Insert,
47
+ expressions.Update,
48
+ expressions.Delete,
49
+ expressions.Create,
50
+ expressions.Alter,
51
+ expressions.Drop,
52
+ expressions.Replace,
53
+ expressions.Merge
54
+ }
55
+ for descendant in node.walk():
56
+ if type(descendant) in write_types:
57
+ return False, type(descendant)
58
+
59
+ if allowed_operations:
60
+ op_type = get_operation_type(node)
61
+ # 检查操作类型
62
+ if op_type not in allowed_operations:
63
+ return False, op_type
64
+
65
+ if enable_wildcard_check and isinstance(node, expressions.Select) and node.find(expressions.Star):
66
+ return False, expressions.Star
18
67
 
19
- forbidden = {"INSERT", "UPDATE", "DELETE", "CREATE", "ALTER", "DROP", "TRUNCATE",
20
- "LOAD", "REPLACE", "LOCK", "UNLOCK", "GRANT", "REVOKE", "EXEC", "CALL"}
21
- allowed_root = {"SELECT", "WITH", "EXPLAIN", "DESCRIBE", "SHOW", "VALUES"}
68
+ return True, None
22
69
 
23
- root = tokens[0].upper() if tokens else ""
24
- if root not in allowed_root:
25
- return False
26
- if any(k in forbidden for k in tokens):
27
- return False
28
- return True
29
70
 
71
+ def session_sql_execute(db_session: Session, sql_text: Union[str, List], format_result: bool = True):
72
+ def __do_execute(_sql: str):
73
+ _result = db_session.execute(text(f"{sql_text}"))
74
+ if format_result:
75
+ columns = list(_result.keys())
76
+ rows = [dict(zip(columns, r)) for r in _result.fetchall()]
77
+ return rows
78
+ else:
79
+ return _result
30
80
 
31
- def session_sql_execute(db_session: Session, sql_text: Union[str, List], query_only: bool = True):
32
81
  if isinstance(sql_text, str):
33
- # 处理单条SQL语句
34
- if query_only:
35
- if not is_readonly_sql(sql_text):
36
- return []
37
- result = db_session.execute(text(f"{sql_text}"))
38
- columns = list(result.keys())
39
- rows = [dict(zip(columns, r)) for r in result.fetchall()]
40
- return rows
82
+ return __do_execute(sql_text)
41
83
  elif isinstance(sql_text, list):
42
84
  results = []
43
85
  for sub_sql_text in sql_text:
44
- if query_only:
45
- if not is_readonly_sql(sub_sql_text):
46
- continue
47
- result = db_session.execute(text(f"{sub_sql_text}"))
48
- columns = list(result.keys())
49
- rows = [dict(zip(columns, r)) for r in result.fetchall()]
50
- results.append(rows)
86
+ result = __do_execute(sub_sql_text)
87
+ results.append(result)
51
88
  return results
52
89
  else:
53
90
  return []
@@ -1,6 +1,7 @@
1
1
  import abc
2
2
 
3
3
  from langchain_community.embeddings import XinferenceEmbeddings
4
+ from langchain_openai import OpenAIEmbeddings
4
5
 
5
6
 
6
7
  class EmbeddingFactoryBase:
@@ -26,14 +27,28 @@ class XinferenceEmbeddingFactory(EmbeddingFactoryBase):
26
27
  return embedding
27
28
 
28
29
 
30
+ class VllmEmbeddingFactory(EmbeddingFactoryBase):
31
+ provider: str = 'vllm'
32
+
33
+ @classmethod
34
+ def create_embedding(cls, base_url: str, model_uid: str, api_key: str = 'api_key', dims: int = 1024):
35
+ embedding = OpenAIEmbeddings(
36
+ openai_api_key=api_key, # vLLM 不校验 key
37
+ openai_api_base=base_url,
38
+ model=model_uid,
39
+ )
40
+ return embedding
41
+
42
+
29
43
  _FACTORY_LIST = [
30
- XinferenceEmbeddingFactory
44
+ XinferenceEmbeddingFactory,
45
+ VllmEmbeddingFactory,
31
46
  ]
32
47
 
33
48
 
34
49
  class EmbeddingFactory:
35
50
  @classmethod
36
- def create_embedding(cls, base_url: str, model_uid: str, api_key: str = '', dims: int = 1024, provider: str = 'xinference'):
51
+ def create_embedding(cls, base_url: str, model_uid: str, api_key: str = 'api_key', dims: int = 1024, provider: str = 'xinference'):
37
52
  for factory in _FACTORY_LIST:
38
53
  if factory.provider == provider:
39
54
  return factory.create_embedding(base_url=base_url, model_uid=model_uid, api_key=api_key, dims=dims)
@@ -4,7 +4,7 @@ from typing import List, Dict
4
4
  from langchain_core.embeddings import Embeddings
5
5
  from pydantic import BaseModel
6
6
  from pymilvus import connections, Collection, CollectionSchema, FieldSchema
7
- from pymilvus.orm import utility
7
+ from pymilvus import db, utility
8
8
 
9
9
  from ..schema import default_search_params, default_index_params_auto, default_index_params_vector, default_fields, \
10
10
  default_output_fields, default_search_field, default_query_fields
@@ -18,14 +18,6 @@ class MilvusManager:
18
18
 
19
19
  collection_name: str
20
20
 
21
- # search_field: str
22
- #
23
- # query_fields: List[str]
24
- #
25
- # output_fields: List[str]
26
- #
27
- # fields: List[FieldSchema]
28
- #
29
21
  model_cls: BaseModel
30
22
 
31
23
  @classmethod
@@ -33,9 +25,9 @@ class MilvusManager:
33
25
  cls,
34
26
  embed_model: Embeddings,
35
27
  vector_store_uri: str,
36
- # database_name: str,
37
- # collection_name: str,
38
28
  # model_cls: BaseModel,
29
+ database_name: str = None,
30
+ collection_name: str = None,
39
31
  search_field: str = None,
40
32
  query_fields: List[str] = None,
41
33
  output_fields: List[str] = None,
@@ -62,8 +54,8 @@ class MilvusManager:
62
54
  return cls(
63
55
  embed_model=embed_model,
64
56
  vector_store_uri=vector_store_uri,
65
- # database_name=database_name,
66
- # collection_name=collection_name,
57
+ database_name=database_name,
58
+ collection_name=collection_name,
67
59
  search_field=search_field,
68
60
  query_fields=query_fields,
69
61
  output_fields=output_fields,
@@ -78,8 +70,8 @@ class MilvusManager:
78
70
  self,
79
71
  embed_model: Embeddings,
80
72
  vector_store_uri: str,
81
- # database_name: str,
82
- # collection_name: str,
73
+ database_name: str,
74
+ collection_name: str,
83
75
  search_field: str,
84
76
  query_fields: List[str],
85
77
  output_fields: List[str],
@@ -89,8 +81,15 @@ class MilvusManager:
89
81
  index_params_auto: Dict,
90
82
  search_params: Dict,
91
83
  ):
92
- # self.database_name = database_name
93
- # self.collection_name = collection_name
84
+ if database_name is None:
85
+ self.database_name = self.__class__.database_name
86
+ else:
87
+ self.database_name = database_name
88
+ if collection_name is None:
89
+ self.collection_name = self.__class__.collection_name
90
+ else:
91
+ self.collection_name = collection_name
92
+
94
93
  self.embed_model = embed_model
95
94
  self.vector_store_uri = vector_store_uri
96
95
  self.search_field = search_field
@@ -106,6 +105,16 @@ class MilvusManager:
106
105
  self.do_init()
107
106
 
108
107
  def do_init(self):
108
+ # 1. 先连到系统默认库,才能操作 database 级别
109
+ res = connections.connect(alias="default", uri=self.vector_store_uri) # 不指定 db_name 即连 default
110
+ # 2. 验证连接成功
111
+ if not connections.has_connection("default"):
112
+ raise RuntimeError("Milvus 连接失败")
113
+ # 3. 判断并创建目标库
114
+ if self.database_name not in db.list_database():
115
+ db.create_database(self.database_name) # pymilvus ≥2.2.9
116
+ print(f"database {self.database_name} created")
117
+
109
118
  res = connections.connect(
110
119
  alias="default",
111
120
  db_name=self.database_name,
@@ -201,13 +210,15 @@ class MilvusManager:
201
210
  def insert(self, items: list[dict]):
202
211
  uids = []
203
212
  vectors = []
213
+ owner_uids = []
204
214
  for item in items:
205
215
  if 'uid' not in item or 'content' not in item:
206
216
  raise Exception(f'insert items must include [uid, content], illegal for {item}')
207
217
  uids.append(item['uid'])
218
+ owner_uids.append(item['owner_uid'])
208
219
  vectors.append(self.embed_model.embed_query(item['content']))
209
220
 
210
- insert_fields = [uids, vectors]
221
+ insert_fields = [uids, owner_uids, vectors]
211
222
  insert_result = self.collection.insert(insert_fields)
212
223
  self.collection.flush()
213
224
  return insert_result
@@ -3,11 +3,12 @@ from pymilvus import FieldSchema, DataType
3
3
  default_fields = [
4
4
  FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
5
5
  FieldSchema(name="uid", dtype=DataType.VARCHAR, max_length=36),
6
+ FieldSchema(name="owner_uid", dtype=DataType.VARCHAR, max_length=64),
6
7
  FieldSchema(name="vectors", dtype=DataType.FLOAT_VECTOR, dim=1024),
7
8
  ]
8
- default_output_fields = ['pk', 'uid']
9
+ default_output_fields = ['pk', 'uid', 'owner_uid']
9
10
  default_search_field = "vectors"
10
- default_query_fields = ['uid']
11
+ default_query_fields = ['uid', 'owner_uid']
11
12
 
12
13
  default_index_params_vector = {
13
14
  'metric_type': 'COSINE', # IP 或 COSINE
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: agentic-kit-common
3
- Version: 0.0.9
3
+ Version: 0.0.11
4
4
  Summary: Common utilities and tools for agentic kit ecosystem
5
5
  Home-page:
6
6
  Author: manson
@@ -26,6 +26,7 @@ Requires-Dist: langchain_community
26
26
  Requires-Dist: langchain_experimental
27
27
  Requires-Dist: mysql-connector-python
28
28
  Requires-Dist: sqlalchemy
29
+ Requires-Dist: sqlglot
29
30
  Requires-Dist: pymilvus
30
31
  Requires-Dist: xinference_client
31
32
  Dynamic: author
@@ -1,32 +1,34 @@
1
1
  agentic_kit_common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ agentic_kit_common/log/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ agentic_kit_common/log/logger.py,sha256=EH8pOW6KVpb-c4RdK2sohpHgfs3_hBcaAletRs1O23k,391
2
4
  agentic_kit_common/minio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
5
  agentic_kit_common/minio/minio_manager.py,sha256=PzoUWn0YqXTHx1UClbkwLkIUmv5O4aSDc7eVf08qOzs,21494
4
6
  agentic_kit_common/orm/__init__.py,sha256=wqY81g3P7FftFvLK5SaxFJzNNxrQwtmcb4RCXOSAZa8,71
5
7
  agentic_kit_common/orm/base.py,sha256=QIura_i2nIY2XeA3-KkO2loLNbEAoJK2qx0hu_8nhYU,2277
6
- agentic_kit_common/orm/execution.py,sha256=sCEcInFGFZ9ZiIs0eHKbQPtN4Z02v7vGdFMyVEWygcU,1951
8
+ agentic_kit_common/orm/execution.py,sha256=beyRJqVGY5nOqMDpvLcfwimUzP6eSz8ayNHLwh1hGOo,2858
7
9
  agentic_kit_common/orm/manager.py,sha256=lWgFk5fUu_9m6yN_fskWHSYGaW30ty--KZKj8AuvIh0,4852
8
10
  agentic_kit_common/orm/multi_session.py,sha256=CAEnOLl0I-r77JAknKQ2sERF40OkHD7BOrYVYoHuQV4,8814
9
11
  agentic_kit_common/orm/schema.py,sha256=ukdVP71NE_JO5_HOe_FApJjBAKUsDluoElVn5vNSGJQ,1023
10
12
  agentic_kit_common/orm/session.py,sha256=LX4ZUKJNXdcQ0KqRt5L0pJX-QG-tDyKXEDijcPUkGD0,2716
11
13
  agentic_kit_common/vector/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
14
  agentic_kit_common/vector/embedding/__init__.py,sha256=mp--cfCDzTn5hNxzcIAU4m0g4F0d0rrZErFVWsGDvx4,40
13
- agentic_kit_common/vector/embedding/embedding.py,sha256=k1hHNkBcucu6lSe2hYSrCal-ML26Eo6Bfe8cVkXSww8,1380
15
+ agentic_kit_common/vector/embedding/embedding.py,sha256=ZEvl2w2-PrAph60u2wJItj4lb68XxK805tRzOp49MdI,1861
14
16
  agentic_kit_common/vector/manager/__init__.py,sha256=w2uAmKGRx9Nv3QySXIAgxQGlYR549AQvbBuBFsQttKI,42
15
- agentic_kit_common/vector/manager/milvus_manager.py,sha256=HaoKKsyMOfmA3jX45XEVhmK6SprjTmwULOyhCqqAIGg,7840
17
+ agentic_kit_common/vector/manager/milvus_manager.py,sha256=gDwWlkrnIUFqMoDNsvqnPAKTyAUnyxD-fh7jy93LcQY,8568
16
18
  agentic_kit_common/vector/schema/__init__.py,sha256=F8WCi4ybnujFW-qHcDI9KOqzw7i8r07dmBUO2R2_Iu0,228
17
19
  agentic_kit_common/vector/schema/base.py,sha256=oo5OSFCX9UeQuWskiebqctF1dyXISD7_czzoKYoHnKg,102
18
- agentic_kit_common/vector/schema/milvus_schema.py,sha256=V8igxe6_FmZ46jicNdKxBbf0vl3bwTC4AyRf9gqB_VA,684
20
+ agentic_kit_common/vector/schema/milvus_schema.py,sha256=4kIUtIIE2vcKP1s1Ghn_eikuzDb1lsSeAaTwrsNzUlQ,784
19
21
  agentic_kit_common/web/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
22
  agentic_kit_common/web/http/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
23
  agentic_kit_common/web/http/response.py,sha256=uxd_MRHsFfQ0pUeHITDQD8tuOY6fo6IJ0MO6uL62AlM,1347
22
24
  test/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
25
  test/config.py,sha256=VMcvfWnuWZKr_p0rC57i90jTq-XVsRuVzfEIe1C7Drg,3804
24
26
  test/settings.py,sha256=vU1dqUvGzshi6MG7JbGfW4jKfUsfAObgUfky-LMjUNs,344
25
- test/test_embedding.py,sha256=78M6t9tt6-oCfdQZfOartCIkUFZMvogcu5UKJh1lJb8,609
27
+ test/test_embedding.py,sha256=xcfHDHGL2_tpXp_VaLnNDAWvKdkc1K1BjiYoV5WxtFY,900
26
28
  test/test_minio.py,sha256=TOkX8A2pPkjrwAIH88xBZmFpDc1ZgTk1QS6mtubOZ-Y,2308
27
29
  test/test_orm.py,sha256=8fGCU7BWaD5sDbg0fgYN0Saf_hi7t-q8svHCHLQDceo,1620
28
- test/test_vector.py,sha256=4sJ7bXt5iRhc6j1Ig0bZ1j0huPU-clnsnT7O0-uFbkw,1311
29
- agentic_kit_common-0.0.9.dist-info/METADATA,sha256=2QyqNG2MQrLYyQ6TDuPXkyBnPdf5La0g4-rloQ4jAZY,7496
30
- agentic_kit_common-0.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- agentic_kit_common-0.0.9.dist-info/top_level.txt,sha256=nEKDlp84vqKSVWssGcxyuIsTqWLhMo45xqMs2GK4Dgg,24
32
- agentic_kit_common-0.0.9.dist-info/RECORD,,
30
+ test/test_vector.py,sha256=Z3Bwvrw0XGBWXHYXk9pkF1cjnUZSZk0gmLsgQIIaZuY,1717
31
+ agentic_kit_common-0.0.11.dist-info/METADATA,sha256=i2EbkVNKrTVCcfwE9ky0Sfr64Lg8A97GsPOQzDCgKwM,7520
32
+ agentic_kit_common-0.0.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
33
+ agentic_kit_common-0.0.11.dist-info/top_level.txt,sha256=nEKDlp84vqKSVWssGcxyuIsTqWLhMo45xqMs2GK4Dgg,24
34
+ agentic_kit_common-0.0.11.dist-info/RECORD,,
test/test_embedding.py CHANGED
@@ -6,13 +6,22 @@ project_root = Path(__file__).parent.parent
6
6
  sys.path.insert(0, str(project_root))
7
7
 
8
8
  from agentic_kit_common.vector.embedding import EmbeddingFactory
9
+ from langchain_openai import OpenAIEmbeddings
9
10
  from .settings import global_settings
10
11
 
11
12
 
12
13
  class MyTestCase(unittest.TestCase):
13
14
  def test_embedding(self):
14
- embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
15
+ # embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
16
+
17
+ embedding = EmbeddingFactory.create_embedding(
18
+ base_url=global_settings.embedding_base_url,
19
+ model_uid=global_settings.embedding_model_uid,
20
+ provider='vllm'
21
+ )
22
+
15
23
  text = '你好'
24
+ # text = 'hello, world'
16
25
  text_vector = embedding.embed_query(text=text)
17
26
  print(text_vector)
18
27
 
test/test_vector.py CHANGED
@@ -17,11 +17,11 @@ from .settings import global_settings
17
17
 
18
18
  logging.basicConfig(level=logging.DEBUG)
19
19
 
20
- embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid)
20
+ embedding = EmbeddingFactory.create_embedding(base_url=global_settings.embedding_base_url, model_uid=global_settings.embedding_model_uid, provider='vllm')
21
21
 
22
22
 
23
23
  class McpServerManager(MilvusManager):
24
- database_name: str = 'skill_center'
24
+ database_name: str = 'nl2sql'
25
25
 
26
26
  collection_name: str = 'mcp_server'
27
27
 
@@ -31,8 +31,16 @@ class McpServerManager(MilvusManager):
31
31
  class MyTestCase(unittest.TestCase):
32
32
  def test_vector(self):
33
33
  manager = McpServerManager.create(embed_model=embedding, vector_store_uri=global_settings.milvus_url)
34
+ # manager = MilvusManager.create(embed_model=embedding, vector_store_uri=global_settings.milvus_url, database_name='nl2sql', collection_name='mcp_server2')
35
+ # manager.insert([{
36
+ # 'uid': 'xxxx',
37
+ # 'content': 'hello',
38
+ # 'owner_uid': 'xxxx2'
39
+ # }])
34
40
 
35
- manager.query()
41
+ owner_uid = 'xxxx2'
42
+ res = manager.query(expr=f'owner_uid == "{owner_uid}"')
43
+ print(res)
36
44
 
37
45
  # manager.search('你好')
38
46
 
@@ -43,9 +51,9 @@ class MyTestCase(unittest.TestCase):
43
51
  # 'content': 'hello',
44
52
  # }])
45
53
 
46
- manager.delete_by_uid(uid='2050d9a9-b125-42b9-80ce-bd9eee89d9eb')
54
+ # manager.delete_by_uid(uid='2050d9a9-b125-42b9-80ce-bd9eee89d9eb')
47
55
 
48
- manager.query()
56
+ # manager.query()
49
57
 
50
58
 
51
59
  if __name__ == '__main__':