streamlet-py 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
streamlet/__init__.py ADDED
@@ -0,0 +1,1233 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ import logging
5
+ import time
6
+ from collections.abc import Callable
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from contextvars import ContextVar
9
+ from dataclasses import dataclass
10
+ from typing import Any
11
+
12
+ from dependency_injector import containers, providers
13
+ from dependency_injector.wiring import inject
14
+ from pydantic import ConfigDict, TypeAdapter, ValidationError, validate_call
15
+
16
+ logger = logging.getLogger("streamlet")
17
+
18
+
19
+ @dataclass
20
+ class ParallelResult:
21
+ """Pydantic model for recording parallel execution results and exception stacks."""
22
+
23
+ node_name: str
24
+ success: bool
25
+ result: Any = None
26
+ error: str | None = None
27
+ error_traceback: str | None = None
28
+ execution_time: float | None = None
29
+
30
+
31
+ # ==================== 异常类型体系 ====================
32
+ class StreamletException(Exception):
33
+ """Streamlet框架基础异常类"""
34
+
35
+ retryable = False # 默认框架异常不重试
36
+
37
+ def __init__(
38
+ self, message: str, node_name: str | None = None, **kwargs: Any
39
+ ) -> None:
40
+ self.node_name = node_name
41
+ self.context = kwargs
42
+ super().__init__(message)
43
+
44
+
45
class ValidationInputException(StreamletException):
    """Raised when the arguments passed to a node fail validation.

    Attributes:
        validation_error: The underlying validation error object, if any.
    """

    # Invalid arguments will not become valid on a retry.
    retryable = False

    def __init__(
        self,
        message: str,
        validation_error: Any = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name, **kwargs)
        self.validation_error = validation_error
59
+
60
+
61
class ValidationOutputException(StreamletException):
    """Raised when the value returned by a node fails validation.

    Attributes:
        validation_error: The underlying validation error object, if any.
    """

    # An invalid return value is not retried.
    retryable = False

    def __init__(
        self,
        message: str,
        validation_error: Any = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name, **kwargs)
        self.validation_error = validation_error
75
+
76
+
77
class UserBusinessException(StreamletException):
    """Base class for user-defined business errors with a tunable retry flag.

    Args:
        message: Human-readable error description.
        retryable: When not None, overrides the class-level ``retryable``
            flag for this instance only.
        node_name: Name of the node the error relates to, if known.
        **kwargs: Extra context stored by ``StreamletException``.
    """

    # Business errors are retried unless the raiser says otherwise.
    retryable = True

    def __init__(
        self,
        message: str,
        retryable: bool | None = None,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name, **kwargs)
        # Only shadow the class attribute when an explicit value was given.
        if retryable is not None:
            self.retryable = retryable
93
+
94
+
95
class NodeExecutionException(StreamletException):
    """Raised when a node fails while executing.

    Attributes:
        original_exception: The exception that caused the failure, if any.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        original_exception: Exception | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, node_name, **kwargs)
        self.original_exception = original_exception
107
+
108
+
109
class NodeTimeoutException(NodeExecutionException):
    """Raised when a node does not finish within its time budget.

    Attributes:
        timeout_seconds: The timeout that was exceeded, if known.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        timeout_seconds: float | None = None,
        **kwargs: Any,
    ) -> None:
        # Remaining keyword arguments (e.g. original_exception) are forwarded
        # to NodeExecutionException unchanged.
        super().__init__(message, node_name, **kwargs)
        self.timeout_seconds = timeout_seconds
121
+
122
+
123
class NodeRetryExhaustedException(NodeExecutionException):
    """Raised when a node still fails after all retries were used.

    Attributes:
        retry_count: Number of retries that were performed, if known.
        last_exception: The exception raised by the final attempt; also
            forwarded to the parent class as ``original_exception``.
    """

    def __init__(
        self,
        message: str,
        node_name: str | None = None,
        retry_count: int | None = None,
        last_exception: Exception | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            message, node_name, original_exception=last_exception, **kwargs
        )
        self.last_exception = last_exception
        self.retry_count = retry_count
139
+
140
+
141
class LoopControlException(StreamletException):
    """Base class for exceptions used to control loop/repeat execution."""
145
+
146
+
147
+ # ==================== 重试装饰器 ====================
148
+
149
+
150
class RetryConfig:
    """Retry policy: attempt budget, delay schedule and retryable exception types."""

    def __init__(
        self,
        retry_count: int = 3,
        retry_delay: float = 1.0,
        exception_types: tuple = (Exception,),
        backoff_factor: float = 1.0,
        max_delay: float = 60.0,
    ):
        """Store and validate the retry settings.

        Args:
            retry_count: Number of retries after the first attempt.
            retry_delay: Base delay in seconds before a retry.
            exception_types: Exception class, or iterable of classes, that
                may be retried when the exception has no ``retryable`` flag.
            backoff_factor: Multiplier applied per attempt
                (``retry_delay * backoff_factor ** attempt``).
            max_delay: Upper bound in seconds for any single delay.

        Raises:
            ValueError: If ``retry_count``, ``retry_delay`` or ``max_delay``
                is negative.
        """
        if retry_count < 0:
            raise ValueError(f"retry_count must be >= 0, got {retry_count}")
        if retry_delay < 0:
            raise ValueError(f"retry_delay must be >= 0, got {retry_delay}")
        if max_delay < 0:
            raise ValueError(f"max_delay must be >= 0, got {max_delay}")
        self.retry_count = retry_count
        self.retry_delay = retry_delay
        # isinstance() only accepts a class or a tuple of classes; normalise
        # so that a list/set (or a single class) does not blow up later in
        # should_retry().
        if isinstance(exception_types, type):
            exception_types = (exception_types,)
        self.exception_types = tuple(exception_types)
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay

    def should_retry(self, exception: Exception) -> bool:
        """Return True when ``exception`` should trigger another attempt.

        An explicit ``retryable`` attribute on the exception takes priority;
        otherwise the exception is matched (including subclasses) against
        ``exception_types``.
        """
        if hasattr(exception, "retryable"):
            return bool(exception.retryable)
        return isinstance(exception, self.exception_types)

    def get_delay(self, attempt: int) -> float:
        """Return the delay in seconds before retrying after ``attempt``.

        The delay grows as ``retry_delay * backoff_factor ** attempt`` and is
        capped at ``max_delay``.
        """
        try:
            delay = self.retry_delay * (self.backoff_factor**attempt)
        except OverflowError:
            # The growth term left float range: treat it as larger than any
            # cap (a zero base delay still yields zero).
            delay = float("inf") if self.retry_delay else 0.0
        return min(delay, self.max_delay)
186
+
187
+
188
+ def _get_func_name(func: Any, fallback_name: str | None = None) -> str:
189
+ """安全获取函数名称"""
190
+ if hasattr(func, "__name__"):
191
+ return str(func.__name__)
192
+ elif hasattr(func, "func") and hasattr(func.func, "__name__"): # partial对象
193
+ return str(func.func.__name__)
194
+ elif hasattr(func, "name"): # Node对象
195
+ return str(func.name)
196
+ elif fallback_name:
197
+ return fallback_name
198
+ else:
199
+ return "unknown_function"
200
+
201
+
202
def retry_decorator(
    config: RetryConfig,
    node_name: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that re-runs a sync or async callable on failure.

    The wrapped callable is attempted up to ``config.retry_count + 1`` times.
    An exception is re-raised unchanged (never wrapped) when
    ``config.should_retry`` rejects it or when the last attempt fails.

    Args:
        config: Retry policy to apply.
        node_name: Name used in log messages; defaults to the callable's name.
    """

    def decorator(func: Callable) -> Callable:
        func_name = node_name or _get_func_name(func)

        def log_attempt(attempt: int) -> None:
            logger.debug(
                f"执行节点 {func_name},尝试 {attempt + 1}/{config.retry_count + 1}"
            )

        def log_recovered(attempt: int) -> None:
            # Only worth mentioning when at least one earlier attempt failed.
            if attempt > 0:
                logger.info(f"节点 {func_name} 在第 {attempt + 1} 次尝试后成功")

        def delay_before_retry(exc: Exception, attempt: int) -> float | None:
            """Return the pause before the next attempt, or None to give up."""
            if not config.should_retry(exc):
                logger.debug(
                    f"节点 {func_name} 异常不支持重试: {type(exc).__name__}: {exc}"
                )
                return None
            if attempt == config.retry_count:
                logger.error(
                    f"节点 {func_name} 重试 {config.retry_count} 次后仍失败: {type(exc).__name__}: {exc}"
                )
                return None
            pause = config.get_delay(attempt)
            logger.warning(
                f"节点 {func_name} 第 {attempt + 1} 次尝试失败: {exc},{pause:.2f}秒后重试"
            )
            return pause

        # KeyboardInterrupt / SystemExit derive from BaseException, so the
        # ``except Exception`` clauses below never intercept them.
        if not inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                for attempt in range(config.retry_count + 1):
                    try:
                        log_attempt(attempt)
                        outcome = func(*args, **kwargs)
                        log_recovered(attempt)
                        return outcome
                    except Exception as exc:
                        pause = delay_before_retry(exc, attempt)
                        if pause is None:
                            raise  # propagate the original exception as-is
                        time.sleep(pause)

            return sync_wrapper

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            for attempt in range(config.retry_count + 1):
                try:
                    log_attempt(attempt)
                    outcome = await func(*args, **kwargs)
                    log_recovered(attempt)
                    return outcome
                except Exception as exc:
                    pause = delay_before_retry(exc, attempt)
                    if pause is None:
                        raise  # propagate the original exception as-is
                    await asyncio.sleep(pause)

        return async_wrapper

    return decorator
300
+
301
+
302
+ # Context variables for asyncio coroutine safety
303
+ _context_state: ContextVar[dict | None] = ContextVar("streamlet_state", default=None)
304
+ _context_context: ContextVar[dict | None] = ContextVar(
305
+ "streamlet_context", default=None
306
+ )
307
+
308
+
309
+ # ==================== 自定义ContextVar Provider ====================
310
+
311
+
312
class ContextVarProvider(providers.Provider):
    """Provider whose value lives in a ``ContextVar``.

    Each provider instance owns a private ``ContextVar``. The first time the
    provider is called in a context where the variable is still unset, the
    value is created by ``default_factory`` and stored in that variable;
    later calls in the same context return the stored value.
    """

    def __init__(self, default_factory: Callable[[], Any] = dict):
        """Create the provider.

        Args:
            default_factory: Zero-argument callable that builds the initial
                value; defaults to ``dict``.
        """
        super().__init__()
        self._context_var = ContextVar(f"streamlet_{id(self)}", default=None)
        self._default_factory = default_factory

    def __deepcopy__(self, memo: dict) -> "ContextVarProvider":
        """Copy the provider while keeping its ``default_factory``.

        NOTE(review): the base ``Provider.__deepcopy__`` appears to rebuild
        the provider with no constructor arguments, which would silently
        reset a custom factory to ``dict``; this follows the
        dependency-injector custom-provider recipe - confirm against the
        installed version.
        """
        copied = memo.get(id(self))
        if copied is not None:
            return copied
        copied = self.__class__(self._default_factory)
        memo[id(self)] = copied
        self._copy_overridings(copied, memo)
        return copied

    def _provide(self, *args: Any, **kwargs: Any) -> Any:
        """Return the value stored for the current context.

        The variable is created with ``default=None``, so ``get()`` never
        raises ``LookupError``; ``None`` is the "not initialised yet" marker.
        """
        value = self._context_var.get()
        if value is None:
            # First access in this context: build and remember the value.
            value = self._default_factory()
            self._context_var.set(value)
        return value
347
+
348
+
349
class BaseFlowContext(containers.DeclarativeContainer):
    """Base declarative container holding the providers used by flows.

    ``state`` and ``context`` are ``ThreadLocalSingleton(dict)`` providers,
    ``shared_data`` is a ``Singleton(dict)`` provider, and ``async_state`` /
    ``async_context`` are ``ContextVarProvider(dict)`` instances intended
    for asyncio code.
    """

    # Thread-local dictionaries: one instance per thread.
    state: providers.Provider = providers.ThreadLocalSingleton(dict)
    context: providers.Provider = providers.ThreadLocalSingleton(dict)

    # Single dictionary instance served to every consumer.
    shared_data: providers.Provider = providers.Singleton(dict)

    # ContextVar-backed dictionaries for coroutine code.
    async_state: providers.Provider = ContextVarProvider(dict)
    async_context: providers.Provider = ContextVarProvider(dict)
361
+
362
+
363
def custom_validate_call(
    validate_return: bool = True,
    config: ConfigDict | None = None,
    node_name: str | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Wrap a callable with separate input and output validation.

    Arguments are validated with pydantic's ``validate_call``; the return
    value is checked with a ``TypeAdapter`` built from the return annotation.
    The two failure modes surface as different exception types.

    Args:
        validate_return: Whether to check the return value (only done when
            the callable has a return annotation).
        config: Pydantic config; defaults to allowing arbitrary types.
        node_name: Name recorded on raised exceptions; defaults to the
            callable's name.

    Returns:
        A decorator producing a sync or async wrapper matching the callable.

    Raises (from the wrapper):
        ValidationInputException: A ``ValidationError`` escaped the
            validated call.
        ValidationOutputException: The return value did not match the
            return annotation.
    """

    def decorator(func: Callable) -> Callable:
        signature = inspect.signature(func)

        # Argument validation only; the return value is handled separately
        # so that the two kinds of failure can be told apart.
        effective_config = config or ConfigDict(arbitrary_types_allowed=True)
        checked_call = validate_call(
            validate_return=False,
            config=effective_config,
        )(func)

        annotated_return = signature.return_annotation
        return_type_adapter = None
        if validate_return and annotated_return != inspect.Signature.empty:
            return_type_adapter = TypeAdapter(annotated_return)

        func_name = node_name or _get_func_name(func)

        def input_failure(exc: ValidationError) -> ValidationInputException:
            return ValidationInputException(
                f"输入参数验证失败: {exc}",
                validation_error=exc,
                node_name=func_name,
            )

        def check_return(value: Any) -> Any:
            """Validate ``value`` and hand back the original object."""
            if not return_type_adapter:
                return value
            try:
                return_type_adapter.validate_python(value)
            except ValidationError as exc:
                raise ValidationOutputException(
                    f"返回值验证失败: {exc}",
                    validation_error=exc,
                    node_name=func_name,
                ) from exc
            return value

        if not inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
                try:
                    outcome = checked_call(*args, **kwargs)
                except ValidationError as exc:
                    raise input_failure(exc) from exc
                return check_return(outcome)

            return sync_wrapper

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                outcome = await checked_call(*args, **kwargs)
            except ValidationError as exc:
                raise input_failure(exc) from exc
            return check_return(outcome)

        return async_wrapper

    return decorator
445
+
446
+
447
+ class Node:
448
+ """A node in the execution graph that supports fluent interface methods."""
449
+
450
+ def __init__(
451
+ self,
452
+ func: Callable,
453
+ name: str,
454
+ is_start_node: bool = True,
455
+ is_async: bool | None = None,
456
+ ):
457
+ # 配置Pydantic支持任意类型(包括dependency injection的类型)
458
+ self.func = func
459
+ self.name = name
460
+ self.is_start_node = is_start_node
461
+ # 智能检测异步特性:处理Node对象、装饰器等复杂情况
462
+ if is_async is not None:
463
+ # 显式传入,直接使用
464
+ self.is_async = is_async
465
+ elif isinstance(func, Node):
466
+ # func是Node对象,使用其is_async属性
467
+ self.is_async = func.is_async
468
+ else:
469
+ # func是普通函数,使用inspect检测
470
+ self.is_async = inspect.iscoroutinefunction(func)
471
+
472
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
473
+ """智能异步适配调用:根据函数类型和执行上下文自动处理同步/异步调用。"""
474
+ if self.is_async:
475
+ # 异步函数:根据当前执行上下文智能处理
476
+ try:
477
+ # 检查是否在事件循环中
478
+ asyncio.get_running_loop()
479
+ # 在事件循环中,返回协程对象让调用者await
480
+ return self.func(*args, **kwargs)
481
+ except RuntimeError:
482
+ # 不在事件循环中,创建新事件循环同步执行
483
+ return asyncio.run(self.func(*args, **kwargs))
484
+ else:
485
+ # 同步函数:直接执行
486
+ return self.func(*args, **kwargs)
487
+
488
+ def then(self, next_node: "Node") -> "Node":
489
+ """Chain this node with another node for sequential execution."""
490
+ return sequential_composition(self, next_node)
491
+
492
+ def fan_out_to(
493
+ self,
494
+ nodes: list["Node"],
495
+ executor: str = "thread",
496
+ max_workers: int | None = None,
497
+ ) -> "Node":
498
+ """Fan out to multiple nodes for parallel execution."""
499
+ # Normalize executor type to lowercase for case-insensitive comparison
500
+ executor_lower = executor.lower()
501
+ if executor_lower not in ["thread", "async", "auto"]:
502
+ raise ValueError(
503
+ "Only 'thread', 'async', and 'auto' executors are supported. ProcessPoolExecutor has been removed to resolve pickle serialization issues."
504
+ )
505
+ return parallel_fan_out(self, nodes, executor_lower, max_workers)
506
+
507
+ def fan_in(self, aggregator: "Node") -> "Node":
508
+ """Aggregate results using the specified aggregator node."""
509
+ return parallel_fan_in(self, aggregator)
510
+
511
+ def fan_out_in(
512
+ self,
513
+ targets: list["Node"],
514
+ aggregator: "Node",
515
+ executor: str = "thread",
516
+ max_workers: int | None = None,
517
+ ) -> "Node":
518
+ """Complete fan-out and fan-in operation in one step."""
519
+ # Normalize executor type to lowercase for case-insensitive comparison
520
+ executor_lower = executor.lower()
521
+ if executor_lower not in ["thread", "async", "auto"]:
522
+ raise ValueError(
523
+ "Only 'thread', 'async', and 'auto' executors are supported. ProcessPoolExecutor has been removed to resolve pickle serialization issues."
524
+ )
525
+ return parallel_fan_out_in(
526
+ self, targets, aggregator, executor_lower, max_workers
527
+ )
528
+
529
+ def branch_on(self, conditions: dict[bool, "Node"]) -> "Node":
530
+ """Branch execution based on the boolean output of this node."""
531
+ return conditional_composition(self, conditions)
532
+
533
+ def repeat(self, times: int, stop_on_error: bool = False) -> "Node":
534
+ """重复执行此节点。
535
+
536
+ Args:
537
+ times: 重复次数
538
+ stop_on_error: 遇到错误时是否立即停止
539
+ """
540
+ return repeat_composition(self, times, stop_on_error)
541
+
542
+ def __repr__(self) -> str:
543
+ return f"Node(name='{self.name}')"
544
+
545
+
546
def sequential_composition(left: Node, right: Node) -> Node:
    """Compose two nodes so that ``right`` receives the result of ``left``.

    Whether the composition is sync or async is decided once, at composition
    time: if either node is async the resulting node wraps a coroutine
    function, otherwise a plain function.
    """
    composition_name = f"({left.name} -> {right.name})"

    if not (left.is_async or right.is_async):

        def run(*args: Any, **kwargs: Any) -> Any:
            return right(left(*args, **kwargs))

        return Node(func=run, name=composition_name, is_start_node=False)

    async def async_run(*args: Any, **kwargs: Any) -> Any:
        # Inside the event loop an async Node returns a coroutine, so only
        # results of async nodes are awaited.
        left_result = left(*args, **kwargs)
        if left.is_async:
            left_result = await left_result

        right_result = right(left_result)
        if right.is_async:
            right_result = await right_result

        return right_result

    return Node(func=async_run, name=composition_name, is_start_node=False)
578
+
579
+
580
def _generate_unique_result_key(base_name: str, existing_results: dict) -> str:
    """Return a key derived from ``base_name`` that is not yet in use.

    Args:
        base_name: Preferred key (usually the node name).
        existing_results: Mapping whose keys are already taken.

    Returns:
        ``base_name`` itself when it is free, otherwise the first free
        ``"base_name[n]"`` for n = 1, 2, 3, ...
    """
    candidate = base_name
    suffix = 0
    while candidate in existing_results:
        suffix += 1
        candidate = f"{base_name}[{suffix}]"
    return candidate
600
+
601
+
602
+ # 定义并行任务执行函数
603
def execute_target_node(node: Node, input_data: Any) -> ParallelResult:
    """Run one target node and capture its outcome instead of raising.

    Args:
        node: The node to execute; it is called as ``node(input_data)``.
        input_data: The single positional argument passed to the node.

    Returns:
        A ``ParallelResult`` with ``success=True`` and the node's return
        value, or ``success=False`` with the error message and formatted
        traceback when the node raised an ``Exception``. The elapsed time in
        seconds is recorded in both cases.
    """
    import traceback

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or made negative) by system clock adjustments.
    started = time.perf_counter()

    try:
        result = node(input_data)
    except Exception as e:
        execution_time = time.perf_counter() - started
        error_traceback = traceback.format_exc()
        logger.error(f"Node '{node.name}' failed: {e}")

        return ParallelResult(
            node_name=node.name,
            success=False,
            error=str(e),
            error_traceback=error_traceback,
            execution_time=execution_time,
        )

    execution_time = time.perf_counter() - started
    return ParallelResult(
        node_name=node.name,
        success=True,
        result=result,
        execution_time=execution_time,
    )
632
+
633
+
634
def parallel_fan_out(
    source: Node,
    targets: list[Node],
    executor: str = "thread",
    max_workers: int | None = None,
) -> Node:
    """Build a node that runs ``source`` and fans its result out to ``targets``.

    Args:
        source: Node executed first; its result is the input of every target.
        targets: Nodes executed in parallel.
        executor: 'thread', 'async', or 'auto' (case-insensitive). 'auto'
            picks 'async' when the source or any target is async, else
            'thread'.
        max_workers: Maximum worker threads (ignored for async).

    Returns:
        A node whose result is a ``dict`` mapping a unique key per target
        (node name, with ``[n]`` suffixes for duplicates) to its
        ``ParallelResult``. In both modes entries follow the order of
        ``targets``.

    Raises:
        ValueError: If ``targets`` is empty or ``executor`` is unsupported.
    """
    if not targets:
        raise ValueError("Target nodes list cannot be empty")

    executor = executor.lower()

    # 'auto' is resolved once, at composition time.
    if executor == "auto":
        all_nodes = [source] + targets
        has_async = any(node.is_async for node in all_nodes)
        executor = "async" if has_async else "thread"
        logger.info(f"Auto-selected executor '{executor}' based on node types")

    if executor not in ["thread", "async"]:
        raise ValueError("Only 'thread', 'async', and 'auto' executors are supported.")

    target_names = [t.name for t in targets]
    composition_name = f"({source.name} -> [{', '.join(target_names)}])"

    if executor == "async":

        async def run_async(*args: Any, **kwargs: Any) -> dict[str, ParallelResult]:
            logger.info(f"Executing Async Parallel Fan-Out: {composition_name}")

            source_result = (
                await source(*args, **kwargs)
                if source.is_async
                else source(*args, **kwargs)
            )

            async def execute_async_target(
                node: Node, input_data: Any
            ) -> ParallelResult:
                # Monotonic clock: immune to wall-clock adjustments.
                started = time.perf_counter()
                try:
                    result = (
                        await node(input_data) if node.is_async else node(input_data)
                    )
                except Exception as e:
                    import traceback

                    return ParallelResult(
                        node_name=node.name,
                        success=False,
                        error=str(e),
                        error_traceback=traceback.format_exc(),
                        execution_time=time.perf_counter() - started,
                    )
                return ParallelResult(
                    node_name=node.name,
                    success=True,
                    result=result,
                    execution_time=time.perf_counter() - started,
                )

            tasks = [execute_async_target(node, source_result) for node in targets]
            results = await asyncio.gather(*tasks)

            # gather() preserves task order, so keys follow ``targets``.
            parallel_results: dict[str, ParallelResult] = {}
            for result in results:
                result_key = _generate_unique_result_key(
                    result.node_name, parallel_results
                )
                parallel_results[result_key] = result

            logger.info(
                f"Async parallel fan-out completed with {len(parallel_results)} results"
            )
            return parallel_results

        return Node(func=run_async, name=composition_name)

    def run_thread(*args: Any, **kwargs: Any) -> dict[str, ParallelResult]:
        logger.info(f"Executing Thread Parallel Fan-Out: {composition_name}")

        source_result = source(*args, **kwargs)

        parallel_results: dict[str, ParallelResult] = {}

        with ThreadPoolExecutor(max_workers=max_workers) as executor_instance:
            # All tasks are submitted up front and still run concurrently.
            submitted = [
                (
                    node,
                    executor_instance.submit(
                        execute_target_node, node, source_result
                    ),
                )
                for node in targets
            ]

            # Collect in submission order rather than completion order so
            # that key order and the ``[n]`` suffixes given to duplicate
            # node names are deterministic and match the async executor.
            for node, future in submitted:
                try:
                    parallel_result = future.result()

                    result_key = _generate_unique_result_key(
                        parallel_result.node_name, parallel_results
                    )

                    parallel_results[result_key] = parallel_result
                    logger.debug(
                        f"Collected result from '{result_key}': success={parallel_result.success}"
                    )

                except Exception as e:
                    import traceback

                    logger.error(f"Failed to get result from '{node.name}': {e}")

                    # Still record a unique entry for the failed task.
                    error_result_key = _generate_unique_result_key(
                        node.name, parallel_results
                    )
                    parallel_results[error_result_key] = ParallelResult(
                        node_name=node.name,
                        success=False,
                        error=str(e),
                        error_traceback=traceback.format_exc(),
                    )

        logger.info(
            f"Thread parallel fan-out completed with {len(parallel_results)} results"
        )
        return parallel_results

    return Node(func=run_thread, name=composition_name)
790
+
791
+
792
def parallel_fan_in(fan_out_node: Node, aggregator: Node) -> Node:
    """Build a node that feeds the fan-out result into ``aggregator``.

    Args:
        fan_out_node: Node executed first with the caller's arguments.
        aggregator: Node called with the fan-out result as its only argument.

    Returns:
        A node returning the aggregator's result. It wraps a coroutine
        function when either input node is async, otherwise a plain function.
    """
    composition_name = f"({fan_out_node.name} -> {aggregator.name})"

    if not (fan_out_node.is_async or aggregator.is_async):

        def run(*args: Any, **kwargs: Any) -> Any:
            logger.info(f"Executing Parallel Fan-In (Sync): {composition_name}")

            collected = fan_out_node(*args, **kwargs)
            aggregated = aggregator(collected)

            logger.info("Fan-in aggregation completed successfully")
            return aggregated

        return Node(func=run, name=composition_name, is_start_node=False)

    async def async_run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"Executing Parallel Fan-In (Async): {composition_name}")

        # Only results of async nodes are awaited.
        collected = fan_out_node(*args, **kwargs)
        if fan_out_node.is_async:
            collected = await collected

        aggregated = aggregator(collected)
        if aggregator.is_async:
            aggregated = await aggregated

        logger.info("Fan-in aggregation completed successfully")
        return aggregated

    return Node(func=async_run, name=composition_name, is_start_node=False)
847
+
848
+
849
def parallel_fan_out_in(
    source: Node,
    targets: list[Node],
    aggregator: Node,
    executor: str = "thread",
    max_workers: int | None = None,
) -> Node:
    """Combine fan-out and fan-in into one composed node.

    Args:
        source: Node executed first.
        targets: Nodes executed in parallel on the source result.
        aggregator: Node that combines the parallel results.
        executor: 'thread', 'async', or 'auto' (case-insensitive). 'auto'
            picks 'async' when the source, any target or the aggregator is
            async, else 'thread'.
        max_workers: Maximum worker threads (ignored for async).

    Returns:
        Node that performs the complete fan-out-in operation.

    Raises:
        ValueError: If ``executor`` is not a supported executor name.
    """
    mode = executor.lower()

    if mode == "auto":
        # Consider every participating node when choosing the executor.
        involved = [source] + targets + [aggregator]
        mode = "async" if any(n.is_async for n in involved) else "thread"
        logger.info(
            f"Auto-selected executor '{mode}' based on node types for fan-out-in"
        )

    if mode not in ("thread", "async"):
        raise ValueError(
            "Only 'thread', 'async', and 'auto' executors are supported. ProcessPoolExecutor has been removed to resolve pickle serialization issues."
        )

    fan_out_node = parallel_fan_out(
        source=source, targets=targets, executor=mode, max_workers=max_workers
    )
    return parallel_fan_in(fan_out_node, aggregator)
893
+
894
+
895
def conditional_composition(condition_node: Node, branches: dict[Any, Node]) -> Node:
    """Build a node that runs ``condition_node`` and then the matching branch.

    The condition result is looked up as a key in ``branches``; the selected
    branch is called with no arguments and its result is returned.

    Raises (from the composed node):
        ValueError: If no branch is registered for the condition result.
    """
    is_async_flow = condition_node.is_async or any(
        candidate.is_async for candidate in branches.values()
    )

    branch_names = {key: branch.name for key, branch in branches.items()}
    composition_name = f"({condition_node.name} ? {branch_names})"

    def pick_branch(condition_result: Any) -> Node:
        """Return the branch for ``condition_result`` or raise ValueError."""
        if condition_result not in branches:
            msg = f"No branch defined for condition result: {condition_result}"
            logger.error(msg)
            raise ValueError(msg)
        selected_branch = branches[condition_result]
        logger.info(
            f"Condition is {condition_result}, executing branch: {selected_branch.name}"
        )
        return selected_branch

    if not is_async_flow:

        def run(*args: Any, **kwargs: Any) -> Any:
            logger.info(f"--- Executing Conditional Branch: {composition_name} ---")
            chosen = pick_branch(condition_node(*args, **kwargs))
            return chosen()

        return Node(
            func=run,
            name=composition_name,
        )

    async def async_run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"--- Executing Conditional Branch: {composition_name} ---")

        # Only results of async nodes are awaited.
        condition_result = condition_node(*args, **kwargs)
        if condition_node.is_async:
            condition_result = await condition_result

        chosen = pick_branch(condition_result)
        outcome = chosen()
        if chosen.is_async:
            outcome = await outcome
        return outcome

    return Node(
        func=async_run,
        name=composition_name,
    )
962
+
963
+
964
def repeat_composition(node: Node, times: int, stop_on_error: bool = False) -> Node:
    """Build a node that executes ``node`` repeatedly.

    The first successful iteration receives the caller's arguments; every
    iteration after a success receives the most recent successful result as
    its single argument. Until an iteration has succeeded, the caller's
    original arguments keep being used (instead of passing ``None``).

    Args:
        node: The node to repeat.
        times: Number of iterations; must be a positive integer.
        stop_on_error: When True, the first failing iteration aborts the
            loop with a ``LoopControlException``; when False, failures are
            logged and the loop continues.

    Returns:
        A node returning the last successful result (``None`` if every
        iteration failed). It is async when ``node`` is async.

    Raises:
        TypeError: If ``times`` is not an int or ``node`` is not a Node.
        ValueError: If ``times`` is not greater than 0.
    """
    # Fail fast on invalid arguments at composition time.
    if not isinstance(times, int):
        raise TypeError(f"times must be an integer, got {type(times).__name__}")
    if times <= 0:
        raise ValueError("Repeat times must be greater than 0")
    if not isinstance(node, Node):
        raise TypeError("node must be a Node instance")

    composition_name = f"({node.name} * {times})"

    def record_failure(index: int, exc: Exception, errors: list) -> None:
        """Log a failed iteration; raise LoopControlException if configured."""
        errors.append(exc)
        logger.error(f"Iteration {index + 1} failed: {exc}")

        if stop_on_error:
            logger.error("Stopping immediately due to stop_on_error=True")
            raise LoopControlException(
                f"Execution stopped due to error at iteration {index + 1}: {exc}"
            ) from exc

        logger.info(f"Continuing with last successful result from iteration {index}")

    if node.is_async:

        async def async_run(*args: Any, **kwargs: Any) -> Any:
            logger.info(
                f"--- Executing Async Repeat Composition: {composition_name} ---"
            )

            last_result = None
            has_result = False  # True once any iteration has succeeded
            errors: list = []

            for i in range(times):
                logger.info(f" - Iteration {i + 1}/{times}")

                try:
                    if has_result:
                        result = await node(last_result)
                    else:
                        result = await node(*args, **kwargs)

                    last_result = result
                    has_result = True
                    logger.debug(f"Iteration {i + 1} completed successfully")

                except Exception as e:
                    record_failure(i, e, errors)

            logger.info(
                f"Async repeat composition completed. Iterations: {times}, Errors: {len(errors)}"
            )

            return last_result

        return Node(func=async_run, name=composition_name, is_start_node=False)

    def run(*args: Any, **kwargs: Any) -> Any:
        logger.info(f"--- Executing Sync Repeat Composition: {composition_name} ---")

        last_result = None
        has_result = False  # True once any iteration has succeeded
        errors: list = []

        for i in range(times):
            logger.info(f" - Iteration {i + 1}/{times}")

            try:
                if has_result:
                    result = node(last_result)
                else:
                    result = node(*args, **kwargs)

                last_result = result
                has_result = True
                logger.debug(f"Iteration {i + 1} completed successfully")

            except Exception as e:
                record_failure(i, e, errors)

        logger.info(
            f"Sync repeat composition completed. Iterations: {times}, Errors: {len(errors)}"
        )

        return last_result

    return Node(func=run, name=composition_name, is_start_node=False)
1086
+
1087
+
1088
def node(
    func: Callable | None = None,
    *,
    retry_count: int = 3,
    name: str | None = None,
    retry_delay: float = 1.0,
    exception_types: tuple = (Exception,),
    backoff_factor: float = 1.0,
    max_delay: float = 60.0,
    enable_retry: bool = False,
) -> Node | Callable:
    """
    Decorator: create a Node from a function, adding validation, optional
    retry, and dependency injection.

    Works both bare (``@node``) and parameterized (``@node(...)``).

    Usage Examples:
    ```python
    # Basic sync node
    @node
    def process_data(data: dict, state: dict = Provide[BaseFlowContext.state]) -> dict:
        result = {"processed": data["value"] * 2}
        state['last_result'] = result
        return result

    # Async node with retry (retry must be switched on explicitly)
    @node(enable_retry=True, retry_count=3, retry_delay=0.5)
    async def fetch_data(data: dict) -> dict:
        await asyncio.sleep(0.1)
        return {"value": f"fetched_{data}"}

    # Custom retry configuration
    @node(
        enable_retry=True,
        retry_count=5,
        retry_delay=2.0,
        exception_types=(ConnectionError, TimeoutError),
    )
    def external_service_call(data: dict) -> dict:
        return call_external_api(data)

    # Sequential composition
    flow = process_data.then(fetch_data).then(external_service_call)
    result = flow({"value": 10})

    # Parallel execution
    parallel_flow = process_data.fan_out_to([
        fetch_data,
        external_service_call
    ])
    results = parallel_flow({"value": 10})

    # Fan-out-in pattern
    @node
    def aggregate_results(parallel_results: dict) -> str:
        successful = [r.result for r in parallel_results.values() if r.success]
        return f"Aggregated: {len(successful)} results"

    complete_flow = process_data.fan_out_in([fetch_data, external_service_call], aggregate_results)
    final_result = complete_flow({"value": 10})

    # No retry (this is the default; shown explicitly)
    @node(enable_retry=False)
    def no_retry_operation(data: dict) -> dict:
        return {"immediate": data}
    ```

    Args:
        func: Function to be decorated. ``None`` when used as ``@node(...)``.
        name: Node identifier name. When omitted, it is derived from the
            function via ``_get_func_name`` with ``"unnamed_node"`` as fallback.
        retry_count: Maximum retry attempts (default: 3)
        retry_delay: Base retry delay in seconds (default: 1.0)
        exception_types: Tuple of exception types to retry (default: (Exception,))
        backoff_factor: Backoff multiplier (default: 1.0)
        max_delay: Maximum delay time in seconds (default: 60.0)
        enable_retry: Wrap the function with the retry decorator
            (default: False). The retry_* / backoff / max_delay / exception_types
            settings are only used when this is True.

    Returns:
        A ``Node`` when used as ``@node``; a decorator that produces a
        ``Node`` when used as ``@node(...)``.

    Notes:
        - Wrapping order, innermost first: ``custom_validate_call``
          (with ``validate_return=True``), then the retry decorator if
          enabled, then ``inject`` outermost so that ``Provide[...]`` defaults
          are resolved before validation and retry see the arguments.
        - The resulting Node's ``is_async`` flag comes from
          ``inspect.iscoroutinefunction`` on the *undecorated* function.
        - A ``RetryConfig`` is built from the retry arguments on every call
          to ``node``, even when ``enable_retry`` is False.
        - When using ``Provide[...]`` defaults, the container must be wired
          before the node is called: ``container.wire(modules=[__name__])``.
    """
    config = RetryConfig(
        retry_count, retry_delay, exception_types, backoff_factor, max_delay
    )

    # NOTE: this inner function is intentionally NOT decorated with
    # functools.wraps(Node). Doing so would copy the Node class's metadata and
    # class __dict__ onto this function and point __wrapped__ at Node, which
    # misleads introspection without changing what the decorator does.
    def decorator(f: Callable) -> Node:
        node_name = name or _get_func_name(f, "unnamed_node")

        # Detect async-ness before wrapping; the wrappers may hide it.
        is_original_async = inspect.iscoroutinefunction(f)

        # Innermost layer: argument and return-value validation.
        decorated_func = custom_validate_call(
            validate_return=True,
            config=ConfigDict(arbitrary_types_allowed=True),
            node_name=node_name,
        )(f)

        # Optional middle layer: retry around the validated call.
        if enable_retry:
            decorated_func = retry_decorator(config=config, node_name=node_name)(
                decorated_func
            )

        # Outermost layer: inject must run first at call time so Provide[...]
        # markers are replaced by real values before validation/retry.
        decorated_func = inject(decorated_func)

        return Node(func=decorated_func, name=node_name, is_async=is_original_async)

    # Support both call styles: @node(...) returns the decorator,
    # bare @node applies it immediately.
    if func is None:
        return decorator
    return decorator(func)
1210
+
1211
+
1212
# Export list - only the public interface that users need is exposed.
__all__ = [
    # Decorators: node creation, and validation with custom exception handling
    "node", "custom_validate_call",
    # Context: base for custom dependency injection
    "BaseFlowContext",
    # Result model for parallel execution
    "ParallelResult",
    # Exception classes
    "StreamletException", "ValidationInputException", "ValidationOutputException",
    "UserBusinessException", "NodeExecutionException", "NodeTimeoutException",
    "NodeRetryExhaustedException", "LoopControlException",
    # Retry configuration
    "RetryConfig",
]