lionagi 0.13.7__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,109 @@
1
+ """Task group implementation for structured concurrency."""
2
+
3
+ from collections.abc import Awaitable, Callable
4
+ from types import TracebackType
5
+ from typing import Any, Optional, TypeVar
6
+
7
+ import anyio
8
+
9
+ T = TypeVar("T")
10
+ R = TypeVar("R")
11
+
12
+
13
+ class TaskGroup:
14
+ """A group of tasks that are treated as a unit."""
15
+
16
+ def __init__(self):
17
+ """Initialize a new task group."""
18
+ self._task_group = None
19
+
20
+ async def start_soon(
21
+ self,
22
+ func: Callable[..., Awaitable[Any]],
23
+ *args: Any,
24
+ name: str | None = None,
25
+ ) -> None:
26
+ """Start a new task in this task group.
27
+
28
+ Args:
29
+ func: The coroutine function to call
30
+ *args: Positional arguments to pass to the function
31
+ name: Optional name for the task
32
+
33
+ Note:
34
+ This method does not wait for the task to initialize.
35
+ """
36
+ if self._task_group is None:
37
+ raise RuntimeError("Task group is not active")
38
+ self._task_group.start_soon(func, *args, name=name)
39
+
40
+ async def start(
41
+ self,
42
+ func: Callable[..., Awaitable[R]],
43
+ *args: Any,
44
+ name: str | None = None,
45
+ ) -> R:
46
+ """Start a new task and wait for it to initialize.
47
+
48
+ Args:
49
+ func: The coroutine function to call
50
+ *args: Positional arguments to pass to the function
51
+ name: Optional name for the task
52
+
53
+ Returns:
54
+ The value passed to task_status.started()
55
+
56
+ Note:
57
+ The function must accept a task_status keyword argument and call
58
+ task_status.started() once initialization is complete.
59
+ """
60
+ if self._task_group is None:
61
+ raise RuntimeError("Task group is not active")
62
+ return await self._task_group.start(func, *args, name=name)
63
+
64
+ async def __aenter__(self) -> "TaskGroup":
65
+ """Enter the task group context.
66
+
67
+ Returns:
68
+ The task group instance.
69
+ """
70
+ task_group = anyio.create_task_group()
71
+ self._task_group = await task_group.__aenter__()
72
+ return self
73
+
74
+ async def __aexit__(
75
+ self,
76
+ exc_type: type[BaseException] | None,
77
+ exc_val: BaseException | None,
78
+ exc_tb: TracebackType | None,
79
+ ) -> bool:
80
+ """Exit the task group context.
81
+
82
+ This will wait for all tasks in the group to complete.
83
+ If any task raised an exception, it will be propagated.
84
+ If multiple tasks raised exceptions, they will be combined into an ExceptionGroup.
85
+
86
+ Returns:
87
+ True if the exception was handled, False otherwise.
88
+ """
89
+ if self._task_group is None:
90
+ return False
91
+
92
+ try:
93
+ return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
94
+ finally:
95
+ self._task_group = None
96
+
97
+
98
def create_task_group() -> TaskGroup:
    """Build a fresh, not-yet-entered task group.

    The returned object has to be used as an async context manager before
    any task can be scheduled on it.

    Returns:
        A new task group instance.

    Example:
        async with create_task_group() as tg:
            await tg.start_soon(task1)
            await tg.start_soon(task2)
    """
    group = TaskGroup()
    return group
@@ -75,6 +75,7 @@ class OperationGraphBuilder:
75
75
  operation: BranchOperations,
76
76
  node_id: str | None = None,
77
77
  depends_on: list[str] | None = None,
78
+ inherit_context: bool = False,
78
79
  **parameters,
79
80
  ) -> str:
80
81
  """
@@ -84,6 +85,8 @@ class OperationGraphBuilder:
84
85
  operation: The branch operation
85
86
  node_id: Optional ID reference for this node
86
87
  depends_on: List of node IDs this depends on
88
+ inherit_context: If True and has dependencies, inherit conversation
89
+ context from primary (first) dependency
87
90
  **parameters: Operation parameters
88
91
 
89
92
  Returns:
@@ -92,6 +95,11 @@ class OperationGraphBuilder:
92
95
  # Create operation node
93
96
  node = Operation(operation=operation, parameters=parameters)
94
97
 
98
+ # Store context inheritance strategy
99
+ if inherit_context and depends_on:
100
+ node.metadata["inherit_context"] = True
101
+ node.metadata["primary_dependency"] = depends_on[0]
102
+
95
103
  self.graph.add_node(node)
96
104
  self._operations[node.id] = node
97
105
 
@@ -126,6 +134,8 @@ class OperationGraphBuilder:
126
134
  source_node_id: str,
127
135
  operation: BranchOperations,
128
136
  strategy: ExpansionStrategy = ExpansionStrategy.CONCURRENT,
137
+ inherit_context: bool = False,
138
+ chain_context: bool = False,
129
139
  **shared_params,
130
140
  ) -> list[str]:
131
141
  """
@@ -139,6 +149,9 @@ class OperationGraphBuilder:
139
149
  source_node_id: ID of node that produced these items
140
150
  operation: Operation to apply to each item
141
151
  strategy: How to organize the expanded operations
152
+ inherit_context: If True, expanded operations inherit context from source
153
+ chain_context: If True and strategy is SEQUENTIAL, each operation
154
+ inherits from the previous (only applies to SEQUENTIAL)
142
155
  **shared_params: Shared parameters for all operations
143
156
 
144
157
  Returns:
@@ -171,6 +184,21 @@ class OperationGraphBuilder:
171
184
  },
172
185
  )
173
186
 
187
+ # Handle context inheritance for expanded operations
188
+ if inherit_context:
189
+ if (
190
+ chain_context
191
+ and strategy == ExpansionStrategy.SEQUENTIAL
192
+ and i > 0
193
+ ):
194
+ # Chain context: inherit from previous expanded operation
195
+ node.metadata["inherit_context"] = True
196
+ node.metadata["primary_dependency"] = new_node_ids[i - 1]
197
+ else:
198
+ # Inherit from source node
199
+ node.metadata["inherit_context"] = True
200
+ node.metadata["primary_dependency"] = source_node_id
201
+
174
202
  self.graph.add_node(node)
175
203
  self._operations[node.id] = node
176
204
  new_node_ids.append(node.id)
@@ -195,7 +223,10 @@ class OperationGraphBuilder:
195
223
  def add_aggregation(
196
224
  self,
197
225
  operation: BranchOperations,
226
+ node_id: str | None = None,
198
227
  source_node_ids: list[str] | None = None,
228
+ inherit_context: bool = False,
229
+ inherit_from_source: int = 0,
199
230
  **parameters,
200
231
  ) -> str:
201
232
  """
@@ -203,7 +234,10 @@ class OperationGraphBuilder:
203
234
 
204
235
  Args:
205
236
  operation: Aggregation operation
237
+ node_id: Optional ID reference for this node
206
238
  source_node_ids: Nodes to aggregate from (defaults to current heads)
239
+ inherit_context: If True, inherit conversation context from one source
240
+ inherit_from_source: Index of source to inherit context from (default: 0)
207
241
  **parameters: Operation parameters
208
242
 
209
243
  Returns:
@@ -226,6 +260,18 @@ class OperationGraphBuilder:
226
260
  metadata={"aggregation": True},
227
261
  )
228
262
 
263
+ # Store node reference if provided
264
+ if node_id:
265
+ node.metadata["reference_id"] = node_id
266
+
267
+ # Store context inheritance for aggregations
268
+ if inherit_context and sources:
269
+ node.metadata["inherit_context"] = True
270
+ # Use the specified source index (bounded by available sources)
271
+ source_idx = min(inherit_from_source, len(sources) - 1)
272
+ node.metadata["primary_dependency"] = sources[source_idx]
273
+ node.metadata["inherit_from_source"] = source_idx
274
+
229
275
  self.graph.add_node(node)
230
276
  self._operations[node.id] = node
231
277