spization 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. spization/__init__.py +3 -0
  2. spization/__internals/__init__.py +0 -0
  3. spization/__internals/general/__init__.py +53 -0
  4. spization/__internals/graph/__init__.py +27 -0
  5. spization/__internals/graph/add_nodes_and_edges.py +20 -0
  6. spization/__internals/graph/longest_path_lengths_from_source.py +39 -0
  7. spization/__internals/graph/lowest_common_ancestor.py +17 -0
  8. spization/__internals/graph/properties.py +34 -0
  9. spization/__internals/graph/sinks.py +7 -0
  10. spization/__internals/graph/sources.py +7 -0
  11. spization/__internals/graph/strata_sort.py +10 -0
  12. spization/__internals/sp/__init__.py +14 -0
  13. spization/__internals/sp/cbc_decomposition.py +93 -0
  14. spization/__internals/sp/inverse_line_graph.py +63 -0
  15. spization/algorithms/__init__.py +12 -0
  16. spization/algorithms/flexible_sync.py +208 -0
  17. spization/algorithms/naive_strata_sync.py +37 -0
  18. spization/algorithms/pure_node_dup.py +96 -0
  19. spization/algorithms/spanish_strata_sync.py +155 -0
  20. spization/benchmarking/benchmarking.py +252 -0
  21. spization/benchmarking/cost_modelling.py +72 -0
  22. spization/benchmarking/graphs.py +259 -0
  23. spization/modular_decomposition/__init__.py +24 -0
  24. spization/modular_decomposition/directed/directed_md.py +104 -0
  25. spization/modular_decomposition/directed/directed_quotient_graph.py +113 -0
  26. spization/modular_decomposition/directed/objects.py +101 -0
  27. spization/modular_decomposition/through_modular_decomposition.py +53 -0
  28. spization/modular_decomposition/undirected/objects.py +105 -0
  29. spization/modular_decomposition/undirected/undirected_md.py +53 -0
  30. spization/modular_decomposition/undirected/undirected_md_naive.py +55 -0
  31. spization/modular_decomposition/undirected/undirected_quotient_graph.py +112 -0
  32. spization/modular_decomposition/utils.py +74 -0
  33. spization/objects/__init__.py +32 -0
  34. spization/objects/edges.py +4 -0
  35. spization/objects/nodes.py +13 -0
  36. spization/objects/splits.py +140 -0
  37. spization/utils/__init__.py +55 -0
  38. spization/utils/bsp_to_sp.py +25 -0
  39. spization/utils/compositions.py +62 -0
  40. spization/utils/critical_path_cost.py +86 -0
  41. spization/utils/dependencies_are_maintained.py +29 -0
  42. spization/utils/get_ancestors.py +54 -0
  43. spization/utils/get_node_counter.py +20 -0
  44. spization/utils/get_nodes.py +19 -0
  45. spization/utils/has_no_duplicate_nodes.py +6 -0
  46. spization/utils/is_empty.py +7 -0
  47. spization/utils/normalize.py +41 -0
  48. spization/utils/random_sp.py +27 -0
  49. spization/utils/replace_node.py +26 -0
  50. spization/utils/sp_to_bsp.py +43 -0
  51. spization/utils/sp_to_spg.py +26 -0
  52. spization/utils/spg_to_sp.py +116 -0
  53. spization/utils/ttspg_to_spg.py +25 -0
  54. spization/utils/work_cost.py +51 -0
  55. spization-1.0.0.dist-info/METADATA +66 -0
  56. spization-1.0.0.dist-info/RECORD +60 -0
  57. spization-1.0.0.dist-info/WHEEL +5 -0
  58. spization-1.0.0.dist-info/entry_points.txt +2 -0
  59. spization-1.0.0.dist-info/licenses/LICENSE +0 -0
  60. spization-1.0.0.dist-info/top_level.txt +1 -0
spization/__init__.py ADDED
@@ -0,0 +1,3 @@
1
"""Top-level spization package: re-exports its public subpackages."""

from spization import algorithms, modular_decomposition, objects, utils

# Names exposed by ``from spization import *``.
__all__ = ["algorithms", "objects", "utils", "modular_decomposition"]
File without changes
@@ -0,0 +1,53 @@
1
+ from functools import reduce
2
+ from itertools import chain
3
+ from typing import Callable, Iterable
4
+
5
+
6
+ def get_only[T](container: Iterable[T]) -> T:
7
+ c = list(container)
8
+ if len(c) != 1:
9
+ raise ValueError(f"Container must only have 1 item, has {len(c)}")
10
+ return c[0]
11
+
12
+
13
+ def get_any[T](container: Iterable[T]) -> T:
14
+ return next(iter(container))
15
+
16
+
17
+ def must[T](value: T | None) -> T:
18
+ if value is None:
19
+ raise ValueError("Used must() on a None value.")
20
+ return value
21
+
22
+
23
+ def flatmap[T, U](
24
+ func: Callable[[T], Iterable[U]], iterable: Iterable[T]
25
+ ) -> Iterable[U]:
26
+ return chain.from_iterable(map(func, iterable))
27
+
28
+
29
+ def are_all_equal[T](iterable: Iterable[T]) -> bool:
30
+ iterator = iter(iterable)
31
+ try:
32
+ first = next(iterator)
33
+ except StopIteration:
34
+ return True
35
+ return all(first == x for x in iterator)
36
+
37
+
38
+ def are_all_disjoint[T](iterable: Iterable[set[T] | frozenset[T]]) -> bool:
39
+ sets = list(iterable)
40
+ if not sets:
41
+ return True
42
+ union = reduce(lambda x, y: x.union(y), sets)
43
+ return len(union) == sum(len(s) for s in sets)
44
+
45
+
46
# Public helpers exported by this module.
__all__ = [
    "get_any",
    "get_only",
    "must",
    "flatmap",
    "are_all_equal",
    "are_all_disjoint",
]
@@ -0,0 +1,27 @@
1
"""Graph helper functions, re-exported from their individual modules."""

from .add_nodes_and_edges import add_edges, add_node, add_nodes
from .longest_path_lengths_from_source import longest_path_lengths_from_source
from .lowest_common_ancestor import lowest_common_ancestor
from .properties import (
    is_2_terminal_dag,
    is_compatible_graph,
    is_single_sourced,
    is_transitively_closed_dag,
)
from .sinks import sinks
from .sources import sources
from .strata_sort import strata_sort

# Names exposed by ``from ... import *``.
__all__ = [
    "longest_path_lengths_from_source",
    "lowest_common_ancestor",
    "is_2_terminal_dag",
    "is_compatible_graph",
    "is_single_sourced",
    "is_transitively_closed_dag",
    "sinks",
    "sources",
    "strata_sort",
    "add_edges",
    "add_node",
    "add_nodes",
]
@@ -0,0 +1,20 @@
1
+ from typing import Iterable
2
+
3
+ from networkx import DiGraph
4
+
5
+ from spization.objects import DiEdge, Node
6
+
7
+
8
def add_node(g: DiGraph) -> Node:
    """Insert a fresh node into ``g`` and return it.

    The new node is one more than the current maximum node, or 0 for an empty
    graph. This assumes the existing nodes are mutually comparable integers
    (TODO confirm for every caller).
    """
    fresh = max(g.nodes(), default=-1) + 1
    g.add_node(fresh)
    return fresh
12
+
13
+
14
def add_nodes(g: DiGraph, n: int) -> list[Node]:
    """Insert ``n`` fresh nodes into ``g`` via ``add_node``.

    Returns the new nodes in creation order.
    """
    created: list[Node] = []
    for _ in range(n):
        created.append(add_node(g))
    return created
16
+
17
+
18
def add_edges(g: DiGraph, edges: Iterable[DiEdge]) -> None:
    """Add each edge to ``g`` using only its first two components (source, target).

    Any further components of an edge (for example a multigraph key) are ignored.
    """
    for edge in edges:
        src, dst = edge[0], edge[1]
        g.add_edge(src, dst)
@@ -0,0 +1,39 @@
1
+ import networkx as nx
2
+ from multimethod import multimethod
3
+ from networkx import DiGraph
4
+
5
+ from spization.__internals.general import get_only
6
+ from spization.objects import Node
7
+
8
+ from .properties import is_single_sourced
9
+ from .sources import sources
10
+
11
+
12
@multimethod
def longest_path_lengths_from_source(g: DiGraph) -> dict[Node, int]:
    """Map every node to the edge count of the longest path reaching it from the source.

    ``g`` must have exactly one source (asserted); the source maps to 0. Nodes
    are visited in ``nx.topological_sort`` order, so each node's predecessors
    are settled before the node itself.
    """
    assert is_single_sourced(g)
    depth: dict[Node, int] = dict.fromkeys(g.nodes, -1)
    source: Node = get_only(sources(g))
    depth[source] = 0
    for node in nx.topological_sort(g):
        if node != source:
            depth[node] = 1 + max(depth[pred] for pred in g.predecessors(node))
    return depth
24
+
25
+
26
@multimethod
def longest_path_lengths_from_source(
    g: DiGraph, cost_map: dict[Node, int | float]
) -> dict[Node, int | float]:
    """Map every node to the largest total node cost along any path from the source.

    Path costs include both endpoints, so the source maps to its own cost.
    ``g`` must have exactly one source (asserted). Nodes are visited in
    ``nx.topological_sort`` order, so each node's predecessors are settled
    first.
    """
    assert is_single_sourced(g)
    weight: dict[Node, int | float] = dict.fromkeys(g.nodes, -1)
    source: Node = get_only(sources(g))
    weight[source] = cost_map[source]
    for node in nx.topological_sort(g):
        if node != source:
            heaviest_pred = max(weight[pred] for pred in g.predecessors(node))
            weight[node] = cost_map[node] + heaviest_pred
    return weight
@@ -0,0 +1,17 @@
1
+ from typing import Optional
2
+
3
+ import networkx as nx
4
+ from networkx import DiGraph
5
+
6
+ from spization.__internals.general import get_any
7
+ from spization.objects import Node
8
+
9
+
10
def lowest_common_ancestor(g: DiGraph, nodes: set[Node]) -> Optional[Node]:
    """Return the lowest common ancestor of all ``nodes`` in ``g``.

    The result is found by folding ``nx.lowest_common_ancestor`` pairwise over
    ``nodes``. It is ``None`` as soon as a fold step finds no common ancestor.
    Every node must belong to ``g`` (asserted). An empty ``nodes`` makes
    ``get_any`` raise.
    """
    assert all(node in g.nodes() for node in nodes)
    current: Optional[Node] = get_any(nodes)
    for node in nodes:
        current = nx.lowest_common_ancestor(g, current, node)
        if current is None:
            break
    return current
@@ -0,0 +1,34 @@
1
+ import networkx as nx
2
+ from networkx import DiGraph
3
+
4
+ from spization.objects import Node
5
+
6
+ from .sinks import sinks
7
+ from .sources import sources
8
+
9
+
10
def is_2_terminal_dag(g: DiGraph) -> bool:
    """Return True if ``g`` is acyclic with exactly one source and exactly one sink."""
    return (
        nx.is_directed_acyclic_graph(g)
        and len(sources(g)) == 1
        and len(sinks(g)) == 1
    )
15
+
16
+
17
def is_compatible_graph(g: DiGraph) -> bool:
    """Return True if every node of ``g`` is an instance of ``Node``."""
    for node in g.nodes():
        if not isinstance(node, Node):
            return False
    return True
19
+
20
+
21
def is_single_sourced(g: DiGraph) -> bool:
    """Return True if exactly one node of ``g`` has no incoming edges."""
    source_count = len(sources(g))
    return source_count == 1
23
+
24
+
25
def is_transitively_closed(g: DiGraph) -> bool:
    """Return True if every node has a direct edge to each node reachable from it."""
    return all(
        g.has_edge(node, reachable)
        for node in g.nodes()
        for reachable in nx.descendants(g, node)
    )
31
+
32
+
33
def is_transitively_closed_dag(g: DiGraph) -> bool:
    """Return True if ``g`` is acyclic and transitively closed."""
    if not nx.is_directed_acyclic_graph(g):
        return False
    return is_transitively_closed(g)
@@ -0,0 +1,7 @@
1
+ from networkx import DiGraph
2
+
3
+ from spization.objects import Node
4
+
5
+
6
def sinks(g: DiGraph) -> set[Node]:
    """Return the nodes of ``g`` that have no outgoing edges."""
    return {node for node in g.nodes() if g.out_degree(node) == 0}
@@ -0,0 +1,7 @@
1
+ from networkx import DiGraph
2
+
3
+ from spization.objects import Node
4
+
5
+
6
def sources(g: DiGraph) -> set[Node]:
    """Return the nodes of ``g`` that have no incoming edges."""
    return {node for node in g.nodes() if g.in_degree(node) == 0}
@@ -0,0 +1,10 @@
1
+ from networkx import DiGraph
2
+
3
+ from spization.objects import Node
4
+
5
+ from .longest_path_lengths_from_source import longest_path_lengths_from_source
6
+
7
+
8
def strata_sort(g: DiGraph) -> list[Node]:
    """Order the nodes of ``g`` by longest-path depth from its unique source, ascending.

    ``sorted`` is stable, so nodes at equal depth keep the order of the keys
    in the depth map.
    """
    depth_map: dict[Node, int] = longest_path_lengths_from_source(g)
    return sorted(depth_map, key=depth_map.__getitem__)
@@ -0,0 +1,14 @@
1
"""Series-parallel helpers: complete-bipartite-composite decomposition and inverse line graphs."""

from .cbc_decomposition import (
    BipartiteComponent,
    CompleteBipartiteCompositeDecomposition,
    cbc_decomposition,
)
from .inverse_line_graph import InverseLineGraphResult, inverse_line_graph

# Names exposed by ``from ... import *``.
__all__ = [
    "inverse_line_graph",
    "InverseLineGraphResult",
    "cbc_decomposition",
    "CompleteBipartiteCompositeDecomposition",
    "BipartiteComponent",
]
@@ -0,0 +1,93 @@
1
+ from collections import deque
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ from networkx import DiGraph
6
+
7
+ from spization.__internals.general import get_only
8
+ from spization.__internals.graph import sinks, sources
9
+ from spization.objects import DiEdge, Node
10
+
11
+
12
@dataclass(frozen=True)
class BipartiteComponent:
    """One component of a complete-bipartite-composite decomposition.

    As built by ``cbc_decomposition``, every head node has an edge to every
    tail node. Frozen, so instances are hashable and can be stored in sets.
    """

    # Nodes on the edge-origin side of the component.
    head_nodes: frozenset[Node]
    # Nodes on the edge-target side of the component.
    tail_nodes: frozenset[Node]
16
+
17
+
18
# A complete-bipartite-composite decomposition: the set of its bipartite components.
CompleteBipartiteCompositeDecomposition = set[BipartiteComponent]
19
+
20
+
21
def is_complete_bipartite_digraph(g: DiGraph, head: frozenset[Node]) -> bool:
    """Return True if every node in ``head`` has an edge to every node of ``g`` outside ``head``.

    Only head -> non-head edges are checked. Edges inside either side are not
    inspected here.
    """
    # Named to avoid shadowing the module-level ``sinks`` helper.
    non_head = set(g.nodes) - head
    return all(g.has_edge(u, v) for u in head for v in non_head)
28
+
29
+
30
def cbc_decomposition(g: DiGraph) -> Optional[CompleteBipartiteCompositeDecomposition]:
    """Split the edges of ``g`` into complete bipartite components, or return ``None``.

    Each unprocessed edge ``(a, b)`` proposes a component. Its head is the set
    of predecessors of ``b`` and its tail is the set of successors of ``a``.
    The proposal is accepted only if all of the following hold:

    * head and tail are disjoint;
    * every head node has an edge to every tail node;
    * the subgraph induced on head | tail has no other edges;
    * head nodes have no out-edges leaving the tail;
    * tail nodes have no in-edges from outside the head.

    Any failure returns ``None``. Edges covered by an accepted component are
    skipped later.

    Returns:
        The set of components. It is asserted that every non-sink of ``g`` lies
        in some head and every non-source lies in some tail.
    """
    edges_to_process = deque(sorted(g.edges()))

    already_in_a_head: set[Node] = set()
    already_in_a_tail: set[Node] = set()
    already_processed: set[DiEdge] = set()
    result: CompleteBipartiteCompositeDecomposition = set()

    while edges_to_process:
        e = edges_to_process.pop()
        # Edge already covered by a previously accepted component.
        if e in already_processed:
            continue

        head = frozenset(g.predecessors(e[1]))
        tail = frozenset(g.successors(e[0]))

        if head & tail:
            return None

        from_head_to_tail = {(u, v) for u in head for v in tail if g.has_edge(u, v)}

        subgraph = g.subgraph(head | tail)

        if not is_complete_bipartite_digraph(subgraph, head):
            return None

        # No extra edges (within head, within tail, or tail -> head) are allowed.
        for u, v in subgraph.edges():
            if (u, v) not in from_head_to_tail:
                return None

        # Head nodes may only point into this component's tail.
        out_edges = {(u, v) for u in head for v in g.successors(u)}
        if out_edges != from_head_to_tail:
            return None

        # Tail nodes may only be reached from this component's head.
        in_edges = {(u, v) for v in tail for u in g.predecessors(v)}
        if in_edges != from_head_to_tail:
            return None

        result.add(BipartiteComponent(head, tail))

        already_processed |= from_head_to_tail
        already_in_a_head.update(head)
        already_in_a_tail.update(tail)

    # Sanity checks: the components must cover every non-sink as a head node
    # and every non-source as a tail node.
    assert already_in_a_head == set(g.nodes) - sinks(g)
    assert already_in_a_tail == set(g.nodes) - sources(g)

    return result
78
+
79
+
80
def get_component_containing_node_in_head(
    cbc: CompleteBipartiteCompositeDecomposition, n: Node
) -> Optional[BipartiteComponent]:
    """Return the component of ``cbc`` whose head contains ``n``, or ``None``.

    At most one component may contain ``n`` in its head (asserted).
    """
    matches = {component for component in cbc if n in component.head_nodes}
    assert len(matches) <= 1
    if not matches:
        return None
    return get_only(matches)
86
+
87
+
88
def get_component_containing_node_in_tail(
    cbc: CompleteBipartiteCompositeDecomposition, n: Node
) -> Optional[BipartiteComponent]:
    """Return the component of ``cbc`` whose tail contains ``n``, or ``None``.

    At most one component may contain ``n`` in its tail (asserted).
    """
    matches = {component for component in cbc if n in component.tail_nodes}
    assert len(matches) <= 1
    if not matches:
        return None
    return get_only(matches)
@@ -0,0 +1,63 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from bidict import bidict
5
+ from networkx import DiGraph, MultiDiGraph
6
+
7
+ from spization.__internals.graph import add_node, sinks, sources
8
+ from spization.__internals.sp.cbc_decomposition import (
9
+ BipartiteComponent,
10
+ cbc_decomposition,
11
+ get_component_containing_node_in_head,
12
+ get_component_containing_node_in_tail,
13
+ )
14
+ from spization.objects import MultiDiEdge, Node, SerialParallelDecomposition
15
+
16
+
17
@dataclass
class InverseLineGraphResult:
    """Result of ``inverse_line_graph``: the multigraph plus the edge <-> node correspondence."""

    # Multigraph with one edge per node of the input graph.
    graph: MultiDiGraph
    # Bijection from each ``(src, dst, key)`` edge of ``graph`` to the input-graph node it represents.
    inverse_edge_to_line_node_map: bidict[MultiDiEdge, SerialParallelDecomposition]
21
+
22
+
23
def inverse_line_graph(g: DiGraph) -> Optional[InverseLineGraphResult]:
    """Build a multigraph whose edges correspond one-to-one to the nodes of ``g``.

    The multigraph has these nodes:

    * a fresh ``alpha`` terminal;
    * a fresh ``omega`` terminal;
    * one node per component of ``cbc_decomposition(g)``.

    Each node ``v`` of ``g`` becomes one edge. The edge starts at ``alpha`` if
    ``v`` is a source, otherwise at the component containing ``v`` in its tail.
    It ends at ``omega`` if ``v`` is a sink, otherwise at the component
    containing ``v`` in its head.

    Returns ``None`` when ``cbc_decomposition(g)`` returns ``None``.
    """
    decomposition = cbc_decomposition(g)
    if decomposition is None:
        return None

    inverse = MultiDiGraph()
    alpha: Node = add_node(inverse)
    omega: Node = add_node(inverse)
    node_of_component = bidict(
        {component: add_node(inverse) for component in decomposition}
    )

    source_set = sources(g)
    sink_set = sinks(g)

    def edge_start(v: Node) -> Node:
        if v in source_set:
            return alpha
        component = get_component_containing_node_in_tail(decomposition, v)
        assert component is not None
        return node_of_component[component]

    def edge_end(v: Node) -> Node:
        if v in sink_set:
            return omega
        component = get_component_containing_node_in_head(decomposition, v)
        assert component is not None
        return node_of_component[component]

    edge_to_node: bidict[MultiDiEdge, SerialParallelDecomposition] = bidict()
    for v in g.nodes:
        start, end = edge_start(v), edge_end(v)
        # MultiDiGraph.add_edge returns the key that distinguishes parallel edges.
        key = inverse.add_edge(start, end)
        edge_to_node[(start, end, key)] = v

    return InverseLineGraphResult(inverse, edge_to_node)
@@ -0,0 +1,12 @@
1
"""Spization algorithms, re-exported from their individual modules."""

from .flexible_sync import flexible_sync
from .naive_strata_sync import naive_strata_sync
from .pure_node_dup import pure_node_dup, tree_pure_node_dup
from .spanish_strata_sync import spanish_strata_sync

# Names exposed by ``from ... import *``.
__all__ = [
    "naive_strata_sync",
    "pure_node_dup",
    "tree_pure_node_dup",
    "spanish_strata_sync",
    "flexible_sync",
]
@@ -0,0 +1,208 @@
1
+ import networkx as nx
2
+ from networkx import DiGraph
3
+
4
+ from spization.__internals.general import get_only, must
5
+ from spization.__internals.graph import (
6
+ is_compatible_graph,
7
+ is_single_sourced,
8
+ longest_path_lengths_from_source,
9
+ lowest_common_ancestor,
10
+ sources,
11
+ )
12
+ from spization.objects import (
13
+ Node,
14
+ NodeRole,
15
+ SerialParallelDecomposition,
16
+ get_initial_node_role_map,
17
+ )
18
+ from spization.utils import (
19
+ contract_out_nodes_of_role,
20
+ critical_path_cost,
21
+ dependencies_are_maintained,
22
+ get_critical_path_cost_map,
23
+ spg_to_sp,
24
+ )
25
+
26
+
27
def get_component(SP: DiGraph, nodes: set[Node]) -> set[Node]:
    """Return the region of ``SP`` affected by the newly added ``nodes``.

    The region is the union of:

    * the parents of ``nodes``;
    * every descendant of those parents;
    * every parent of those descendants.
    """
    parents: set[Node] = set()
    for node in nodes:
        parents.update(SP.predecessors(node))

    children: set[Node] = set()
    for parent in parents:
        children.update(nx.descendants(SP, parent))

    other_parents: set[Node] = set()
    for child in children:
        other_parents.update(SP.predecessors(child))

    return parents | children | other_parents
32
+
33
+
34
def get_forest(
    SP: DiGraph, handle: Node, component: set[Node], node_roles: dict[Node, NodeRole]
) -> set[Node]:
    """Collect the non-sync nodes below ``handle`` whose subtrees touch ``component``.

    Each child of ``handle`` roots a subtree (the child plus its descendants).
    Only subtrees that intersect ``component`` are kept, and ``handle`` itself
    is always included. Nodes whose role is ``NodeRole.SYNC`` are dropped from
    the result.
    """
    forest: set[Node] = {handle}
    for child in SP.successors(handle):
        subtree = {child} | set(nx.descendants(SP, child))
        if subtree & component:
            forest |= subtree
    return {node for node in forest if node_roles[node] != NodeRole.SYNC}
44
+
45
+
46
def get_up_and_down(
    nodes: set[Node],
    SP: DiGraph,
    forest: set[Node],
    cost_map: dict[Node, float],
    node_roles: dict[Node, NodeRole],
) -> tuple[set[Node], set[Node], set[Node], set[Node]]:
    """Split ``forest`` into an "up" and a "down" side around the newly added ``nodes``.

    Candidate bipartitions are generated from the critical-path costs of the
    forest nodes and filtered for validity. The one with the lowest
    ``partition_cost`` is chosen.

    Args:
        nodes: Nodes just inserted into ``SP``; they must end up in "down".
        SP: The partial decomposition graph. Within this function it is
            rebound to the result of contracting out ``NodeRole.SYNC`` nodes.
        forest: The region of ``SP`` being re-synchronized.
        cost_map: Per-node costs.
        node_roles: Role of every node; used to contract out sync nodes.

    Returns:
        ``(up, down, up_frontier, down_frontier)``. ``up_frontier`` is the set
        of ``up`` nodes with no successor inside ``up``. ``down_frontier`` is
        the set of ``down`` nodes with no predecessor inside ``down``.
    """
    SP = contract_out_nodes_of_role(SP, NodeRole.SYNC, node_roles)

    # New nodes are forced down; their ancestors within the forest are forced up.
    # Every other forest node may go to either side.
    base_down = set(nodes)
    base_up = set().union(*[nx.ancestors(SP, node) for node in nodes]) & forest
    assignable_nodes = forest - (base_up | base_down)
    critical_path_cost_map = get_critical_path_cost_map(SP.subgraph(forest), cost_map)

    def get_partitions() -> set[tuple[frozenset[Node], frozenset[Node]]]:
        # Candidates: "all assignable nodes down", plus one threshold cut per
        # assignable node. In a cut, assignable nodes whose critical-path cost
        # is <= that node's cost go up and the rest go down.
        bipartitions = set()
        bipartitions.add((frozenset(base_up), frozenset(base_down | assignable_nodes)))
        for node in assignable_nodes:
            reference_cost = critical_path_cost_map[node]
            # The comprehension's ``node`` has its own scope; the loop variable is untouched.
            up = base_up | {
                node
                for node in assignable_nodes
                if critical_path_cost_map[node] <= reference_cost
            }
            down = (base_down | assignable_nodes) - up
            bipartitions.add((frozenset(up), frozenset(down)))
        return bipartitions

    bipartitions = get_partitions()

    def is_valid_bipartition(up: frozenset[Node], down: frozenset[Node]) -> bool:
        # Every newly added node must be on the "down" side.
        for node in nodes:
            if node not in down:
                return False

        for node in SP.nodes():
            if node in down:
                # No edge may lead from "down" back into "up".
                if any(child in up for child in SP.successors(node)):
                    return False
                # Every forest parent must be assigned to a side.
                # NOTE(review): every candidate from get_partitions has
                # up | down covering forest, so this check appears unable to fail.
                parents = SP.predecessors(node)
                if any(p in forest and p not in up and p not in down for p in parents):
                    return False
        return True

    valid_partitions = [
        (up, down) for up, down in bipartitions if is_valid_bipartition(up, down)
    ]

    # At least one candidate must be valid.
    assert valid_partitions

    def partition_cost(
        partition: tuple[frozenset[Node], frozenset[Node]],
    ) -> tuple[float, float, int]:
        # Lexicographic key; lower is better. Compares, in order: the summed
        # critical-path cost of both sides, the cost of "down", the size of "down".
        up, down = partition
        up_cost = critical_path_cost(SP.subgraph(up), cost_map)
        down_cost = critical_path_cost(SP.subgraph(down), cost_map)
        return (up_cost + down_cost, down_cost, len(down))

    best_up, best_down = min(valid_partitions, key=partition_cost)
    # Frontiers are the sinks of the "up" subgraph and the sources of the
    # "down" subgraph.
    up_subgraph = SP.subgraph(best_up)
    up_frontier = {node for node in best_up if up_subgraph.out_degree(node) == 0}

    down_subgraph = SP.subgraph(best_down)
    down_frontier = {node for node in best_down if down_subgraph.in_degree(node) == 0}

    return best_up, best_down, up_frontier, down_frontier
112
+
113
+
114
def edges_to_remove(
    SP: DiGraph, up: set[Node], down: set[Node], node_roles: dict[Node, NodeRole]
) -> set[tuple[Node, Node]]:
    """Collect the direct ``up`` -> ``down`` edges of ``SP`` and prune redundant sync nodes.

    Returns the set of edges ``(u, v)`` of ``SP`` with ``u`` in ``up`` and
    ``v`` in ``down``.

    Side effect: before returning, this removes from ``SP`` (in place) every
    ``NodeRole.SYNC`` node whose predecessors all lie in ``up`` and whose
    successors all lie in ``down``.
    """
    crossing = {(u, v) for u in up for v in SP.successors(u) if v in down}

    # Snapshot the node list because SP is mutated during the scan.
    for candidate in list(SP.nodes()):
        if node_roles[candidate] != NodeRole.SYNC:
            continue
        fed_only_by_up = all(p in up for p in SP.predecessors(candidate))
        if fed_only_by_up and all(s in down for s in SP.successors(candidate)):
            SP.remove_node(candidate)
    return crossing
130
+
131
+
132
def edges_to_add(up: set[Node], down: set[Node], sync: Node) -> set[tuple[Node, Node]]:
    """Return the edges that route every ``up`` node through ``sync`` into every ``down`` node."""
    into_sync = {(u, sync) for u in up}
    out_of_sync = {(sync, d) for d in down}
    return into_sync | out_of_sync
139
+
140
+
141
def get_next_nodes(SP: DiGraph, g: DiGraph, cost_map: dict[Node, float]) -> set[Node]:
    """Choose the next batch of nodes of ``g`` to insert into the partial graph ``SP``.

    Candidates are the nodes of ``g`` that are not yet in ``SP`` but whose
    predecessors in ``g`` are all in ``SP``. A candidate's cost is its own cost
    plus the largest longest-path cost (within ``SP``) among its predecessors.
    The cheapest candidate is chosen, with ties broken by node value. Every
    other candidate with exactly the same predecessor set in ``g`` joins it.

    Returns:
        A non-empty set containing the chosen candidate and its siblings.
    """
    sp_longest_paths = longest_path_lengths_from_source(SP, cost_map)
    sp_nodes = SP.nodes

    candidate_nodes: set[Node] = {
        node
        for node in g.nodes()
        if node not in sp_nodes
        and all(parent in sp_nodes for parent in g.predecessors(node))
    }

    assert candidate_nodes

    critical_path_costs: dict[Node, float] = {
        node: cost_map[node]
        + max(sp_longest_paths[parent] for parent in g.predecessors(node))
        for node in candidate_nodes
    }

    ref_node = min(
        candidate_nodes,
        key=lambda node: (critical_path_costs[node], node),
    )

    # Compare predecessor *sets*. ``g.predecessors`` returns a fresh iterator on
    # every call, and two iterators never compare equal, so the old direct
    # comparison never grouped any node besides ``ref_node``.
    ref_parents = set(g.predecessors(ref_node))
    return {
        node for node in candidate_nodes if set(g.predecessors(node)) == ref_parents
    }
168
+
169
+
170
def flexible_sync(
    g: DiGraph, cost_map: dict[Node, float]
) -> SerialParallelDecomposition:
    """Build a series-parallel decomposition of ``g`` by incremental synchronization.

    Each iteration:

    1. Insert the batch of nodes chosen by ``get_next_nodes`` into a growing
       graph ``SP``.
    2. Find the affected component and its lowest common ancestor.
    3. Split the forest below that ancestor into "up" and "down" sides.
    4. Wire a fresh sync node between the two frontiers.

    Finally, sync nodes are contracted out and the graph is converted with
    ``spg_to_sp``.

    Args:
        g: Input graph. It must have a single source and only ``Node`` nodes
            (asserted).
        cost_map: Per-node costs. The caller's dict is copied, not mutated.

    Returns:
        A decomposition that maintains every dependency of the transitive
        reduction of ``g`` (asserted).
    """
    assert is_single_sourced(g) and is_compatible_graph(g)
    # Work on the transitive reduction; it has the same reachability as g.
    g = nx.transitive_reduction(g)
    node_roles = get_initial_node_role_map(g.nodes)
    # Sync node ids are allocated above the largest existing node id.
    next_sync = max(g.nodes(), default=-1) + 1
    SP = DiGraph()
    # Copied so the zero costs recorded for sync nodes below stay local.
    cost_map = cost_map.copy()
    root: Node = get_only(sources(g))
    SP.add_node(root)
    node = root  # NOTE(review): dead store; ``node`` is rebound by the loop below.
    while not set(g.nodes).issubset(SP.nodes):
        nodes = get_next_nodes(SP, g, cost_map)
        SP.add_nodes_from(nodes)
        for node in nodes:
            SP.add_edges_from(g.in_edges(node))
        SP = nx.transitive_reduction(
            SP
        )
        component: set[Node] = get_component(SP, nodes)
        handle = must(lowest_common_ancestor(SP, component))
        forest: set[Node] = get_forest(SP, handle, component, node_roles)
        up, down, up_frontier, down_frontier = get_up_and_down(
            nodes, SP, forest, cost_map, node_roles
        )

        # Allocate a fresh zero-cost sync node between the chosen frontiers.
        sync = next_sync
        next_sync += 1
        SP.add_node(sync)
        node_roles[sync] = NodeRole.SYNC
        cost_map[sync] = 0
        # edges_to_remove also prunes redundant sync nodes from SP in place.
        SP.remove_edges_from(edges_to_remove(SP, up, down, node_roles))
        SP.add_edges_from(edges_to_add(up_frontier, down_frontier, sync))
    SP = contract_out_nodes_of_role(SP, NodeRole.SYNC, node_roles)
    decomp = spg_to_sp(SP)
    assert decomp is not None
    assert dependencies_are_maintained(g, decomp)
    return decomp
@@ -0,0 +1,37 @@
1
+ from collections import defaultdict
2
+
3
+ from networkx import DiGraph
4
+
5
+ from spization.__internals.graph import (
6
+ is_2_terminal_dag,
7
+ is_compatible_graph,
8
+ longest_path_lengths_from_source,
9
+ )
10
+ from spization.objects import (
11
+ Node,
12
+ SerialParallelDecomposition,
13
+ )
14
+ from spization.utils import (
15
+ normalize,
16
+ sp_parallel_composition,
17
+ sp_serial_composition,
18
+ )
19
+
20
+
21
def naive_strata_sync(g: DiGraph) -> SerialParallelDecomposition:
    """Serialize ``g`` stratum by stratum.

    Nodes are grouped by the length of the longest path reaching them from the
    unique source. Each group becomes a parallel composition, and the groups
    are composed in series by increasing depth. The result is normalized.
    ``g`` must be a 2-terminal DAG with only ``Node`` nodes (asserted).
    """
    assert is_2_terminal_dag(g) and is_compatible_graph(g)

    depth_of: dict[Node, int] = longest_path_lengths_from_source(g)

    strata: dict[int, list[Node]] = defaultdict(list)
    for node, depth in depth_of.items():
        strata[depth].append(node)

    layers = [sp_parallel_composition(list(strata[depth])) for depth in sorted(strata)]
    return normalize(sp_serial_composition(layers))
+ return sp