sycommon-python-lib 0.1.46__py3-none-any.whl → 0.1.56b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. sycommon/config/Config.py +6 -2
  2. sycommon/config/RerankerConfig.py +1 -0
  3. sycommon/database/async_base_db_service.py +36 -0
  4. sycommon/database/async_database_service.py +96 -0
  5. sycommon/llm/__init__.py +0 -0
  6. sycommon/llm/embedding.py +149 -0
  7. sycommon/llm/get_llm.py +246 -0
  8. sycommon/llm/llm_logger.py +126 -0
  9. sycommon/llm/llm_tokens.py +119 -0
  10. sycommon/logging/async_sql_logger.py +65 -0
  11. sycommon/logging/kafka_log.py +21 -9
  12. sycommon/logging/logger_levels.py +23 -0
  13. sycommon/middleware/context.py +2 -0
  14. sycommon/middleware/traceid.py +155 -32
  15. sycommon/notice/__init__.py +0 -0
  16. sycommon/notice/uvicorn_monitor.py +195 -0
  17. sycommon/rabbitmq/rabbitmq_client.py +144 -152
  18. sycommon/rabbitmq/rabbitmq_pool.py +213 -479
  19. sycommon/rabbitmq/rabbitmq_service.py +77 -127
  20. sycommon/services.py +78 -75
  21. sycommon/synacos/feign.py +18 -7
  22. sycommon/synacos/feign_client.py +26 -8
  23. sycommon/synacos/nacos_service.py +18 -2
  24. sycommon/tools/merge_headers.py +97 -0
  25. sycommon/tools/snowflake.py +290 -23
  26. {sycommon_python_lib-0.1.46.dist-info → sycommon_python_lib-0.1.56b5.dist-info}/METADATA +15 -10
  27. {sycommon_python_lib-0.1.46.dist-info → sycommon_python_lib-0.1.56b5.dist-info}/RECORD +30 -18
  28. {sycommon_python_lib-0.1.46.dist-info → sycommon_python_lib-0.1.56b5.dist-info}/WHEEL +0 -0
  29. {sycommon_python_lib-0.1.46.dist-info → sycommon_python_lib-0.1.56b5.dist-info}/entry_points.txt +0 -0
  30. {sycommon_python_lib-0.1.46.dist-info → sycommon_python_lib-0.1.56b5.dist-info}/top_level.txt +0 -0
sycommon/config/Config.py CHANGED
@@ -16,10 +16,9 @@ class Config(metaclass=SingletonMeta):
16
16
  self.config = yaml.safe_load(f)
17
17
  self.MaxBytes = self.config.get('MaxBytes', 209715200)
18
18
  self.Timeout = self.config.get('Timeout', 300000)
19
+ self.MaxRetries = self.config.get('MaxRetries', 3)
19
20
  self.OCR = self.config.get('OCR', None)
20
21
  self.INVOICE_OCR = self.config.get('INVOICE_OCR', None)
21
- self.UnstructuredAPI = self.config.get('UnstructuredAPI', None)
22
- self.MaxRetries = self.config.get('MaxRetries', 3)
23
22
  self.llm_configs = []
24
23
  self.embedding_configs = []
25
24
  self.reranker_configs = []
@@ -71,3 +70,8 @@ class Config(metaclass=SingletonMeta):
71
70
  self.reranker_configs.append(validated_config.model_dump())
72
71
  except ValueError as e:
73
72
  print(f"Invalid LLM configuration: {e}")
73
+
74
+ def set_attr(self, share_configs: dict):
75
+ self.config = {**self.config, **
76
+ share_configs.get('llm', {}), **share_configs}
77
+ self._process_config()
@@ -5,6 +5,7 @@ class RerankerConfig(BaseModel):
5
5
  model: str
6
6
  provider: str
7
7
  baseUrl: str
8
+ maxTokens: int
8
9
 
9
10
  @classmethod
10
11
  def from_config(cls, model_name: str):
@@ -0,0 +1,36 @@
1
+ from contextlib import asynccontextmanager
2
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
3
+ from sycommon.config.Config import SingletonMeta
4
+ from sycommon.database.async_database_service import AsyncDatabaseService
5
+ from sycommon.logging.kafka_log import SYLogger
6
+
7
+
8
+ class AsyncBaseDBService(metaclass=SingletonMeta):
9
+ """数据库操作基础服务类,封装异步会话管理功能"""
10
+
11
+ def __init__(self):
12
+ # 获取异步引擎 (假设 DatabaseService.engine() 返回的是 AsyncEngine)
13
+ self.engine = AsyncDatabaseService.engine()
14
+
15
+ # 创建异步 Session 工厂
16
+ # class_=AsyncSession 是必须的,用于指定生成的是异步会话
17
+ self.Session = async_sessionmaker(
18
+ bind=self.engine,
19
+ class_=AsyncSession,
20
+ expire_on_commit=False
21
+ )
22
+
23
+ @asynccontextmanager
24
+ async def session(self):
25
+ """
26
+ 异步数据库会话上下文管理器
27
+ 自动处理会话的创建、提交、回滚和关闭
28
+ """
29
+ async with self.Session() as session:
30
+ try:
31
+ yield session
32
+ await session.commit()
33
+ except Exception as e:
34
+ await session.rollback()
35
+ SYLogger.error(f"Database operation failed: {str(e)}")
36
+ raise
@@ -0,0 +1,96 @@
1
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
2
+ from sqlalchemy import text
3
+
4
+ from sycommon.config.Config import SingletonMeta
5
+ from sycommon.config.DatabaseConfig import DatabaseConfig, convert_dict_keys
6
+ from sycommon.logging.kafka_log import SYLogger
7
+ from sycommon.logging.async_sql_logger import AsyncSQLTraceLogger
8
+ from sycommon.synacos.nacos_service import NacosService
9
+
10
+
11
+ class AsyncDatabaseService(metaclass=SingletonMeta):
12
+ _engine = None
13
+
14
+ @staticmethod
15
+ async def setup_database(config: dict, shareConfigKey: str):
16
+ common = NacosService(config).share_configs.get(shareConfigKey, {})
17
+ if common and common.get('spring', {}).get('datasource', None):
18
+ databaseConfig = common.get('spring', {}).get('datasource', None)
19
+ converted_dict = convert_dict_keys(databaseConfig)
20
+ db_config = DatabaseConfig.model_validate(converted_dict)
21
+
22
+ # 初始化 DatabaseConnector (传入配置)
23
+ connector = AsyncDatabaseConnector(db_config)
24
+
25
+ # 赋值 engine
26
+ AsyncDatabaseService._engine = connector.engine
27
+
28
+ # 执行异步测试连接
29
+ if not await connector.test_connection():
30
+ raise Exception("Database connection test failed")
31
+
32
+ @staticmethod
33
+ def engine():
34
+ return AsyncDatabaseService._engine
35
+
36
+
37
+ class AsyncDatabaseConnector(metaclass=SingletonMeta):
38
+ def __init__(self, db_config: DatabaseConfig):
39
+ # 从 DatabaseConfig 中提取数据库连接信息
40
+ self.db_user = db_config.username
41
+ self.db_password = db_config.password
42
+
43
+ # 提取 URL 中的主机、端口和数据库名
44
+ url_parts = db_config.url.split('//')[1].split('/')
45
+ host_port = url_parts[0].split(':')
46
+ self.db_host = host_port[0]
47
+ self.db_port = host_port[1]
48
+ self.db_name = url_parts[1].split('?')[0]
49
+
50
+ # 提取 URL 中的参数
51
+ params_str = url_parts[1].split('?')[1] if len(
52
+ url_parts[1].split('?')) > 1 else ''
53
+ params = {}
54
+ for param in params_str.split('&'):
55
+ if param:
56
+ key, value = param.split('=')
57
+ params[key] = value
58
+
59
+ # 在params中去掉指定的参数
60
+ for key in ['useUnicode', 'characterEncoding', 'serverTimezone', 'zeroDateTimeBehavior']:
61
+ if key in params:
62
+ del params[key]
63
+
64
+ # 构建数据库连接 URL
65
+ # 注意:这里将 mysqlconnector 替换为 aiomysql 以支持异步
66
+ self.db_url = f'mysql+aiomysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}'
67
+
68
+ SYLogger.info(f"Database URL: {self.db_url}")
69
+
70
+ # 优化连接池配置
71
+ # 使用 create_async_engine 替代 create_engine
72
+ self.engine = create_async_engine(
73
+ self.db_url,
74
+ connect_args=params,
75
+ pool_size=10, # 连接池大小
76
+ max_overflow=20, # 最大溢出连接数
77
+ pool_timeout=30, # 连接超时时间(秒)
78
+ pool_recycle=3600, # 连接回收时间(秒)
79
+ pool_pre_ping=True, # 每次获取连接前检查连接是否有效
80
+ echo=False, # 打印 SQL 语句
81
+ )
82
+
83
+ # 注册 SQL 日志拦截器 (注意:SQLTraceLogger 需要支持异步引擎,或者您可能需要调整日志逻辑)
84
+ # 假设 SQLTraceLogger.setup_sql_logging 能够处理 AsyncEngine
85
+ AsyncSQLTraceLogger.setup_sql_logging(self.engine)
86
+
87
+ async def test_connection(self):
88
+ try:
89
+ # 异步上下文管理器
90
+ async with self.engine.connect() as connection:
91
+ # 执行简单查询
92
+ await connection.execute(text("SELECT 1"))
93
+ return True
94
+ except Exception as e:
95
+ SYLogger.error(f"Database connection test failed: {e}")
96
+ return False
File without changes
@@ -0,0 +1,149 @@
1
+ import asyncio
2
+ import json
3
+ import aiohttp
4
+ from typing import Union, List, Optional
5
+
6
+ from sycommon.config.Config import SingletonMeta
7
+ from sycommon.config.EmbeddingConfig import EmbeddingConfig
8
+ from sycommon.config.RerankerConfig import RerankerConfig
9
+ from sycommon.logging.kafka_log import SYLogger
10
+
11
+
12
+ class Embedding(metaclass=SingletonMeta):
13
+ def __init__(self):
14
+ # 1. 并发限制
15
+ self.max_concurrency = 20
16
+ # 保留默认模型名称
17
+ self.default_embedding_model = "bge-large-zh-v1.5"
18
+ self.default_reranker_model = "bge-reranker-large"
19
+
20
+ # 初始化默认模型的基础URL
21
+ self.embeddings_base_url = EmbeddingConfig.from_config(
22
+ self.default_embedding_model).baseUrl
23
+ self.reranker_base_url = RerankerConfig.from_config(
24
+ self.default_reranker_model).baseUrl
25
+
26
+ # 并发信号量
27
+ self.semaphore = asyncio.Semaphore(self.max_concurrency)
28
+
29
+ async def _get_embeddings_http_async(
30
+ self,
31
+ input: Union[str, List[str]],
32
+ encoding_format: str = None,
33
+ model: str = None,
34
+ **kwargs
35
+ ):
36
+ async with self.semaphore:
37
+ # 优先使用传入的模型名,无则用默认值
38
+ target_model = model or self.default_embedding_model
39
+ target_base_url = EmbeddingConfig.from_config(target_model).baseUrl
40
+ url = f"{target_base_url}/v1/embeddings"
41
+
42
+ request_body = {
43
+ "model": target_model,
44
+ "input": input,
45
+ "encoding_format": encoding_format or "float"
46
+ }
47
+ request_body.update(kwargs)
48
+
49
+ async with aiohttp.ClientSession() as session:
50
+ async with session.post(url, json=request_body) as response:
51
+ if response.status != 200:
52
+ error_detail = await response.text()
53
+ SYLogger.error(
54
+ f"Embedding request failed (model: {target_model}): {error_detail}")
55
+ return None
56
+ return await response.json()
57
+
58
+ async def _get_reranker_http_async(
59
+ self,
60
+ documents: List[str],
61
+ query: str,
62
+ top_n: Optional[int] = None,
63
+ model: str = None,
64
+ max_chunks_per_doc: Optional[int] = None,
65
+ return_documents: Optional[bool] = True,
66
+ return_len: Optional[bool] = True,
67
+ **kwargs
68
+ ):
69
+ async with self.semaphore:
70
+ # 优先使用传入的模型名,无则用默认值
71
+ target_model = model or self.default_reranker_model
72
+ target_base_url = RerankerConfig.from_config(target_model).baseUrl
73
+ url = f"{target_base_url}/v1/rerank"
74
+
75
+ request_body = {
76
+ "model": target_model,
77
+ "documents": documents,
78
+ "query": query,
79
+ "top_n": top_n or len(documents),
80
+ "max_chunks_per_doc": max_chunks_per_doc,
81
+ "return_documents": return_documents,
82
+ "return_len": return_len,
83
+ "kwargs": json.dumps(kwargs),
84
+ }
85
+ request_body.update(kwargs)
86
+
87
+ async with aiohttp.ClientSession() as session:
88
+ async with session.post(url, json=request_body) as response:
89
+ if response.status != 200:
90
+ error_detail = await response.text()
91
+ SYLogger.error(
92
+ f"Rerank request failed (model: {target_model}): {error_detail}")
93
+ return None
94
+ return await response.json()
95
+
96
+ async def get_embeddings(
97
+ self,
98
+ corpus: List[str],
99
+ model: str = None
100
+ ):
101
+ """
102
+ 获取语料库的嵌入向量,结果顺序与输入语料库顺序一致
103
+
104
+ Args:
105
+ corpus: 待生成嵌入向量的文本列表
106
+ model: 可选,指定使用的embedding模型名称,默认使用bge-large-zh-v1.5
107
+ """
108
+ SYLogger.info(
109
+ f"Requesting embeddings for corpus: {corpus} (model: {model or self.default_embedding_model}, max_concurrency: {self.max_concurrency})")
110
+ # 给每个异步任务传入模型名称
111
+ tasks = [self._get_embeddings_http_async(
112
+ text, model=model) for text in corpus]
113
+ results = await asyncio.gather(*tasks)
114
+
115
+ vectors = []
116
+ for result in results:
117
+ if result is None:
118
+ zero_vector = [0.0] * 1024
119
+ vectors.append(zero_vector)
120
+ SYLogger.warning(
121
+ f"Embedding request failed, append zero vector (1024D)")
122
+ continue
123
+ for item in result["data"]:
124
+ vectors.append(item["embedding"])
125
+
126
+ SYLogger.info(
127
+ f"Embeddings for corpus: {corpus} created (model: {model or self.default_embedding_model})")
128
+ return vectors
129
+
130
+ async def get_reranker(
131
+ self,
132
+ top_results: List[str],
133
+ query: str,
134
+ model: str = None
135
+ ):
136
+ """
137
+ 对搜索结果进行重排序
138
+
139
+ Args:
140
+ top_results: 待重排序的文本列表
141
+ query: 排序参考的查询语句
142
+ model: 可选,指定使用的reranker模型名称,默认使用bge-reranker-large
143
+ """
144
+ SYLogger.info(
145
+ f"Requesting reranker for top_results: {top_results} (model: {model or self.default_reranker_model}, max_concurrency: {self.max_concurrency})")
146
+ data = await self._get_reranker_http_async(top_results, query, model=model)
147
+ SYLogger.info(
148
+ f"Reranker for top_results: {top_results} completed (model: {model or self.default_reranker_model})")
149
+ return data
@@ -0,0 +1,246 @@
1
+ from typing import Dict, Type, List, Optional, Callable, Any
2
+ from sycommon.llm.llm_logger import LLMLogger
3
+ from langchain_core.language_models import BaseChatModel
4
+ from langchain_core.runnables import Runnable, RunnableLambda, RunnableConfig
5
+ from langchain_core.output_parsers import PydanticOutputParser
6
+ from langchain_core.messages import BaseMessage, HumanMessage
7
+ from langchain.chat_models import init_chat_model
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+ from pydantic import BaseModel, ValidationError, Field
10
+ from sycommon.config.LLMConfig import LLMConfig
11
+ from sycommon.llm.llm_tokens import TokensCallbackHandler
12
+ from sycommon.logging.kafka_log import SYLogger
13
+
14
+
15
+ class StructuredRunnableWithToken(Runnable):
16
+ """带Token统计的Runnable类"""
17
+
18
+ def __init__(self, retry_chain: Runnable):
19
+ super().__init__()
20
+ self.retry_chain = retry_chain
21
+
22
+ def _adapt_input(self, input: Any) -> List[BaseMessage]:
23
+ """适配输入格式"""
24
+ if isinstance(input, list) and all(isinstance(x, BaseMessage) for x in input):
25
+ return input
26
+ elif isinstance(input, BaseMessage):
27
+ return [input]
28
+ elif isinstance(input, str):
29
+ return [HumanMessage(content=input)]
30
+ elif isinstance(input, dict) and "input" in input:
31
+ return [HumanMessage(content=str(input["input"]))]
32
+ else:
33
+ raise ValueError(f"不支持的输入格式:{type(input)}")
34
+
35
+ def _get_callback_config(self, config: Optional[RunnableConfig] = None) -> tuple[RunnableConfig, TokensCallbackHandler]:
36
+ """构建包含Token统计的回调配置"""
37
+ # 每次调用创建新的Token处理器实例
38
+ token_handler = TokensCallbackHandler()
39
+
40
+ # 初始化配置
41
+ if config is None:
42
+ processed_config = {"callbacks": []}
43
+ else:
44
+ processed_config = config.copy()
45
+ if "callbacks" not in processed_config:
46
+ processed_config["callbacks"] = []
47
+
48
+ # 添加回调(去重)
49
+ callbacks = processed_config["callbacks"]
50
+ # 添加LLMLogger(如果不存在)
51
+ if not any(isinstance(cb, LLMLogger) for cb in callbacks):
52
+ callbacks.append(LLMLogger())
53
+ # 添加Token处理器
54
+ callbacks.append(token_handler)
55
+
56
+ # 按类型去重
57
+ callback_types = {}
58
+ unique_callbacks = []
59
+ for cb in callbacks:
60
+ cb_type = type(cb)
61
+ if cb_type not in callback_types:
62
+ callback_types[cb_type] = cb
63
+ unique_callbacks.append(cb)
64
+
65
+ processed_config["callbacks"] = unique_callbacks
66
+
67
+ return processed_config, token_handler
68
+
69
+ # 同步调用
70
+ def invoke(self, input: Any, config: Optional[RunnableConfig] = None, ** kwargs) -> Dict[str, Any]:
71
+ try:
72
+ processed_config, token_handler = self._get_callback_config(
73
+ config)
74
+ adapted_input = self._adapt_input(input)
75
+
76
+ structured_result = self.retry_chain.invoke(
77
+ {"messages": adapted_input},
78
+ config=processed_config,
79
+ **kwargs
80
+ )
81
+
82
+ # 获取Token统计结果
83
+ token_usage = token_handler.usage_metadata
84
+ structured_result._token_usage_ = token_usage
85
+
86
+ return structured_result
87
+
88
+ except Exception as e:
89
+ SYLogger.error(f"同步LLM调用失败: {str(e)}", exc_info=True)
90
+ return None
91
+
92
+ # 异步调用
93
+ async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, ** kwargs) -> Dict[str, Any]:
94
+ try:
95
+ processed_config, token_handler = self._get_callback_config(
96
+ config)
97
+ adapted_input = self._adapt_input(input)
98
+
99
+ structured_result = await self.retry_chain.ainvoke(
100
+ {"messages": adapted_input},
101
+ config=processed_config,
102
+ **kwargs
103
+ )
104
+
105
+ token_usage = token_handler.usage_metadata
106
+ structured_result._token_usage_ = token_usage
107
+
108
+ return structured_result
109
+
110
+ except Exception as e:
111
+ SYLogger.error(f"异步LLM调用失败: {str(e)}", exc_info=True)
112
+ return None
113
+
114
+
115
+ class LLMWithAutoTokenUsage(BaseChatModel):
116
+ """自动为结构化调用返回token_usage的LLM包装类"""
117
+ llm: BaseChatModel = Field(default=None)
118
+
119
+ def __init__(self, llm: BaseChatModel, **kwargs):
120
+ super().__init__(llm=llm, ** kwargs)
121
+
122
+ def with_structured_output(
123
+ self,
124
+ output_model: Type[BaseModel],
125
+ max_retries: int = 3,
126
+ is_extract: bool = False,
127
+ override_prompt: ChatPromptTemplate = None,
128
+ custom_processors: Optional[List[Callable[[str], str]]] = None,
129
+ custom_parser: Optional[Callable[[str], BaseModel]] = None
130
+ ) -> Runnable:
131
+ """返回支持自动统计Token的结构化Runnable"""
132
+ parser = PydanticOutputParser(pydantic_object=output_model)
133
+
134
+ # 提示词模板
135
+ accuracy_instructions = """
136
+ 字段值的抽取准确率(0~1之间),评分规则:
137
+ 1.0(完全准确):直接从原文提取,无需任何加工,且格式与原文完全一致
138
+ 0.9(轻微处理):数据来源明确,但需进行格式标准化或冗余信息剔除(不改变原始数值)
139
+ 0.8(有限推断):数据需通过上下文关联或简单计算得出,仍有明确依据
140
+ 0.8以下(不可靠):数据需大量推测、存在歧义或来源不明,处理方式:直接忽略该数据,设置为None
141
+ """
142
+
143
+ if is_extract:
144
+ prompt = ChatPromptTemplate.from_messages([
145
+ MessagesPlaceholder(variable_name="messages"),
146
+ HumanMessage(content=f"""
147
+ 请提取信息并遵循以下规则:
148
+ 1. 准确率要求:{accuracy_instructions.strip()}
149
+ 2. 输出格式:{parser.get_format_instructions()}
150
+ """)
151
+ ])
152
+ else:
153
+ prompt = override_prompt or ChatPromptTemplate.from_messages([
154
+ MessagesPlaceholder(variable_name="messages"),
155
+ HumanMessage(content=f"""
156
+ 输出格式:{parser.get_format_instructions()}
157
+ """)
158
+ ])
159
+
160
+ # 文本处理函数
161
+ def extract_response_content(response: BaseMessage) -> str:
162
+ try:
163
+ return response.content
164
+ except Exception as e:
165
+ raise ValueError(f"提取响应内容失败:{str(e)}") from e
166
+
167
+ def strip_code_block_markers(content: str) -> str:
168
+ try:
169
+ return content.strip("```json").strip("```").strip()
170
+ except Exception as e:
171
+ raise ValueError(f"移除代码块标记失败:{str(e)}") from e
172
+
173
+ def normalize_in_json(content: str) -> str:
174
+ try:
175
+ return content.replace("None", "null").replace("none", "null").replace("NONE", "null").replace("''", '""')
176
+ except Exception as e:
177
+ raise ValueError(f"JSON格式化失败:{str(e)}") from e
178
+
179
+ def default_parse_to_pydantic(content: str) -> BaseModel:
180
+ try:
181
+ return parser.parse(content)
182
+ except (ValidationError, ValueError) as e:
183
+ raise ValueError(f"解析结构化结果失败:{str(e)}") from e
184
+
185
+ # ========== 构建处理链 ==========
186
+ base_chain = prompt | self.llm | RunnableLambda(
187
+ extract_response_content)
188
+
189
+ # 文本处理链
190
+ process_runnables = custom_processors or [
191
+ RunnableLambda(strip_code_block_markers),
192
+ RunnableLambda(normalize_in_json)
193
+ ]
194
+ process_chain = base_chain
195
+ for runnable in process_runnables:
196
+ process_chain = process_chain | runnable
197
+
198
+ # 解析链
199
+ parse_chain = process_chain | RunnableLambda(
200
+ custom_parser or default_parse_to_pydantic)
201
+
202
+ # 重试链
203
+ retry_chain = parse_chain.with_retry(
204
+ retry_if_exception_type=(ValidationError, ValueError),
205
+ stop_after_attempt=max_retries,
206
+ wait_exponential_jitter=True,
207
+ exponential_jitter_params={
208
+ "initial": 0.1, "max": 3.0, "exp_base": 2.0, "jitter": 1.0}
209
+ )
210
+
211
+ return StructuredRunnableWithToken(retry_chain)
212
+
213
+ # ========== 实现BaseChatModel抽象方法 ==========
214
+ def _generate(self, messages, stop=None, run_manager=None, ** kwargs):
215
+ return self.llm._generate(messages, stop=stop, run_manager=run_manager, ** kwargs)
216
+
217
+ @property
218
+ def _llm_type(self) -> str:
219
+ return self.llm._llm_type
220
+
221
+
222
+ def get_llm(
223
+ model: str = None,
224
+ streaming: bool = False
225
+ ) -> LLMWithAutoTokenUsage:
226
+ if not model:
227
+ model = "Qwen2.5-72B"
228
+
229
+ llmConfig = LLMConfig.from_config(model)
230
+ if not llmConfig:
231
+ raise Exception(f"无效的模型配置:{model}")
232
+
233
+ llm = init_chat_model(
234
+ model_provider=llmConfig.provider,
235
+ model=llmConfig.model,
236
+ base_url=llmConfig.baseUrl,
237
+ api_key="-",
238
+ temperature=0.1,
239
+ streaming=streaming,
240
+ callbacks=[LLMLogger()]
241
+ )
242
+
243
+ if llm is None:
244
+ raise Exception(f"初始化原始LLM实例失败:{model}")
245
+
246
+ return LLMWithAutoTokenUsage(llm)
@@ -0,0 +1,126 @@
1
+ from langchain_core.callbacks import AsyncCallbackHandler
2
+ from typing import Any, Dict, List
3
+ from langchain_core.outputs import GenerationChunk, ChatGeneration
4
+ from langchain_core.messages import BaseMessage
5
+
6
+ from sycommon.logging.kafka_log import SYLogger
7
+
8
+
9
+ class LLMLogger(AsyncCallbackHandler):
10
+ """
11
+ 通用LLM日志回调处理器,同时支持:
12
+ - 同步调用(如 chain.invoke())
13
+ - 异步调用(如 chain.astream())
14
+ - 聊天模型调用
15
+ """
16
+
17
+ # ------------------------------
18
+ # 同步回调方法(处理 invoke 等同步调用)
19
+ # ------------------------------
20
+ def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
21
+ model_name = serialized.get('name', 'unknown')
22
+ SYLogger.info(
23
+ f"[同步] LLM调用开始 | 模型: {model_name} | 提示词数: {len(prompts)}")
24
+ self._log_prompts(prompts)
25
+
26
+ def on_chat_model_start(
27
+ self,
28
+ serialized: Dict[str, Any],
29
+ messages: List[List[BaseMessage]],
30
+ **kwargs: Any
31
+ ) -> None:
32
+ model_name = serialized.get('name', 'unknown')
33
+ SYLogger.info(
34
+ f"[同步] 聊天模型调用开始 | 模型: {model_name} | 消息组数: {len(messages)}")
35
+ self._log_chat_messages(messages)
36
+
37
+ def on_llm_end(self, response: Any, **kwargs: Any) -> None:
38
+ # 处理普通LLM结果
39
+ if hasattr(response, 'generations') and all(
40
+ isinstance(gen[0], GenerationChunk) for gen in response.generations
41
+ ):
42
+ for i, generation in enumerate(response.generations):
43
+ result = generation[0].text
44
+ SYLogger.info(
45
+ f"[同步] LLM调用结束 | 结果 #{i+1} 长度: {len(result)}")
46
+ self._log_result(result, i+1)
47
+ # 处理聊天模型结果
48
+ elif hasattr(response, 'generations') and all(
49
+ isinstance(gen[0], ChatGeneration) for gen in response.generations
50
+ ):
51
+ for i, generation in enumerate(response.generations):
52
+ result = generation[0].message.content
53
+ SYLogger.info(
54
+ f"[同步] 聊天模型调用结束 | 结果 #{i+1} 长度: {len(result)}")
55
+ self._log_result(result, i+1)
56
+
57
+ def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
58
+ if isinstance(error, GeneratorExit):
59
+ SYLogger.info("[同步] LLM生成器正常关闭")
60
+ return
61
+ SYLogger.error(f"[同步] LLM调用出错: {str(error)}")
62
+
63
+ # ------------------------------
64
+ # 异步回调方法(处理 astream 等异步调用)
65
+ # ------------------------------
66
+ async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
67
+ model_name = serialized.get('name', 'unknown')
68
+ SYLogger.info(
69
+ f"[异步] LLM调用开始 | 模型: {model_name} | 提示词数: {len(prompts)}")
70
+ self._log_prompts(prompts)
71
+
72
+ async def on_chat_model_start(
73
+ self,
74
+ serialized: Dict[str, Any],
75
+ messages: List[List[BaseMessage]],
76
+ **kwargs: Any
77
+ ) -> None:
78
+ model_name = serialized.get('name', 'unknown')
79
+ SYLogger.info(
80
+ f"[异步] 聊天模型调用开始 | 模型: {model_name} | 消息组数: {len(messages)}")
81
+ self._log_chat_messages(messages)
82
+
83
+ async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
84
+ # 处理普通LLM结果
85
+ if hasattr(response, 'generations') and all(
86
+ isinstance(gen[0], GenerationChunk) for gen in response.generations
87
+ ):
88
+ for i, generation in enumerate(response.generations):
89
+ result = generation[0].text
90
+ SYLogger.info(
91
+ f"[异步] LLM调用结束 | 结果 #{i+1} 长度: {len(result)}")
92
+ self._log_result(result, i+1)
93
+ # 处理聊天模型结果
94
+ elif hasattr(response, 'generations') and all(
95
+ isinstance(gen[0], ChatGeneration) for gen in response.generations
96
+ ):
97
+ for i, generation in enumerate(response.generations):
98
+ result = generation[0].message.content
99
+ SYLogger.info(
100
+ f"[异步] 聊天模型调用结束 | 结果 #{i+1} 长度: {len(result)}")
101
+ self._log_result(result, i+1)
102
+
103
+ async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
104
+ if isinstance(error, GeneratorExit):
105
+ SYLogger.info("[异步] LLM生成器正常关闭")
106
+ return
107
+ SYLogger.error(f"[异步] LLM调用出错: {str(error)}")
108
+
109
+ # ------------------------------
110
+ # 共享工具方法(避免代码重复)
111
+ # ------------------------------
112
+ def _log_prompts(self, prompts: List[str]) -> None:
113
+ """记录提示词"""
114
+ for i, prompt in enumerate(prompts):
115
+ SYLogger.info(f"提示词 #{i+1}:\n{prompt}")
116
+
117
+ def _log_chat_messages(self, messages: List[List[BaseMessage]]) -> None:
118
+ """记录聊天模型的消息"""
119
+ for i, message_group in enumerate(messages):
120
+ SYLogger.info(f"消息组 #{i+1}:")
121
+ for msg in message_group:
122
+ SYLogger.info(f" {msg.type}: {msg.content}")
123
+
124
+ def _log_result(self, result: str, index: int) -> None:
125
+ """记录结果"""
126
+ SYLogger.info(f"结果 #{index}:\n{result}")