Graphinate 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphinate/__init__.py +18 -0
- graphinate/__main__.py +4 -0
- graphinate/builders/__init__.py +55 -0
- graphinate/builders/_builder.py +58 -0
- graphinate/builders/_d3.py +61 -0
- graphinate/builders/_graphql.py +521 -0
- graphinate/builders/_mermaid.py +45 -0
- graphinate/builders/_networkx.py +227 -0
- graphinate/cli.py +123 -0
- graphinate/color.py +100 -0
- graphinate/constants.py +4 -0
- graphinate/converters.py +94 -0
- graphinate/enums.py +44 -0
- graphinate/modeling.py +337 -0
- graphinate/renderers/__init__.py +5 -0
- graphinate/renderers/graphql.py +111 -0
- graphinate/renderers/matplotlib.py +82 -0
- graphinate/server/__init__.py +0 -0
- graphinate/server/starlette/__init__.py +31 -0
- graphinate/server/starlette/views.py +17 -0
- graphinate/server/web/__init__.py +25 -0
- graphinate/server/web/elements/index.html +23 -0
- graphinate/server/web/graphiql/index.html +160 -0
- graphinate/server/web/rapidoc/index.html +17 -0
- graphinate/server/web/static/images/logo-128.png +0 -0
- graphinate/server/web/static/images/logo.svg +50 -0
- graphinate/server/web/static/images/network_graph.png +0 -0
- graphinate/server/web/viewer/index.html +719 -0
- graphinate/server/web/voyager/index.html +55 -0
- graphinate/tools.py +7 -0
- graphinate/typing.py +83 -0
- graphinate-0.12.0.dist-info/METADATA +284 -0
- graphinate-0.12.0.dist-info/RECORD +36 -0
- graphinate-0.12.0.dist-info/WHEEL +4 -0
- graphinate-0.12.0.dist-info/entry_points.txt +2 -0
- graphinate-0.12.0.dist-info/licenses/LICENSE +165 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
import operator
|
|
2
|
+
from collections import Counter
|
|
3
|
+
from collections.abc import Callable, Hashable, Mapping
|
|
4
|
+
from typing import Any, Union
|
|
5
|
+
|
|
6
|
+
import networkx as nx
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from mappingtools.transformers import simplify
|
|
9
|
+
from networkx.classes.reportviews import EdgeDataView, EdgeView, NodeDataView, NodeView
|
|
10
|
+
|
|
11
|
+
from .. import color
|
|
12
|
+
from ..enums import GraphType
|
|
13
|
+
from ..modeling import GraphModel, Multiplicity
|
|
14
|
+
from ..tools import utcnow
|
|
15
|
+
from ..typing import NodeTypeAbsoluteId, UniverseNode
|
|
16
|
+
from ._builder import Builder
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class NetworkxBuilder(Builder):
|
|
20
|
+
"""Build a NetworkX Graph"""
|
|
21
|
+
|
|
22
|
+
def __init__(self, model: GraphModel, graph_type: GraphType = GraphType.Graph):
|
|
23
|
+
super().__init__(model, graph_type)
|
|
24
|
+
self._graph: nx.Graph | None = None
|
|
25
|
+
|
|
26
|
+
def _initialize_graph(self):
|
|
27
|
+
"""Initialize an empty NetworkX graph with metadata and default attributes."""
|
|
28
|
+
self._graph: nx.Graph = self.graph_type.value(name=self.model.name,
|
|
29
|
+
node_types=Counter(),
|
|
30
|
+
edge_types=Counter())
|
|
31
|
+
|
|
32
|
+
def _graph_edges(self, data, default=None):
|
|
33
|
+
params = {'data': data, 'default': default}
|
|
34
|
+
|
|
35
|
+
if isinstance(self._graph, nx.MultiGraph):
|
|
36
|
+
params['keys'] = True
|
|
37
|
+
|
|
38
|
+
return self._graph.edges(**params)
|
|
39
|
+
|
|
40
|
+
def _populate_node_type(self, node_type: Union[Hashable, UniverseNode] = UniverseNode, **kwargs):
|
|
41
|
+
for parent_node_type, child_node_types in self.model.node_children_types(node_type).items():
|
|
42
|
+
for child_node_type in child_node_types:
|
|
43
|
+
node_type_absolute_id = (parent_node_type, child_node_type)
|
|
44
|
+
self._populate_nodes(node_type_absolute_id, **kwargs)
|
|
45
|
+
|
|
46
|
+
@staticmethod
|
|
47
|
+
def _parent_node_id(node_type_absolute_id: NodeTypeAbsoluteId, **kwargs: Any):
|
|
48
|
+
if node_type_absolute_id[0] is UniverseNode:
|
|
49
|
+
return UniverseNode
|
|
50
|
+
|
|
51
|
+
ids = []
|
|
52
|
+
for k, v in kwargs.items():
|
|
53
|
+
if k[:-3] == node_type_absolute_id[1]:
|
|
54
|
+
break
|
|
55
|
+
ids.append(v)
|
|
56
|
+
|
|
57
|
+
return tuple(ids)
|
|
58
|
+
|
|
59
|
+
def _populate_nodes(self, node_type_absolute_id: NodeTypeAbsoluteId, **kwargs: Any):
|
|
60
|
+
"""Populate graph nodes based on the provided model and ID."""
|
|
61
|
+
for node_model in self.model.node_models[node_type_absolute_id]:
|
|
62
|
+
unique = node_model.uniqueness
|
|
63
|
+
for node in node_model.generator(**kwargs):
|
|
64
|
+
parent_node_id = self._parent_node_id(node_type_absolute_id, **kwargs)
|
|
65
|
+
node_lineage = (*parent_node_id, node.key) if parent_node_id is not UniverseNode else (node.key,)
|
|
66
|
+
node_id = (node.key,) if unique else node_lineage
|
|
67
|
+
|
|
68
|
+
label = node.key
|
|
69
|
+
if node_model.label is not None:
|
|
70
|
+
label = node_model.label(node.value) if callable(node_model.label) else node_model.label
|
|
71
|
+
|
|
72
|
+
node_type = node.__class__.__name__.lower()
|
|
73
|
+
if node_type == 'tuple':
|
|
74
|
+
node_type = node_model.type.lower()
|
|
75
|
+
|
|
76
|
+
if node_id in self._graph:
|
|
77
|
+
logger.debug("Updating node. ID: {}, Label: {}", node_id, label)
|
|
78
|
+
|
|
79
|
+
match node_model.multiplicity:
|
|
80
|
+
case Multiplicity.ADD:
|
|
81
|
+
self._graph.nodes[node_id]['value'] = [self._graph.nodes[node_id]['value'] + node.value]
|
|
82
|
+
case Multiplicity.ALL:
|
|
83
|
+
self._graph.nodes[node_id]['value'].append(node.value)
|
|
84
|
+
case Multiplicity.FIRST:
|
|
85
|
+
...
|
|
86
|
+
case Multiplicity.LAST:
|
|
87
|
+
self._graph.nodes[node_id]['value'] = [node.value]
|
|
88
|
+
|
|
89
|
+
self._graph.nodes[node_id]['magnitude'] += 1
|
|
90
|
+
self._graph.nodes[node_id]['updated'] = utcnow()
|
|
91
|
+
else:
|
|
92
|
+
logger.debug("Adding node. ID: {}, Label: {}", node_id, label)
|
|
93
|
+
self._graph.add_node(node_id,
|
|
94
|
+
label=label,
|
|
95
|
+
type=node_type,
|
|
96
|
+
value=[node.value],
|
|
97
|
+
magnitude=1,
|
|
98
|
+
lineage=list(node_lineage),
|
|
99
|
+
created=utcnow())
|
|
100
|
+
|
|
101
|
+
self._graph.graph['node_types'].update({node_type: 1})
|
|
102
|
+
|
|
103
|
+
if node_model.parent_type is not UniverseNode:
|
|
104
|
+
logger.debug("Adding edge. Source: {}, Target: {}", parent_node_id, node_id)
|
|
105
|
+
self._graph.add_edge(parent_node_id,
|
|
106
|
+
node_id,
|
|
107
|
+
created=utcnow())
|
|
108
|
+
|
|
109
|
+
new_kwargs = kwargs.copy()
|
|
110
|
+
new_kwargs[f"{node_type}_id"] = node.key
|
|
111
|
+
self._populate_node_type(node_model.type, **new_kwargs)
|
|
112
|
+
|
|
113
|
+
def _populate_edges(self, **kwargs: Any):
|
|
114
|
+
"""Populate graph edges based on defined connections."""
|
|
115
|
+
for edge_model, edge_generators in self.model.edge_generators.items():
|
|
116
|
+
for edge_generator in edge_generators:
|
|
117
|
+
for edge in edge_generator(**kwargs):
|
|
118
|
+
edge_id = ((edge.source,), (edge.target,))
|
|
119
|
+
edge_label = edge.label(edge_id) if callable(edge.label) else edge.label
|
|
120
|
+
edge_weight = edge.weight or 1.0
|
|
121
|
+
edge_type = edge.type.lower()
|
|
122
|
+
logger.debug("Adding edge. Source: {}, Target: {}", *edge_id)
|
|
123
|
+
|
|
124
|
+
if isinstance(self._graph, nx.MultiGraph) or edge_id not in self._graph.edges:
|
|
125
|
+
self._graph.add_edge(*edge_id,
|
|
126
|
+
label=edge_label,
|
|
127
|
+
type=edge_type,
|
|
128
|
+
value=[edge.value],
|
|
129
|
+
weight=edge_weight,
|
|
130
|
+
created=utcnow())
|
|
131
|
+
self._graph.graph['edge_types'].update({edge_type: 1})
|
|
132
|
+
else:
|
|
133
|
+
self._graph.edges[edge_id]['value'].append(edge.value)
|
|
134
|
+
self._graph.edges[edge_id]['weight'] += edge_weight
|
|
135
|
+
self._graph.edges[edge_id]['updated'] = utcnow()
|
|
136
|
+
|
|
137
|
+
@staticmethod
|
|
138
|
+
def _rectified_values(name: str, default: Any, elements: Callable[
|
|
139
|
+
[str, Any], NodeView[Any] | EdgeView[Any] | NodeDataView[Any] | EdgeDataView], k: Callable[[Any], Any],
|
|
140
|
+
v: Callable[[Any], Any]) -> dict:
|
|
141
|
+
if callable(default):
|
|
142
|
+
elem = elements(data=name, default=None)
|
|
143
|
+
return {k(e): default(k(e))
|
|
144
|
+
for e in elem
|
|
145
|
+
if (v(e) is None if isinstance(elem, NodeDataView) else v(e) is not None)}
|
|
146
|
+
elif isinstance(default, dict):
|
|
147
|
+
return default
|
|
148
|
+
elif default:
|
|
149
|
+
return {k(e): v(e) for e in elements(data=name, default=default) if v(e) == default}
|
|
150
|
+
else: # default is None or empty collection
|
|
151
|
+
return {k(e): k(e) for e in elements(data=name, default=default) if v(e) is default}
|
|
152
|
+
|
|
153
|
+
def _rectify_node_attributes(self, **defaults):
|
|
154
|
+
for name, default in defaults.items():
|
|
155
|
+
if values := self._rectified_values(
|
|
156
|
+
name,
|
|
157
|
+
default,
|
|
158
|
+
self._graph.nodes,
|
|
159
|
+
operator.itemgetter(0),
|
|
160
|
+
operator.itemgetter(1),
|
|
161
|
+
):
|
|
162
|
+
nx.set_node_attributes(self._graph, values=values, name=name)
|
|
163
|
+
|
|
164
|
+
if default_type := defaults.get('type'):
|
|
165
|
+
type_count = sum(1 for n, d in self._graph.nodes(data='type') if d == default_type)
|
|
166
|
+
if type_count:
|
|
167
|
+
self._graph.graph['node_types'].update({default_type: type_count})
|
|
168
|
+
|
|
169
|
+
def _rectify_edge_attributes(self, **defaults):
|
|
170
|
+
for name, default in defaults.items():
|
|
171
|
+
if values := self._rectified_values(
|
|
172
|
+
name,
|
|
173
|
+
default,
|
|
174
|
+
self._graph_edges,
|
|
175
|
+
lambda x: tuple(x[:-1]),
|
|
176
|
+
lambda x: x[-1]
|
|
177
|
+
):
|
|
178
|
+
nx.set_edge_attributes(self._graph, values=values, name=name)
|
|
179
|
+
|
|
180
|
+
if default_type := defaults.get('type'):
|
|
181
|
+
type_count = sum(1 for *_, d in self._graph_edges(data='type') if d == default_type)
|
|
182
|
+
if type_count:
|
|
183
|
+
self._graph.graph['edge_types'].update({default_type: type_count})
|
|
184
|
+
|
|
185
|
+
def _finalize_graph(self, **node_attributes):
|
|
186
|
+
self._rectify_node_attributes(**node_attributes)
|
|
187
|
+
|
|
188
|
+
if 'color' not in node_attributes:
|
|
189
|
+
self._rectify_node_attributes(color=color.node_color_mapping(self._graph))
|
|
190
|
+
|
|
191
|
+
self._rectify_edge_attributes(**self.default_edge_attributes)
|
|
192
|
+
|
|
193
|
+
for counter_name in ('node_types', 'edge_types'):
|
|
194
|
+
counter = self._graph.graph[counter_name]
|
|
195
|
+
self._graph.graph[counter_name] = simplify(counter)
|
|
196
|
+
|
|
197
|
+
self._graph.graph['created'] = utcnow()
|
|
198
|
+
|
|
199
|
+
def _rectify_model(self, node_attributes: Mapping):
|
|
200
|
+
default_type = node_attributes.get('type')
|
|
201
|
+
default_label = node_attributes.get('label')
|
|
202
|
+
self.model.rectify(_type=default_type, parent_type=default_type, label=default_label)
|
|
203
|
+
|
|
204
|
+
def _build_graph(self, node_attributes: Mapping, **kwargs: Any):
|
|
205
|
+
self._initialize_graph()
|
|
206
|
+
self._populate_node_type(**kwargs)
|
|
207
|
+
self._populate_edges(**kwargs)
|
|
208
|
+
self._finalize_graph(**node_attributes)
|
|
209
|
+
|
|
210
|
+
def build(self, **kwargs: Any) -> nx.Graph:
|
|
211
|
+
"""Build a NetworkX graph representation.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
**kwargs:
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
NetworkX Graph
|
|
218
|
+
"""
|
|
219
|
+
super().build(**kwargs)
|
|
220
|
+
|
|
221
|
+
default_node_attributes = dict(**self.default_node_attributes)
|
|
222
|
+
if 'default_node_attributes' in kwargs:
|
|
223
|
+
default_node_attributes.update(kwargs.pop('default_node_attributes') or {})
|
|
224
|
+
|
|
225
|
+
self._rectify_model(default_node_attributes)
|
|
226
|
+
self._build_graph(default_node_attributes, **kwargs)
|
|
227
|
+
return self._graph
|
graphinate/cli.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from types import ModuleType
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import click
|
|
8
|
+
from strawberry import Schema
|
|
9
|
+
|
|
10
|
+
from . import GraphModel, builders, graphql
|
|
11
|
+
from .renderers.graphql import DEFAULT_PORT
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_kwargs(ctx: click.Context) -> dict:
    """Collect ``--key=value`` pairs from the context's extra arguments.

    Extra arguments not starting with ``--`` are ignored.

    Args:
        ctx: Click context whose ``args`` holds the unparsed extra arguments.

    Returns:
        Mapping of option name (without the leading ``--``) to its string value.

    Raises:
        click.UsageError: If a ``--`` argument is not of the form ``--key=value``.
    """
    kwargs = {}
    for item in ctx.args:
        if not item.startswith('--'):
            continue
        # removeprefix (not strip('--')) keeps dashes that belong to the key or value;
        # partition splits on the first '=' only, so values may themselves contain '='.
        key, separator, value = item.removeprefix('--').partition('=')
        if not separator:
            raise click.UsageError(f"Expected an option of the form --key=value, got '{item}'.")
        kwargs[key] = value
    return kwargs
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def import_from_string(import_str: str) -> GraphModel:
    """Resolve a ``{module-name}:{variable-name}`` reference to a GraphModel instance.

    For example, if `model: GraphModel = GraphModel(...)` is defined in app.py, the
    reference is ``app:model``. The attribute part may be a dotted path.

    Args:
        import_str: The ``module:attribute`` reference.

    Returns:
        The referenced GraphModel.

    Raises:
        ImportFromStringError: If the reference is malformed, the module or attribute
            cannot be found, or the target is not a GraphModel.
        ModuleNotFoundError: Re-raised when the module exists but one of ITS imports is missing.
    """
    if not isinstance(import_str, str):
        raise ImportFromStringError(f"{import_str} is not a string")

    module_name, _, attribute_path = import_str.partition(':')
    if not (module_name and attribute_path):
        raise ImportFromStringError(f"Import string '{import_str}' must be in format '<module>:<attribute>'.")

    try:
        target: ModuleType | GraphModel = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        # Only a missing *target* module is reported as a reference error; a missing
        # dependency of that module propagates unchanged.
        if exc.name != module_name:
            raise exc from None
        raise ImportFromStringError(f"Could not import module '{module_name}'.") from exc

    try:
        for attribute in attribute_path.split('.'):
            target = getattr(target, attribute)
    except AttributeError as e:
        raise ImportFromStringError(
            f"Attribute '{attribute_path}' not found in import string reference '{import_str}'."
        ) from e

    if not isinstance(target, GraphModel):
        raise ImportFromStringError(f"GraphModel instance cannot be determined from reference '{import_str}'")
    return target
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ImportFromStringError(Exception):
    """Raised when a ``module:attribute`` reference cannot be resolved to a GraphModel."""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class GraphModelType(click.ParamType):
    """Click parameter type resolving a ``module:attribute`` reference to a GraphModel."""

    name = "MODEL"

    def convert(self,
                value: Any,
                param: click.Parameter | None,
                ctx: click.Context | None) -> GraphModel:
        """Return ``value`` as a GraphModel, importing it from a reference string if needed.

        Args:
            value: A GraphModel instance, or a ``module:attribute`` reference string.
            param: The parameter being converted; used for error attribution.
            ctx: The current click context; used for error attribution.

        Returns:
            The GraphModel instance.

        Raises:
            click.BadParameter: (via ``self.fail``) if the reference cannot be resolved.
        """
        if isinstance(value, GraphModel):
            return value

        try:
            return import_from_string(value)
        except Exception as e:
            # Importing user code can raise anything; report it as a usage error that
            # click attributes to the offending option.
            self.fail(str(e), param, ctx)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Shared --model option for the sub-commands. It is required because every command
# dereferences the model (e.g. ``model.name`` in ``save``); without it they would crash on None.
model_option = click.option('-m', '--model',
                            type=GraphModelType(),
                            required=True,
                            help="A GraphModel instance reference {module-name}:{GraphModel-instance-variable-name}"
                                 " For example given a var `model=GraphModel()` defined in app.py file, then the"
                                 " reference would be app:model")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@click.group()
@click.pass_context
def cli(ctx: click.Context) -> None:
    # Root command group (no docstring on purpose: click would show it as --help text).
    # Make sure a dict context object exists for the sub-commands.
    ctx.ensure_object(dict)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@cli.command()
@model_option
@click.pass_context
def save(ctx: click.Context, model: GraphModel) -> None:
    # Writes the model's D3 graph to '<model name>.d3_graph.json' in the current directory.
    # (Comments instead of a docstring: click would turn a docstring into --help text.)
    file_path = Path(f"{model.name}.d3_graph.json")

    if file_path.is_absolute():
        raise click.ClickException("Please provide a relative file path for saving the graph.")

    if file_path.parent != Path('.'):
        raise click.ClickException("Saving to subdirectories is not supported. Please provide a file name only.")

    if file_path.exists():
        click.confirm(f"The file '{file_path}' already exists. Do you want to overwrite it?", abort=True)

    kwargs = _get_kwargs(ctx)
    # Build BEFORE opening: mode 'w' truncates the file, so a failing build would
    # otherwise leave an empty (or clobbered) file behind.
    graph = builders.D3Builder(model, **kwargs).build()
    with open(file_path, mode='w', encoding='utf-8') as fp:
        # NOTE(review): the same extra kwargs go to both D3Builder and json.dump; confirm intended.
        json.dump(graph, fp=fp, default=str, **kwargs)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@cli.command()
@model_option
@click.option('-p', '--port', type=int, default=DEFAULT_PORT, help='Port number.')
@click.option('-b', '--browse', type=bool, default=False, help='Open server address in browser.')
@click.pass_context
def server(ctx: click.Context, model: GraphModel, port: int, browse: bool) -> None:
    # Print the banner, build the GraphQL schema for the model, then serve it.
    # Extra --key=value arguments are forwarded to graphql.server.
    banner = """
██████╗ ██████╗ █████╗ ██████╗ ██╗ ██╗██╗███╗ ██╗ █████╗ ████████╗███████╗
██╔════╝ ██╔══██╗██╔══██╗██╔══██╗██║ ██║██║████╗ ██║██╔══██╗╚══██╔══╝██╔════╝
██║ ███╗██████╔╝███████║██████╔╝███████║██║██╔██╗ ██║███████║ ██║ █████╗
██║ ██║██╔══██╗██╔══██║██╔═══╝ ██╔══██║██║██║╚██╗██║██╔══██║ ██║ ██╔══╝
╚██████╔╝██║ ██║██║ ██║██║ ██║ ██║██║██║ ╚████║██║ ██║ ██║ ███████╗
╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚══════╝"""
    click.echo(banner)
    graphql_schema: Schema = builders.GraphQLBuilder(model).build()
    extra_options = _get_kwargs(ctx)
    graphql.server(graphql_schema, port=port, browse=browse, **extra_options)
|
graphinate/color.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from collections.abc import Mapping, Sequence
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
import networkx as nx
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Map nodes to RGBA colors according to their node type.

    Each type listed in ``graph.graph['node_types']`` gets an integer index, in key
    order. Each node's index is normalized over the observed range and passed through
    ``cmap``.

    Args:
        graph: nx.Graph - The input graph for which node colors need to be mapped.
        cmap: Union[str, mpl.colors.Colormap], optional - The colormap used to map values to RGBA colors.
            Default is "tab20".

    Returns:
        Mapping - A dictionary mapping nodes to their corresponding RGBA colors (lists of
        four floats). It is empty for a graph without nodes.

    .. note::
        A missing 'node_types' graph attribute is treated as empty. When other types
        exist, the generic 'node' type gets no index of its own. Nodes whose type has no
        index fall back to index 0.

        The result is deliberately not cached. The graph is mutable, so a cached mapping
        could go stale, and caching would also require ``cmap`` to be hashable.
    """
    if not graph.nodes:
        return {}

    node_type_keys = graph.graph.get('node_types', {}).keys()

    if len(node_type_keys) > 1 and 'node' in node_type_keys:
        # Keep key order but drop the generic 'node' type.
        final_keys = [k for k in node_type_keys if k != 'node']
    else:
        final_keys = list(node_type_keys)

    type_lookup = {t: i for i, t in enumerate(final_keys)}

    # One type index per node, in graph.nodes order.
    color_values_ndarray = np.fromiter(
        (type_lookup.get(graph.nodes[node].get('type'), 0) for node in graph.nodes),
        dtype=int,
        count=len(graph),
    )
    if len(color_values_ndarray) > 1:
        low, high = color_values_ndarray.min(), color_values_ndarray.max()
    else:
        low = high = 0

    norm = mpl.colors.Normalize(vmin=low, vmax=high, clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = mapper.to_rgba(color_values_ndarray).tolist()

    return dict(zip(graph.nodes, colors))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Get HEX color code.

    Args:
        color: Input color. Tuples and lists are treated as RGB(A) and converted.
            Anything else (e.g. a color name or an existing hex string) is returned unchanged.

    Returns:
        ``'#rrggbb'`` for tuple/list input, otherwise ``color`` itself.

    Raises:
        ValueError: If a tuple/list has fewer than three components, or its first three
            components are neither all floats in [0, 1] nor all ints in [0, 255].

    .. note::
        Only the first three components are used; an alpha component is ignored.
        Float components are scaled by 255 and rounded, so ``k / 255`` maps back to
        exactly ``k`` despite floating-point error.
    """
    if not isinstance(color, (tuple, list)):
        return color

    rgb = color[:3]
    if len(rgb) < 3:
        raise ValueError("Input color should have at least three (RGB) components")

    if all(isinstance(c, float) and 0 <= c <= 1 for c in rgb):
        # round() instead of truncating int(): a product landing a hair below an integer
        # (e.g. 30.999999999999996) must map to that integer, not one less.
        # The outer int() guarantees a Python int (needed for the 'x' format below).
        rgb = tuple(int(round(c * 255)) for c in rgb)
    elif all(isinstance(c, int) and 0 <= c <= 255 for c in rgb):
        rgb = tuple(rgb)
    else:
        msg = "Input values should either be a float between 0 and 1 or an int between 0 and 255"
        raise ValueError(msg)

    r, g, b = rgb
    return f'#{r:02x}{g:02x}{b:02x}'
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite every node's color attribute in hexadecimal form, in place.

    Args:
        graph (nx.Graph): Graph whose node attributes are updated.
        color (str): Name of the color attribute. Default is 'color'.

    Returns:
        None: The graph is modified in place.

    .. note::
        Nodes that do not have the color attribute are left untouched. Values are
        converted with ``color_hex``, so non-tuple/list values are kept as they are.
    """
    hex_colors = {}
    for node, attributes in graph.nodes(data=True):
        if color in attributes:
            hex_colors[node] = color_hex(attributes[color])
    nx.set_node_attributes(graph, values=hex_colors, name=color)
|
graphinate/constants.py
ADDED
graphinate/converters.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import base64
|
|
3
|
+
import decimal
|
|
4
|
+
import math
|
|
5
|
+
from types import MappingProxyType
|
|
6
|
+
from typing import Any, Union
|
|
7
|
+
|
|
8
|
+
import strawberry
|
|
9
|
+
|
|
10
|
+
from .constants import DEFAULT_EDGE_DELIMITER, DEFAULT_NODE_DELIMITER
|
|
11
|
+
|
|
12
|
+
# Public API of this module.
__all__ = [
    'InfNumber',
    'decode_edge_id',
    'decode_id',
    'edge_label_converter',
    'encode_edge_id',
    'encode_id',
    'infnum_to_value',
    'label_converter',
    'node_label_converter',
    'value_to_infnum',
]

# Numeric types that may hold an infinite value.
InfNumber = Union[float, int, decimal.Decimal]

# Textual spellings of infinity -> numeric value (read-only).
INFINITY_MAPPING: MappingProxyType[str, InfNumber] = MappingProxyType(
    {'Infinity': math.inf, '+Infinity': math.inf, '-Infinity': -math.inf}
)

# Numeric infinity -> canonical textual spelling (read-only); the inverse of the above,
# using 'Infinity' (not '+Infinity') for positive infinity.
MATH_INF_MAPPING: MappingProxyType[InfNumber, str] = MappingProxyType(
    dict([(math.inf, 'Infinity'), (-math.inf, '-Infinity')])
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def value_to_infnum(value: str | InfNumber) -> InfNumber:
    """Convert a textual infinity to its numeric value.

    Strings listed in ``INFINITY_MAPPING`` are translated; every other value is
    returned unchanged.
    """
    try:
        return INFINITY_MAPPING[value]
    except KeyError:
        return value
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def infnum_to_value(value: InfNumber) -> InfNumber | str:
    """Convert a numeric infinity to its textual form.

    Values listed in ``MATH_INF_MAPPING`` are translated; every other value is
    returned unchanged.
    """
    try:
        return MATH_INF_MAPPING[value]
    except KeyError:
        return value
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def label_converter(value: Any, delimiter: str) -> str | None:
|
|
48
|
+
if value is not None:
|
|
49
|
+
return delimiter.join(str(v) for v in value) if isinstance(value, tuple) else str(value)
|
|
50
|
+
return value
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def node_label_converter(value: Any) -> str | None:
    """Render a node id as a label, joining tuple ids with ``DEFAULT_NODE_DELIMITER``."""
    node_delimiter = DEFAULT_NODE_DELIMITER
    return label_converter(value, delimiter=node_delimiter)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def edge_label_converter(value: Any) -> str | None:
    """Render an edge (an iterable of node ids) as a label.

    Each node id is rendered with ``node_label_converter``, and the results are joined
    with ``DEFAULT_EDGE_DELIMITER``.
    """
    node_labels = tuple(map(node_label_converter, value))
    return label_converter(node_labels, delimiter=DEFAULT_EDGE_DELIMITER)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def encode(value: Any, encoding: str = 'utf-8') -> str:
    """Encode ``repr(value)`` as a URL-safe base64 string.

    Args:
        value: Any object whose ``repr`` should be encoded.
        encoding: Text encoding used for the repr and the resulting string.

    Returns:
        The URL-safe base64 text.
    """
    raw = repr(value).encode(encoding)
    return base64.urlsafe_b64encode(raw).decode(encoding)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def decode(value: str, encoding: str = 'utf-8') -> Any:
    """Decode a URL-safe base64 string back into a Python literal.

    Args:
        value: Text produced by ``encode``.
        encoding: Text encoding used when the value was encoded.

    Returns:
        The object rebuilt with ``ast.literal_eval``, so only literal structures
        (tuples, strings, numbers, ...) are accepted.
    """
    raw = base64.urlsafe_b64decode(value.encode(encoding))
    return ast.literal_eval(raw.decode(encoding))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def encode_id(graph_node_id: tuple,
              encoding: str = 'utf-8') -> str:
    """Encode a graph node id into an opaque URL-safe string (see ``encode``)."""
    encoded_id: str = encode(graph_node_id, encoding)
    return encoded_id
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def decode_id(graphql_node_id: strawberry.ID,
              encoding: str = 'utf-8') -> tuple[str, ...]:
    """Decode an id produced by ``encode_id`` back into the graph node id (see ``decode``)."""
    node_id = decode(graphql_node_id, encoding)
    return node_id
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def encode_edge_id(edge: tuple, encoding: str = 'utf-8') -> str:
    """Encode an edge into an opaque id.

    Each endpoint is encoded with ``encode_id`` first; the tuple of those encodings is
    then encoded again.
    """
    endpoint_ids = tuple([encode_id(node, encoding) for node in edge])
    return encode_id(endpoint_ids, encoding)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def decode_edge_id(graphql_edge_id: strawberry.ID, encoding: str = 'utf-8') -> tuple:
    """Decode an id produced by ``encode_edge_id`` back into its node ids.

    Args:
        graphql_edge_id: The encoded edge id.
        encoding: Text encoding used when encoding. Like ``encode_edge_id``, it is applied
            to both the outer edge id and every inner node id.

    Returns:
        Tuple of decoded node ids.
    """
    encoded_edge: tuple = decode_id(graphql_edge_id, encoding)
    # Pass the encoding to the inner decode too; it must mirror encode_edge_id.
    return tuple(decode_id(encoded_node, encoding) for encoded_node in encoded_edge)
|
graphinate/enums.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
import networkx as nx
|
|
4
|
+
from typing_extensions import Self
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GraphType(Enum):
    """Graph Types

    Each member's value is the NetworkX graph class it stands for. Pick the one that
    matches the structure of the graph you want to represent.

    | **Graph Type** | **Type**   | **Self-loops allowed** | **Parallel edges allowed** |
    |----------------|------------|:----------------------:|:--------------------------:|
    | Graph          | Undirected | Yes                    | No                         |
    | DiGraph        | Directed   | Yes                    | No                         |
    | MultiGraph     | Undirected | Yes                    | Yes                        |
    | MultiDiGraph   | Directed   | Yes                    | Yes                        |

    See more here: [NetworkX Reference](https://networkx.org/documentation/stable/reference/classes)
    """

    Graph = nx.Graph
    DiGraph = nx.DiGraph
    MultiDiGraph = nx.MultiDiGraph
    MultiGraph = nx.MultiGraph

    @classmethod
    def of(cls, graph: nx.Graph) -> Self:
        """Return the member matching the directedness and multiplicity of ``graph``.

        Args:
            graph (nx.Graph): A NetworkX graph object.

        Returns:
            GraphType: The member whose graph class has the same structural properties.
        """
        directed = graph.is_directed()
        if graph.is_multigraph():
            return cls.MultiDiGraph if directed else cls.MultiGraph
        return cls.DiGraph if directed else cls.Graph
|