jaxspec 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
1
+ """Helper functions to deal with the graph logic within model building"""
2
+
3
+ import re
4
+
5
+ from collections.abc import Callable
6
+ from uuid import uuid4
7
+
8
+ import networkx as nx
9
+
10
+
11
def get_component_names(graph: nx.DiGraph):
    """
    Get the set of component names from the nodes of a graph.

    A node is treated as a component when its 'type' attribute contains the
    substring "component". Nodes without a 'type' attribute (or with a None
    type) are skipped instead of raising a TypeError.

    Parameters:
        graph: The graph to get the component names from.

    Returns:
        The set of 'name' attributes of all component nodes.
    """
    return {
        data["name"]
        for _, data in graph.nodes(data=True)
        # `or ""` guards against a missing/None type, which `in` cannot handle.
        if "component" in (data.get("type") or "")
    }
21
+
22
+
23
def increment_name(name: str, used_names: set):
    """
    Return a variant of `name` that is not present in `used_names`.

    If `name` is unused it is returned unchanged. Otherwise the name is split
    into a base and an optional numeric suffix ('name_3' -> 'name', 3; a name
    without a suffix counts as suffix 1), and the suffix is bumped until the
    resulting 'base_N' is free.

    Parameters:
        name: The name to increment.
        used_names: The set of names that are already used.
    """
    if name not in used_names:
        return name

    # Split off an optional trailing "_<digits>" suffix.
    base_name, suffix = re.match(r"(.*?)(?:_(\d+))?$", name).groups()
    counter = int(suffix) if suffix else 1

    while True:
        counter += 1
        candidate = f"{base_name}_{counter}"
        if candidate not in used_names:
            return candidate
46
+
47
+
48
def compose_with_rename(graph_1: nx.DiGraph, graph_2: nx.DiGraph):
    """
    Compose two graphs, renaming the components of graph_2 whose 'name'
    attribute collides with a component name already in use.

    Nodes sharing the same identifier in both graphs (such as the 'out' node)
    are merged by `nx.compose`; for merged nodes, the attributes from graph_2
    take precedence. Renaming is done on a copy of graph_2, so neither input
    graph is modified.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose.

    Returns:
        The composed graph.
    """
    # Names already taken by the components of graph_1.
    used_names = get_component_names(graph_1)

    # Rename on a copy: editing graph_2's attribute dicts in place would alter
    # the caller's graph (and graph_1 too when both arguments are the same
    # object). `copy()` gives each node a fresh attribute dict.
    graph_2 = graph_2.copy()

    for _, data in graph_2.nodes(data=True):
        if "component" not in (data.get("type") or ""):
            continue
        # increment_name returns the name unchanged when it is not yet used.
        new_name = increment_name(data["name"], used_names)
        data["name"] = new_name
        used_names.add(new_name)

    return nx.compose(graph_1, graph_2)
79
+
80
+
81
def compose(
    graph_1: nx.DiGraph,
    graph_2: nx.DiGraph,
    operation: str = "",
    operation_func: Callable = lambda x, y: None,
):
    """
    Compose two graphs by joining the 'out' node of graph_1 and graph_2, and turning
    it into an 'operation' node with the relevant operator, then add a new 'out' node.

    Every node of the result gets a 'depth' attribute equal to the length of
    the longest path in the graph minus the node's shortest-path distance to
    'out'.

    Parameters:
        graph_1: The first graph to compose.
        graph_2: The second graph to compose.
        operation: The string describing the operation to perform.
        operation_func: The callable that performs the operation.

    Returns:
        The composed graph.
    """
    combined_graph = compose_with_rename(graph_1, graph_2)

    # The shared 'out' node becomes the operation node under a fresh unique id.
    node_id = str(uuid4())
    graph = nx.relabel_nodes(combined_graph, {"out": node_id})
    nx.set_node_attributes(graph, {node_id: f"{operation}_operation"}, "type")
    nx.set_node_attributes(graph, {node_id: operation_func}, "operator")

    # Now add the output node and link it to the operation node
    graph.add_node("out", type="out")
    graph.add_edge(node_id, "out")

    # Compute the new depth of each node. A single target-only
    # shortest_path_length call yields the distance from every node to 'out'
    # in one reverse BFS, instead of one search per node. A node that cannot
    # reach 'out' is absent from the dict and raises KeyError here.
    longest_path = nx.dag_longest_path_length(graph)
    distance_to_out = nx.shortest_path_length(graph, target="out")
    nx.set_node_attributes(
        graph,
        {node: longest_path - distance_to_out[node] for node in graph.nodes},
        "depth",
    )

    return graph
119
+
120
+
121
def export_to_mermaid(graph, file=None):
    """
    Render a model graph as Mermaid flowchart code (left to right).

    - The 'out' node is drawn as "Output".
    - Component nodes are labelled with their capitalized base name and
      numeric suffix, e.g. 'powerlaw_1' -> 'Powerlaw (1)'. Base names may
      contain underscores, and the suffix is optional.
    - 'add' and 'mul' operation nodes are drawn as '+' and 'x' diamonds.
      Other non-empty operation types are drawn with their own name.
    - Edges are emitted after all nodes.

    Parameters:
        graph: The graph to export. Every node needs a 'type' attribute;
            component nodes also need a 'name' attribute.
        file: Optional path. When given, the Mermaid code is written to it
            (UTF-8) and nothing is returned.

    Returns:
        The Mermaid code as a string when `file` is None, otherwise None.
    """
    operation_symbols = {"add": "+", "mul": "x"}
    lines = ["graph LR\n"]  # LR = left to right

    # Add nodes
    for node, attributes in graph.nodes(data=True):
        if attributes["type"] == "out":
            lines.append(f' {node}("Output")\n')
            continue

        # rpartition never raises, even for a type without '_'.
        operation_type, _, node_type = attributes["type"].rpartition("_")

        if node_type == "component":
            # Same suffix convention as increment_name: optional "_<digits>".
            base_name, number = re.match(
                r"(.*?)(?:_(\d+))?$", attributes["name"]
            ).groups()
            label = base_name.capitalize()
            if number is not None:
                label = f"{label} ({number})"
            lines.append(f' {node}("{label}")\n')

        elif node_type == "operation":
            symbol = operation_symbols.get(operation_type, operation_type)
            if symbol:
                lines.append(f" {node}{{**{symbol}**}}\n")

    # Draw the connections between nodes
    lines.extend(f" {source} --> {target}\n" for source, target in graph.edges())

    mermaid_code = "".join(lines)

    if file is None:
        return mermaid_code

    with open(file, "w", encoding="utf-8") as f:
        f.write(mermaid_code)