Graphinate 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,227 @@
1
+ import operator
2
+ from collections import Counter
3
+ from collections.abc import Callable, Hashable, Mapping
4
+ from typing import Any, Union
5
+
6
+ import networkx as nx
7
+ from loguru import logger
8
+ from mappingtools.transformers import simplify
9
+ from networkx.classes.reportviews import EdgeDataView, EdgeView, NodeDataView, NodeView
10
+
11
+ from .. import color
12
+ from ..enums import GraphType
13
+ from ..modeling import GraphModel, Multiplicity
14
+ from ..tools import utcnow
15
+ from ..typing import NodeTypeAbsoluteId, UniverseNode
16
+ from ._builder import Builder
17
+
18
+
19
+ class NetworkxBuilder(Builder):
20
+ """Build a NetworkX Graph"""
21
+
22
+ def __init__(self, model: GraphModel, graph_type: GraphType = GraphType.Graph):
23
+ super().__init__(model, graph_type)
24
+ self._graph: nx.Graph | None = None
25
+
26
+ def _initialize_graph(self):
27
+ """Initialize an empty NetworkX graph with metadata and default attributes."""
28
+ self._graph: nx.Graph = self.graph_type.value(name=self.model.name,
29
+ node_types=Counter(),
30
+ edge_types=Counter())
31
+
32
+ def _graph_edges(self, data, default=None):
33
+ params = {'data': data, 'default': default}
34
+
35
+ if isinstance(self._graph, nx.MultiGraph):
36
+ params['keys'] = True
37
+
38
+ return self._graph.edges(**params)
39
+
40
+ def _populate_node_type(self, node_type: Union[Hashable, UniverseNode] = UniverseNode, **kwargs):
41
+ for parent_node_type, child_node_types in self.model.node_children_types(node_type).items():
42
+ for child_node_type in child_node_types:
43
+ node_type_absolute_id = (parent_node_type, child_node_type)
44
+ self._populate_nodes(node_type_absolute_id, **kwargs)
45
+
46
+ @staticmethod
47
+ def _parent_node_id(node_type_absolute_id: NodeTypeAbsoluteId, **kwargs: Any):
48
+ if node_type_absolute_id[0] is UniverseNode:
49
+ return UniverseNode
50
+
51
+ ids = []
52
+ for k, v in kwargs.items():
53
+ if k[:-3] == node_type_absolute_id[1]:
54
+ break
55
+ ids.append(v)
56
+
57
+ return tuple(ids)
58
+
59
+ def _populate_nodes(self, node_type_absolute_id: NodeTypeAbsoluteId, **kwargs: Any):
60
+ """Populate graph nodes based on the provided model and ID."""
61
+ for node_model in self.model.node_models[node_type_absolute_id]:
62
+ unique = node_model.uniqueness
63
+ for node in node_model.generator(**kwargs):
64
+ parent_node_id = self._parent_node_id(node_type_absolute_id, **kwargs)
65
+ node_lineage = (*parent_node_id, node.key) if parent_node_id is not UniverseNode else (node.key,)
66
+ node_id = (node.key,) if unique else node_lineage
67
+
68
+ label = node.key
69
+ if node_model.label is not None:
70
+ label = node_model.label(node.value) if callable(node_model.label) else node_model.label
71
+
72
+ node_type = node.__class__.__name__.lower()
73
+ if node_type == 'tuple':
74
+ node_type = node_model.type.lower()
75
+
76
+ if node_id in self._graph:
77
+ logger.debug("Updating node. ID: {}, Label: {}", node_id, label)
78
+
79
+ match node_model.multiplicity:
80
+ case Multiplicity.ADD:
81
+ self._graph.nodes[node_id]['value'] = [self._graph.nodes[node_id]['value'] + node.value]
82
+ case Multiplicity.ALL:
83
+ self._graph.nodes[node_id]['value'].append(node.value)
84
+ case Multiplicity.FIRST:
85
+ ...
86
+ case Multiplicity.LAST:
87
+ self._graph.nodes[node_id]['value'] = [node.value]
88
+
89
+ self._graph.nodes[node_id]['magnitude'] += 1
90
+ self._graph.nodes[node_id]['updated'] = utcnow()
91
+ else:
92
+ logger.debug("Adding node. ID: {}, Label: {}", node_id, label)
93
+ self._graph.add_node(node_id,
94
+ label=label,
95
+ type=node_type,
96
+ value=[node.value],
97
+ magnitude=1,
98
+ lineage=list(node_lineage),
99
+ created=utcnow())
100
+
101
+ self._graph.graph['node_types'].update({node_type: 1})
102
+
103
+ if node_model.parent_type is not UniverseNode:
104
+ logger.debug("Adding edge. Source: {}, Target: {}", parent_node_id, node_id)
105
+ self._graph.add_edge(parent_node_id,
106
+ node_id,
107
+ created=utcnow())
108
+
109
+ new_kwargs = kwargs.copy()
110
+ new_kwargs[f"{node_type}_id"] = node.key
111
+ self._populate_node_type(node_model.type, **new_kwargs)
112
+
113
+ def _populate_edges(self, **kwargs: Any):
114
+ """Populate graph edges based on defined connections."""
115
+ for edge_model, edge_generators in self.model.edge_generators.items():
116
+ for edge_generator in edge_generators:
117
+ for edge in edge_generator(**kwargs):
118
+ edge_id = ((edge.source,), (edge.target,))
119
+ edge_label = edge.label(edge_id) if callable(edge.label) else edge.label
120
+ edge_weight = edge.weight or 1.0
121
+ edge_type = edge.type.lower()
122
+ logger.debug("Adding edge. Source: {}, Target: {}", *edge_id)
123
+
124
+ if isinstance(self._graph, nx.MultiGraph) or edge_id not in self._graph.edges:
125
+ self._graph.add_edge(*edge_id,
126
+ label=edge_label,
127
+ type=edge_type,
128
+ value=[edge.value],
129
+ weight=edge_weight,
130
+ created=utcnow())
131
+ self._graph.graph['edge_types'].update({edge_type: 1})
132
+ else:
133
+ self._graph.edges[edge_id]['value'].append(edge.value)
134
+ self._graph.edges[edge_id]['weight'] += edge_weight
135
+ self._graph.edges[edge_id]['updated'] = utcnow()
136
+
137
+ @staticmethod
138
+ def _rectified_values(name: str, default: Any, elements: Callable[
139
+ [str, Any], NodeView[Any] | EdgeView[Any] | NodeDataView[Any] | EdgeDataView], k: Callable[[Any], Any],
140
+ v: Callable[[Any], Any]) -> dict:
141
+ if callable(default):
142
+ elem = elements(data=name, default=None)
143
+ return {k(e): default(k(e))
144
+ for e in elem
145
+ if (v(e) is None if isinstance(elem, NodeDataView) else v(e) is not None)}
146
+ elif isinstance(default, dict):
147
+ return default
148
+ elif default:
149
+ return {k(e): v(e) for e in elements(data=name, default=default) if v(e) == default}
150
+ else: # default is None or empty collection
151
+ return {k(e): k(e) for e in elements(data=name, default=default) if v(e) is default}
152
+
153
+ def _rectify_node_attributes(self, **defaults):
154
+ for name, default in defaults.items():
155
+ if values := self._rectified_values(
156
+ name,
157
+ default,
158
+ self._graph.nodes,
159
+ operator.itemgetter(0),
160
+ operator.itemgetter(1),
161
+ ):
162
+ nx.set_node_attributes(self._graph, values=values, name=name)
163
+
164
+ if default_type := defaults.get('type'):
165
+ type_count = sum(1 for n, d in self._graph.nodes(data='type') if d == default_type)
166
+ if type_count:
167
+ self._graph.graph['node_types'].update({default_type: type_count})
168
+
169
+ def _rectify_edge_attributes(self, **defaults):
170
+ for name, default in defaults.items():
171
+ if values := self._rectified_values(
172
+ name,
173
+ default,
174
+ self._graph_edges,
175
+ lambda x: tuple(x[:-1]),
176
+ lambda x: x[-1]
177
+ ):
178
+ nx.set_edge_attributes(self._graph, values=values, name=name)
179
+
180
+ if default_type := defaults.get('type'):
181
+ type_count = sum(1 for *_, d in self._graph_edges(data='type') if d == default_type)
182
+ if type_count:
183
+ self._graph.graph['edge_types'].update({default_type: type_count})
184
+
185
+ def _finalize_graph(self, **node_attributes):
186
+ self._rectify_node_attributes(**node_attributes)
187
+
188
+ if 'color' not in node_attributes:
189
+ self._rectify_node_attributes(color=color.node_color_mapping(self._graph))
190
+
191
+ self._rectify_edge_attributes(**self.default_edge_attributes)
192
+
193
+ for counter_name in ('node_types', 'edge_types'):
194
+ counter = self._graph.graph[counter_name]
195
+ self._graph.graph[counter_name] = simplify(counter)
196
+
197
+ self._graph.graph['created'] = utcnow()
198
+
199
+ def _rectify_model(self, node_attributes: Mapping):
200
+ default_type = node_attributes.get('type')
201
+ default_label = node_attributes.get('label')
202
+ self.model.rectify(_type=default_type, parent_type=default_type, label=default_label)
203
+
204
+ def _build_graph(self, node_attributes: Mapping, **kwargs: Any):
205
+ self._initialize_graph()
206
+ self._populate_node_type(**kwargs)
207
+ self._populate_edges(**kwargs)
208
+ self._finalize_graph(**node_attributes)
209
+
210
+ def build(self, **kwargs: Any) -> nx.Graph:
211
+ """Build a NetworkX graph representation.
212
+
213
+ Args:
214
+ **kwargs:
215
+
216
+ Returns:
217
+ NetworkX Graph
218
+ """
219
+ super().build(**kwargs)
220
+
221
+ default_node_attributes = dict(**self.default_node_attributes)
222
+ if 'default_node_attributes' in kwargs:
223
+ default_node_attributes.update(kwargs.pop('default_node_attributes') or {})
224
+
225
+ self._rectify_model(default_node_attributes)
226
+ self._build_graph(default_node_attributes, **kwargs)
227
+ return self._graph
graphinate/cli.py ADDED
@@ -0,0 +1,123 @@
1
+ import importlib
2
+ import json
3
+ from pathlib import Path
4
+ from types import ModuleType
5
+ from typing import Any
6
+
7
+ import click
8
+ from strawberry import Schema
9
+
10
+ from . import GraphModel, builders, graphql
11
+ from .renderers.graphql import DEFAULT_PORT
12
+
13
+
14
def _get_kwargs(ctx: click.Context) -> dict:
    """Collect extra ``--key=value`` command-line arguments into a dict.

    Only arguments starting with ``--`` are considered. Exactly one leading ``--`` is removed,
    and the value is everything after the first ``=``. Values may therefore contain ``=`` or
    end with ``-`` (``str.strip('--')`` previously trimmed those dashes).

    Raises:
        click.UsageError: if an extra ``--`` argument has no ``=value`` part.
    """
    kwargs = {}
    for item in ctx.args:
        if not item.startswith("--"):
            continue
        key, separator, value = item.removeprefix("--").partition("=")
        if not separator:
            raise click.UsageError(f"Extra argument '{item}' must be in the form --key=value.")
        kwargs[key] = value
    return kwargs
16
+
17
+
18
def import_from_string(import_str: str) -> GraphModel:
    """Resolve a ``{module-name}:{variable-name}`` reference to a GraphModel instance.

    The attribute part may be a dotted path that is resolved attribute by attribute.
    For example, if `model: GraphModel = GraphModel(...)` is a variable defined in an app.py file,
    then the reference would be app:model.

    Raises:
        ImportFromStringError: if ``import_str`` is not a string, is not of the form
            ``<module>:<attribute>``, names a module that cannot be found, names a missing
            attribute, or does not resolve to a GraphModel.
    """
    if not isinstance(import_str, str):
        raise ImportFromStringError(f"{import_str} is not a string")

    module_name, _separator, attribute_path = import_str.partition(':')
    if not (module_name and attribute_path):
        raise ImportFromStringError(f"Import string '{import_str}' must be in format '<module>:<attribute>'.")

    try:
        module: ModuleType = importlib.import_module(module_name)
    except ModuleNotFoundError as error:
        # Only translate the error when the requested module itself is missing; a missing
        # dependency imported *by* that module is re-raised unchanged.
        if error.name == module_name:
            raise ImportFromStringError(f"Could not import module '{module_name}'.") from error
        raise error from None

    resolved: ModuleType | GraphModel = module
    try:
        for attribute_name in attribute_path.split('.'):
            resolved = getattr(resolved, attribute_name)
    except AttributeError as error:
        raise ImportFromStringError(
            f"Attribute '{attribute_path}' not found in import string reference '{import_str}'."
        ) from error

    if not isinstance(resolved, GraphModel):
        raise ImportFromStringError(f"GraphModel instance cannot be determined from reference '{import_str}'")
    return resolved
52
+
53
+
54
class ImportFromStringError(Exception):
    """Raised when an import string cannot be resolved to a GraphModel instance."""
56
+
57
+
58
class GraphModelType(click.ParamType):
    """Click parameter type that resolves a ``module:variable`` reference to a GraphModel."""

    name = "MODEL"

    def convert(self,
                value: Any,
                param: click.Parameter | None,
                ctx: click.Context | None) -> GraphModel:
        """Return ``value`` if it is already a GraphModel, otherwise import it from its reference.

        Any failure is reported through ``self.fail`` with ``param`` and ``ctx``, so click can
        name the offending option in the usage error.
        """
        if isinstance(value, GraphModel):
            return value

        try:
            return import_from_string(value)
        except Exception as e:
            # Broad on purpose: importing user code can raise anything, and at this CLI
            # boundary every failure should become a clean usage error.
            self.fail(str(e), param, ctx)
72
+
73
+
74
# Shared -m/--model option. It is required: every command that uses it dereferences the
# model immediately, so a missing value must be reported as a usage error instead of
# surfacing later as an AttributeError on None.
model_option = click.option('-m', '--model',
                            type=GraphModelType(),
                            required=True,
                            help="A GraphModel instance reference {module-name}:{GraphModel-instance-variable-name}."
                                 " For example, given a var `model=GraphModel()` defined in an app.py file,"
                                 " the reference would be app:model")
79
+
80
+
81
@click.group()
@click.pass_context
def cli(ctx: click.Context) -> None:
    # Root command group; the `save` and `server` subcommands attach to it via @cli.command().
    # (Comments rather than a docstring: click would show a docstring as --help text.)
    # Ensure ctx.obj is a dict (created if absent); ctx.obj is not otherwise used in this file.
    ctx.ensure_object(dict)
85
+
86
+
87
@cli.command()
@model_option
@click.pass_context
def save(ctx: click.Context, model: GraphModel) -> None:
    """Save the graph of MODEL as '<model-name>.d3_graph.json' in the current directory."""
    file_path = Path(f"{model.name}.d3_graph.json")

    # The file name is derived from model.name; reject names that would escape the
    # current directory.
    if file_path.is_absolute():
        raise click.ClickException("Please provide a relative file path for saving the graph.")

    if file_path.parent != Path('.'):
        raise click.ClickException("Saving to subdirectories is not supported. Please provide a file name only.")

    if file_path.exists():
        click.confirm(f"The file '{file_path}' already exists. Do you want to overwrite it?", abort=True)

    # NOTE(review): the same extra kwargs are passed to both D3Builder and json.dumps, and
    # this command does not enable click's ignore_unknown_options/allow_extra_args, so extra
    # --key=value options are currently rejected by click — confirm the intended routing.
    kwargs = _get_kwargs(ctx)

    # Build and serialize completely before touching the file. Otherwise a failure in the
    # model's generators or in serialization would leave an empty/truncated file behind,
    # destroying an existing file the user agreed to overwrite.
    graph = builders.D3Builder(model, **kwargs).build()
    content = json.dumps(graph, default=str, **kwargs)
    file_path.write_text(content, encoding='utf-8')
106
+
107
+
108
@cli.command()
@model_option
@click.option('-p', '--port', type=int, default=DEFAULT_PORT, help='Port number.')
@click.option('-b', '--browse', type=bool, default=False, help='Open server address in browser.')
@click.pass_context
def server(ctx: click.Context, model: GraphModel, port: int, browse: bool) -> None:
    # Print the banner, build a GraphQL schema for the model, and serve it.
    # (Comments rather than a docstring: click would show a docstring as --help text.)
    banner = """
██████╗ ██████╗ █████╗ ██████╗ ██╗ ██╗██╗███╗ ██╗ █████╗ ████████╗███████╗
██╔════╝ ██╔══██╗██╔══██╗██╔══██╗██║ ██║██║████╗ ██║██╔══██╗╚══██╔══╝██╔════╝
██║ ███╗██████╔╝███████║██████╔╝███████║██║██╔██╗ ██║███████║ ██║ █████╗
██║ ██║██╔══██╗██╔══██║██╔═══╝ ██╔══██║██║██║╚██╗██║██╔══██║ ██║ ██╔══╝
╚██████╔╝██║ ██║██║ ██║██║ ██║ ██║██║██║ ╚████║██║ ██║ ██║ ███████╗
╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚══════╝"""
    click.echo(banner)

    graphql_schema: Schema = builders.GraphQLBuilder(model).build()
    # Extra --key=value arguments are forwarded to the GraphQL server.
    server_kwargs = _get_kwargs(ctx)
    graphql.server(graphql_schema, port=port, browse=browse, **server_kwargs)
graphinate/color.py ADDED
@@ -0,0 +1,100 @@
1
+ import functools
2
+ from collections.abc import Mapping, Sequence
3
+ from typing import Union
4
+
5
+ import matplotlib as mpl
6
+ import networkx as nx
7
+ import numpy as np
8
+
9
+
10
def node_color_mapping(graph: nx.Graph, cmap: Union[str, mpl.colors.Colormap] = "tab20") -> Mapping:
    """Map each node of ``graph`` to an RGBA color derived from its node type.

    Args:
        graph: nx.Graph - The input graph whose nodes should be colored.
        cmap: Union[str, mpl.colors.Colormap], optional - The colormap used to map type indices
            to RGBA colors. Default is "tab20".
    Returns:
        Mapping - A dictionary mapping each node to its RGBA color (a list of four floats).
        An empty graph yields an empty dict.

    .. note::
        Node types are taken, in order, from the keys of the graph-level 'node_types' attribute.
        When there is more than one type, the generic 'node' type is dropped from that list;
        nodes whose type is not in the list get index 0 and so share the first type's color.
        The colormap can be specified as a string or a matplotlib colormap object.
        The mapping is recomputed on every call. It is not cached because caching on a
        mutable graph returned stale colors after the graph changed and kept graphs alive.
    """
    if not graph.nodes:
        return {}

    node_type_keys = graph.graph.get('node_types', {}).keys()

    if len(node_type_keys) > 1 and 'node' in node_type_keys:
        # Create a new list of keys, preserving order, but excluding 'node'
        final_keys = [k for k in node_type_keys if k != 'node']
    else:
        final_keys = list(node_type_keys)

    type_lookup = {t: i for i, t in enumerate(final_keys)}

    # One type index per node, in graph.nodes order (nodes(data='type') iterates the same order).
    color_values_ndarray = np.fromiter(
        (type_lookup.get(node_type, 0) for _, node_type in graph.nodes(data='type')),
        dtype=int,
        count=len(graph),
    )
    if len(color_values_ndarray) > 1:
        low, high = color_values_ndarray.min(), color_values_ndarray.max()
    else:
        low = high = 0

    norm = mpl.colors.Normalize(vmin=low, vmax=high, clip=True)
    mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    colors = mapper.to_rgba(color_values_ndarray).tolist()

    return dict(zip(graph.nodes, colors))
53
+
54
+
55
def color_hex(color: Union[str, Sequence[Union[float, int]]]) -> Union[str, Sequence[Union[float, int]]]:
    """Convert an RGB(A) tuple/list to a ``#rrggbb`` HEX string.

    Args:
        color: A tuple or list whose first three items are either all floats in [0, 1]
            or all ints in [0, 255]; any further items (e.g. alpha) are ignored.
            Any other kind of value is passed through untouched.
    Returns:
        The HEX code for tuple/list input; otherwise ``color`` itself.
    Raises:
        ValueError: if the first three items are neither all unit floats nor all byte ints,
            or if there are fewer than three of them.

    .. note::
        Float channels are scaled by 255 and truncated with ``int()``.
    """
    if not isinstance(color, (tuple, list)):
        return color

    channels = color[:3]

    if all(isinstance(c, float) and 0 <= c <= 1 for c in channels):
        red, green, blue = (int(c * 255) for c in channels)
    elif all(isinstance(c, int) and 0 <= c <= 255 for c in channels):
        red, green, blue = channels
    else:
        raise ValueError("Input values should either be a float between 0 and 1 or an int between 0 and 255")

    return f'#{red:02x}{green:02x}{blue:02x}'
83
+
84
+
85
def convert_colors_to_hex(graph: nx.Graph, color: str = 'color') -> None:
    """Rewrite every node's color attribute in place as a HEX string (via ``color_hex``).

    Args:
        graph (nx.Graph): The graph whose node attributes are updated.
        color (str): Name of the node attribute holding the color. Default is 'color'.

    Returns:
        None: The graph is modified in place.

    .. note::
        Nodes that do not have the attribute are skipped and left unchanged.
    """
    hex_colors = {}
    for node, attributes in graph.nodes(data=True):
        if color in attributes:
            hex_colors[node] = color_hex(attributes[color])
    nx.set_node_attributes(graph, values=hex_colors, name=color)
@@ -0,0 +1,4 @@
1
# Separator placed between the parts of a node label (e.g. the items of a tuple node id).
DEFAULT_NODE_DELIMITER: str = ' ∋ '
# Separator placed between the endpoint labels of an edge label.
DEFAULT_EDGE_DELIMITER: str = ' ⟹ '

__all__ = ['DEFAULT_EDGE_DELIMITER', 'DEFAULT_NODE_DELIMITER']
@@ -0,0 +1,94 @@
1
+ import ast
2
+ import base64
3
+ import decimal
4
+ import math
5
+ from types import MappingProxyType
6
+ from typing import Any, Union
7
+
8
+ import strawberry
9
+
10
+ from .constants import DEFAULT_EDGE_DELIMITER, DEFAULT_NODE_DELIMITER
11
+
12
# Public API of this module; every name listed is defined further down.
__all__ = [
    'InfNumber', 'decode_edge_id', 'decode_id', 'edge_label_converter', 'encode_edge_id',
    'encode_id', 'infnum_to_value', 'label_converter', 'node_label_converter', 'value_to_infnum',
]
24
+
25
# Numeric types that may carry an infinite value.
InfNumber = Union[float, int, decimal.Decimal]

# Accepted string spellings of infinity -> float infinity (read-only view).
INFINITY_MAPPING: MappingProxyType[str, InfNumber] = MappingProxyType(
    {'Infinity': math.inf, '+Infinity': math.inf, '-Infinity': -math.inf}
)

# Float infinity -> canonical string spelling (read-only view).
MATH_INF_MAPPING: MappingProxyType[InfNumber, str] = MappingProxyType(
    {math.inf: 'Infinity', -math.inf: '-Infinity'}
)
37
+
38
+
39
def value_to_infnum(value: str | InfNumber) -> InfNumber:
    """Translate 'Infinity', '+Infinity' or '-Infinity' to the matching float infinity.

    Any other value is returned unchanged.
    """
    if value in INFINITY_MAPPING:
        return INFINITY_MAPPING[value]
    return value
41
+
42
+
43
def infnum_to_value(value: InfNumber) -> InfNumber | str:
    """Translate a positive/negative infinity to 'Infinity'/'-Infinity'.

    Any other value is returned unchanged.
    """
    if value in MATH_INF_MAPPING:
        return MATH_INF_MAPPING[value]
    return value
45
+
46
+
47
+ def label_converter(value: Any, delimiter: str) -> str | None:
48
+ if value is not None:
49
+ return delimiter.join(str(v) for v in value) if isinstance(value, tuple) else str(value)
50
+ return value
51
+
52
+
53
def node_label_converter(value: Any) -> str | None:
    """Render a node id as a label, joining tuple parts with DEFAULT_NODE_DELIMITER."""
    return label_converter(value, DEFAULT_NODE_DELIMITER)
55
+
56
+
57
def edge_label_converter(value: Any) -> str | None:
    """Render an edge as a label.

    Each endpoint of ``value`` is rendered with node_label_converter, and the resulting
    labels are joined with DEFAULT_EDGE_DELIMITER.
    """
    endpoint_labels = tuple(map(node_label_converter, value))
    return label_converter(endpoint_labels, delimiter=DEFAULT_EDGE_DELIMITER)
59
+
60
+
61
def encode(value: Any, encoding: str = 'utf-8') -> str:
    """Encode ``value`` as the URL-safe base64 of its ``repr()``, returned as text."""
    return base64.urlsafe_b64encode(repr(value).encode(encoding)).decode(encoding)
67
+
68
+
69
def decode(value: str, encoding: str = 'utf-8') -> Any:
    """Base64-decode ``value`` (URL-safe alphabet) and parse the text as a Python literal.

    Parsing uses ``ast.literal_eval``, so only literal expressions are accepted; anything
    else raises ValueError (or SyntaxError for invalid syntax).
    """
    literal_text = base64.urlsafe_b64decode(value.encode(encoding)).decode(encoding)
    return ast.literal_eval(literal_text)
75
+
76
+
77
def encode_id(graph_node_id: tuple,
              encoding: str = 'utf-8') -> str:
    """Encode a graph node id into an opaque string id (see ``encode``)."""
    return encode(graph_node_id, encoding=encoding)
80
+
81
+
82
def decode_id(graphql_node_id: strawberry.ID,
              encoding: str = 'utf-8') -> tuple[str, ...]:
    """Decode an id produced by ``encode_id`` back into the original node id (see ``decode``)."""
    return decode(graphql_node_id, encoding=encoding)
85
+
86
+
87
def encode_edge_id(edge: tuple, encoding: str = 'utf-8') -> str:
    """Encode an edge: encode each endpoint id, then encode the tuple of encoded endpoints."""
    return encode_id(tuple(encode_id(endpoint, encoding) for endpoint in edge), encoding)
90
+
91
+
92
def decode_edge_id(graphql_edge_id: strawberry.ID, encoding: str = 'utf-8') -> tuple:
    """Decode an id produced by ``encode_edge_id`` back into a tuple of node ids.

    ``encoding`` is applied at both levels (the outer edge tuple and every endpoint),
    mirroring ``encode_edge_id``.
    """
    encoded_edge: tuple = decode_id(graphql_edge_id, encoding)
    # Pass `encoding` to the inner decode too; endpoints were encoded with it.
    return tuple(decode_id(encoded_node, encoding) for encoded_node in encoded_edge)
graphinate/enums.py ADDED
@@ -0,0 +1,44 @@
1
+ from enum import Enum
2
+
3
+ import networkx as nx
4
+ from typing_extensions import Self
5
+
6
+
7
class GraphType(Enum):
    """Graph Types

    The choice of graph class depends on the structure of the graph you want to represent.

    | **Graph Type** | **Type**   | **Self-loops allowed** | **Parallel edges allowed** |
    |----------------|------------|:----------------------:|:--------------------------:|
    | Graph          | Undirected | Yes                    | No                         |
    | DiGraph        | Directed   | Yes                    | No                         |
    | MultiGraph     | Undirected | Yes                    | Yes                        |
    | MultiDiGraph   | Directed   | Yes                    | Yes                        |

    Each member's value is the corresponding NetworkX graph class.

    See more here: [NetworkX Reference](https://networkx.org/documentation/stable/reference/classes)
    """

    Graph = nx.Graph
    DiGraph = nx.DiGraph
    MultiDiGraph = nx.MultiDiGraph
    MultiGraph = nx.MultiGraph

    @classmethod
    def of(cls, graph: nx.Graph) -> Self:
        """Return the member matching ``graph``'s directedness and multi-edge support.

        Args:
            graph (nx.Graph): A NetworkX graph object.

        Returns:
            GraphType: The member whose graph class has the same structure as ``graph``.
        """
        directed = graph.is_directed()
        if graph.is_multigraph():
            return cls.MultiDiGraph if directed else cls.MultiGraph
        return cls.DiGraph if directed else cls.Graph