sera-2 1.17.0__py3-none-any.whl → 1.18.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
sera/libs/directed_computing_graph/__init__.py ADDED
@@ -0,0 +1,53 @@
+"""
+Directed Computing Graph package for sera.
+
+This package provides classes and utilities for working with directed computing graphs.
+"""
+
+# Import all classes, type aliases, and annotated types from submodules
+from ._dcg import DirectedComputingGraph, NodeId, TaskArgs, TaskKey
+from ._edge import DCGEdge
+from ._flow import Flow
+from ._node import ComputeFn, ComputeFnId, DCGNode, PartialFn
+from ._runtime import SKIP, UNSET, ArgValueType, NodeRuntime
+
+# Import utility functions from type conversion
+from ._type_conversion import (
+    ComposeTypeConversion,
+    TypeConversion,
+    UnitTypeConversion,
+    align_generic_type,
+    ground_generic_type,
+    is_generic_type,
+    patch_get_origin,
+)
+
+# Define __all__ to control what gets exported
+__all__ = [
+    # Main classes
+    "DirectedComputingGraph",
+    "DCGNode",
+    "DCGEdge",
+    "Flow",
+    "PartialFn",
+    "TypeConversion",
+    "NodeRuntime",
+    # Enums and special values
+    "ArgValueType",
+    "UNSET",
+    "SKIP",
+    # Type aliases and annotations
+    "NodeId",
+    "TaskKey",
+    "TaskArgs",
+    "ComputeFnId",
+    "ComputeFn",
+    "UnitTypeConversion",
+    "ComposeTypeConversion",
+    # Utility functions
+    "patch_get_origin",
+    "is_generic_type",
+    "align_generic_type",
+    "ground_generic_type",
+]
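
Everything above is re-exported at the package root, so downstream code can stay off the private submodules. A minimal sketch of the intended import surface (the names are taken from `__all__` above; the selection is illustrative):

from sera.libs.directed_computing_graph import (
    SKIP,
    DirectedComputingGraph,
    Flow,
    PartialFn,
)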
sera/libs/directed_computing_graph/_dcg.py ADDED
@@ -0,0 +1,403 @@
+from __future__ import annotations
+
+import asyncio
+import inspect
+from dataclasses import dataclass
+from enum import Enum
+from typing import (
+    Annotated,
+    Any,
+    Awaitable,
+    Callable,
+    MutableSequence,
+    Optional,
+    Sequence,
+)
+
+from graph.retworkx import RetworkXStrDiGraph
+
+from sera.libs.directed_computing_graph._edge import DCGEdge
+from sera.libs.directed_computing_graph._flow import Flow
+from sera.libs.directed_computing_graph._node import (
+    ComputeFn,
+    ComputeFnId,
+    DCGNode,
+    NodeId,
+)
+from sera.libs.directed_computing_graph._runtime import SKIP, NodeRuntime
+from sera.libs.directed_computing_graph._type_conversion import (
+    ComposeTypeConversion,
+    TypeConversion,
+    UnitTypeConversion,
+    align_generic_type,
+    ground_generic_type,
+    is_generic_type,
+)
+from sera.misc import identity
+
+TaskKey = Annotated[tuple, "TaskKey"]
+TaskArgs = Annotated[MutableSequence, "TaskArgs"]
+
+
+class DirectedComputingGraph:
+    """
+    A Directed Computing Graph (DCG) is a directed graph where nodes represent functions
+    and edges represent dependencies between these functions. The graph is used to manage
+    the execution of functions in a specific order based on their dependencies.
+    """
+
+    def __init__(
+        self,
+        graph: RetworkXStrDiGraph[int, DCGNode, DCGEdge],
+        type_service: TypeConversion,
+    ):
+        self.graph = graph
+        self.type_service = type_service
+
+    @staticmethod
+    def from_flows(
+        flows: dict[ComputeFnId, Flow | ComputeFn],
+        type_conversions: Optional[
+            Sequence[UnitTypeConversion | ComposeTypeConversion]
+        ] = None,
+        strict: bool = True,
+    ):
+        """Create a computing graph from a flow mapping.
+
+        Args:
+            flows: A dictionary mapping an identifier to either:
+                1. a function, or
+                2. a flow specifying the upstream functions and the function itself.
+            type_conversions: A list of type conversions used to convert input types.
+            strict: If True, raise an error when no conversion is known from an
+                upstream output type to the input type it feeds.
+        Returns:
+            DirectedComputingGraph: A directed computing graph constructed from the provided flows.
+        """
+        # add typing conversions
+        upd_type_conversions: list[UnitTypeConversion | ComposeTypeConversion] = list(
+            type_conversions or []
+        )
+        type_service = TypeConversion(upd_type_conversions)
+
+        g: RetworkXStrDiGraph[int, DCGNode, DCGEdge] = RetworkXStrDiGraph(
+            check_cycle=False, multigraph=False
+        )
+
+        # create the nodes of the graph
+        for uid, uinfo in flows.items():
+            if isinstance(uinfo, Flow):
+                func = uinfo.target
+            else:
+                func = uinfo
+            g.add_node(DCGNode(uid, func))
+
+        # ground functions that have generic input or output types
+        for uid, flow in flows.items():
+            if not isinstance(flow, Flow):
+                continue
+
+            u = g.get_node(uid)
+            usig = u.signature
+            if is_generic_type(usig.return_type) or any(
+                is_generic_type(t) for t in usig.argtypes
+            ):
+                var2type = {}
+                for i, t in enumerate(usig.argtypes):
+                    if is_generic_type(t):
+                        # align the generic type with the return type of the upstream actor
+                        if len(flow.source) <= i and strict:
+                            raise TypeConversion.UnknownConversion(
+                                f"Cannot ground the generic type based on upstream actors for actor {uid}"
+                            )
+
+                        source_return_type = g.get_node(
+                            flow.source[i]
+                        ).signature.return_type
+
+                        try:
+                            usig.argtypes[i], (var, nt) = align_generic_type(
+                                t, source_return_type
+                            )
+                        except Exception as e:
+                            raise TypeConversion.UnknownConversion(
+                                f"Cannot align the generic type {t} based on upstream actors for actor {uid}"
+                            ) from e
+                        var2type[var] = nt
+                if is_generic_type(usig.return_type):
+                    usig.return_type = ground_generic_type(
+                        usig.return_type,
+                        var2type,
+                    )
+
+        # create the edges, deriving a type conversion for each one
+        for uid, flow in flows.items():
+            if not isinstance(flow, Flow):
+                continue
+
+            u = g.get_node(uid)
+            usig = u.signature
+            for idx, sid in enumerate(flow.source):
+                s = g.get_node(sid)
+                ssig = s.signature
+                cast_fn = identity
+                try:
+                    cast_fn = type_service.get_conversion(
+                        ssig.return_type, usig.argtypes[idx]
+                    )
+                except Exception as e:
+                    if strict:
+                        raise TypeConversion.UnknownConversion(
+                            f"Don't know how to convert output of `{sid}` to input of `{uid}`"
+                        ) from e
+                g.add_edge(
+                    DCGEdge(
+                        id=-1,
+                        source=sid,
+                        target=uid,
+                        argindex=idx,
+                        filter_fn=flow.filter_fn,
+                        type_conversion=cast_fn,
+                    )
+                )
+
+        # postprocessing: type conversions and required args/context
+        for u in g.iter_nodes():
+            inedges = g.in_edges(u.id)
+
+            # update the type conversion
+            u.type_conversions = [identity] * len(u.signature.argnames)
+            for inedge in inedges:
+                u.type_conversions[inedge.argindex] = inedge.type_conversion
+
+            # update the required args and context
+            u.required_args = u.signature.argnames[: g.in_degree(u.id)]
+            # arguments of a compute function that are not provided by the upstream actors must be provided by the context.
+            u.required_context = u.signature.argnames[g.in_degree(u.id) :]
+            u.required_context_default_args = {
+                k: u.signature.default_args[k]
+                for k in u.required_context
+                if k in u.signature.default_args
+            }
+
+        return DirectedComputingGraph(g, type_service)
+
+    def execute(
+        self,
+        input: dict[ComputeFnId, tuple],
+        output: set[str],
+        context: Optional[
+            dict[str, Callable | Any] | Callable[[], dict[str, Any]]
+        ] = None,
+    ):
+        """
+        Execute the directed computing graph with the given input and output specifications.
+
+        Args:
+            input: A dictionary mapping function identifiers to their input arguments.
+            output: A set of function identifiers whose outputs should be collected and returned.
+            context: Optional context values; either a dictionary (whose values may be
+                callables producing the value) or a callable returning the whole dictionary.
+        """
+        assert all(
+            isinstance(v, tuple) for v in input.values()
+        ), "All input values must be tuples"
+
+        if context is None:
+            context = {}
+        elif isinstance(context, Callable):
+            context = context()
+        else:
+            context = {k: v() if callable(v) else v for k, v in context.items()}
+
+        # This is a quick reactive algorithm; we may be able to do better.
+        # The idea is that when all inputs of a function are available, we can execute it.
+        # We assume that memory is large enough to hold all the functions and their
+        # inputs at once.
+
+        # We execute the computing nodes; when one finishes, we push its outgoing
+        # edges onto a stack.
+        runtimes: dict[NodeId, NodeRuntime] = {}
+
+        for u in self.graph.iter_nodes():
+            if u.id in input:
+                # user-provided input supersedes the context
+                n_provided_args = len(input[u.id])
+                n_consumed_context = n_provided_args - len(u.required_args)
+            else:
+                n_consumed_context = 0
+
+            node_context = tuple(
+                (
+                    context[name]
+                    if name in context
+                    else u.required_context_default_args[name]
+                )
+                for name in u.required_context[n_consumed_context:]
+            )
+
+            runtimes[u.id] = NodeRuntime.from_node(self.graph, u, node_context)
+        stack: list[NodeId] = []
+
+        for id, args in input.items():
+            runtimes[id].add_task((0,), list(args))
+            stack.append(id)
+
+        return_output = {id: [] for id in output}
+
+        while len(stack) > 0:
+            # pop a node from the stack and execute it.
+            id = stack.pop()
+            runtime = runtimes[id]
+
+            # If the node has enough data, we can execute it. If not, we just skip
+            # it; it will be pushed back onto the stack by one of its parents, so
+            # we don't miss it.
+            if not runtime.has_enough_data():
+                continue
+
+            outedges = self.graph.out_edges(id)
+            successors: Sequence[tuple[DCGEdge, DCGNode]] = [
+                (edge, self.graph.get_node(edge.target)) for edge in outedges
+            ]
+
+            # run the tasks and pass the output to the successors
+            for task_id, task in runtime.tasks.items():
+                if any(arg is SKIP for arg in task):
+                    task_output = SKIP
+                else:
+                    task_output = runtime.execute(task)
+
+                for outedge, succ in successors:
+                    runtimes[succ.id].add_task_args(
+                        task_id,
+                        id,
+                        (
+                            SKIP
+                            if task_output is SKIP or not outedge.filter(task_output)
+                            else task_output
+                        ),
+                    )
+
+                if id in output and task_output is not SKIP:
+                    return_output[id].append(task_output)
+
+            # push the successor nodes onto the stack
+            for outedge, succ in successors:
+                stack.append(succ.id)
+
+        return return_output
+
+    async def execute_async(
+        self,
+        input: dict[ComputeFnId, tuple],
+        output: set[str],
+        context: Optional[
+            dict[str, Callable | Any] | Callable[[], dict[str, Any]]
+        ] = None,
+    ):
+        """
+        Asynchronously execute the directed computing graph with the given input and output specifications.
+        This method handles both synchronous and asynchronous functions.
+
+        Args:
+            input: A dictionary mapping function identifiers to their input arguments.
+            output: A set of function identifiers whose outputs should be collected and returned.
+            context: Optional context values; either a dictionary (whose values may be
+                callables producing the value) or a callable returning the whole dictionary.
+        """
+        assert all(
+            isinstance(v, tuple) for v in input.values()
+        ), "All input values must be tuples"
+
+        if context is None:
+            context = {}
+        elif isinstance(context, Callable):
+            context = context()
+        else:
+            context = {k: v() if callable(v) else v for k, v in context.items()}
+
+        # This is a quick reactive algorithm; we may be able to do better.
+        # The idea is that when all inputs of a function are available, we can execute it.
+        # We assume that memory is large enough to hold all the functions and their
+        # inputs at once.
+
+        # We execute the computing nodes; when one finishes, we push its outgoing
+        # edges onto a stack.
+        runtimes: dict[NodeId, NodeRuntime] = {}
+
+        for u in self.graph.iter_nodes():
+            if u.id in input:
+                # user-provided input supersedes the context
+                n_provided_args = len(input[u.id])
+                n_consumed_context = n_provided_args - len(u.required_args)
+            else:
+                n_consumed_context = 0
+
+            node_context = tuple(
+                (
+                    context[name]
+                    if name in context
+                    else u.required_context_default_args[name]
+                )
+                for name in u.required_context[n_consumed_context:]
+            )
+
+            runtimes[u.id] = NodeRuntime.from_node(self.graph, u, node_context)
+        stack: list[NodeId] = []
+
+        for id, args in input.items():
+            runtimes[id].add_task((0,), list(args))
+            stack.append(id)
+
+        return_output = {id: [] for id in output}
+
+        while len(stack) > 0:
+            # pop a node from the stack and execute it.
+            id = stack.pop()
+            runtime = runtimes[id]
+
+            # If the node has enough data, we can execute it. If not, we just skip
+            # it; it will be pushed back onto the stack by one of its parents, so
+            # we don't miss it.
+            if not runtime.has_enough_data():
+                continue
+
+            outedges = self.graph.out_edges(id)
+            successors: Sequence[tuple[DCGEdge, DCGNode]] = [
+                (edge, self.graph.get_node(edge.target)) for edge in outedges
+            ]
+
+            # run the tasks and pass the output to the successors
+            for task_id, task in runtime.tasks.items():
+                if any(arg is SKIP for arg in task):
+                    task_output = SKIP
+                else:
+                    if runtime.node.signature.is_async:
+                        task_output = await runtime.execute(task)
+                    else:
+                        task_output = runtime.execute(task)
+
+                for outedge, succ in successors:
+                    runtimes[succ.id].add_task_args(
+                        task_id,
+                        id,
+                        (
+                            SKIP
+                            if task_output is SKIP or not outedge.filter(task_output)
+                            else task_output
+                        ),
+                    )
+
+                if id in output and task_output is not SKIP:
+                    return_output[id].append(task_output)
+
+            # push the successor nodes onto the stack
+            for outedge, succ in successors:
+                stack.append(succ.id)
+
+        return return_output
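
A minimal sketch of how `from_flows` and `execute` fit together, based on the code above (the two-node pipeline and its values are hypothetical; running it requires the `graph-wrapper` dependency added in this release):

from sera.libs.directed_computing_graph import DirectedComputingGraph, Flow

def fetch(x: int) -> int:
    return x + 1

def double(x: int) -> int:
    return x * 2

dcg = DirectedComputingGraph.from_flows(
    {
        "fetch": fetch,  # a root node; it is fed through `input`
        "double": Flow(source="fetch", target=double),
    }
)
# Root nodes receive their arguments as tuples; outputs are collected per node id.
result = dcg.execute(input={"fetch": (41,)}, output={"double"})
# result == {"double": [84]}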
sera/libs/directed_computing_graph/_edge.py ADDED
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from typing import Annotated, Any, Callable, Optional
+
+from graph.interface import BaseEdge
+
+from sera.libs.directed_computing_graph._node import NodeId
+from sera.libs.directed_computing_graph._type_conversion import UnitTypeConversion
+
+
+class DCGEdge(BaseEdge[NodeId, int]):
+
+    def __init__(
+        self,
+        id: int,
+        source: NodeId,
+        target: NodeId,
+        argindex: int,
+        type_conversion: UnitTypeConversion,
+        filter_fn: Optional[Callable[[Any], bool]] = None,
+    ):
+        super().__init__(id, source, target, key=argindex)
+        self.argindex = argindex
+        self.type_conversion = type_conversion
+        self.filter_fn = filter_fn
+
+    def filter(self, value: Any) -> bool:
+        """Filter the value passing through this edge.
+
+        Returns:
+            True if the value should flow through this edge, False to block it.
+        """
+        if self.filter_fn is not None:
+            return self.filter_fn(value)
+        return True
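
`filter` is what `execute()` consults on every edge: a blocked value is replaced by the `SKIP` sentinel, so the downstream task is marked complete but never calls its function. A small sketch with a hypothetical producer/consumer pair:

from sera.libs.directed_computing_graph import DirectedComputingGraph, Flow

def produce(x: int) -> int:
    return x

def consume(x: int) -> str:
    return f"kept {x}"

dcg = DirectedComputingGraph.from_flows(
    {
        "produce": produce,
        # block odd values on the produce -> consume edge
        "consume": Flow(source="produce", target=consume,
                        filter_fn=lambda v: v % 2 == 0),
    }
)
dcg.execute(input={"produce": (3,)}, output={"consume"})  # {"consume": []}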
sera/libs/directed_computing_graph/_flow.py ADDED
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Optional
+
+from sera.libs.directed_computing_graph._node import ComputeFn, ComputeFnId
+
+
+class Flow:
+    def __init__(
+        self,
+        source: list[ComputeFnId] | ComputeFnId,
+        target: ComputeFn,
+        filter_fn: Optional[Callable[[Any], bool]] = None,
+    ):
+        self.source = [source] if isinstance(source, str) else source
+        self.target = target
+        self.filter_fn = filter_fn
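
`source` accepts a single upstream id or a list; the position in the list determines which positional argument of `target` each upstream output feeds (the `argindex` recorded on each `DCGEdge`). A sketch with hypothetical node names, assuming edge ids reflect insertion order as `NodeRuntime.from_node` relies on:

from sera.libs.directed_computing_graph import DirectedComputingGraph, Flow

def first(x: int) -> int:
    return x

def second(x: int) -> int:
    return x

def add(a: int, b: int) -> int:
    return a + b

dcg = DirectedComputingGraph.from_flows(
    {
        "first": first,
        "second": second,
        # output of "first" becomes `a`, output of "second" becomes `b`
        "sum": Flow(source=["first", "second"], target=add),
    }
)
dcg.execute(input={"first": (1,), "second": (2,)}, output={"sum"})  # {"sum": [3]}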
sera/libs/directed_computing_graph/_fn_signature.py ADDED
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, get_args, get_origin, get_type_hints
+
+from sera.misc import get_classpath
+
+
+@dataclass
+class FnSignature:
+    return_type: type
+    argnames: list[str]
+    argtypes: list[type]
+    default_args: dict[str, Any]  # Added this field to store default values
+    is_async: bool = False
+
+    @staticmethod
+    def parse(func: Callable) -> FnSignature:
+        sig = get_type_hints(func)
+        # every hint except "return" names an argument (slicing off the last key
+        # would silently assume "return" is always last)
+        argnames = [name for name in sig if name != "return"]
+
+        # Get the default values using inspect.signature
+        inspect_sig = inspect.signature(func)
+        defaults = {}
+        for name, param in inspect_sig.parameters.items():
+            if param.default is not inspect.Parameter.empty:
+                defaults[name] = param.default
+
+        try:
+            return FnSignature(
+                sig["return"],
+                argnames,
+                [sig[arg] for arg in argnames],
+                defaults,  # Add the default values to the signature
+                is_async=inspect.iscoroutinefunction(func),
+            )
+        except Exception:
+            print("Cannot figure out the signature of", func)
+            print("The parsed signature is:", sig)
+            raise
+
+
+def type_to_string(_type: type) -> str:
+    """Return a fully qualified type name."""
+    origin = get_origin(_type)
+    if origin is None:
+        return get_classpath(_type)
+    return (
+        get_classpath(origin)
+        + "["
+        + ", ".join([get_classpath(arg) for arg in get_args(_type)])
+        + "]"
+    )
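
What `FnSignature.parse` extracts, shown on a hypothetical annotated function:

from sera.libs.directed_computing_graph._fn_signature import FnSignature

def greet(name: str, excited: bool = False) -> str:
    return name + ("!" if excited else "")

sig = FnSignature.parse(greet)
assert sig.return_type is str
assert sig.argnames == ["name", "excited"]
assert sig.argtypes == [str, bool]
assert sig.default_args == {"excited": False}
assert sig.is_async is False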
sera/libs/directed_computing_graph/_node.py ADDED
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Annotated, Any, Callable
+
+from graph.interface import BaseNode
+
+from sera.libs.directed_computing_graph._fn_signature import FnSignature
+from sera.libs.directed_computing_graph._type_conversion import UnitTypeConversion
+
+
+class PartialFn:
+    def __init__(self, fn: Callable, **kwargs):
+        self.fn = fn
+        self.default_args = kwargs
+        self.signature = FnSignature.parse(fn)
+
+        argnames = set(self.signature.argnames)
+        for arg, val in self.default_args.items():
+            if arg not in argnames:
+                raise Exception(f"Argument {arg} is not in the function signature")
+            self.signature.default_args[arg] = val
+
+    def __call__(self, *args, **kwargs):
+        return self.fn(*args, **kwargs)
+
+
+ComputeFnId = Annotated[str, "ComputeFn Identifier"]
+ComputeFn = PartialFn | Callable
+NodeId = ComputeFnId
+
+
+class DCGNode(BaseNode[NodeId]):
+    id: NodeId
+    func: ComputeFn
+
+    def __init__(self, id: NodeId, func: ComputeFn):
+        super().__init__(id)
+        self.func = func
+        self.signature = self.get_signature(self.func)
+        self.type_conversions: list[UnitTypeConversion] = []
+        self.required_args: list[str] = []
+        self.required_context: list[str] = []
+        self.required_context_default_args: dict[str, Any] = {}
+
+    @staticmethod
+    def get_signature(actor: ComputeFn) -> FnSignature:
+        if isinstance(actor, PartialFn):
+            return actor.signature
+        else:
+            return FnSignature.parse(actor)
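
`PartialFn` pins keyword defaults onto a function and folds them into its parsed signature. Unlike `functools.partial`, `__call__` simply forwards its arguments; the pinned defaults are applied by the graph runtime, which picks them up through `required_context_default_args`. A sketch with a hypothetical function:

from sera.libs.directed_computing_graph import PartialFn

def scale(x: int, factor: int) -> int:
    return x * factor

triple = PartialFn(scale, factor=3)
assert triple.signature.default_args == {"factor": 3}
assert triple(5, factor=3) == 15  # a direct call still needs every argument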
sera/libs/directed_computing_graph/_runtime.py ADDED
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Annotated, Any, MutableSequence, Sequence
+
+from graph.retworkx import RetworkXStrDiGraph
+
+from sera.libs.directed_computing_graph._edge import DCGEdge
+from sera.libs.directed_computing_graph._node import DCGNode, NodeId
+
+TaskKey = Annotated[tuple, "TaskKey"]
+TaskArgs = Annotated[MutableSequence, "TaskArgs"]
+
+
+class ArgValueType(Enum):
+    UNSET = "UNSET"
+    SKIP = "SKIP"
+
+
+UNSET = ArgValueType.UNSET
+SKIP = ArgValueType.SKIP
+
+
+@dataclass
+class NodeRuntime:
+    id: NodeId
+    tasks: dict[TaskKey, TaskArgs]
+    context: Sequence[Any]
+
+    graph: RetworkXStrDiGraph[int, DCGNode, DCGEdge]
+    node: DCGNode
+    indegree: int
+    # This is a mapping from parent node id to the index of the argument in the task.
+    parent2argindex: dict[str, int]
+
+    @staticmethod
+    def from_node(
+        graph: RetworkXStrDiGraph[int, DCGNode, DCGEdge],
+        node: DCGNode,
+        context: Sequence[Any],
+    ) -> NodeRuntime:
+        return NodeRuntime(
+            id=node.id,
+            tasks={},
+            context=context,
+            graph=graph,
+            node=node,
+            indegree=graph.in_degree(node.id),
+            parent2argindex={
+                edge.source: i
+                # Map parent node ID to argument index based on sorted in-edge order
+                for i, edge in enumerate(
+                    sorted(graph.in_edges(node.id), key=lambda e: e.id)
+                )
+            },
+        )
+
+    def add_task(self, key: TaskKey, args: TaskArgs) -> NodeRuntime:
+        """
+        Add a task to the node runtime.
+
+        Args:
+            key: The key identifying the task.
+            args: The arguments for the task.
+        Returns:
+            NodeRuntime: The updated node runtime with the new task added.
+        """
+        self.tasks[key] = args
+        return self
+
+    def add_task_args(
+        self, key: TaskKey, parent_node: NodeId, argvalue: Any
+    ) -> NodeRuntime:
+        """
+        Add an argument to an existing task.
+
+        Args:
+            key: The key identifying the task.
+            parent_node: Identifier of the parent node from which the argument is coming.
+            argvalue: The value of the argument to add.
+        Returns:
+            NodeRuntime: The updated node runtime with the new argument added to the task.
+        """
+        if key not in self.tasks:
+            self.tasks[key] = [UNSET] * self.indegree
+        self.tasks[key][self.parent2argindex[parent_node]] = argvalue
+        return self
+
+    def has_enough_data(self) -> bool:
+        """
+        Check if the node has enough data to execute its tasks.
+
+        Returns:
+            bool: True if the node has enough data, False otherwise.
+        """
+        return all(
+            all(arg is not UNSET for arg in args) for args in self.tasks.values()
+        )
+
+    def execute(self, task: TaskArgs) -> Any:
+        """
+        Execute a task with the node's context.
+
+        Args:
+            task: The arguments for the task; each is converted with the node's
+                type conversions before the function is called.
+        """
+        norm_args = (self.node.type_conversions[i](a) for i, a in enumerate(task))
+        return self.node.func(*norm_args, *self.context)
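
The two sentinels drive the scheduler in `_dcg.py`: `UNSET` marks an argument slot no parent has filled yet (so `has_enough_data` keeps the node waiting), while `SKIP` marks an argument filtered out upstream (the task counts as ready but yields `SKIP` instead of calling the function). A small sketch of that readiness check, mirroring `has_enough_data` on a hand-built argument list:

from sera.libs.directed_computing_graph import SKIP, UNSET

args = [UNSET, SKIP]
assert not all(arg is not UNSET for arg in args)  # still waiting on slot 0

args[0] = 42
assert all(arg is not UNSET for arg in args)  # ready, but will produce SKIP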
sera/libs/directed_computing_graph/_type_conversion.py ADDED
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import collections.abc
+import inspect
+from types import UnionType
+from typing import (
+    Annotated,
+    Any,
+    Callable,
+    Mapping,
+    MutableMapping,
+    MutableSequence,
+    MutableSet,
+    Sequence,
+    Set,
+    TypeVar,
+    Union,
+    cast,
+    get_args,
+    get_origin,
+    get_type_hints,
+)
+
+from sera.misc import identity
+
+UnitTypeConversion = Annotated[
+    Callable[[Any], Any], "A function that converts an object of type T1 to T2"
+]
+ComposeTypeConversion = Annotated[
+    Callable[[Any, UnitTypeConversion], Any],
+    "A function that converts a generic object of type G[T1] to G[T2]",
+]
+
+
+class TypeConversion:
+    """Inspired by Rust's type conversion traits. This class derives a function for
+    converting the output type of one pipe to the input type of another."""
+
+    class UnknownConversion(Exception):
+        pass
+
+    def __init__(
+        self, type_casts: Sequence[UnitTypeConversion | ComposeTypeConversion]
+    ):
+        self.generic_single_type_conversion: dict[type, UnitTypeConversion] = {}
+        self.unit_type_conversions: dict[tuple[type, type], UnitTypeConversion] = {}
+        self.compose_type_conversion: dict[type, ComposeTypeConversion] = {}
+
+        for fn in type_casts:
+            assert not inspect.iscoroutinefunction(
+                fn
+            ), "Async conversion functions are not supported"
+            sig = get_type_hints(fn)
+            if len(sig) == 2:
+                fn = cast(UnitTypeConversion, fn)
+
+                intype = sig[[x for x in sig if x != "return"][0]]
+                outtype = sig["return"]
+
+                intype_origin = get_origin(intype)
+                intype_args = get_args(intype)
+                if (
+                    intype_origin is not None
+                    and len(intype_args) == 1
+                    and intype_args[0] is outtype
+                    and isinstance(outtype, TypeVar)
+                ):
+                    # this is a generic conversion G[T] => T
+                    self.generic_single_type_conversion[intype_origin] = fn
+                else:
+                    self.unit_type_conversions[intype, outtype] = fn
+            else:
+                assert len(sig) == 3, "Invalid type conversion function"
+                fn = cast(ComposeTypeConversion, fn)
+
+                intype = sig[[x for x in sig if x != "return"][0]]
+                outtype = sig["return"]
+                intype_origin = get_origin(intype)
+                assert intype_origin is not None
+                self.compose_type_conversion[intype_origin] = fn
+
+    def get_conversion(
+        self, source_type: type, target_type: type
+    ) -> UnitTypeConversion:
+        # handle identity conversion, which happens when source_type == target_type
+        # or target_type is Union[source_type, ...]
+        if source_type == target_type:
+            # `source_type is target_type` doesn't work with collections.abc.Sequence
+            return identity
+        if get_origin(target_type) in (Union, UnionType) and source_type in get_args(
+            target_type
+        ):
+            return identity
+
+        if (source_type, target_type) in self.unit_type_conversions:
+            # we already have a unit type conversion function for these types
+            return self.unit_type_conversions[source_type, target_type]
+
+        # check if this is a generic conversion
+        intype_origin = get_origin(source_type)
+        intype_args = get_args(source_type)
+
+        if intype_origin is None or len(intype_args) != 1:
+            raise TypeConversion.UnknownConversion(
+                f"Cannot find conversion from {source_type} to {target_type}"
+            )
+
+        outtype_origin = get_origin(target_type)
+        outtype_args = get_args(target_type)
+
+        if outtype_origin is None:
+            # we are converting G[T] => T'
+            if (
+                target_type is not intype_args[0]
+                or intype_origin not in self.generic_single_type_conversion
+            ):
+                # either T != T' or G is unknown
+                raise TypeConversion.UnknownConversion(
+                    f"Cannot find conversion from {source_type} to {target_type}"
+                )
+            return self.generic_single_type_conversion[intype_origin]
+
+        # we are converting G[T] => G'[T']
+        if (
+            outtype_origin is not intype_origin
+            or intype_origin not in self.compose_type_conversion
+        ):
+            # either G != G' or G is unknown
+            raise TypeConversion.UnknownConversion(
+                f"Cannot find conversion from {source_type} to {target_type}"
+            )
+        # G == G' => T == T'
+        compose_func = self.compose_type_conversion[intype_origin]
+        func = self.get_conversion(intype_args[0], outtype_args[0])
+        return lambda x: compose_func(x, func)
+
+
+def patch_get_origin(t: type) -> Any:
+    """The original get_origin(typing.Sequence) returns collections.abc.Sequence,
+    and typing.Sequence[T] then compares unequal to collections.abc.Sequence[T].
+
+    This function returns typing.Sequence instead.
+    """
+    origin = get_origin(t)
+    if origin is None:
+        return origin
+    return {
+        collections.abc.Mapping: Mapping,
+        collections.abc.Sequence: Sequence,
+        collections.abc.MutableSequence: MutableSequence,
+        collections.abc.MutableMapping: MutableMapping,
+        collections.abc.Set: Set,
+        collections.abc.MutableSet: MutableSet,
+    }.get(origin, origin)
+
+
+def is_generic_type(t: type) -> bool:
+    return isinstance(t, TypeVar) or any(is_generic_type(a) for a in get_args(t))
+
+
+def align_generic_type(
+    generic_type: type, target_type: type
+) -> tuple[type, tuple[type, type]]:
+    """Return the grounded outer type and the mapping from the TypeVar to the concrete type."""
+    if isinstance(generic_type, TypeVar):
+        return target_type, (generic_type, target_type)
+
+    origin = patch_get_origin(generic_type)
+    assert origin is not None
+    if origin != patch_get_origin(target_type):
+        raise TypeConversion.UnknownConversion(
+            f"Cannot ground generic type {generic_type} to {target_type}"
+        )
+
+    if len(get_args(generic_type)) != 1:
+        raise NotImplementedError()
+
+    gt = align_generic_type(get_args(generic_type)[0], get_args(target_type)[0])
+    return origin[gt[0]], gt[1]
+
+
+def ground_generic_type(generic_type: type, var2type: dict[TypeVar, type]) -> type:
+    if isinstance(generic_type, TypeVar):
+        return var2type[generic_type]
+
+    origin = get_origin(generic_type)
+    if origin is None:
+        # nothing to ground
+        return generic_type
+
+    return origin[*(ground_generic_type(t, var2type) for t in get_args(generic_type))]
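
A sketch of how the registry is meant to be populated and queried, with two hypothetical conversions: a unit conversion int -> str, and a compose conversion list[T] -> list[R] that `get_conversion` combines with the unit conversion on demand:

from typing import TypeVar

from sera.libs.directed_computing_graph import TypeConversion, UnitTypeConversion

T = TypeVar("T")
R = TypeVar("R")

def int_to_str(x: int) -> str:  # registered as a unit conversion (two type hints)
    return str(x)

def map_list(xs: list[T], item: UnitTypeConversion) -> list[R]:  # compose (three hints)
    return [item(x) for x in xs]

ts = TypeConversion([int_to_str, map_list])
assert ts.get_conversion(int, str)(42) == "42"
assert ts.get_conversion(list[int], list[str])([1, 2]) == ["1", "2"]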
sera/misc/__init__.py CHANGED
@@ -3,6 +3,7 @@ from sera.misc._utils import (
     assert_isinstance,
     assert_not_null,
     filter_duplication,
+    get_classpath,
     identity,
     load_data,
     to_camel_case,
@@ -22,4 +23,5 @@ __all__ = [
     "File",
     "load_data",
     "identity",
+    "get_classpath",
 ]
sera/misc/_utils.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

 import re
 from pathlib import Path
-from typing import Any, Callable, Iterable, Optional, TypeVar
+from typing import Any, Callable, Iterable, Optional, Type, TypeVar

 import serde.csv
 from sqlalchemy import Engine, text
@@ -10,6 +10,9 @@ from sqlalchemy.orm import Session
 from tqdm import tqdm

 T = TypeVar("T")
+
+TYPE_ALIASES = {"typing.List": "list", "typing.Dict": "dict", "typing.Set": "set"}
+
 reserved_keywords = {
     "and",
     "or",
@@ -153,3 +156,24 @@ def load_data(
 def identity(x: T) -> T:
     """Identity function that returns the input unchanged."""
     return x
+
+
+def get_classpath(type: Type | Callable) -> str:
+    if type.__module__ == "builtins":
+        return type.__qualname__
+
+    if hasattr(type, "__qualname__"):
+        return type.__module__ + "." + type.__qualname__
+
+    # typically a class from the typing module
+    if hasattr(type, "_name") and type._name is not None:
+        path = type.__module__ + "." + type._name
+        if path in TYPE_ALIASES:
+            path = TYPE_ALIASES[path]
+    elif hasattr(type, "__origin__") and hasattr(type.__origin__, "_name"):
+        # found one case which is typing.Union
+        path = type.__module__ + "." + type.__origin__._name
+    else:
+        raise NotImplementedError(type)
+
+    return path
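
A sketch of what `get_classpath` returns for a few representative inputs, inferred from the branches above (the typing-alias case assumes such aliases expose `_name` rather than `__qualname__`, which is what the TYPE_ALIASES table relies on):

import typing
from pathlib import Path

from sera.misc import get_classpath

assert get_classpath(int) == "int"           # builtins stay unqualified
assert get_classpath(Path) == "pathlib.Path"
assert get_classpath(typing.List) == "list"  # "typing.List" rewritten via TYPE_ALIASES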
sera_2-1.18.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sera-2
-Version: 1.17.0
+Version: 1.18.1
 Summary:
 Author: Binh Vu
 Author-email: bvu687@gmail.com
@@ -9,13 +9,15 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Requires-Dist: black (==25.1.0)
-Requires-Dist: codegen-2 (>=2.11.1,<3.0.0)
+Requires-Dist: codegen-2 (>=2.12.0,<3.0.0)
+Requires-Dist: graph-wrapper (>=1.7.2,<2.0.0)
 Requires-Dist: isort (==6.0.1)
 Requires-Dist: litestar (>=2.15.1,<3.0.0)
 Requires-Dist: loguru (>=0.7.0,<0.8.0)
 Requires-Dist: msgspec (>=0.19.0,<0.20.0)
-Requires-Dist: serde2 (>=1.9.0,<2.0.0)
+Requires-Dist: serde2 (>=1.9.2,<2.0.0)
 Requires-Dist: sqlalchemy[asyncio] (>=2.0.41,<3.0.0)
+Requires-Dist: tqdm (>=4.67.1,<5.0.0)
 Requires-Dist: typer (>=0.12.3,<0.13.0)
 Project-URL: Repository, https://github.com/binh-vu/sera
 Description-Content-Type: text/markdown
sera_2-1.18.1.dist-info/RECORD CHANGED
@@ -8,8 +8,14 @@ sera/libs/api_helper.py,sha256=47y1kcwk3Xd2ZEMnUj_0OwCuUmgwOs5kYrE95BDVUn4,5411
 sera/libs/api_test_helper.py,sha256=3tRr8sLN4dBSrHgKAHMmyoENI0xh7K_JLel8AvujU7k,1323
 sera/libs/base_orm.py,sha256=5hOH_diUeaABm3cpE2-9u50VRqG1QW2osPQnvVHIhIA,3365
 sera/libs/base_service.py,sha256=AX1WoTHte6Z_birkkfagkNE6BrCLTlTjQE4jEsKEaAY,5152
-sera/libs/dag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sera/libs/dag/_dag.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sera/libs/directed_computing_graph/__init__.py,sha256=xiF5_I1y9HtQ-cyq02iwkRYgEZvxBB8YIvysCHCLBco,1290
+sera/libs/directed_computing_graph/_dcg.py,sha256=AGTzKVSl-EsSOJlNKPOA1Io7pIxfq0SMXuumq1IExl0,14902
+sera/libs/directed_computing_graph/_edge.py,sha256=iBq6cpLWWyuD99QWTHVEh8naWUJrR4WJJuq5iuCrwHo,1026
+sera/libs/directed_computing_graph/_flow.py,sha256=6v39yKPIDYrQ3KvFqjeAWs88-oQSnDTaED2F3LF2z_I,478
+sera/libs/directed_computing_graph/_fn_signature.py,sha256=73iPUITcRKW0-l6sqjwMSk_FZnJESaKOmUKDGHTOh9Q,1598
+sera/libs/directed_computing_graph/_node.py,sha256=9FsKceW_hq6RYaC7d5YKF5aSXmbAcj-LGakh_GCNgHw,1597
+sera/libs/directed_computing_graph/_runtime.py,sha256=76Ccl1Rj31SkzRJPWFvYNu9ZzUABoeHp5v3tfScekcI,3319
+sera/libs/directed_computing_graph/_type_conversion.py,sha256=_XGvDidOJVmHS4gqdPlhJGzdV34YtNiPF5Kr2nV6ZgE,6806
 sera/libs/middlewares/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sera/libs/middlewares/auth.py,sha256=r6aix1ZBwxMd1Jv5hMCTB8a_gFOJQ6egvxIrf3DWEOs,2323
 sera/libs/middlewares/uscp.py,sha256=H5umW8iEQSCdb_MJ5Im49kxg1E7TpxSg1p2_2A5zI1U,2600
@@ -20,9 +26,9 @@ sera/make/make_python_api.py,sha256=iXGbKQ3IJvsY1ur_fhurr_THFNnH66E3Wl85o0emUbw,
 sera/make/make_python_model.py,sha256=Nc4vDGgM8icgWBqzNnMgEkLadf5EsZwbbHs3WLW9_co,62778
 sera/make/make_python_services.py,sha256=0ZpWLwQ7Nwfn8BXAikAB4JRpNknpSJyJgY5b1cjtxV4,2073
 sera/make/make_typescript_model.py,sha256=1ouYFCeqOlwEzsGBiXUn4VZtLJjJW7GSacdOSlQzhjI,67012
-sera/misc/__init__.py,sha256=mPKkik00j3tO_m45VPDJBjm8K85NpymRPl36Kh4hBn8,473
+sera/misc/__init__.py,sha256=Tali_UBtwemETM30a6sP6BbwBMHr3hklPCX0bgiAcbw,513
 sera/misc/_formatter.py,sha256=aCGYL08l8f3aLODHxSocxBBwkRYEo3K1QzCDEn3suj0,1685
-sera/misc/_utils.py,sha256=pGYv8p7m7opiDTLYbsPrhF0YA4WjFff7beMQQZ9NnEs,4095
+sera/misc/_utils.py,sha256=f5mOgDlGh-OVwd6DXou2gTo9eRvJGK_aUT7pM3qzr98,4882
 sera/models/__init__.py,sha256=vJC5Kzo_N7wd16ocNPy1VvAZDGNiWeiAhWJ4ihATKvA,780
 sera/models/_class.py,sha256=1J4Bd_LanzhhDWwZFHWGtFYD7lupe_alaB3D02ebNDI,2862
 sera/models/_collection.py,sha256=ZnQEriKC4X88Zz48Kn1AVZKH-1_l8OgWa-zf2kcQOOE,1414
@@ -36,6 +42,6 @@ sera/models/_parse.py,sha256=ciTLzCkO0q6xA1R_rHbnYJYK3Duo2oh56WeuwxXwJaI,12392
 sera/models/_property.py,sha256=9yMDxrmbyuF6-29lQjiq163Xzwbk75TlmGBpu0NLpkI,7485
 sera/models/_schema.py,sha256=VxJEiqgVvbXgcSUK4UW6JnRcggk4nsooVSE6MyXmfNY,1636
 sera/typing.py,sha256=o_DKfSvs8JpNRQ0kdaTc3BbfdkvibY3uY4tJRt-n2fQ,1023
-sera_2-1.17.0.dist-info/METADATA,sha256=aIaXid2dkyX8P9nty-1eFHFBuH0Cpy34vOGDi1wTFkI,852
-sera_2-1.17.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-sera_2-1.17.0.dist-info/RECORD,,
+sera_2-1.18.1.dist-info/METADATA,sha256=TATTG19o7HW6O681m1dUdFs92SJ6oqizZUT_vF52zx8,936
+sera_2-1.18.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sera_2-1.18.1.dist-info/RECORD,,
sera/libs/dag/__init__.py DELETED
File without changes
sera/libs/dag/_dag.py DELETED
File without changes