local-coze 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- local_coze/__init__.py +110 -0
- local_coze/cli/__init__.py +3 -0
- local_coze/cli/chat.py +126 -0
- local_coze/cli/cli.py +34 -0
- local_coze/cli/constants.py +7 -0
- local_coze/cli/db.py +81 -0
- local_coze/cli/embedding.py +193 -0
- local_coze/cli/image.py +162 -0
- local_coze/cli/knowledge.py +195 -0
- local_coze/cli/search.py +198 -0
- local_coze/cli/utils.py +41 -0
- local_coze/cli/video.py +191 -0
- local_coze/cli/video_edit.py +888 -0
- local_coze/cli/voice.py +351 -0
- local_coze/core/__init__.py +25 -0
- local_coze/core/client.py +253 -0
- local_coze/core/config.py +58 -0
- local_coze/core/exceptions.py +67 -0
- local_coze/database/__init__.py +29 -0
- local_coze/database/client.py +170 -0
- local_coze/database/migration.py +342 -0
- local_coze/embedding/__init__.py +31 -0
- local_coze/embedding/client.py +350 -0
- local_coze/embedding/models.py +130 -0
- local_coze/image/__init__.py +19 -0
- local_coze/image/client.py +110 -0
- local_coze/image/models.py +163 -0
- local_coze/knowledge/__init__.py +19 -0
- local_coze/knowledge/client.py +148 -0
- local_coze/knowledge/models.py +45 -0
- local_coze/llm/__init__.py +25 -0
- local_coze/llm/client.py +317 -0
- local_coze/llm/models.py +48 -0
- local_coze/memory/__init__.py +14 -0
- local_coze/memory/client.py +176 -0
- local_coze/s3/__init__.py +12 -0
- local_coze/s3/client.py +580 -0
- local_coze/s3/models.py +18 -0
- local_coze/search/__init__.py +19 -0
- local_coze/search/client.py +183 -0
- local_coze/search/models.py +57 -0
- local_coze/video/__init__.py +17 -0
- local_coze/video/client.py +347 -0
- local_coze/video/models.py +39 -0
- local_coze/video_edit/__init__.py +23 -0
- local_coze/video_edit/examples.py +340 -0
- local_coze/video_edit/frame_extractor.py +176 -0
- local_coze/video_edit/models.py +362 -0
- local_coze/video_edit/video_edit.py +631 -0
- local_coze/voice/__init__.py +17 -0
- local_coze/voice/asr.py +82 -0
- local_coze/voice/models.py +86 -0
- local_coze/voice/tts.py +94 -0
- local_coze-0.0.1.dist-info/METADATA +636 -0
- local_coze-0.0.1.dist-info/RECORD +58 -0
- local_coze-0.0.1.dist-info/WHEEL +4 -0
- local_coze-0.0.1.dist-info/entry_points.txt +3 -0
- local_coze-0.0.1.dist-info/licenses/LICENSE +21 -0
local_coze/llm/client.py
ADDED
@@ -0,0 +1,317 @@
from typing import Any, Dict, Iterator, List, Optional, Union

from coze_coding_utils.runtime_ctx.context import Context, default_headers
from cozeloop.decorator import observe
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
)
from langchain_openai import ChatOpenAI

from ..core.config import Config
from .models import LLMConfig


class LLMClient:
    def __init__(
        self,
        config: Optional[Config] = None,
        ctx: Optional[Context] = None,
        custom_headers: Optional[Dict[str, str]] = None,
        verbose: bool = False,
    ):
        if config is None:
            config = Config()
        self.config = config
        self.ctx = ctx
        self.custom_headers = custom_headers or {}
        self.verbose = verbose
        self.base_url = self.config.base_model_url
        self.api_key = self.config.api_key

    def _create_llm(
        self,
        llm_config: LLMConfig,
        use_caching: bool = False,
        previous_response_id: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> ChatOpenAI:
        extra_body = {}

        if llm_config.thinking:
            extra_body["thinking"] = {"type": llm_config.thinking}

        if llm_config.caching:
            extra_body["caching"] = {"type": llm_config.caching}

        headers = {}

        if self.ctx is not None:
            ctx_headers = default_headers(self.ctx)
            headers.update(ctx_headers)

        if self.custom_headers:
            headers.update(self.custom_headers)

        # Normalize the token caps: treat 0 as "use the 32768 default", and
        # drop max_tokens when both caps are set so max_completion_tokens wins.
        if llm_config.max_tokens == 0:
            llm_config.max_tokens = 32768
        if llm_config.max_completion_tokens == 0:
            llm_config.max_completion_tokens = 32768
        if llm_config.max_tokens and llm_config.max_completion_tokens:
            llm_config.max_tokens = None

        config_headers = self.config.get_headers(extra_headers)
        headers.update(config_headers)

        llm = ChatOpenAI(
            model=llm_config.model,
            api_key=self.api_key,
            base_url=self.base_url,
            streaming=llm_config.streaming,
            extra_body=extra_body if extra_body else None,
            temperature=llm_config.temperature,
            frequency_penalty=llm_config.frequency_penalty,
            top_p=llm_config.top_p,
            max_tokens=llm_config.max_tokens,
            max_completion_tokens=llm_config.max_completion_tokens,
            default_headers=headers,
            use_responses_api=use_caching,
            use_previous_response_id=previous_response_id is not None,
        )

        return llm

    @observe(name="llm_stream")
    def stream(
        self,
        messages: List[BaseMessage],
        model: str = "doubao-seed-1-8-251228",
        thinking: Optional[str] = "disabled",
        caching: Optional[str] = "disabled",
        temperature: Optional[float] = 1.0,
        frequency_penalty: Optional[float] = 0,
        top_p: Optional[float] = 0,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = 32768,
        previous_response_id: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Iterator[BaseMessageChunk]:
        """
        Call the LLM in streaming mode, yielding generated content chunk by chunk.

        Args:
            messages: Message list in LangChain message format (required)
                - SystemMessage: system prompt defining the AI's role and behavior
                - HumanMessage: user message
                - AIMessage: AI reply, used for multi-turn conversations

            model: Model ID, default "doubao-seed-1-8-251228"
                Available models:
                - "doubao-seed-1-8-251228": latest model, higher performance
                - "doubao-seed-1-6-251015": balanced performance
                - "doubao-seed-1-6-flash-250615": fast model
                - "doubao-seed-1-6-thinking-250715": thinking model

            thinking: Thinking mode, default "disabled"
                - "enabled": enables deep thinking, suited to complex reasoning tasks
                - "disabled": disabled, suited to fast responses

            caching: Caching mode, default "disabled"
                - "enabled": enables caching, speeding up responses over repeated context
                - "disabled": disabled

            temperature: Sampling temperature controlling output randomness, range 0-2, default 1.0
                - 0.0-0.3: deterministic output, suited to code generation and data analysis
                - 0.7-0.9: balanced creativity, suited to general conversation
                - 1.0-2.0: high creativity, suited to creative writing and brainstorming

            frequency_penalty: Frequency penalty reducing repetition, range -2 to 2, default 0
                Positive values reduce repetition; negative values increase it

            top_p: Nucleus-sampling parameter controlling output diversity, range 0-1, default 0
                Smaller values give more deterministic output; larger values give more diverse output

            max_tokens: Maximum number of output tokens, default None (unlimited)
                Use to cap output length or control cost

            max_completion_tokens: Maximum number of completion tokens, default 32768
                Finer-grained length control

            previous_response_id: ID of the previous response, used with caching, default None

            extra_headers: Additional HTTP request headers, default None

        Returns:
            Iterator[BaseMessageChunk]: streamed message chunks, each containing:
                - content: text content
                - response_metadata: response metadata

        Example:
            >>> from langchain_core.messages import HumanMessage
            >>> client = LLMClient()
            >>> messages = [HumanMessage(content="Hello")]
            >>>
            >>> # Simplest usage
            >>> for chunk in client.stream(messages):
            ...     if chunk.content:
            ...         print(chunk.content, end="")
            >>>
            >>> # Adjust the temperature
            >>> for chunk in client.stream(messages, temperature=0.7):
            ...     if chunk.content:
            ...         print(chunk.content, end="")
            >>>
            >>> # Enable thinking mode
            >>> for chunk in client.stream(messages, thinking="enabled"):
            ...     if chunk.content:
            ...         print(chunk.content, end="")
        """
        llm_config = LLMConfig(
            model=model,
            thinking=thinking,
            caching=caching,
            temperature=temperature,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            streaming=True,
        )

        use_caching = caching == "enabled" or previous_response_id is not None

        # Tag the most recent AIMessage with the previous response id so the
        # Responses API path can resume from the cached context.
        if previous_response_id:
            for i in range(len(messages) - 1, -1, -1):
                msg = messages[i]
                if isinstance(msg, AIMessage):
                    msg.response_metadata["id"] = previous_response_id
                    break

        llm = self._create_llm(
            llm_config,
            use_caching=use_caching,
            previous_response_id=previous_response_id,
            extra_headers=extra_headers,
        )

        for chunk in llm.stream(messages):
            yield chunk

    @observe(name="llm_invoke")
    def invoke(
        self,
        messages: List[BaseMessage],
        model: str = "doubao-seed-1-8-251228",
        thinking: Optional[str] = "disabled",
        caching: Optional[str] = "disabled",
        temperature: Optional[float] = 1.0,
        frequency_penalty: Optional[float] = 0,
        top_p: Optional[float] = 0,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = 32768,
        previous_response_id: Optional[str] = None,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> AIMessage:
        """
        Call the LLM in non-streaming mode and return the complete response.

        Implemented internally via the streaming call; the full response is
        assembled automatically before it is returned.

        Args:
            messages: Message list in LangChain message format (required)
                - SystemMessage: system prompt defining the AI's role and behavior
                - HumanMessage: user message
                - AIMessage: AI reply, used for multi-turn conversations

            model: Model ID, default "doubao-seed-1-8-251228"
                Available models:
                - "doubao-seed-1-8-251228": latest model, higher performance
                - "doubao-seed-1-6-251015": balanced performance
                - "doubao-seed-1-6-flash-250615": fast model
                - "doubao-seed-1-6-thinking-250715": thinking model

            thinking: Thinking mode, default "disabled"
                - "enabled": enables deep thinking, suited to complex reasoning tasks
                - "disabled": disabled, suited to fast responses

            caching: Caching mode, default "disabled"
                - "enabled": enables caching, speeding up responses over repeated context
                - "disabled": disabled

            temperature: Sampling temperature controlling output randomness, range 0-2, default 1.0
                - 0.0-0.3: deterministic output, suited to code generation and data analysis
                - 0.7-0.9: balanced creativity, suited to general conversation
                - 1.0-2.0: high creativity, suited to creative writing and brainstorming

            frequency_penalty: Frequency penalty reducing repetition, range -2 to 2, default 0
                Positive values reduce repetition; negative values increase it

            top_p: Nucleus-sampling parameter controlling output diversity, range 0-1, default 0
                Smaller values give more deterministic output; larger values give more diverse output

            max_tokens: Maximum number of output tokens, default None (unlimited)
                Use to cap output length or control cost

            max_completion_tokens: Maximum number of completion tokens, default 32768
                Finer-grained length control

            previous_response_id: ID of the previous response, used with caching, default None

            extra_headers: Additional HTTP request headers, default None

        Returns:
            AIMessage: the complete response message, containing:
                - content: the full text content
                - response_metadata: response metadata (model info, token usage, etc.)

        Example:
            >>> from langchain_core.messages import SystemMessage, HumanMessage
            >>> client = LLMClient()
            >>>
            >>> # Simplest usage
            >>> messages = [HumanMessage(content="Hello")]
            >>> response = client.invoke(messages)
            >>> print(response.content)
            >>>
            >>> # With a system prompt
            >>> messages = [
            ...     SystemMessage(content="You are a Python expert"),
            ...     HumanMessage(content="What is a decorator?")
            ... ]
            >>> response = client.invoke(messages)
            >>> print(response.content)
            >>>
            >>> # Tuning parameters
            >>> response = client.invoke(
            ...     messages=messages,
            ...     temperature=0.7,
            ...     max_tokens=500
            ... )
            >>> print(response.content)
            >>> print(f"Token usage: {response.response_metadata}")
        """
        full_content = ""
        response_metadata = {}

        for chunk in self.stream(
            messages=messages,
            model=model,
            thinking=thinking,
            caching=caching,
            temperature=temperature,
            frequency_penalty=frequency_penalty,
            top_p=top_p,
            max_tokens=max_tokens,
            max_completion_tokens=max_completion_tokens,
            previous_response_id=previous_response_id,
            extra_headers=extra_headers,
        ):
            if chunk.content:
                full_content += str(chunk.content)
            if chunk.response_metadata:
                response_metadata.update(chunk.response_metadata)

        return AIMessage(content=full_content, response_metadata=response_metadata)
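
A minimal multi-turn sketch for the client above. This is hypothetical usage, not shipped in the wheel; it assumes Config() resolves credentials from the environment, and that the backend echoes a response id in response_metadata["id"], which is what the previous_response_id path in stream() expects.

    # Multi-turn sketch with context caching (hypothetical usage).
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    from local_coze.llm.client import LLMClient

    client = LLMClient()
    history = [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="Explain nucleus sampling in two sentences."),
    ]
    first = client.invoke(history, caching="enabled")
    print(first.content)

    # Follow-up turn: append the reply and pass its response id so the server
    # can reuse the cached prefix instead of reprocessing the whole history.
    # Assumes the backend returns an "id" field in response_metadata.
    history.append(AIMessage(content=first.content))
    history.append(HumanMessage(content="Now give a one-line example."))
    second = client.invoke(
        history,
        caching="enabled",
        previous_response_id=first.response_metadata.get("id"),
    )
    print(second.content)
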
local_coze/llm/models.py
ADDED
@@ -0,0 +1,48 @@
from typing import Optional, Literal, Dict, Any, List, Union
from pydantic import BaseModel, Field


class ThinkingConfig(BaseModel):
    type: Literal["enabled", "disabled"] = Field("disabled", description="Whether to enable deep-thinking capability")


class CachingConfig(BaseModel):
    type: Literal["enabled", "disabled"] = Field("disabled", description="Whether to enable model caching")


class LLMConfig(BaseModel):
    model: str = Field("doubao-seed-1-8-251228", description="Model ID")
    thinking: Optional[Literal["enabled", "disabled"]] = Field("disabled", description="Whether to enable deep-thinking capability")
    caching: Optional[Literal["enabled", "disabled"]] = Field("disabled", description="Whether to enable model caching")
    temperature: Optional[float] = Field(1.0, ge=0, le=2, description="Controls the randomness of model output")
    frequency_penalty: Optional[float] = Field(0, ge=-2, le=2, description="Penalty for repeated phrasing")
    top_p: Optional[float] = Field(0, ge=0, le=1, description="Controls the diversity of model output")
    max_tokens: Optional[int] = Field(None, description="Maximum number of output tokens")
    max_completion_tokens: Optional[int] = Field(None, description="Maximum number of completion tokens")
    streaming: bool = Field(True, description="Whether to use streaming output")


class TextContent(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ImageURLDetail(BaseModel):
    url: str


class ImageURLContent(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageURLDetail


class VideoURLDetail(BaseModel):
    url: str


class VideoURLContent(BaseModel):
    type: Literal["video_url"] = "video_url"
    video_url: VideoURLDetail


MessageContent = Union[str, List[Union[TextContent, ImageURLContent, VideoURLContent]]]
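
The content-part models compose into the OpenAI-style multimodal content format described by MessageContent. A short sketch of how they could be used (hypothetical, not part of the package; model_dump() assumes pydantic v2, and the image URL is a placeholder):

    # Building multimodal message content from these models (sketch).
    from langchain_core.messages import HumanMessage

    from local_coze.llm.models import ImageURLContent, ImageURLDetail, TextContent

    parts = [
        TextContent(text="What is shown in this image?"),
        ImageURLContent(image_url=ImageURLDetail(url="https://example.com/photo.png")),
    ]
    # ChatOpenAI expects plain dicts in the OpenAI content-part format, so dump
    # the validated models before building the message.
    message = HumanMessage(content=[part.model_dump() for part in parts])
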
local_coze/memory/client.py
ADDED

@@ -0,0 +1,176 @@
"""
Memory module

Provides LangGraph checkpoint management, with PostgreSQL persistence and an in-memory fallback
"""

import os
import time
import logging
from typing import Optional, Union

import psycopg
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.base import BaseCheckpointSaver

logger = logging.getLogger(__name__)

# Database connection timeout (seconds): 15 seconds per attempt, 2 attempts in total
DB_CONNECTION_TIMEOUT = 15
DB_MAX_RETRIES = 2


def _load_env() -> None:
    """Load environment variables"""
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    try:
        from coze_workload_identity import Client
        client = Client()
        env_vars = client.get_project_env_vars()
        client.close()
        for env_var in env_vars:
            if env_var.key not in os.environ:
                os.environ[env_var.key] = env_var.value
    except Exception as e:
        logger.debug(f"coze_workload_identity not available: {e}")


def _get_db_url() -> Optional[str]:
    """Safely fetch the db_url, returning None on failure"""
    _load_env()

    url = os.getenv("PGDATABASE_URL")
    if url and url.strip():
        return url

    logger.warning("PGDATABASE_URL is not set, will fallback to MemorySaver")
    return None


class MemoryManager:
    """Singleton memory manager"""

    _instance: Optional['MemoryManager'] = None
    _checkpointer: Optional[Union[AsyncPostgresSaver, MemorySaver]] = None
    _pool: Optional[AsyncConnectionPool] = None
    _setup_done: bool = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _connect_with_retry(self, db_url: str) -> Optional[psycopg.Connection]:
        """Connect to the database with retries: 15-second timeout per attempt, 2 attempts in total"""
        last_error = None
        for attempt in range(1, DB_MAX_RETRIES + 1):
            try:
                logger.info(f"Attempting database connection (attempt {attempt}/{DB_MAX_RETRIES})")
                conn = psycopg.connect(db_url, autocommit=True, connect_timeout=DB_CONNECTION_TIMEOUT)
                logger.info(f"Database connection established on attempt {attempt}")
                return conn
            except Exception as e:
                last_error = e
                logger.warning(f"Database connection attempt {attempt} failed: {e}")
                if attempt < DB_MAX_RETRIES:
                    time.sleep(1)  # brief wait before retrying
        logger.error(f"All {DB_MAX_RETRIES} database connection attempts failed, last error: {last_error}")
        return None

    def _setup_schema_and_tables(self, db_url: str) -> bool:
        """Synchronously create the schema and tables (runs only once); returns whether it succeeded"""
        if self._setup_done:
            return True

        conn = self._connect_with_retry(db_url)
        if conn is None:
            return False

        try:
            with conn.cursor() as cur:
                cur.execute("CREATE SCHEMA IF NOT EXISTS memory")
            conn.execute("SET search_path TO memory")
            PostgresSaver(conn).setup()
            self._setup_done = True
            logger.info("Memory schema and tables created")
            return True
        except Exception as e:
            logger.warning(f"Failed to setup schema/tables: {e}")
            return False
        finally:
            conn.close()

    def _create_fallback_checkpointer(self) -> MemorySaver:
        """Create the in-memory fallback checkpointer"""
        self._checkpointer = MemorySaver()
        logger.warning("Using MemorySaver as fallback checkpointer (data will not persist across restarts)")
        return self._checkpointer

    def get_checkpointer(self) -> BaseCheckpointSaver:
        """Get the checkpointer, preferring PostgresSaver and falling back to MemorySaver on failure"""
        if self._checkpointer is not None:
            return self._checkpointer

        # 1. Try to fetch the db_url
        db_url = _get_db_url()
        if not db_url:
            return self._create_fallback_checkpointer()

        # 2. Try to connect to the database and create the schema/tables (with retries)
        if not self._setup_schema_and_tables(db_url):
            return self._create_fallback_checkpointer()

        # 3. Append the search_path to the connection string
        if "?" in db_url:
            db_url = f"{db_url}&options=-csearch_path%3Dmemory"
        else:
            db_url = f"{db_url}?options=-csearch_path%3Dmemory"

        # 4. Try to create the connection pool and checkpointer
        try:
            self._pool = AsyncConnectionPool(
                conninfo=db_url,
                timeout=DB_CONNECTION_TIMEOUT,
                min_size=1,
                max_idle=300,
            )
            self._checkpointer = AsyncPostgresSaver(self._pool)
            logger.info("AsyncPostgresSaver initialized successfully")
        except Exception as e:
            logger.warning(f"Failed to create AsyncPostgresSaver: {e}, will fallback to MemorySaver")
            return self._create_fallback_checkpointer()

        return self._checkpointer


_memory_manager: Optional[MemoryManager] = None


def get_memory_saver() -> BaseCheckpointSaver:
    """
    Get the checkpointer, preferring PostgresSaver; falls back to MemorySaver when the
    db_url is unavailable or the connection fails

    Returns:
        BaseCheckpointSaver: LangGraph checkpoint saver

    Example:
        >>> from local_coze.memory import get_memory_saver
        >>> checkpointer = get_memory_saver()
    """
    global _memory_manager
    if _memory_manager is None:
        _memory_manager = MemoryManager()
    return _memory_manager.get_checkpointer()


__all__ = [
    "MemoryManager",
    "get_memory_saver",
]
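
The checkpointer plugs directly into a LangGraph graph: when PostgreSQL is reachable, conversation state persists across restarts; otherwise it lives in process memory. A sketch under stated assumptions (the graph shape is illustrative and not part of this package; imports assume a recent langgraph release):

    # Wiring get_memory_saver() into a LangGraph graph (sketch).
    from langchain_core.messages import AIMessage
    from langgraph.graph import END, START, MessagesState, StateGraph

    from local_coze.memory import get_memory_saver


    def respond(state: MessagesState):
        # Stub node; a real app would call LLMClient().invoke here.
        return {"messages": [AIMessage(content="ok")]}


    builder = StateGraph(MessagesState)
    builder.add_node("respond", respond)
    builder.add_edge(START, "respond")
    builder.add_edge("respond", END)

    graph = builder.compile(checkpointer=get_memory_saver())

    # Each thread_id keeps its own checkpoint history. Note that when the
    # AsyncPostgresSaver is active the graph must be driven asynchronously,
    # e.g. await graph.ainvoke({"messages": [...]}, config).
    config = {"configurable": {"thread_id": "user-42"}}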