streamlet-py 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
streamlet/__init__.py
ADDED
|
@@ -0,0 +1,1233 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
8
|
+
from contextvars import ContextVar
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from dependency_injector import containers, providers
|
|
13
|
+
from dependency_injector.wiring import inject
|
|
14
|
+
from pydantic import ConfigDict, TypeAdapter, ValidationError, validate_call
|
|
15
|
+
|
|
16
|
+
# Framework-wide logger named "streamlet"; the retry and parallel-execution helpers log through it.
logger: logging.Logger = logging.getLogger("streamlet")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
|
|
20
|
+
class ParallelResult:
|
|
21
|
+
"""Pydantic model for recording parallel execution results and exception stacks."""
|
|
22
|
+
|
|
23
|
+
node_name: str
|
|
24
|
+
success: bool
|
|
25
|
+
result: Any = None
|
|
26
|
+
error: str | None = None
|
|
27
|
+
error_traceback: str | None = None
|
|
28
|
+
execution_time: float | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ==================== 异常类型体系 ====================
|
|
32
|
+
class StreamletException(Exception):
|
|
33
|
+
"""Streamlet框架基础异常类"""
|
|
34
|
+
|
|
35
|
+
retryable = False # 默认框架异常不重试
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self, message: str, node_name: str | None = None, **kwargs: Any
|
|
39
|
+
) -> None:
|
|
40
|
+
self.node_name = node_name
|
|
41
|
+
self.context = kwargs
|
|
42
|
+
super().__init__(message)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ValidationInputException(StreamletException):
    """Raised when argument validation (the ``validate_call`` pre-check) fails.

    Never retried: calling again with the same invalid arguments cannot succeed.

    Attributes:
        validation_error: The underlying validation error object, if supplied.
    """

    retryable = False

    def __init__(
        self,
        message: str,
        validation_error: Any = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name=node_name, **kwargs)
        self.validation_error = validation_error
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ValidationOutputException(StreamletException):
    """Raised when a return value fails validation against its annotation.

    Never retried: a function returning an invalid value is treated as a
    deterministic contract violation.

    Attributes:
        validation_error: The underlying validation error object, if supplied.
    """

    retryable = False

    def __init__(
        self,
        message: str,
        validation_error: Any = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name=node_name, **kwargs)
        self.validation_error = validation_error
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class UserBusinessException(StreamletException):
    """Base class for user-defined business errors.

    Retryable by default. Pass ``retryable=`` to override the policy for a
    single instance without defining a subclass.
    """

    retryable = True

    def __init__(
        self,
        message: str,
        retryable: bool | None = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name=node_name, **kwargs)
        # Only shadow the class-level default when the caller made an explicit choice.
        if retryable is not None:
            self.retryable = retryable
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class NodeExecutionException(StreamletException):
    """Raised when a node fails during execution.

    Inherits ``retryable = False`` from ``StreamletException``.

    Attributes:
        original_exception: The underlying exception, if one was supplied.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        original_exception: Exception | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name=node_name, **kwargs)
        self.original_exception = original_exception
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class NodeTimeoutException(NodeExecutionException):
    """Raised when a node's execution times out.

    Attributes:
        timeout_seconds: The timeout that was exceeded, if supplied.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        timeout_seconds: float | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name=node_name, **kwargs)
        self.timeout_seconds = timeout_seconds
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class NodeRetryExhaustedException(NodeExecutionException):
    """Raised when a node has used up all of its retry attempts.

    ``last_exception`` is also forwarded as ``original_exception`` so callers
    handling any ``NodeExecutionException`` can reach the root cause.

    Attributes:
        retry_count: Number of retries that were attempted, if supplied.
        last_exception: The exception raised by the final attempt, if supplied.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        retry_count: int | None = None,
        last_exception: Exception | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            message,
            node_name=node_name,
            original_exception=last_exception,
            **kwargs,
        )
        self.retry_count = retry_count
        self.last_exception = last_exception
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class LoopControlException(StreamletException):
    """Base class for loop-control exceptions."""
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# ==================== 重试装饰器 ====================
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class RetryConfig:
    """Retry policy: how many retries, how long to wait, and which errors qualify.

    Args:
        retry_count: Number of retries after the first attempt (``>= 0``).
        retry_delay: Base delay in seconds before the first retry (``>= 0``).
        exception_types: Exception classes eligible for retry when the raised
            exception carries no ``retryable`` attribute. Anything ``isinstance``
            accepts as its second argument; a list is also accepted and is
            converted to a tuple.
        backoff_factor: Per-attempt multiplier for exponential back-off (``>= 0``).
        max_delay: Upper bound in seconds for any computed delay (``>= 0``).

    Raises:
        ValueError: If any numeric argument is negative.
    """

    def __init__(
        self,
        retry_count: int = 3,
        retry_delay: float = 1.0,
        exception_types: tuple = (Exception,),
        backoff_factor: float = 1.0,
        max_delay: float = 60.0,
    ):
        if retry_count < 0:
            raise ValueError(f"retry_count must be >= 0, got {retry_count}")
        if retry_delay < 0:
            raise ValueError(f"retry_delay must be >= 0, got {retry_delay}")
        if max_delay < 0:
            raise ValueError(f"max_delay must be >= 0, got {max_delay}")
        # A negative factor yields negative delays on odd attempts, and
        # time.sleep() raises ValueError for those in the middle of a retry loop.
        if backoff_factor < 0:
            raise ValueError(f"backoff_factor must be >= 0, got {backoff_factor}")
        self.retry_count = retry_count
        self.retry_delay = retry_delay
        # isinstance() rejects lists with TypeError, which would otherwise surface
        # from should_retry() inside the caller's except-block.
        self.exception_types = (
            tuple(exception_types)
            if isinstance(exception_types, list)
            else exception_types
        )
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay

    def should_retry(self, exception: Exception) -> bool:
        """Return whether ``exception`` should trigger another attempt.

        An explicit ``retryable`` attribute on the exception takes priority;
        otherwise the exception must be an instance of ``exception_types``
        (subclasses included).
        """
        if hasattr(exception, "retryable"):
            return bool(exception.retryable)

        return isinstance(exception, self.exception_types)

    def get_delay(self, attempt: int) -> float:
        """Return the delay in seconds before the retry following ``attempt``.

        Computes ``retry_delay * backoff_factor ** attempt`` capped at
        ``max_delay``. When the exponential term overflows a float (very large
        ``attempt``), the cap is returned directly instead of raising
        ``OverflowError``.
        """
        try:
            delay = self.retry_delay * (self.backoff_factor**attempt)
        except OverflowError:
            # A zero base delay stays zero however large the multiplier grows.
            return 0.0 if self.retry_delay == 0 else self.max_delay
        return min(delay, self.max_delay)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def _get_func_name(func: Any, fallback_name: str | None = None) -> str:
|
|
189
|
+
"""安全获取函数名称"""
|
|
190
|
+
if hasattr(func, "__name__"):
|
|
191
|
+
return str(func.__name__)
|
|
192
|
+
elif hasattr(func, "func") and hasattr(func.func, "__name__"): # partial对象
|
|
193
|
+
return str(func.func.__name__)
|
|
194
|
+
elif hasattr(func, "name"): # Node对象
|
|
195
|
+
return str(func.name)
|
|
196
|
+
elif fallback_name:
|
|
197
|
+
return fallback_name
|
|
198
|
+
else:
|
|
199
|
+
return "unknown_function"
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def retry_decorator(
    config: RetryConfig,
    node_name: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that re-invokes a failing callable according to ``config``.

    Coroutine functions get an ``async`` wrapper that waits with
    ``asyncio.sleep``; any other callable gets a plain wrapper that waits with
    ``time.sleep``. Exceptions are always re-raised unchanged (never wrapped),
    both when ``config.should_retry`` rejects them and when attempts run out.

    Args:
        config: Retry policy (attempt count, delays, retryable exceptions).
        node_name: Name used in log messages; defaults to the callable's name.

    Returns:
        A decorator that wraps a callable in the retry loop.
    """

    def decorator(func: Callable) -> Callable:
        label = node_name or _get_func_name(func)

        def announce_attempt(attempt: int) -> None:
            logger.debug(
                f"执行节点 {label},尝试 {attempt + 1}/{config.retry_count + 1}"
            )

        def announce_recovery(attempt: int) -> None:
            # Only worth an INFO line when at least one earlier attempt failed.
            if attempt > 0:
                logger.info(f"节点 {label} 在第 {attempt + 1} 次尝试后成功")

        def plan_retry(attempt: int, error: Exception) -> float | None:
            """Log the failure; return the pause before retrying, or None to give up."""
            if not config.should_retry(error):
                logger.debug(
                    f"节点 {label} 异常不支持重试: {type(error).__name__}: {error}"
                )
                return None
            if attempt == config.retry_count:
                logger.error(
                    f"节点 {label} 重试 {config.retry_count} 次后仍失败: {type(error).__name__}: {error}"
                )
                return None
            pause = config.get_delay(attempt)
            logger.warning(
                f"节点 {label} 第 {attempt + 1} 次尝试失败: {error},{pause:.2f}秒后重试"
            )
            return pause

        if not inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                for attempt in range(config.retry_count + 1):
                    try:
                        announce_attempt(attempt)
                        outcome = func(*args, **kwargs)
                        announce_recovery(attempt)
                        return outcome
                    except (KeyboardInterrupt, SystemExit):
                        raise
                    except Exception as error:
                        pause = plan_retry(attempt, error)
                        if pause is None:
                            raise  # propagate the original exception untouched
                        time.sleep(pause)

            return sync_wrapper

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            for attempt in range(config.retry_count + 1):
                try:
                    announce_attempt(attempt)
                    outcome = await func(*args, **kwargs)
                    announce_recovery(attempt)
                    return outcome
                except (KeyboardInterrupt, SystemExit):
                    raise
                except Exception as error:
                    pause = plan_retry(attempt, error)
                    if pause is None:
                        raise  # propagate the original exception untouched
                    await asyncio.sleep(pause)

        return async_wrapper

    return decorator
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# Context variables for asyncio coroutine safety
|
|
303
|
+
_context_state: ContextVar[dict | None] = ContextVar("streamlet_state", default=None)
|
|
304
|
+
_context_context: ContextVar[dict | None] = ContextVar(
|
|
305
|
+
"streamlet_context", default=None
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
# ==================== 自定义ContextVar Provider ====================
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class ContextVarProvider(providers.Provider):
    """dependency-injector provider whose value lives in a per-provider ContextVar.

    On first access within a ``contextvars`` context, a value is created with
    ``default_factory`` and stored in that context; later accesses in the same
    context return the same object. Because asyncio tasks run in a *copy* of
    the context they were created in, a value created before a task is spawned
    is shared by reference with that task, while a value first created inside
    the task stays local to it.

    NOTE(review): each instance (and each deep copy made when a container is
    instantiated) creates a new ``ContextVar``; Python recommends creating
    context variables at module level, since contexts keep them alive.
    """

    def __init__(self, default_factory: Callable[[], Any] = dict):
        """Create the provider.

        Args:
            default_factory: Zero-argument callable producing the initial
                value for a context; defaults to ``dict``.
        """
        super().__init__()
        self._context_var = ContextVar(f"streamlet_{id(self)}", default=None)
        self._default_factory = default_factory

    def __deepcopy__(self, memo: dict) -> "ContextVarProvider":
        """Return an equivalent copy, as dependency-injector requires of custom providers.

        The inherited implementation re-creates the provider with no arguments,
        which would silently reset a custom ``default_factory`` to ``dict``.
        The copy gets its own ContextVar, so copied containers do not share state.
        """
        copied = memo.get(id(self))
        if copied is not None:
            return copied

        copied = self.__class__(self._default_factory)
        memo[id(self)] = copied
        self._copy_overridings(copied, memo)
        return copied

    def _provide(self, *args: Any, **kwargs: Any) -> Any:
        """Return this context's value, creating and storing it on first use.

        Returns:
            The value stored in the ContextVar for the current context.
        """
        # The variable is declared with default=None, so get() never raises
        # LookupError; None is the "not yet initialised here" marker.
        value = self._context_var.get()
        if value is None:
            value = self._default_factory()
            self._context_var.set(value)
        return value
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
class BaseFlowContext(containers.DeclarativeContainer):
    """Base dependency-injection container exposing flow-state dictionaries.

    Three scopes of ``dict`` are provided:

    * ``state`` / ``context``: one dict per thread (``ThreadLocalSingleton``).
    * ``shared_data``: a single dict shared across threads (``Singleton``).
    * ``async_state`` / ``async_context``: one dict per ``contextvars`` context
      (``ContextVarProvider``), intended for isolating asyncio tasks.

    NOTE(review): ``shared_data`` is one plain dict shared by every thread and
    the container adds no locking around it; concurrent writers must
    synchronise themselves.
    """

    # ThreadLocalSingleton: each thread lazily gets its own dict instance.
    state: providers.Provider = providers.ThreadLocalSingleton(dict)
    context: providers.Provider = providers.ThreadLocalSingleton(dict)
    # Singleton: one dict instance shared by all threads.
    shared_data: providers.Provider = providers.Singleton(dict)

    # ContextVar-backed providers: one dict per contextvars context, so values
    # first created inside separate asyncio tasks do not leak between them.
    async_state: providers.Provider = ContextVarProvider(dict)
    async_context: providers.Provider = ContextVarProvider(dict)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def custom_validate_call(
    validate_return: bool = True,
    config: ConfigDict | None = None,
    node_name: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Wrap ``pydantic.validate_call`` so input and output validation failures
    surface as distinct exception types.

    Arguments are validated by ``pydantic.validate_call(validate_return=False)``;
    the return value is checked separately with a ``TypeAdapter`` built from
    the function's return annotation.

    Args:
        validate_return: Whether to validate the return value. Skipped when the
            function has no return annotation.
        config: Pydantic configuration. When falsy,
            ``ConfigDict(arbitrary_types_allowed=True)`` is used. It applies to
            argument validation and, where Pydantic permits it, to return-value
            validation as well.
        node_name: Name reported in raised exceptions; defaults to the
            function's name.

    Returns:
        A decorator producing a sync or async wrapper matching the function.

    The wrapper raises:
        ValidationInputException: argument validation raised ``ValidationError``.
        ValidationOutputException: the return value failed validation.

    Note:
        The return value is only *checked*: the wrapper returns the original
        object, not the (possibly coerced) value produced by the adapter.
    """

    def _build_return_adapter(annotation: Any, effective_config: ConfigDict) -> TypeAdapter:
        """Create the return-value adapter, applying ``effective_config`` when allowed."""
        # Local import: only needed to recognise TypeAdapter's config errors.
        from pydantic import PydanticUserError

        try:
            # Without the config, a plain (non-Pydantic) return type fails with
            # PydanticSchemaGenerationError at decoration time, even though the
            # same type is accepted for arguments via arbitrary_types_allowed.
            return TypeAdapter(annotation, config=effective_config)
        except PydanticUserError:
            # BaseModel / pydantic dataclass / TypedDict carry their own config and
            # TypeAdapter refuses an explicit one for them; fall back to the
            # config-less construction (which also re-raises genuine schema errors).
            return TypeAdapter(annotation)

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)
        effective_config = config or ConfigDict(arbitrary_types_allowed=True)

        # Argument-only validator; return validation happens separately so its
        # failures can be reported as ValidationOutputException.
        input_validator = validate_call(
            validate_return=False,
            config=effective_config,
        )(func)

        return_type_adapter = None
        if validate_return and sig.return_annotation is not inspect.Signature.empty:
            return_type_adapter = _build_return_adapter(
                sig.return_annotation, effective_config
            )

        func_name = node_name or _get_func_name(func)

        def create_input_exception(e: ValidationError) -> ValidationInputException:
            return ValidationInputException(
                f"输入参数验证失败: {e}",
                validation_error=e,
                node_name=func_name,
            )

        def create_output_exception(e: ValidationError) -> ValidationOutputException:
            return ValidationOutputException(
                f"返回值验证失败: {e}",
                validation_error=e,
                node_name=func_name,
            )

        def validate_result(result: Any) -> Any:
            """Check ``result`` against the return annotation; return it unchanged."""
            if return_type_adapter is not None:
                try:
                    return_type_adapter.validate_python(result)
                except ValidationError as e:
                    raise create_output_exception(e) from e
            return result

        # NOTE(review): the try-blocks below also catch a ValidationError raised
        # by the function body itself (e.g. an inner model failing), which is
        # then reported as an input-validation failure.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                try:
                    result = await input_validator(*args, **kwargs)
                except ValidationError as e:
                    raise create_input_exception(e) from e
                return validate_result(result)

            return async_wrapper
        else:

            @functools.wraps(func)
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                try:
                    result = input_validator(*args, **kwargs)
                except ValidationError as e:
                    raise create_input_exception(e) from e
                return validate_result(result)

            return sync_wrapper

    return decorator
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class Node:
|
|
448
|
+
"""A node in the execution graph that supports fluent interface methods."""
|
|
449
|
+
|
|
450
|
+
def __init__(
|
|
451
|
+
self,
|
|
452
|
+
func: Callable,
|
|
453
|
+
name: str,
|
|
454
|
+
is_start_node: bool = True,
|
|
455
|
+
is_async: bool | None = None,
|
|
456
|
+
):
|
|
457
|
+
# 配置Pydantic支持任意类型(包括dependency injection的类型)
|
|
458
|
+
self.func = func
|
|
459
|
+
self.name = name
|
|
460
|
+
self.is_start_node = is_start_node
|
|
461
|
+
# 智能检测异步特性:处理Node对象、装饰器等复杂情况
|
|
462
|
+
if is_async is not None:
|
|
463
|
+
# 显式传入,直接使用
|
|
464
|
+
self.is_async = is_async
|
|
465
|
+
elif isinstance(func, Node):
|
|
466
|
+
# func是Node对象,使用其is_async属性
|
|
467
|
+
self.is_async = func.is_async
|
|
468
|
+
else:
|
|
469
|
+
# func是普通函数,使用inspect检测
|
|
470
|
+
self.is_async = inspect.iscoroutinefunction(func)
|
|
471
|
+
|
|
472
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
473
|
+
"""智能异步适配调用:根据函数类型和执行上下文自动处理同步/异步调用。"""
|
|
474
|
+
if self.is_async:
|
|
475
|
+
# 异步函数:根据当前执行上下文智能处理
|
|
476
|
+
try:
|
|
477
|
+
# 检查是否在事件循环中
|
|
478
|
+
asyncio.get_running_loop()
|
|
479
|
+
# 在事件循环中,返回协程对象让调用者await
|
|
480
|
+
return self.func(*args, **kwargs)
|
|
481
|
+
except RuntimeError:
|
|
482
|
+
# 不在事件循环中,创建新事件循环同步执行
|
|
483
|
+
return asyncio.run(self.func(*args, **kwargs))
|
|
484
|
+
else:
|
|
485
|
+
# 同步函数:直接执行
|
|
486
|
+
return self.func(*args, **kwargs)
|
|
487
|
+
|
|
488
|
+
def then(self, next_node: "Node") -> "Node":
|
|
489
|
+
"""Chain this node with another node for sequential execution."""
|
|
490
|
+
return sequential_composition(self, next_node)
|
|
491
|
+
|
|
492
|
+
def fan_out_to(
|
|
493
|
+
self,
|
|
494
|
+
nodes: list["Node"],
|
|
495
|
+
executor: str = "thread",
|
|
496
|
+
max_workers: int | None = None,
|
|
497
|
+
) -> "Node":
|
|
498
|
+
"""Fan out to multiple nodes for parallel execution."""
|
|
499
|
+
# Normalize executor type to lowercase for case-insensitive comparison
|
|
500
|
+
executor_lower = executor.lower()
|
|
501
|
+
if executor_lower not in ["thread", "async", "auto"]:
|
|
502
|
+
raise ValueError(
|
|
503
|
+
"Only 'thread', 'async', and 'auto' executors are supported. ProcessPoolExecutor has been removed to resolve pickle serialization issues."
|
|
504
|
+
)
|
|
505
|
+
return parallel_fan_out(self, nodes, executor_lower, max_workers)
|
|
506
|
+
|
|
507
|
+
def fan_in(self, aggregator: "Node") -> "Node":
|
|
508
|
+
"""Aggregate results using the specified aggregator node."""
|
|
509
|
+
return parallel_fan_in(self, aggregator)
|
|
510
|
+
|
|
511
|
+
def fan_out_in(
|
|
512
|
+
self,
|
|
513
|
+
targets: list["Node"],
|
|
514
|
+
aggregator: "Node",
|
|
515
|
+
executor: str = "thread",
|
|
516
|
+
max_workers: int | None = None,
|
|
517
|
+
) -> "Node":
|
|
518
|
+
"""Complete fan-out and fan-in operation in one step."""
|
|
519
|
+
# Normalize executor type to lowercase for case-insensitive comparison
|
|
520
|
+
executor_lower = executor.lower()
|
|
521
|
+
if executor_lower not in ["thread", "async", "auto"]:
|
|
522
|
+
raise ValueError(
|
|
523
|
+
"Only 'thread', 'async', and 'auto' executors are supported. ProcessPoolExecutor has been removed to resolve pickle serialization issues."
|
|
524
|
+
)
|
|
525
|
+
return parallel_fan_out_in(
|
|
526
|
+
self, targets, aggregator, executor_lower, max_workers
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
def branch_on(self, conditions: dict[bool, "Node"]) -> "Node":
|
|
530
|
+
"""Branch execution based on the boolean output of this node."""
|
|
531
|
+
return conditional_composition(self, conditions)
|
|
532
|
+
|
|
533
|
+
def repeat(self, times: int, stop_on_error: bool = False) -> "Node":
|
|
534
|
+
"""重复执行此节点。
|
|
535
|
+
|
|
536
|
+
Args:
|
|
537
|
+
times: 重复次数
|
|
538
|
+
stop_on_error: 遇到错误时是否立即停止
|
|
539
|
+
"""
|
|
540
|
+
return repeat_composition(self, times, stop_on_error)
|
|
541
|
+
|
|
542
|
+
def __repr__(self) -> str:
|
|
543
|
+
return f"Node(name='{self.name}')"
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def sequential_composition(left: Node, right: Node) -> Node:
    """Chain two nodes so that ``right`` receives ``left``'s return value.

    Whether the composite is async is decided once, at composition time: it is
    async when either side is.

    Args:
        left: Node called first, with the composite's arguments.
        right: Node called second, with ``left``'s result as its only argument.

    Returns:
        A non-start ``Node`` named ``"(<left> -> <right>)"``.
    """
    composition_name = f"({left.name} -> {right.name})"

    if not (left.is_async or right.is_async):

        def run(*args: Any, **kwargs: Any) -> Any:
            return right(left(*args, **kwargs))

        return Node(func=run, name=composition_name, is_start_node=False)

    async def async_run(*args: Any, **kwargs: Any) -> Any:
        # Only the async side is awaited; sync nodes run inline.
        if left.is_async:
            intermediate = await left(*args, **kwargs)
        else:
            intermediate = left(*args, **kwargs)

        if right.is_async:
            return await right(intermediate)
        return right(intermediate)

    return Node(func=async_run, name=composition_name, is_start_node=False)
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _generate_unique_result_key(base_name: str, existing_results: dict) -> str:
|
|
581
|
+
"""生成唯一的结果键,避免重复覆盖
|
|
582
|
+
|
|
583
|
+
Args:
|
|
584
|
+
base_name: 基础键名(通常是节点名称)
|
|
585
|
+
existing_results: 现有的结果字典
|
|
586
|
+
|
|
587
|
+
Returns:
|
|
588
|
+
唯一的键名,如果base_name无冲突则返回原名,否则返回带数字后缀的名称
|
|
589
|
+
"""
|
|
590
|
+
if base_name not in existing_results:
|
|
591
|
+
return base_name
|
|
592
|
+
|
|
593
|
+
counter = 1
|
|
594
|
+
unique_key = f"{base_name}[{counter}]"
|
|
595
|
+
while unique_key in existing_results:
|
|
596
|
+
counter += 1
|
|
597
|
+
unique_key = f"{base_name}[{counter}]"
|
|
598
|
+
|
|
599
|
+
return unique_key
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
# Task submitted to the thread pool for each target of a thread-based parallel fan-out.
def execute_target_node(node: Node, input_data: Any) -> ParallelResult:
    """Run one target node on ``input_data`` and capture the outcome.

    The node is invoked through ``Node.__call__``, so an async node is driven
    with ``asyncio.run`` when no event loop is running in the calling thread.
    Exceptions derived from ``Exception`` are not propagated: they are logged
    and recorded in the returned ``ParallelResult``.

    Args:
        node: Target node to execute.
        input_data: Single positional argument passed to the node.

    Returns:
        ``ParallelResult`` with ``success=True`` and the node's result, or
        ``success=False`` with the error message and formatted traceback.
        ``execution_time`` is the elapsed time in seconds.
    """
    import traceback

    # perf_counter() is monotonic; time.time() follows wall-clock adjustments
    # and can yield wrong (even negative) durations.
    start_time = time.perf_counter()

    try:
        result = node(input_data)
    except Exception as e:
        execution_time = time.perf_counter() - start_time
        error_traceback = traceback.format_exc()
        logger.error(f"Node '{node.name}' failed: {e}")

        return ParallelResult(
            node_name=node.name,
            success=False,
            error=str(e),
            error_traceback=error_traceback,
            execution_time=execution_time,
        )

    return ParallelResult(
        node_name=node.name,
        success=True,
        result=result,
        execution_time=time.perf_counter() - start_time,
    )
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def parallel_fan_out(
    source: Node,
    targets: list[Node],
    executor: str = "thread",
    max_workers: int | None = None,
) -> Node:
    """
    Simplified parallel fan-out execution with type-based executor selection.

    The returned node runs ``source`` once with its own arguments, then calls
    every target with ``source``'s result. It returns a ``dict`` mapping one
    key per target (the node name, suffixed ``[n]`` for duplicates) to that
    target's ``ParallelResult``. Target failures are recorded in the results,
    not raised. Keys are assigned in ``targets`` order for both executors.

    Args:
        source: Source node to execute first
        targets: Target nodes for parallel execution (any non-empty sequence)
        executor: 'thread', 'async', or 'auto' for automatic selection
            (case-insensitive); 'auto' picks 'async' if any node is async
        max_workers: Maximum worker threads (ignored for async)

    Returns:
        Node that performs parallel fan-out execution

    Raises:
        ValueError: If ``targets`` is empty or ``executor`` is unsupported.
    """
    if not targets:
        raise ValueError("Target nodes list cannot be empty")

    executor = executor.lower()

    if executor == "auto":
        # Unpacking (instead of list concatenation) also accepts tuple targets.
        has_async = any(node.is_async for node in [source, *targets])
        executor = "async" if has_async else "thread"
        logger.info(f"Auto-selected executor '{executor}' based on node types")

    if executor not in ["thread", "async"]:
        raise ValueError("Only 'thread', 'async', and 'auto' executors are supported.")

    target_names = [t.name for t in targets]
    composition_name = f"({source.name} -> [{', '.join(target_names)}])"

    if executor == "async":

        async def run_async(*args: Any, **kwargs: Any) -> dict[str, ParallelResult]:
            logger.info(f"Executing Async Parallel Fan-Out: {composition_name}")

            source_result = (
                await source(*args, **kwargs)
                if source.is_async
                else source(*args, **kwargs)
            )

            async def execute_async_target(
                node: Node, input_data: Any
            ) -> ParallelResult:
                """Run one target inside the event loop and capture its outcome."""
                import traceback

                start_time = time.perf_counter()
                try:
                    # Sync targets run inline on the event loop, so they execute
                    # one after another rather than concurrently.
                    result = (
                        await node(input_data) if node.is_async else node(input_data)
                    )
                except Exception as e:
                    execution_time = time.perf_counter() - start_time
                    error_traceback = traceback.format_exc()
                    # Same log line as the thread path, so failures are visible
                    # regardless of the executor.
                    logger.error(f"Node '{node.name}' failed: {e}")
                    return ParallelResult(
                        node_name=node.name,
                        success=False,
                        error=str(e),
                        error_traceback=error_traceback,
                        execution_time=execution_time,
                    )

                return ParallelResult(
                    node_name=node.name,
                    success=True,
                    result=result,
                    execution_time=time.perf_counter() - start_time,
                )

            # gather() returns results in argument order, i.e. targets order.
            results = await asyncio.gather(
                *(execute_async_target(node, source_result) for node in targets)
            )

            parallel_results: dict[str, ParallelResult] = {}
            for result in results:
                result_key = _generate_unique_result_key(
                    result.node_name, parallel_results
                )
                parallel_results[result_key] = result

            logger.info(
                f"Async parallel fan-out completed with {len(parallel_results)} results"
            )
            return parallel_results

        return Node(func=run_async, name=composition_name)

    def run_thread(*args: Any, **kwargs: Any) -> dict[str, ParallelResult]:
        logger.info(f"Executing Thread Parallel Fan-Out: {composition_name}")

        source_result = source(*args, **kwargs)

        parallel_results: dict[str, ParallelResult] = {}

        with ThreadPoolExecutor(max_workers=max_workers) as executor_instance:
            # Submit every target up front so they all run concurrently.
            futures = [
                executor_instance.submit(execute_target_node, node, source_result)
                for node in targets
            ]

            # Collect in submission order, not completion order. With
            # as_completed() the "[n]" suffixes for duplicate node names and the
            # dict order depended on thread timing. Total wait time is unchanged.
            for node, future in zip(targets, futures):
                try:
                    parallel_result = future.result()

                    result_key = _generate_unique_result_key(
                        parallel_result.node_name, parallel_results
                    )
                    parallel_results[result_key] = parallel_result
                    logger.debug(
                        f"Collected result from '{result_key}': success={parallel_result.success}"
                    )

                except Exception as e:
                    import traceback

                    logger.error(f"Failed to get result from '{node.name}': {e}")

                    error_result_key = _generate_unique_result_key(
                        node.name, parallel_results
                    )
                    parallel_results[error_result_key] = ParallelResult(
                        node_name=node.name,
                        success=False,
                        error=str(e),
                        error_traceback=traceback.format_exc(),
                    )

        logger.info(
            f"Thread parallel fan-out completed with {len(parallel_results)} results"
        )
        return parallel_results

    return Node(func=run_thread, name=composition_name)
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def parallel_fan_in(fan_out_node: Node, aggregator: Node) -> Node:
    """
    Feed the output of a fan-out node into an aggregator node.

    The composite is async when either node is async (decided once, here),
    mirroring ``sequential_composition``; only the async side is awaited.

    Args:
        fan_out_node: The fan-out node to execute first
        aggregator: Node called with the fan-out result as its only argument

    Returns:
        A non-start Node named ``"(<fan_out> -> <aggregator>)"``
    """
    composition_name = f"({fan_out_node.name} -> {aggregator.name})"

    if not (fan_out_node.is_async or aggregator.is_async):

        def run(*args: Any, **kwargs: Any) -> Any:
            logger.info(f"Executing Parallel Fan-In (Sync): {composition_name}")
            combined = aggregator(fan_out_node(*args, **kwargs))
            logger.info("Fan-in aggregation completed successfully")
            return combined

        return Node(func=run, name=composition_name, is_start_node=False)

    async def async_run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"Executing Parallel Fan-In (Async): {composition_name}")

        if fan_out_node.is_async:
            gathered = await fan_out_node(*args, **kwargs)
        else:
            gathered = fan_out_node(*args, **kwargs)

        if aggregator.is_async:
            combined = await aggregator(gathered)
        else:
            combined = aggregator(gathered)

        logger.info("Fan-in aggregation completed successfully")
        return combined

    return Node(func=async_run, name=composition_name, is_start_node=False)
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def parallel_fan_out_in(
    source: Node,
    targets: list[Node],
    aggregator: Node,
    executor: str = "thread",
    max_workers: int | None = None,
) -> Node:
    """
    Compose fan-out and fan-in into a single node.

    The executor choice is resolved first. A fan-out stage is then built with
    ``parallel_fan_out`` and wrapped with ``parallel_fan_in``, so that
    ``aggregator`` is called with the fan-out stage's output.

    Args:
        source: Node executed first; its output feeds the parallel targets.
        targets: Nodes executed in parallel by the fan-out stage.
        aggregator: Node that combines the parallel results.
        executor: ``'thread'``, ``'async'`` or ``'auto'`` (case-insensitive).
            ``'auto'`` selects ``'async'`` when any of source, targets or
            aggregator is async, and ``'thread'`` otherwise.
        max_workers: Maximum worker threads, forwarded to ``parallel_fan_out``
            (ignored for async).

    Returns:
        Node that performs the complete fan-out-in operation.

    Raises:
        ValueError: If ``executor`` is not one of the supported values.
    """
    mode = executor.lower()

    if mode == "auto":
        # The decision looks at every node in the pipeline, aggregator included.
        candidates = [source] + targets + [aggregator]
        mode = "async" if any(n.is_async for n in candidates) else "thread"
        logger.info(
            f"Auto-selected executor '{mode}' based on node types for fan-out-in"
        )

    if mode != "thread" and mode != "async":
        raise ValueError(
            "Only 'thread', 'async', and 'auto' executors are supported. "
            "ProcessPoolExecutor has been removed to resolve pickle serialization issues."
        )

    # Stage 1: parallel fan-out. Stage 2: fan-in aggregation over its output.
    fan_out_stage = parallel_fan_out(
        source=source, targets=targets, executor=mode, max_workers=max_workers
    )
    return parallel_fan_in(fan_out_stage, aggregator)
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
def conditional_composition(condition_node: Node, branches: dict[Any, Node]) -> Node:
    """Build a node that runs ``condition_node`` and dispatches to one branch.

    When the composite node is called, its arguments are passed to
    ``condition_node``. The returned value is looked up as a key in
    ``branches``; keys may be any hashable value, not only booleans. The
    matching branch node is then called with no arguments, and its return
    value becomes the composite node's result.

    Args:
        condition_node: Node whose output selects the branch.
        branches: Mapping from condition output to the node to execute.

    Returns:
        A composite Node. Its function is async when ``condition_node`` or any
        branch node is async, and sync otherwise.

    Raises:
        ValueError: At call time, when no branch matches the condition output.
            This includes unhashable outputs, which can never match a key.
    """
    has_async = condition_node.is_async or any(
        branch.is_async for branch in branches.values()
    )

    branch_names = {k: v.name for k, v in branches.items()}
    composition_name = f"({condition_node.name} ? {branch_names})"

    def select_branch(condition_result: Any) -> Node:
        """Return the branch registered for ``condition_result``, or raise ValueError."""
        try:
            matched = condition_result in branches
        except TypeError:
            # An unhashable result (list, dict, ...) cannot be a dict key. Report
            # it as "no branch" rather than leaking a TypeError from the lookup.
            matched = False
        if not matched:
            msg = f"No branch defined for condition result: {condition_result}"
            logger.error(msg)
            raise ValueError(msg)
        selected_branch = branches[condition_result]
        logger.info(
            f"Condition is {condition_result}, executing branch: {selected_branch.name}"
        )
        return selected_branch

    # NOTE(review): branches are invoked without arguments (neither the inputs
    # nor the condition result are forwarded). Unlike the other compositions in
    # this module, is_start_node is not passed to Node, so Node's default
    # applies. Both behaviors are preserved as-is; confirm they are intended.
    if has_async:

        async def async_run(*args: Any, **kwargs: Any) -> Any:
            logger.info(f"--- Executing Conditional Branch: {composition_name} ---")

            # The condition node may be sync or async; await it only when needed.
            if condition_node.is_async:
                condition_result = await condition_node(*args, **kwargs)
            else:
                condition_result = condition_node(*args, **kwargs)

            selected_branch = select_branch(condition_result)
            if selected_branch.is_async:
                return await selected_branch()
            return selected_branch()

        return Node(
            func=async_run,
            name=composition_name,
        )

    def run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"--- Executing Conditional Branch: {composition_name} ---")
        condition_result = condition_node(*args, **kwargs)
        return select_branch(condition_result)()

    return Node(
        func=run,
        name=composition_name,
    )
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
def repeat_composition(node: Node, times: int, stop_on_error: bool = False) -> Node:
    """Build a node that executes ``node`` repeatedly, chaining its results.

    The first attempt receives the composite node's own arguments. Once an
    iteration has succeeded, each later iteration is called with the most
    recent successful result as its single positional argument. While no
    iteration has succeeded yet, later iterations are called with the original
    arguments again, never with a placeholder ``None``.

    Args:
        node: The node to execute repeatedly.
        times: Number of iterations; must be a positive ``int``. ``bool`` is
            rejected even though it is an ``int`` subclass.
        stop_on_error: If True, the first failing iteration raises
            ``LoopControlException``, chained to the original error. If False,
            the failure is logged and the loop continues with the next iteration.

    Returns:
        A composite Node; its function is async when ``node.is_async``. Calling
        it returns the result of the last successful iteration, or ``None`` if
        no iteration succeeded.

    Raises:
        TypeError: If ``times`` is not an int (or is a bool), or if ``node`` is
            not a ``Node``.
        ValueError: If ``times`` is not greater than 0.
    """
    # Validate eagerly so misconfiguration fails at composition time (fail-fast).
    if isinstance(times, bool) or not isinstance(times, int):
        raise TypeError(f"times must be an integer, got {type(times).__name__}")
    if times <= 0:
        raise ValueError("Repeat times must be greater than 0")
    if not isinstance(node, Node):
        raise TypeError("node must be a Node instance")

    composition_name = f"({node.name} * {times})"

    def record_failure(
        iteration: int,
        error: Exception,
        errors: list[Exception],
        last_ok: int | None,
    ) -> None:
        """Log and record a failed iteration; abort the loop if stop_on_error is set."""
        errors.append(error)
        logger.error(f"Iteration {iteration} failed: {error}")
        if stop_on_error:
            logger.error("Stopping immediately due to stop_on_error=True")
            # Abort the whole loop, keeping the original error as __cause__.
            raise LoopControlException(
                f"Execution stopped due to error at iteration {iteration}: {error}"
            ) from error
        if last_ok is None:
            logger.info(
                "No successful iteration yet; the next iteration reuses the original inputs"
            )
        else:
            logger.info(
                f"Continuing with last successful result from iteration {last_ok}"
            )

    if node.is_async:

        async def async_run(*args: Any, **kwargs: Any) -> Any:
            logger.info(
                f"--- Executing Async Repeat Composition: {composition_name} ---"
            )

            last_result = None
            last_ok: int | None = None  # 1-based index of the last successful iteration
            errors: list[Exception] = []

            for iteration in range(1, times + 1):
                logger.info(f" - Iteration {iteration}/{times}")
                try:
                    # Until something has succeeded, keep feeding the original inputs.
                    if last_ok is None:
                        result = await node(*args, **kwargs)
                    else:
                        result = await node(last_result)
                except Exception as e:
                    record_failure(iteration, e, errors, last_ok)
                    continue
                last_result = result
                last_ok = iteration
                logger.debug(f"Iteration {iteration} completed successfully")

            logger.info(
                f"Async repeat composition completed. Iterations: {times}, Errors: {len(errors)}"
            )
            return last_result

        return Node(func=async_run, name=composition_name, is_start_node=False)

    def run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"--- Executing Sync Repeat Composition: {composition_name} ---")

        last_result = None
        last_ok: int | None = None  # 1-based index of the last successful iteration
        errors: list[Exception] = []

        for iteration in range(1, times + 1):
            logger.info(f" - Iteration {iteration}/{times}")
            try:
                # Until something has succeeded, keep feeding the original inputs.
                if last_ok is None:
                    result = node(*args, **kwargs)
                else:
                    result = node(last_result)
            except Exception as e:
                record_failure(iteration, e, errors, last_ok)
                continue
            last_result = result
            last_ok = iteration
            logger.debug(f"Iteration {iteration} completed successfully")

        logger.info(
            f"Sync repeat composition completed. Iterations: {times}, Errors: {len(errors)}"
        )
        return last_result

    return Node(func=run, name=composition_name, is_start_node=False)
|
|
1086
|
+
|
|
1087
|
+
|
|
1088
|
+
def node(
    func: Callable | None = None,
    *,
    retry_count: int = 3,
    name: str | None = None,
    retry_delay: float = 1.0,
    exception_types: tuple = (Exception,),
    backoff_factor: float = 1.0,
    max_delay: float = 60.0,
    enable_retry: bool = False,
) -> Node | Callable:
    """
    Decorator: Create a Node from a function with dependency injection, type
    validation and an optional retry mechanism.

    **This is the standard and only recommended way to use dependency injection!**

    Usage Examples:
    ```python
    # Basic sync node with an injected dependency
    @node
    def process_data(data: dict, state: dict = Provide[BaseFlowContext.state]) -> dict:
        result = {"processed": data["value"] * 2}
        state['last_result'] = result
        return result

    # Async node with retry (retry must be enabled explicitly)
    @node(enable_retry=True, retry_count=3, retry_delay=0.5)
    async def fetch_data(data: dict) -> dict:
        # Simulate async data fetching
        await asyncio.sleep(0.1)
        return {"fetched": data}

    # Retry only on selected exception types
    @node(enable_retry=True, retry_count=5, retry_delay=2.0,
          exception_types=(ConnectionError, TimeoutError))
    def external_service_call(data: dict) -> dict:
        return call_external_api(data)

    # Sequential composition (async/sync mixing); the flow contains an async
    # node, so calling it yields an awaitable
    flow = process_data.then(fetch_data).then(external_service_call)
    result = await flow({"value": 10})

    # Parallel execution
    parallel_flow = process_data.fan_out_to([
        fetch_data,
        external_service_call
    ])
    results = await parallel_flow({"value": 10})

    # Fan-out-in pattern
    @node
    def aggregate_results(parallel_results: dict) -> str:
        successful = [r.result for r in parallel_results.values() if r.success]
        return f"Aggregated: {len(successful)} results"

    complete_flow = process_data.fan_out_in([fetch_data, external_service_call], aggregate_results)
    final_result = await complete_flow({"value": 10})

    # Retry is disabled by default, so a plain @node never retries
    @node
    def no_retry_operation(data: dict) -> dict:
        return {"immediate": data}
    ```

    Args:
        func: Function to be decorated. Leave as None when using the
            ``@node(...)`` form.
        name: Node identifier name (defaults to the function's name).
        retry_count: Maximum retry attempts (default: 3). Only used when
            ``enable_retry`` is True.
        retry_delay: Base retry delay in seconds (default: 1.0). Only used when
            ``enable_retry`` is True.
        exception_types: Tuple of exception types to retry (default:
            (Exception,)). Only used when ``enable_retry`` is True.
        backoff_factor: Backoff multiplier for exponential backoff (default:
            1.0). Only used when ``enable_retry`` is True.
        max_delay: Maximum delay time in seconds (default: 60.0). Only used
            when ``enable_retry`` is True.
        enable_retry: Enable/disable the retry mechanism (default: False).

    Returns:
        A Node instance when used as ``@node``, or a decorator that produces a
        Node when used as ``@node(...)``.

    Notes:
        - Supports both sync and async functions. The Node's ``is_async`` flag
          is taken from the undecorated function.
        - Async functions use `asyncio.sleep()` for retry delays, sync
          functions use `time.sleep()`.
        - If using BaseFlowContext dependency injection, the @node decorator is
          required.
        - The container must be wired before usage: `container.wire(modules=[__name__])`
        - Retry is disabled by default; pass ``enable_retry=True`` to enable it
          for sync or async nodes.
        - Sequential composition handles async/sync mixing via Node.__call__().
        - Parallel execution supports mixed async/sync nodes with auto executor
          selection.
    """
    config = RetryConfig(
        retry_count, retry_delay, exception_types, backoff_factor, max_delay
    )

    def decorator(f: Callable) -> Node:
        node_name = name or _get_func_name(f, "unnamed_node")

        # Check async-ness on the undecorated function; the wrappers applied
        # below may not preserve coroutine-function detection.
        is_original_async = inspect.iscoroutinefunction(f)

        # Wrappers are applied in list order, each wrapping the previous result:
        # validation (innermost) -> optional retry -> inject (outermost).
        # inject must be outermost so that Provide[...] defaults are resolved to
        # real values before validation and retry run.
        decorators = [
            custom_validate_call(
                validate_return=True,
                config=ConfigDict(arbitrary_types_allowed=True),
                node_name=node_name,
            ),
        ]
        if enable_retry:
            decorators.append(retry_decorator(config=config, node_name=node_name))
        decorators.append(inject)

        decorated_func = functools.reduce(
            lambda wrapped, deco: deco(wrapped), decorators, f
        )

        return Node(func=decorated_func, name=node_name, is_async=is_original_async)

    # Support both forms: @node and @node(...)
    if func is None:
        # @node(...) with arguments: return the decorator itself.
        return decorator
    # Bare @node: decorate immediately.
    return decorator(func)
|
|
1210
|
+
|
|
1211
|
+
|
|
1212
|
+
# Public API: only the names end users are expected to import.
__all__ = [
    "node",  # decorator that turns a function into a Node
    "custom_validate_call",  # validation decorator with custom validation-exception handling
    "BaseFlowContext",  # context for custom dependency injection
    "ParallelResult",  # result record for parallel execution
    # Exception hierarchy
    "StreamletException",
    "ValidationInputException",
    "ValidationOutputException",
    "UserBusinessException",
    "NodeExecutionException",
    "NodeTimeoutException",
    "NodeRetryExhaustedException",
    "LoopControlException",
    # Retry configuration
    "RetryConfig",
]
|