coze-coding-utils 0.2.2a1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- coze_coding_utils/error/codes.py +2 -0
- coze_coding_utils/error/patterns.py +13 -1
- coze_coding_utils/file/file.py +0 -8
- coze_coding_utils/helper/stream_runner.py +563 -0
- {coze_coding_utils-0.2.2a1.dist-info → coze_coding_utils-0.2.3.dist-info}/METADATA +1 -1
- {coze_coding_utils-0.2.2a1.dist-info → coze_coding_utils-0.2.3.dist-info}/RECORD +8 -7
- {coze_coding_utils-0.2.2a1.dist-info → coze_coding_utils-0.2.3.dist-info}/WHEEL +0 -0
- {coze_coding_utils-0.2.2a1.dist-info → coze_coding_utils-0.2.3.dist-info}/licenses/LICENSE +0 -0
coze_coding_utils/error/codes.py
CHANGED
|
@@ -84,6 +84,7 @@ class ErrorCode(IntEnum):
|
|
|
84
84
|
API_LLM_CONTENT_FILTER = 301007 # 内容过滤
|
|
85
85
|
API_LLM_IMAGE_FORMAT = 301008 # 图片格式不支持
|
|
86
86
|
API_LLM_VIDEO_FORMAT = 301009 # 视频格式不支持
|
|
87
|
+
API_PROJECT_NOT_FOUND = 301010 # 项目不存在
|
|
87
88
|
|
|
88
89
|
# 302xxx - 图片生成API错误
|
|
89
90
|
API_IMAGE_GEN_FAILED = 302001 # 图片生成失败
|
|
@@ -136,6 +137,7 @@ class ErrorCode(IntEnum):
|
|
|
136
137
|
INTEGRATION_FEISHU_AUTH_FAILED = 501001 # 飞书认证失败
|
|
137
138
|
INTEGRATION_FEISHU_API_FAILED = 501002 # 飞书API调用失败
|
|
138
139
|
INTEGRATION_FEISHU_DOC_FAILED = 501003 # 飞书文档操作失败
|
|
140
|
+
INTEGRATION_FEISHU_TABLE_FAILED = 501004 # 飞书表格调用失败
|
|
139
141
|
|
|
140
142
|
# 502xxx - 微信集成错误
|
|
141
143
|
INTEGRATION_WECHAT_AUTH_FAILED = 502001 # 微信认证失败
|
|
@@ -112,6 +112,14 @@ ERROR_PATTERNS: List[ErrorPattern] = [
|
|
|
112
112
|
(['headobject operation', 'not found'],
|
|
113
113
|
ErrorCode.RESOURCE_S3_DOWNLOAD_FAILED, "S3对象不存在"),
|
|
114
114
|
|
|
115
|
+
# ==================== 权益类错误 ====================
|
|
116
|
+
(['因触发限流调用内置集成失败',"限流"],
|
|
117
|
+
ErrorCode.API_LLM_RATE_LIMIT, "限流"),
|
|
118
|
+
|
|
119
|
+
(['project not found'],
|
|
120
|
+
ErrorCode.API_PROJECT_NOT_FOUND, "项目不存在"),
|
|
121
|
+
|
|
122
|
+
|
|
115
123
|
# ==================== OCR/文档处理错误 ====================
|
|
116
124
|
(['ocr识别失败', '无法从响应中提取有效的json'],
|
|
117
125
|
ErrorCode.RESOURCE_FILE_FORMAT_ERROR, "OCR识别失败"),
|
|
@@ -261,7 +269,7 @@ ERROR_PATTERNS: List[ErrorPattern] = [
|
|
|
261
269
|
ErrorCode.CONFIG_API_KEY_MISSING, "AWS凭证缺失"),
|
|
262
270
|
(['生成pdf报告失败', 'stylesheet'],
|
|
263
271
|
ErrorCode.RESOURCE_FILE_FORMAT_ERROR, "PDF样式错误"),
|
|
264
|
-
(['从数据库查询'
|
|
272
|
+
(['从数据库查询'],
|
|
265
273
|
ErrorCode.INTEGRATION_DB_QUERY, "数据库查询失败"),
|
|
266
274
|
(['excel文件解析', '表格结构检测失败'],
|
|
267
275
|
ErrorCode.RESOURCE_FILE_FORMAT_ERROR, "Excel解析失败"),
|
|
@@ -293,6 +301,8 @@ ERROR_PATTERNS: List[ErrorPattern] = [
|
|
|
293
301
|
ErrorCode.INTEGRATION_DB_CONNECTION, "数据库连接已关闭"),
|
|
294
302
|
(['psycopg2', 'postgresql'],
|
|
295
303
|
ErrorCode.INTEGRATION_DB_QUERY, "数据库错误"),
|
|
304
|
+
(['数据库读取失败'],
|
|
305
|
+
ErrorCode.INTEGRATION_DB_CONNECTION, "数据库连接失败"),
|
|
296
306
|
|
|
297
307
|
# ==================== 网络相关错误 ====================
|
|
298
308
|
(['broken pipe', 'errno 32'],
|
|
@@ -365,6 +375,8 @@ ERROR_PATTERNS: List[ErrorPattern] = [
|
|
|
365
375
|
ErrorCode.API_AUDIO_GEN_FAILED, "腾讯云TTS生成失败"),
|
|
366
376
|
|
|
367
377
|
# ==================== 飞书相关错误 ====================
|
|
378
|
+
(['FeishuBitable API error'],
|
|
379
|
+
ErrorCode.INTEGRATION_FEISHU_TABLE_FAILED, "飞书Bitable API错误"),
|
|
368
380
|
(['获取草稿列表失败'],
|
|
369
381
|
ErrorCode.INTEGRATION_FEISHU_API_FAILED, "飞书获取草稿列表失败"),
|
|
370
382
|
(['飞书api错误'],
|
coze_coding_utils/file/file.py
CHANGED
|
@@ -92,14 +92,6 @@ def infer_file_category(path_or_url: str) -> tuple[str, str]:
|
|
|
92
92
|
class FileOps:
|
|
93
93
|
DOWNLOAD_DIR = "/tmp"
|
|
94
94
|
|
|
95
|
-
@staticmethod
|
|
96
|
-
def read_content(file_obj:File, max_length=10000) -> str:
|
|
97
|
-
return ""
|
|
98
|
-
|
|
99
|
-
@staticmethod
|
|
100
|
-
def get_local_path(file_obj:File) -> str:
|
|
101
|
-
return file_obj.url
|
|
102
|
-
|
|
103
95
|
@staticmethod
|
|
104
96
|
def _get_bytes_stream(file_obj:File) -> tuple[bytes, str]:
|
|
105
97
|
"""
|
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import asyncio
|
|
3
|
+
import threading
|
|
4
|
+
import contextvars
|
|
5
|
+
import logging
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any, Dict, Iterator, AsyncIterable
|
|
8
|
+
from langchain_core.runnables import RunnableConfig
|
|
9
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
10
|
+
from coze_coding_utils.helper.agent_helper import (
|
|
11
|
+
to_stream_input,
|
|
12
|
+
agent_iter_server_messages,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from coze_coding_utils.error import classify_error
|
|
16
|
+
import asyncio
|
|
17
|
+
import time
|
|
18
|
+
import traceback
|
|
19
|
+
from typing import Any, Dict, AsyncGenerator, Callable
|
|
20
|
+
from coze_coding_utils.runtime_ctx.context import Context
|
|
21
|
+
from coze_coding_utils.messages.server import (
|
|
22
|
+
create_message_end_dict,
|
|
23
|
+
create_message_error_dict,
|
|
24
|
+
MESSAGE_END_CODE_CANCELED,
|
|
25
|
+
)
|
|
26
|
+
from coze_coding_utils.helper.agent_helper import to_client_message
|
|
27
|
+
from coze_coding_utils.error.classifier import ErrorClassifier
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)

# Upper bound, in seconds, on a run driven by an ``astream`` producer thread.
# It is only checked when the graph yields a new item.
TIMEOUT_SECONDS = 900

# Minimum spacing, in seconds, between workflow ``ping`` heartbeat events.
PING_INTERVAL_SECONDS = 30
|
|
33
|
+
|
|
34
|
+
class WorkflowEventType:
    """String values used in the ``type`` field of workflow stream events."""

    # Lifecycle of the whole workflow run.
    WORKFLOW_START: str = "workflow_start"
    WORKFLOW_END: str = "workflow_end"
    # Per-node lifecycle; only emitted in debug mode.
    NODE_START: str = "node_start"
    NODE_END: str = "node_end"
    # Failure, cancellation or timeout.
    ERROR: str = "error"
    # Heartbeat.
    PING: str = "ping"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class WorkflowErrorCode:
    """``code`` values for workflow ``error`` events raised by the runner itself."""

    # The stream was cancelled before the workflow finished.
    CANCELED: str = "CANCELED"
    # The run exceeded TIMEOUT_SECONDS.
    TIMEOUT: str = "TIMEOUT"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class BaseStreamRunner(ABC):
    """Interface for running a compiled graph and streaming its output.

    Implementations in this module mutate ``run_config`` in place before
    starting the graph.
    """

    @abstractmethod
    def stream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> Iterator[Any]:
        """Run ``graph`` for ``payload`` synchronously and yield output items."""

    @abstractmethod
    async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> AsyncIterable[Any]:
        """Run ``graph`` for ``payload`` asynchronously and yield output items."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AgentStreamRunner(BaseStreamRunner):
    """Run an agent graph and stream its output as server-message dicts.

    The graph runs with ``stream_mode="messages"``. Its raw output is
    converted by ``agent_iter_server_messages``, and each resulting message
    is emitted as ``sm.dict()``. Failures are reported as a final error
    message instead of being raised.
    """

    @staticmethod
    def _prepare(payload: Dict[str, Any], run_config: RunnableConfig):
        """Parse ``payload`` and configure ``run_config`` in place.

        Sets ``recursion_limit`` to 100 and *replaces* ``configurable`` with
        ``{"thread_id": session_id}``; caller-supplied ``configurable`` keys
        are discarded. Returns ``(client_msg, stream_input)``.
        """
        client_msg, session_id = to_client_message(payload)
        run_config["recursion_limit"] = 100
        run_config["configurable"] = {"thread_id": session_id}
        return client_msg, to_stream_input(client_msg)

    @staticmethod
    def _server_messages(graph: CompiledStateGraph, stream_input: Any, run_config: RunnableConfig, ctx: Context, client_msg: Any):
        """Start ``graph`` and return the iterator of converted server messages."""
        items = graph.stream(stream_input, stream_mode="messages", config=run_config, context=ctx)
        return agent_iter_server_messages(
            items,
            session_id=client_msg.session_id,
            query_msg_id=client_msg.local_msg_id,
            local_msg_id=client_msg.local_msg_id,
            run_id=ctx.run_id,
            log_id=ctx.logid,
        )

    def stream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> Iterator[Any]:
        """Yield server-message dicts for one agent turn, synchronously.

        On ``asyncio.CancelledError`` a "canceled" end message is yielded and
        the error re-raised. Any other exception is classified and turned
        into a final error message.
        """
        client_msg, stream_input = self._prepare(payload, run_config)
        t0 = time.time()
        try:
            for sm in self._server_messages(graph, stream_input, run_config, ctx, client_msg):
                yield sm.dict()
        except asyncio.CancelledError:
            # A synchronous generator only sees this if something it calls raises it.
            logger.info(f"Stream cancelled for run_id: {ctx.run_id}")
            end_msg = create_message_end_dict(
                code=MESSAGE_END_CODE_CANCELED,
                message="Stream execution cancelled",
                session_id=client_msg.session_id,
                query_msg_id=client_msg.local_msg_id,
                log_id=ctx.logid,
                time_cost_ms=int((time.time() - t0) * 1000),
                reply_id="",
                sequence_id=1,
            )
            yield end_msg
            raise
        except Exception as ex:
            err = classify_error(ex, {"node_name": "stream"})
            end_msg = create_message_error_dict(
                code=str(err.code),
                message=err.message,
                session_id=client_msg.session_id,
                query_msg_id=client_msg.local_msg_id,
                log_id=ctx.logid,
                reply_id="",
                sequence_id=1,
            )
            yield end_msg

    async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> AsyncIterable[Any]:
        """Yield server-message dicts for one agent turn without blocking the loop.

        The synchronous graph runs in a daemon producer thread (inside a copy
        of the caller's contextvars). Items are handed to this coroutine
        through an ``asyncio.Queue``, with ``None`` as the end-of-stream
        sentinel.

        The producer stops early when:

        * the consumer is cancelled or stops iterating (``cancelled`` is set);
        * ``TIMEOUT_SECONDS`` elapse, checked whenever a new message arrives;
        * the graph raises, in which case a classified error message is sent.
        """
        client_msg, stream_input = self._prepare(payload, run_config)

        loop = asyncio.get_running_loop()
        q: asyncio.Queue = asyncio.Queue()
        context = contextvars.copy_context()
        start_time = time.time()
        cancelled = threading.Event()

        def put(item: Any) -> None:
            """Schedule ``item`` onto the queue from the producer thread.

            The producer is a daemon thread that can outlive the consumer. If
            the event loop has already been closed, the item is dropped
            instead of raising ``RuntimeError`` inside the thread.
            """
            try:
                loop.call_soon_threadsafe(q.put_nowait, item)
            except RuntimeError:
                logger.info(f"Event loop closed, dropping stream item for run_id: {ctx.run_id}")

        def producer():
            last_seq = 0
            try:
                if cancelled.is_set():
                    logger.info(f"Producer cancelled before start for run_id: {ctx.run_id}")
                    return

                for sm in self._server_messages(graph, stream_input, run_config, ctx, client_msg):
                    if cancelled.is_set():
                        logger.info(f"Producer cancelled during iteration for run_id: {ctx.run_id}")
                        cancel_msg = create_message_end_dict(
                            code=MESSAGE_END_CODE_CANCELED,
                            message="Stream cancelled by upstream",
                            session_id=client_msg.session_id,
                            query_msg_id=client_msg.local_msg_id,
                            log_id=ctx.logid,
                            time_cost_ms=int((time.time() - start_time) * 1000),
                            reply_id=getattr(sm, 'reply_id', ''),
                            sequence_id=last_seq + 1,
                        )
                        put(cancel_msg)
                        return

                    if time.time() - start_time > TIMEOUT_SECONDS:
                        logger.error(f"Agent execution timeout after {TIMEOUT_SECONDS}s for run_id: {ctx.run_id}")
                        timeout_msg = create_message_end_dict(
                            code="TIMEOUT",
                            message=f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds",
                            session_id=client_msg.session_id,
                            query_msg_id=client_msg.local_msg_id,
                            log_id=ctx.logid,
                            time_cost_ms=int((time.time() - start_time) * 1000),
                            reply_id=getattr(sm, 'reply_id', ''),
                            sequence_id=last_seq + 1,
                        )
                        put(timeout_msg)
                        return

                    put(sm.dict())
                    last_seq = sm.sequence_id
            except Exception as ex:
                # Errors raised after the consumer went away are only logged.
                if cancelled.is_set():
                    logger.info(f"Producer exception after cancel for run_id: {ctx.run_id}, ignoring: {ex}")
                    return
                err = classify_error(ex, {"node_name": "astream"})
                end_msg = create_message_error_dict(
                    code=str(err.code),
                    message=err.message,
                    session_id=client_msg.session_id,
                    query_msg_id=client_msg.local_msg_id,
                    log_id=ctx.logid,
                    reply_id="",
                    sequence_id=last_seq + 1,
                )
                put(end_msg)
            finally:
                put(None)

        threading.Thread(target=lambda: context.run(producer), daemon=True).start()

        try:
            while True:
                item = await q.get()
                if item is None:
                    break
                yield item
        except asyncio.CancelledError:
            logger.info(f"Stream cancelled for run_id: {ctx.run_id}, signaling producer to stop")
            cancelled.set()
            raise
        finally:
            # Also stop the producer when the consumer stops iterating without
            # being cancelled (e.g. ``break`` / ``aclose()``). Otherwise the
            # thread would keep driving the graph to completion.
            cancelled.set()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class WorkflowStreamRunner(BaseStreamRunner):
    """Run a workflow graph and stream ``(sequence_id, event_dict)`` tuples.

    Every run starts with a ``workflow_start`` event. On success it ends with
    a ``workflow_end`` event carrying the last node output.

    When ``run_config["configurable"]["workflow_debug"]`` is truthy, the
    graph runs with ``stream_mode="debug"`` and ``node_start`` / ``node_end``
    events are emitted per node. Otherwise ``stream_mode="updates"`` is used.
    Failures are reported as an ``error`` event rather than raised.
    """

    def __init__(self):
        # Not read by the methods below; per-run node start times are tracked
        # in local dicts instead.
        self._node_start_times: Dict[str, float] = {}

    def _serialize_data(self, data: Any) -> Any:
        """Recursively convert ``data`` into plain dicts/lists.

        Dicts and lists/tuples are converted element-wise (tuples become
        lists). Objects exposing ``model_dump`` or ``dict`` are converted with
        that method. Other objects with a ``__dict__`` become a dict of their
        non-underscore attributes. Anything else is returned unchanged.
        """
        if isinstance(data, dict):
            return {k: self._serialize_data(v) for k, v in data.items()}
        elif isinstance(data, (list, tuple)):
            return [self._serialize_data(item) for item in data]
        elif hasattr(data, 'model_dump'):
            return data.model_dump()
        elif hasattr(data, 'dict'):
            return data.dict()
        elif hasattr(data, '__dict__'):
            return {k: self._serialize_data(v) for k, v in data.__dict__.items() if not k.startswith('_')}
        else:
            return data

    def _build_event(self, event_type: str, ctx: Context, **kwargs) -> Dict[str, Any]:
        """Build an event dict with common fields, then ``kwargs`` merged in."""
        result = {
            "type": event_type,
            "timestamp": int(time.time() * 1000),
            "log_id": ctx.logid,
            "run_id": ctx.run_id,
        }
        result.update(kwargs)
        return result

    def _abort_event(self, ctx: Context, code: str, message: str) -> Dict[str, Any]:
        """Build the ``error`` event for a cancelled or timed-out run.

        The text is sent under ``message`` (the original key) and also under
        ``error_msg``. The latter is the key used by classified-error events
        and by ``workflow_stream_handler``, so clients can rely on one field.
        """
        return self._build_event(WorkflowEventType.ERROR, ctx, code=code, message=message, error_msg=message)

    @staticmethod
    def _prepare_run_config(run_config: RunnableConfig, ctx: Context) -> bool:
        """Configure ``run_config`` in place and return whether debug mode is on.

        Sets ``recursion_limit`` to 100 and ``configurable.thread_id`` to the
        run id. Any other ``configurable`` keys are kept.
        """
        run_config["recursion_limit"] = 100
        if "configurable" not in run_config:
            run_config["configurable"] = {}
        run_config["configurable"]["thread_id"] = ctx.run_id
        return bool(run_config["configurable"].get("workflow_debug", False))

    def _output_from_result(self, result: Any) -> Any:
        """Serialize a debug ``task_result`` payload result.

        ``None`` becomes ``{}``. A one-element list/tuple is unwrapped. A
        longer one becomes ``{"results": [...]}``. Anything else is
        serialized as-is.
        """
        if result is None:
            return {}
        if isinstance(result, (list, tuple)) and len(result) > 0:
            if len(result) == 1:
                return self._serialize_data(result[0])
            return {"results": [self._serialize_data(r) for r in result]}
        return self._serialize_data(result)

    def _handle_graph_event(self, event: Any, is_debug: bool, current_time: float,
                            node_start_times: Dict[str, float], ctx: Context, final_output: Any):
        """Translate one raw graph event into stream events.

        Returns ``(final_output, events)``: the (possibly updated) workflow
        output and the list of event dicts to emit for ``event``.

        * updates mode: a dict ``event`` maps node name to node output. The
          last node's serialized output (``{}`` if falsy) becomes
          ``final_output`` and nothing is emitted. Non-dict events are ignored.
        * debug mode: ``task`` records the node start time and emits
          ``node_start``. ``task_result`` emits ``node_end`` and replaces
          ``final_output``. Other event types are ignored.
        """
        if not is_debug:
            if isinstance(event, dict):
                for node_name, node_output in event.items():
                    logger.info(f"Node output: {node_name}")
                    final_output = self._serialize_data(node_output) if node_output else {}
            return final_output, []

        event_type = event.get("type", "")

        if event_type == "task":
            task = event.get("payload", {})
            node_name = task.get("name", "")
            node_start_times[node_name] = current_time
            return final_output, [self._build_event(
                WorkflowEventType.NODE_START,
                ctx,
                node_name=node_name,
                input=self._serialize_data(task.get("input", {})),
            )]

        if event_type == "task_result":
            task = event.get("payload", {})
            node_name = task.get("name", "")
            output_data = self._output_from_result(task.get("result"))
            node_start_time = node_start_times.pop(node_name, current_time)
            return output_data, [self._build_event(
                WorkflowEventType.NODE_END,
                ctx,
                node_name=node_name,
                output=output_data,
                time_cost_ms=int((current_time - node_start_time) * 1000),
            )]

        return final_output, []

    def stream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> Iterator[Any]:
        """Synchronously yield ``(seq, event)`` tuples for one workflow run.

        Pings are only emitted between graph events, because this path has no
        independent timer.
        """
        is_debug = self._prepare_run_config(run_config, ctx)
        stream_mode = "debug" if is_debug else "updates"

        t0 = time.time()
        last_ping_time = t0
        node_start_times: Dict[str, float] = {}
        final_output: Any = {}
        seq = 0

        try:
            seq += 1
            yield (seq, self._build_event(WorkflowEventType.WORKFLOW_START, ctx))

            for event in graph.stream(payload, stream_mode=stream_mode, config=run_config, context=ctx):
                current_time = time.time()
                if current_time - last_ping_time >= PING_INTERVAL_SECONDS:
                    seq += 1
                    yield (seq, self._build_event(WorkflowEventType.PING, ctx))
                    last_ping_time = current_time

                final_output, node_events = self._handle_graph_event(
                    event, is_debug, current_time, node_start_times, ctx, final_output
                )
                for node_event in node_events:
                    seq += 1
                    yield (seq, node_event)

            seq += 1
            yield (seq, self._build_event(
                WorkflowEventType.WORKFLOW_END,
                ctx,
                output=final_output,
                time_cost_ms=int((time.time() - t0) * 1000),
            ))

        except asyncio.CancelledError:
            logger.info(f"Workflow stream cancelled for run_id: {ctx.run_id}")
            seq += 1
            yield (seq, self._abort_event(ctx, WorkflowErrorCode.CANCELED, "Stream execution cancelled"))
            raise
        except Exception as ex:
            err = classify_error(ex, {"node_name": "workflow_stream"})
            seq += 1
            yield (seq, self._build_event(WorkflowEventType.ERROR, ctx, code=str(err.code), error_msg=err.message))

    async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Context) -> AsyncIterable[Any]:
        """Asynchronously yield ``(seq, event)`` tuples for one workflow run.

        The synchronous graph runs in a daemon producer thread (inside a copy
        of the caller's contextvars) and feeds an ``asyncio.Queue``. ``None``
        is the end-of-stream sentinel. A ping task on the event loop emits
        heartbeats even when no graph events arrive.
        """
        is_debug = self._prepare_run_config(run_config, ctx)
        stream_mode = "debug" if is_debug else "updates"
        logger.info(f"Stream mode: {stream_mode}")

        loop = asyncio.get_running_loop()
        q: asyncio.Queue = asyncio.Queue()
        context = contextvars.copy_context()
        start_time = time.time()
        cancelled = threading.Event()

        # The sequence counter and last ping time are shared by the producer
        # thread and the ping task running on the loop. Every access goes
        # through this lock.
        state_lock = threading.Lock()
        seq = [0]
        last_ping_time = [start_time]

        def put(item: Any) -> None:
            """Schedule ``item`` onto the queue from any thread.

            The producer is a daemon thread that can outlive the consumer. If
            the event loop has already been closed, the item is dropped
            instead of raising ``RuntimeError``.
            """
            try:
                loop.call_soon_threadsafe(q.put_nowait, item)
            except RuntimeError:
                logger.info(f"Event loop closed, dropping workflow stream item for run_id: {ctx.run_id}")

        def emit(event: Dict[str, Any]) -> None:
            """Assign the next sequence id to ``event`` and enqueue it.

            Numbering and scheduling happen under one lock. Loop callbacks run
            in the order they are scheduled, so ids are unique and queue order
            matches id order regardless of which thread emits.
            """
            with state_lock:
                seq[0] += 1
                put((seq[0], event))

        def maybe_ping(now: float) -> None:
            """Emit a ping unless one went out within PING_INTERVAL_SECONDS."""
            with state_lock:
                if now - last_ping_time[0] < PING_INTERVAL_SECONDS:
                    return
                last_ping_time[0] = now
            emit(self._build_event(WorkflowEventType.PING, ctx))

        def producer():
            node_start_times: Dict[str, float] = {}
            final_output: Any = {}
            try:
                if cancelled.is_set():
                    logger.info(f"Workflow producer cancelled before start for run_id: {ctx.run_id}")
                    return

                emit(self._build_event(WorkflowEventType.WORKFLOW_START, ctx))

                for event in graph.stream(payload, stream_mode=stream_mode, config=run_config, context=ctx):
                    if cancelled.is_set():
                        logger.info(f"Workflow producer cancelled during iteration for run_id: {ctx.run_id}")
                        emit(self._abort_event(ctx, WorkflowErrorCode.CANCELED, "Stream cancelled by upstream"))
                        return

                    # Only checked between events; a single long node is not interrupted.
                    if time.time() - start_time > TIMEOUT_SECONDS:
                        logger.error(f"Workflow execution timeout after {TIMEOUT_SECONDS}s for run_id: {ctx.run_id}")
                        emit(self._abort_event(
                            ctx,
                            WorkflowErrorCode.TIMEOUT,
                            f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds",
                        ))
                        return

                    current_time = time.time()
                    maybe_ping(current_time)

                    final_output, node_events = self._handle_graph_event(
                        event, is_debug, current_time, node_start_times, ctx, final_output
                    )
                    for node_event in node_events:
                        emit(node_event)

                emit(self._build_event(
                    WorkflowEventType.WORKFLOW_END,
                    ctx,
                    output=final_output,
                    time_cost_ms=int((time.time() - start_time) * 1000),
                ))

            except Exception as ex:
                # Errors raised after the consumer went away are only logged.
                if cancelled.is_set():
                    logger.info(f"Workflow producer exception after cancel for run_id: {ctx.run_id}, ignoring: {ex}")
                    return
                err = classify_error(ex, {"node_name": "workflow_astream"})
                emit(self._build_event(WorkflowEventType.ERROR, ctx, code=str(err.code), error_msg=err.message))
            finally:
                put(None)

        async def ping_sender():
            while not cancelled.is_set():
                await asyncio.sleep(PING_INTERVAL_SECONDS)
                if cancelled.is_set():
                    break
                maybe_ping(time.time())

        threading.Thread(target=lambda: context.run(producer), daemon=True).start()
        ping_task = asyncio.create_task(ping_sender())

        try:
            while True:
                item = await q.get()
                if item is None:
                    break
                yield item
        except asyncio.CancelledError:
            logger.info(f"Workflow stream cancelled for run_id: {ctx.run_id}, signaling producer to stop")
            cancelled.set()
            raise
        finally:
            # Stop both the producer and the ping task however the consumer ends.
            cancelled.set()
            ping_task.cancel()
            try:
                await ping_task
            except asyncio.CancelledError:
                pass
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
async def agent_stream_handler(
    payload: Dict[str, Any],
    ctx: Context,
    run_id: str,
    stream_sse_func: Callable,
    sse_event_func: Callable,
    error_classifier: ErrorClassifier,
    register_task_func: Callable[[str, asyncio.Task], None],
) -> AsyncGenerator[str, None]:
    """Relay the SSE chunks of an agent run, closing it out on cancel or error.

    Before streaming, the current task is registered under ``run_id`` via
    ``register_task_func``. Chunks from
    ``stream_sse_func(payload, ctx, need_detail=False)`` are passed through
    unchanged.

    On ``asyncio.CancelledError``, a "canceled" end message is yielded through
    ``sse_event_func`` and the error is re-raised. Any other exception is
    classified, logged with its traceback, and reported as an error message.
    """
    current = asyncio.current_task()
    if current:
        register_task_func(run_id, current)
        logger.info(f"Registered agent streaming task for run_id: {run_id}")

    client_msg, _unused_session = to_client_message(payload)
    started_at = time.time()

    try:
        async for chunk in stream_sse_func(payload, ctx, need_detail=False):
            yield chunk
    except asyncio.CancelledError:
        logger.info(f"Agent stream cancelled for run_id: {run_id}")
        elapsed_ms = int((time.time() - started_at) * 1000)
        yield sse_event_func(
            create_message_end_dict(
                code=MESSAGE_END_CODE_CANCELED,
                message="Stream cancelled by user",
                session_id=client_msg.session_id,
                query_msg_id=client_msg.local_msg_id,
                log_id=ctx.logid,
                time_cost_ms=elapsed_ms,
                reply_id="",
                sequence_id=1,
            )
        )
        raise
    except Exception as ex:
        classify_ctx = {"node_name": "agent_stream", "run_id": run_id}
        err = error_classifier.classify(ex, classify_ctx)
        logger.error(
            f"Unexpected error in agent_stream: [{err.code}] {err.message}, "
            f"traceback: {traceback.format_exc()}"
        )
        yield sse_event_func(
            create_message_error_dict(
                code=str(err.code),
                message=str(ex),
                session_id=client_msg.session_id,
                query_msg_id=client_msg.local_msg_id,
                log_id=ctx.logid,
                reply_id="",
                sequence_id=1,
                local_msg_id=client_msg.local_msg_id,
            )
        )
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
async def workflow_stream_handler(
    payload: Dict[str, Any],
    ctx: Context,
    run_id: str,
    stream_sse_func: Callable,
    sse_event_func: Callable,
    error_classifier: ErrorClassifier,
    register_task_func: Callable[[str, asyncio.Task], None],
    workflow_debug: bool = False,
) -> AsyncGenerator[str, None]:
    """Relay the SSE chunks of a workflow run, closing it out on cancel or error.

    Before streaming, the current task is registered under ``run_id``.
    Chunks from ``stream_sse_func(payload, ctx, need_detail=workflow_debug)``
    are passed through unchanged.

    On cancellation, an ``error`` event with code ``CANCELED`` is yielded
    through ``sse_event_func`` and the error is re-raised. Any other
    exception is classified, logged with its traceback, and reported as an
    ``error`` event whose ``error_msg`` is ``str(ex)``.
    """

    def make_error_event(code: str, error_msg: str) -> Dict[str, Any]:
        # Timestamp is taken when the event is built.
        return {
            "type": "error",
            "timestamp": int(time.time() * 1000),
            "log_id": ctx.logid,
            "run_id": run_id,
            "code": code,
            "error_msg": error_msg,
        }

    current = asyncio.current_task()
    if current:
        register_task_func(run_id, current)
        logger.info(f"Registered workflow streaming task for run_id: {run_id}")

    try:
        async for chunk in stream_sse_func(payload, ctx, need_detail=workflow_debug):
            yield chunk
    except asyncio.CancelledError:
        logger.info(f"Workflow stream cancelled for run_id: {run_id}")
        yield sse_event_func(make_error_event("CANCELED", "Stream cancelled by user"))
        raise
    except Exception as ex:
        classify_ctx = {"node_name": "workflow_stream", "run_id": run_id}
        err = error_classifier.classify(ex, classify_ctx)
        logger.error(
            f"Unexpected error in workflow_stream: [{err.code}] {err.message}, "
            f"traceback: {traceback.format_exc()}"
        )
        yield sse_event_func(make_error_event(str(err.code), str(ex)))
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
def get_stream_runner(is_agent: bool) -> BaseStreamRunner:
    """Return a new ``AgentStreamRunner`` if ``is_agent``, else a new ``WorkflowStreamRunner``."""
    runner_cls = AgentStreamRunner if is_agent else WorkflowStreamRunner
    return runner_cls()
|
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
coze_coding_utils/__init__.py,sha256=OIMKOQLy07Uo5wQkLw3D7j6qRKt4o-smdW-dndYhpHo,37
|
|
2
2
|
coze_coding_utils/error/__init__.py,sha256=SbhsopZ8ZQsbXKZ-GPsw3Fq8AQAOC8W6bZgUZhIOw_k,886
|
|
3
3
|
coze_coding_utils/error/classifier.py,sha256=uXVmufL_sn4w7oNyvrEFXSI_8mCi4mXY353UK5d-d0Y,10028
|
|
4
|
-
coze_coding_utils/error/codes.py,sha256=
|
|
4
|
+
coze_coding_utils/error/codes.py,sha256=IdSRHoWlwaIzfzUswmjT_lGS04_RHaHjSJUbV2DIhEA,17162
|
|
5
5
|
coze_coding_utils/error/exceptions.py,sha256=QjGk56ovGG-2V4gHcTeJq3-3ZIQQ8DF692zgIYcEJxI,17074
|
|
6
|
-
coze_coding_utils/error/patterns.py,sha256=
|
|
6
|
+
coze_coding_utils/error/patterns.py,sha256=_Z_CtsiVng6dQnqWQwYxKrZm_DrLNYL8tw5_oO5I8x8,44871
|
|
7
7
|
coze_coding_utils/error/test_classifier.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
coze_coding_utils/file/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
-
coze_coding_utils/file/file.py,sha256=
|
|
9
|
+
coze_coding_utils/file/file.py,sha256=fBda18EGSQZ3Xl8OqEaGAb5Rd90_SmhJ1k0jgQk2v7Y,11636
|
|
10
10
|
coze_coding_utils/helper/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
coze_coding_utils/helper/agent_helper.py,sha256=q1ZM30xLXoW-m0NJmJ_Y0M-kUAQCBstG_j7xkqsyRSU,22546
|
|
12
12
|
coze_coding_utils/helper/graph_helper.py,sha256=UNtqqiQNAQ4319qcC1vHiLYIL2eGzvGQRgXu3mgLq8Y,8893
|
|
13
|
+
coze_coding_utils/helper/stream_runner.py,sha256=f66n6QJ3zCakhk7Fe4Vz9vTZ2KJuM9v9UJfqX5S3nDA,24050
|
|
13
14
|
coze_coding_utils/log/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
15
|
coze_coding_utils/log/common.py,sha256=mUNkCm68oaPaI6-a5UwLf87AfhrMnVPkEuri16guqKc,168
|
|
15
16
|
coze_coding_utils/log/config.py,sha256=Qkw3JRuGUKJ6CBY7WqHJOFeyCU47cArvUtMsSBifFMo,195
|
|
@@ -31,7 +32,7 @@ coze_coding_utils/openai/types/request.py,sha256=IuNMT2Ce1--_32R30Q2q7Lb2dAwKNy3
|
|
|
31
32
|
coze_coding_utils/openai/types/response.py,sha256=pjHHVR8LSMVFCc3fGzKqXrdoKDIfSCJEfICd_X9Nohc,4808
|
|
32
33
|
coze_coding_utils/runtime_ctx/__init__.py,sha256=4W8VliAYUP1KY2gLJ_YDy2TmcXYVm-PY7XikQD_bFwA,2
|
|
33
34
|
coze_coding_utils/runtime_ctx/context.py,sha256=G8ld-WnQ1pTJe5OOXC_dTbagXj9IxmpRiPM4X_jWW6o,3992
|
|
34
|
-
coze_coding_utils-0.2.
|
|
35
|
-
coze_coding_utils-0.2.
|
|
36
|
-
coze_coding_utils-0.2.
|
|
37
|
-
coze_coding_utils-0.2.
|
|
35
|
+
coze_coding_utils-0.2.3.dist-info/METADATA,sha256=-V3YwNo9i5Z6nlEz92Gy-l_Kg_X4yZ01M4_CtgsEjG4,977
|
|
36
|
+
coze_coding_utils-0.2.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
37
|
+
coze_coding_utils-0.2.3.dist-info/licenses/LICENSE,sha256=lzckZhAjHlpSJcWvppoST095IHFpBwKiB2pKcBv7vP4,1078
|
|
38
|
+
coze_coding_utils-0.2.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|