nodebpy 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,512 @@
1
+ # SPDX-License-Identifier: GPL-2.0-or-later
2
+
3
+ # https://link.springer.com/chapter/10.1007/3-540-36151-0_26
4
+ # https://doi.org/10.1016/j.jvlc.2013.11.005
5
+ # https://doi.org/10.1007/978-3-642-11805-0_14
6
+ # https://link.springer.com/chapter/10.1007/978-3-540-31843-9_22
7
+ # https://doi.org/10.7155/jgaa.00088
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ from collections import defaultdict
13
+ from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
14
+ from dataclasses import replace
15
+ from functools import cache
16
+ from itertools import chain, pairwise
17
+ from math import inf
18
+ from operator import itemgetter
19
+ from statistics import fmean
20
+ from typing import TypeAlias, cast
21
+
22
+ import networkx as nx
23
+
24
+ from .. import config
25
+ from .graph import FROM_SOCKET, TO_SOCKET, Cluster, Kind, Node, Socket, socket_graph
26
+
27
+ # -------------------------------------------------------------------
28
+
29
# A directed graph whose nodes are layout nodes and/or clusters: the cluster nesting
# tree ``T`` and the per-column sub-trees extracted from it have this type.
_MixedGraph: TypeAlias = 'nx.DiGraph[Node | Cluster]'


def get_col_nesting_trees(
    columns: Sequence[Collection[Node]],
    T: _MixedGraph,
) -> list[_MixedGraph]:
    """Extract, for every column, the part of the nesting tree ``T`` above it.

    The tree of a column holds every edge of ``T`` found by a reverse breadth-first
    search starting at the column's nodes, i.e. the chains of enclosing clusters of
    those nodes. Edges keep their orientation from ``T``.

    Args:
        columns: The layout columns, in order.
        T: The cluster nesting tree.

    Returns:
        One tree per column, in the same order as ``columns``.
    """
    trees: list[_MixedGraph] = []
    for col in columns:
        reverse_edges = nx.edge_bfs(T, col, orientation='reverse')
        tree = nx.DiGraph()
        # edge_bfs yields (u, v, direction); only the endpoints are kept.
        tree.add_edges_from([edge[:2] for edge in reverse_edges])
        trees.append(tree)

    return trees
44
+
45
+
46
def expand_multi_inputs(G: nx.MultiDiGraph[Node]) -> None:
    """Give every link into a multi-input socket its own target-socket index.

    For each node of ``G`` that owns a socket listed in
    ``config.multi_input_sort_ids``, the ``TO_SOCKET`` of its incoming edges is
    rewritten in place (on ``G``'s edge data). Indices are made consecutive,
    starting at the lowest linked index:

    - a multi-input socket with ``k`` links takes ``k`` consecutive indices, ordered
      by descending sort id;
    - every later linked socket is renumbered to follow them.

    NOTE(review): the entries of ``config.multi_input_sort_ids[socket]`` are presumed
    to be ``(from_socket, sort_id)`` pairs. This is inferred from the ``[0]`` and
    ``itemgetter(1)`` accesses below; confirm against ``config``.
    """
    H = socket_graph(G)
    reroutes = {v for v in H if v.owner.is_reroute}
    for v in {s.owner for s in config.multi_input_sort_ids}:
        if v not in G:
            continue

        # Snapshot the incoming edge data, grouped by each edge's *original* target
        # socket, before anything is renumbered. Re-finding edges by socket equality
        # after renumbering could match an already-renumbered multi-input link whose
        # new index equals a later socket's original index, and corrupt it.
        links_by_socket: defaultdict[Socket, list[tuple[Node, dict]]] = defaultdict(list)
        for u, _, d in G.in_edges(v, data=True):
            links_by_socket[d[TO_SOCKET]].append((u, d))

        if not links_by_socket:
            # No incoming links: nothing to renumber.
            continue

        inputs = sorted(links_by_socket, key=lambda s: s.idx)
        i = inputs[0].idx
        for socket in inputs:
            links = links_by_socket[socket]
            if socket not in config.multi_input_sort_ids:
                # Ordinary socket: only shift it past the indices taken so far.
                if i != socket.idx:
                    for _, d in links:
                        d[TO_SOCKET] = replace(socket, idx=i)
                i += 1
                continue

            sort_ids = config.multi_input_sort_ids[socket]
            SH = H.subgraph({entry[0] for entry in sort_ids} | {socket} | reroutes)
            seen = set()
            for base_from_socket in sorted(sort_ids, key=itemgetter(1), reverse=True):
                # Walk the socket graph (possibly through reroutes) to the socket that
                # links directly into `socket`, skipping ones already assigned.
                from_socket = next(
                    s for s, t in nx.edge_dfs(SH, base_from_socket) if t == socket and s not in seen)
                d = next(
                    d for u, d in links
                    if u == from_socket.owner and d[FROM_SOCKET] == from_socket)
                d[TO_SOCKET] = replace(socket, idx=i)
                seen.add(from_socket)
                i += 1
74
+
75
+
76
@cache
def reflexive_transitive_closure(LT: _MixedGraph) -> _MixedGraph:
    """Return the reflexive transitive closure of ``LT``, memoized per graph.

    In the result, ``TC[x]`` holds ``x`` together with everything reachable from it.
    ``TC.pred[x]`` holds ``x`` together with everything that reaches it.
    """
    closure = nx.transitive_closure(LT, reflexive=True)
    return cast(_MixedGraph, closure)
79
+
80
+
81
@cache
def topologically_sorted_clusters(LT: _MixedGraph) -> list[Cluster]:
    """Return the clusters of ``LT`` in topological order (ancestors before descendants).

    The result is memoized per graph, so the same list object is returned on every
    call for a given ``LT``. Callers must not mutate it.
    """
    ordered = nx.topological_sort(LT)
    return [node for node in ordered if node.type == Kind.CLUSTER]
84
+
85
+
86
def crossing_reduction_graph(
    h: Cluster,
    LT: _MixedGraph,
    G: nx.MultiDiGraph[Node],
) -> nx.MultiDiGraph[Node | Cluster]:
    """Build the two-layer graph used to reorder the direct children of cluster ``h``.

    Every edge of ``G`` that ends at a node nested (at any depth) under ``h`` in
    ``LT`` is redirected to the child of ``h`` containing that node. Edges ending at a
    child cluster are attached to a single placeholder socket (index 0) owned by the
    cluster. Edges with the same source, child, from-socket and to-socket are merged
    into one edge whose ``weight`` counts them.

    Args:
        h: The cluster whose children are being ordered.
        LT: The nesting tree of the column being reordered.
        G: The node graph. It may be a reversed view, in which case the socket owned
            by an edge's source is found under ``TO_SOCKET``.

    Returns:
        A multigraph with edge data ``weight``, ``from_socket`` (the socket owned by
        the edge's source) and ``to_socket``.
    """
    G_h = nx.MultiDiGraph()
    G_h.add_nodes_from(LT[h])
    TC = reflexive_transitive_closure(LT)

    # (source, child, from_socket, to_socket) -> data dict of the merged G_h edge.
    # The merge is keyed on sockets, not on G's multi-edge key: G's keys are unrelated
    # to the keys G_h assigns, so indexing G_h with them merged edges inconsistently.
    merged: dict[tuple[Node | Cluster, Node | Cluster, Socket, Socket], dict] = {}
    for s, t, d in G.in_edges(TC[h], data=True):  # type: ignore
        # The unique child of h that contains t (t itself if t is a direct child).
        c = next(c for c in TC.pred[t] if c in LT[h])

        input_k = TO_SOCKET
        output_k = FROM_SOCKET
        if d[output_k].owner != s:
            # Reversed view: the socket owned by `s` is stored as TO_SOCKET.
            input_k, output_k = output_k, input_k

        from_socket = d[output_k]
        to_socket = d[input_k] if c.type != Kind.CLUSTER else replace(d[input_k], owner=c, idx=0)

        merge_key = (s, c, from_socket, to_socket)
        if merge_key in merged:
            merged[merge_key]['weight'] += 1
            continue

        k = G_h.add_edge(s, c, weight=1, from_socket=from_socket, to_socket=to_socket)
        merged[merge_key] = G_h.edges[s, c, k]

    return G_h
110
+
111
+
112
# Scales the weight of the border edges added by
# `_CrossingReductionGraph._insert_border_edges`.
_BALANCING_FAC = 1


class _CrossingReductionGraph:
    """State for reordering the direct children of one cluster ``h`` in a sweep step.

    ``fixed_col`` is the column whose order is taken as given, and ``free_col`` is the
    column being reordered. Rather than ordering ``free_col`` node by node, the
    children of ``h`` in the free column's nesting tree (``reduced_free_col``) are
    ordered, so every child cluster moves as a unit.

    Attributes:
        graph: Two-layer multigraph from ``crossing_reduction_graph`` plus the border
            edges added by ``_insert_border_edges``.
        fixed_LT, free_LT: Nesting trees of the fixed and the free column.
        fixed_col, free_col: The fixed and the free column.
        expanded_fixed_col: ``fixed_col`` plus the border vertices.
        reduced_free_col: Direct children of ``h`` in ``free_LT``.
        fixed_sockets: Per fixed-side vertex, the ``FROM_SOCKET`` of its out-edges.
        free_sockets: Per free-side item, the ``FROM_SOCKET`` of each in-edge.
        border_pairs: ``(upper, lower)`` border vertices -> the fixed-column members
            they enclose.
        constrained_clusters: Children of ``h`` that also appear in ``fixed_LT``.
        N, S: The two socket layers of the bipartite crossing graph.
        bipartite_edges: ``(north, south, weight)`` triples.
    """

    graph: nx.MultiDiGraph[Node | Cluster]

    fixed_LT: _MixedGraph
    free_LT: _MixedGraph

    fixed_col: list[Node]
    free_col: list[Node]

    expanded_fixed_col: list[Node]
    reduced_free_col: list[Node | Cluster]

    fixed_sockets: dict[Node, list[Socket]]
    free_sockets: dict[Node | Cluster, list[Socket]]

    border_pairs: dict[tuple[Node, Node], list[Node]]
    constrained_clusters: list[Cluster]

    N: list[Socket]
    S: list[Socket]
    bipartite_edges: list[tuple[Socket, Socket, int]]

    __slots__ = tuple(__annotations__)

    def _insert_border_edges(self, is_forwards: bool) -> None:
        """Add an upper and a lower border vertex for each child cluster shared with the fixed column.

        Both border vertices get an edge to the cluster. Their weight is
        ``0.5 * _BALANCING_FAC`` times ``1 +`` the number of the cluster's descendants
        in ``free_LT`` that are also in ``fixed_LT``. The enclosed fixed-column members
        are recorded in ``border_pairs``.
        """
        self.border_pairs = {}
        free_clusters = {v for v in self.reduced_free_col if v.type == Kind.CLUSTER}
        for c in free_clusters & self.fixed_LT.nodes:
            upper_v = Node(type=Kind.VERTICAL_BORDER)
            lower_v = Node(type=Kind.VERTICAL_BORDER)
            self.expanded_fixed_col.extend((upper_v, lower_v))

            fac = 1 + len((nx.descendants(self.free_LT, c) & self.fixed_LT.nodes))
            for border_v in upper_v, lower_v:
                self.graph.add_edge(
                    border_v,
                    c,
                    weight=(0.5 * _BALANCING_FAC) * fac,
                    from_socket=Socket(border_v, 0, is_forwards),
                    to_socket=Socket(c, 0, not is_forwards),  # type: ignore
                )

            bordered_nodes = [
                v for v in nx.descendants(self.fixed_LT, c) if v.type != Kind.CLUSTER]
            self.border_pairs[upper_v, lower_v] = bordered_nodes

    def _add_bipartite_edges(self) -> None:
        """Collapse ``graph`` into weighted edges between two layers of sockets.

        The weights of ``graph`` edges joining the same pair of sockets are summed.
        ``N`` receives the larger socket set (the from-socket side on a tie) and ``S``
        the other; both are sorted by socket index.
        """
        weights: dict[tuple[Socket, Socket], float] = {}
        for *_, d in self.graph.edges.data():
            pair = (d[FROM_SOCKET], d[TO_SOCKET])
            # Accumulate rather than overwrite. Several graph edges can join the same
            # socket pair, and every one of them contributes to the crossing weight
            # (loading them into an nx.DiGraph kept only the last weight).
            weights[pair] = weights.get(pair, 0) + d['weight']

        if not weights:
            self.N = []
            self.S = []
            self.bipartite_edges = []
            return

        N, S = map(set, zip(*weights))
        edges = [(u, v, w) for (u, v), w in weights.items()]
        if len(S) > len(N):
            N, S = S, N
            edges = [(v, u, w) for u, v, w in edges]

        self.N = sorted(N, key=lambda d: d.idx)
        self.S = sorted(S, key=lambda d: d.idx)
        self.bipartite_edges = edges

    def __init__(
        self,
        G: nx.MultiDiGraph[Node],
        h: Cluster,
        fixed_LT: _MixedGraph,
        free_LT: _MixedGraph,
        is_forwards: bool,
    ) -> None:
        G_h = crossing_reduction_graph(h, free_LT, G)
        self.graph = G_h

        self.fixed_LT = fixed_LT
        self.free_LT = free_LT

        fixed_col = next(v.col for v in fixed_LT if v.type != Kind.CLUSTER)
        self.fixed_col = fixed_col
        self.free_col = next(v.col for v in free_LT if v.type != Kind.CLUSTER)

        G_h.add_nodes_from(fixed_col)

        self.expanded_fixed_col = fixed_col.copy()

        def column_position(v: Node | Cluster) -> float:
            # Non-cluster children keep their current column order; clusters go last.
            return v.col.index(v) if v.type != Kind.CLUSTER else inf

        self.reduced_free_col = sorted(free_LT[h], key=column_position)

        self._insert_border_edges(is_forwards)

        self.fixed_sockets = {}
        for u in self.expanded_fixed_col:
            if sockets := {e[2] for e in G_h.out_edges(u, data=FROM_SOCKET)}:
                self.fixed_sockets[u] = sorted(sockets, key=lambda d: d.idx, reverse=is_forwards)

        self.free_sockets = {}
        for v in self.reduced_free_col:
            self.free_sockets[v] = [e[2] for e in G_h.in_edges(v, data=FROM_SOCKET)]

        self.constrained_clusters = [
            cast(Cluster, v) for v in self.reduced_free_col if v in fixed_LT]

        self._add_bipartite_edges()
221
+
222
+
223
def crossing_reduction_items(
    trees: Iterable[_MixedGraph],
    G: nx.MultiDiGraph[Node],
    is_forwards: bool,
) -> list[list[_CrossingReductionGraph]]:
    """Prepare the per-cluster crossing-reduction graphs for one sweep direction.

    ``trees`` are the column nesting trees in sweep order. For every pair of
    consecutive trees ``(fixed, free)``, one ``_CrossingReductionGraph`` is built for
    each cluster of the free tree, in topological order (outermost cluster first).
    """
    return [
        [
            _CrossingReductionGraph(G, h, fixed_LT, free_LT, is_forwards)
            for h in topologically_sorted_clusters(free_LT)
        ]
        for fixed_LT, free_LT in pairwise(trees)
    ]
236
+
237
+
238
+ # -------------------------------------------------------------------
239
+
240
+
241
def sort_expanded_fixed_col(H: _CrossingReductionGraph) -> None:
    """Order ``H.expanded_fixed_col`` by position in the fixed column.

    Each border pair is placed tightly around the members it encloses: the upper
    vertex 0.1 above the first member, and the lower vertex 0.1 below the last one.
    """
    order: dict[Node, float] = {}
    for index, node in enumerate(H.fixed_col):
        order[node] = index

    for (upper, lower), members in H.border_pairs.items():
        member_positions = [order[m] for m in members]
        order[upper] = min(member_positions) - 0.1
        order[lower] = max(member_positions) + 0.1

    H.expanded_fixed_col.sort(key=order.get)  # type: ignore
250
+
251
+
252
def calc_socket_ranks(H: _CrossingReductionGraph, is_forwards: bool) -> None:
    """Assign a fractional rank to every socket on the fixed side of ``H``.

    For a vertex at 0-based position ``p`` in ``H.expanded_fixed_col`` with ``k``
    sockets, the ranks are spaced ``1 / (k + 1)`` apart, starting from ``p + 1``:

    - on forward sweeps they step downwards and lie strictly between ``p`` and ``p + 1``;
    - otherwise they step upwards and lie between ``p + 1`` and ``p + 2``.

    Sockets are visited in the order given by ``H.fixed_sockets``. Ranks are stored in
    ``v.cr.socket_ranks``.
    """
    # Look positions up in O(1) instead of a list.index scan for every vertex.
    position = {v: i for i, v in enumerate(H.expanded_fixed_col)}
    for v, sockets in H.fixed_sockets.items():
        incr = 1 / (len(sockets) + 1)
        rank = position[v] + 1
        if is_forwards:
            incr = -incr

        for socket in sockets:
            rank += incr
            v.cr.socket_ranks[socket] = rank
262
+
263
+
264
def random_perturbation() -> float:
    """Return a small random offset used to break ties between barycenters.

    An amplitude ``a`` is drawn uniformly from [-1, 1], then ``u`` uniformly from
    [0, 1]. The result is ``u * a - a / 2``, whose magnitude is at most 0.5.
    Draws come from the module-level ``random`` generator.
    """
    amplitude = random.uniform(-1, 1)
    scale = random.uniform(0, 1)
    return scale * amplitude - amplitude / 2
267
+
268
+
269
def calc_barycenters(H: _CrossingReductionGraph) -> None:
    """Set the barycenter of every free-side item of ``H`` that has incoming sockets.

    The barycenter is the mean of those sockets' ranks (read from each socket owner's
    ``cr.socket_ranks``) plus a random perturbation. Items without incoming sockets
    are left untouched.
    """
    for item in H.reduced_free_col:
        sockets = H.free_sockets[item]
        if not sockets:
            continue
        ranks = [socket.owner.cr.socket_ranks[socket] for socket in sockets]
        item.cr.barycenter = fmean(ranks) + random_perturbation()
274
+
275
+
276
def get_barycenter(v: Node | Cluster) -> float:
    """Return ``v.cr.barycenter``, asserting that it has been assigned (is not None)."""
    value = v.cr.barycenter
    assert value is not None
    return value
280
+
281
+
282
def fill_in_unknown_barycenters(col: list[Node | Cluster], is_first_sweep: bool) -> None:
    """Give a barycenter to every item of ``col`` that does not have one yet.

    On the first sweep, a missing value is ``u * (m + 2) - 1`` plus a random
    perturbation. Here ``u`` is uniform in [0, 1] and ``m`` is the largest barycenter
    already known (0 if none).

    On later sweeps, a missing value is the midpoint of two barycenters, plus a random
    perturbation:

    - the preceding item's (0 for the first item);
    - the next known one further down ``col`` (defaulting to the preceding value + 1).
    """
    if not is_first_sweep:
        for i, v in enumerate(col):
            if v.cr.barycenter is not None:
                continue

            before = 0 if i == 0 else get_barycenter(col[i - 1])
            after = before + 1
            for w in col[i + 1:]:
                b = w.cr.barycenter
                if b is not None:
                    after = b
                    break
            v.cr.barycenter = (before + after) / 2 + random_perturbation()
        return

    known = [b for b in (v.cr.barycenter for v in col) if b is not None]
    span = max(known, default=0) + 2
    for v in col:
        if v.cr.barycenter is None:
            v.cr.barycenter = random.uniform(0, 1) * span - 1 + random_perturbation()
297
+
298
+
299
def find_violated_constraint(GC: _MixedGraph) -> tuple[Node | Cluster, Node | Cluster] | None:
    """Return the first violated ordering constraint of ``GC``, or ``None``.

    Each edge ``(u, v)`` of ``GC`` requires ``u`` to have a smaller barycenter than
    ``v``. Nodes are visited first-in-first-out, and a node is visited once all of its
    incoming constraints have been reached. On visiting ``v``, its incoming constraints
    are checked (most recently reached first); the first one with
    ``u.cr.barycenter >= v.cr.barycenter`` is returned.
    """
    queue = [v for v in GC if GC[v] and not GC.pred[v]]
    # node -> incoming constraints reached so far, in the order they were reached.
    reached = defaultdict(list)
    head = 0
    while head < len(queue):
        v = queue[head]
        head += 1

        for constraint in reversed(reached[v]):
            source, _ = constraint
            if source.cr.barycenter >= v.cr.barycenter:
                return constraint

        for t in GC[v]:
            reached[t].append((v, t))
            if len(reached[t]) == GC.in_degree[t]:
                queue.append(t)

    return None
315
+
316
+
317
def handle_constraints(H: _CrossingReductionGraph) -> None:
    """Make the barycenters of ``H`` respect the order of ``H.constrained_clusters``.

    Consecutive entries of ``H.constrained_clusters`` form a chain of ordering
    constraints. While ``find_violated_constraint`` reports a violated constraint
    ``(s, t)``, ``s`` and ``t`` are merged into one dummy node. The dummy's barycenter
    is their degree-weighted mean, or the plain mean if both degrees are 0.

    Finally, every item of ``H.reduced_free_col`` gets an integer barycenter equal to
    its rank. Merged items stay together, in constraint order.
    """
    # Optimization: don't pass constraints to `nx.DiGraph` constructor
    GC = nx.DiGraph()
    GC.add_edges_from(pairwise(H.constrained_clusters))

    unconstrained = set(H.reduced_free_col) - GC.nodes
    # Each (possibly merged) node -> the original items it stands for, in order.
    L = {v: [v] for v in H.reduced_free_col}

    deg = {v: H.graph.degree[v] for v in GC}
    while c := find_violated_constraint(GC):
        v_c = Node(type=Kind.DUMMY)
        s, t = c

        deg[v_c] = deg[s] + deg[t]
        # Compare with None explicitly: 0.0 is a valid barycenter (e.g. the weighted
        # mean of two opposite values), and a truthiness check would reject it.
        assert s.cr.barycenter is not None and t.cr.barycenter is not None
        if deg[v_c] > 0:
            v_c.cr.barycenter = (s.cr.barycenter * deg[s] + t.cr.barycenter * deg[t]) / deg[v_c]
        else:
            v_c.cr.barycenter = (s.cr.barycenter + t.cr.barycenter) / 2

        L[v_c] = L[s] + L[t]

        # Collapse s and t into v_c; the violated s -> t constraint becomes a self-loop.
        nx.relabel_nodes(GC, {s: v_c, t: v_c}, copy=False)
        if (v_c, v_c) in GC.edges:
            GC.remove_edge(v_c, v_c)

        if v_c not in GC:
            unconstrained.add(v_c)

    groups = sorted(unconstrained | GC.nodes, key=get_barycenter)
    for i, v in enumerate(chain(*[L[v] for v in groups])):
        v.cr.barycenter = i
350
+
351
+
352
def get_cross_count(H: _CrossingReductionGraph) -> float:
    """Return the weighted number of crossings between the two socket layers of ``H``.

    Sockets are positioned by their owner:

    - owners in ``H.reduced_free_col`` by their barycenter;
    - all other owners by their index in ``H.expanded_fixed_col``.

    Ties keep their previous relative order, because the sorts are stable. A pair of
    crossing edges contributes the product of their weights. Counting uses a binary
    accumulator tree over the south positions.

    Side effect: ``H.N``, ``H.S`` and ``H.bipartite_edges`` are re-sorted in place.
    """
    edges = H.bipartite_edges

    if not edges:
        return 0

    reduced_free_col = set(H.reduced_free_col)
    # O(1) lookups instead of a list.index scan for every socket in the sort key.
    fixed_pos = {v: i for i, v in enumerate(H.expanded_fixed_col)}

    def pos(s: Socket) -> float:
        v = s.owner
        if v in reduced_free_col:
            return v.cr.barycenter  # type: ignore
        return fixed_pos[v]

    H.N.sort(key=pos)
    H.S.sort(key=pos)

    south_indices = {k: i for i, k in enumerate(H.S)}
    north_indices = {k: i for i, k in enumerate(H.N)}

    # Lexicographic (north, south) order: sort by the secondary key first, then stably
    # by the primary key.
    edges.sort(key=lambda e: south_indices[e[1]])
    edges.sort(key=lambda e: north_indices[e[0]])

    # Complete binary tree whose leaves are the south positions.
    first_idx = 1
    while first_idx < len(H.S):
        first_idx *= 2

    tree = [0] * (2 * first_idx - 1)
    first_idx -= 1  # index of the leftmost leaf

    cross_weight = 0
    for _, v, weight in edges:
        idx = south_indices[v] + first_idx
        tree[idx] += weight
        weight_sum = 0
        while idx > 0:
            # At a left child, the right sibling's subtree holds weight that lies further
            # south and was inserted by an edge with a smaller north index: a crossing.
            if idx % 2 == 1:
                weight_sum += tree[idx + 1]

            idx = (idx - 1) // 2
            tree[idx] += weight

        cross_weight += weight * weight_sum

    return cross_weight
398
+
399
+
400
def get_new_col_order(v: Node | Cluster, LT: _MixedGraph) -> Iterator[Node]:
    """Yield the non-cluster nodes under ``v`` in their new top-to-bottom order.

    Clusters are expanded depth-first, visiting each cluster's children in order of
    increasing barycenter. A non-cluster ``v`` yields just itself.
    """
    stack: list[Node | Cluster] = [v]
    while stack:
        current = stack.pop()
        if current.type != Kind.CLUSTER:
            yield current
            continue
        children = sorted(LT[current], key=get_barycenter)
        # Push in reverse so that the lowest-barycenter child is popped next.
        stack.extend(reversed(children))
406
+
407
+
408
@cache
def non_cluster_descendant(T: _MixedGraph, c: Cluster) -> Node:
    """Return the first non-cluster node reached by a breadth-first search of ``T`` from ``c``.

    The result is memoized per ``(T, c)``. Raises ``StopIteration`` if ``c`` has no
    such descendant.
    """
    reached = (child for _, child in nx.bfs_edges(T, c))
    return next(w for w in reached if w.type != Kind.CLUSTER)
411
+
412
+
413
def sort_reduced_free_columns(items: Iterable[Sequence[_CrossingReductionGraph]]) -> None:
    """Re-sort every ``H.reduced_free_col`` so it follows the current order of ``H.free_col``.

    A non-cluster item is placed by its own index in ``H.free_col``. A cluster is placed
    by the index of its first non-cluster descendant in ``H.free_LT`` (see
    ``non_cluster_descendant``).
    """
    for crossing_reduction_graphs in items:
        for H in crossing_reduction_graphs:
            # Positions are computed once per graph rather than with a list.index scan
            # per sort key. The key is bound to this H explicitly; previously it relied
            # on a closure reading the loop variable late.
            position = {v: i for i, v in enumerate(H.free_col)}
            keys = {
                v: position[non_cluster_descendant(H.free_LT, v) if v.type == Kind.CLUSTER else v]
                for v in H.reduced_free_col}
            H.reduced_free_col.sort(key=keys.__getitem__)
422
+
423
+
424
+ # -------------------------------------------------------------------
425
+
426
+
427
def minimized_cross_count(
    columns: Sequence[list[Node]],
    forward_items: list[list[_CrossingReductionGraph]],
    backward_items: list[list[_CrossingReductionGraph]],
    T: _MixedGraph,
) -> float:
    """Alternate forward and backward sweeps until the crossing count stops improving.

    The first direction is chosen with the module-level ``random`` generator. Each sweep
    reorders every free column in place and sums the crossing counts of all steps.
    When a sweep fails to lower that total, the columns are restored to the order of
    the best sweep, and the best total is returned. A total of 0 is returned as soon as
    it is reached.
    """
    cross_count = inf
    is_forwards = random.choice((True, False))
    is_first_sweep = True
    while True:
        for v in T:
            v.cr.reset()

        if cross_count == 0:
            return 0

        is_forwards = not is_forwards
        old_cross_count = cross_count
        cross_count = 0

        items = forward_items if is_forwards else backward_items
        for i, crossing_reduction_graphs in enumerate(items):
            if i == 0:
                # No barycenters exist yet for the first fixed column. Order its
                # constrained clusters by the position of their bottom-most member in
                # that column (later positions overwrite earlier ones).
                clusters = {
                    c: j
                    for j, v in enumerate(crossing_reduction_graphs[0].fixed_col)
                    for c in nx.ancestors(T, v)}
                key = cast(Callable[[Cluster], int], clusters.get)
            else:
                key = get_barycenter

            for H in crossing_reduction_graphs:
                H.constrained_clusters.sort(key=key)
                sort_expanded_fixed_col(H)

                calc_socket_ranks(H, is_forwards)
                calc_barycenters(H)
                fill_in_unknown_barycenters(H.reduced_free_col, is_first_sweep)
                handle_constraints(H)

                cross_count += get_cross_count(H)

            # All graphs of one step share the same free tree and column. Rebuild the
            # column order from the cluster hierarchy now that all children are ranked.
            root = topologically_sorted_clusters(H.free_LT)[0]
            new_position = {v: j for j, v in enumerate(get_new_col_order(root, H.free_LT))}
            H.free_col.sort(key=new_position.__getitem__)

        if old_cross_count > cross_count:
            sort_reduced_free_columns(forward_items + backward_items)
            best_columns = [c.copy() for c in columns]
            is_first_sweep = False
        else:
            for col, best_col in zip(columns, best_columns):
                # O(1) lookups instead of a list.index scan per element.
                best_pos = {v: j for j, v in enumerate(best_col)}
                col.sort(key=best_pos.__getitem__)
            break

    return old_cross_count
483
+
484
+
485
def minimize_crossings(G: nx.MultiDiGraph[Node], T: _MixedGraph) -> None:
    """Reorder the lists in ``G.graph['columns']`` in place to reduce edge crossings.

    Crossing-reduction graphs are built once for each sweep direction. Multi-input
    links are expanded on a copy of ``G``, so ``G``'s own edge data is untouched.
    The module-level ``random`` generator is seeded with 0, then up to
    ``config.SETTINGS.iterations`` rounds of ``minimized_cross_count`` are run. The
    column order with the lowest crossing count is kept, and the loop stops early once
    that count is 0.
    """
    columns = G.graph['columns']
    try:
        trees = get_col_nesting_trees(columns, T)
        G_ = G.copy()

        expand_multi_inputs(G_)

        forward_items = crossing_reduction_items(trees, G_, True)

        G__ = cast('nx.MultiDiGraph[Node]', nx.reverse_view(G_))  # type: ignore
        backward_items = crossing_reduction_items(reversed(trees), G__, False)

        # -------------------------------------------------------------------

        random.seed(0)
        best_cross_count = inf
        best_columns = [c.copy() for c in columns]
        for _ in range(config.SETTINGS.iterations):
            cross_count = minimized_cross_count(columns, forward_items, backward_items, T)
            if cross_count < best_cross_count:
                best_cross_count = cross_count
                best_columns = [c.copy() for c in columns]
                if best_cross_count == 0:
                    break
        else:
            for col, best_col in zip(columns, best_columns):
                # O(1) lookups instead of a list.index scan per element.
                best_pos = {v: i for i, v in enumerate(best_col)}
                col.sort(key=best_pos.__getitem__)
            sort_reduced_free_columns(forward_items + backward_items)
    finally:
        # These unbounded @cache helpers are filled here with keys built for this call
        # only (the per-column nesting trees). Left alone, every layout run would keep
        # those graphs alive forever. Clearing also drops entries added by other
        # callers, but those are only memoization.
        reflexive_transitive_closure.cache_clear()
        topologically_sorted_clusters.cache_clear()
        non_cluster_descendant.cache_clear()