Graphinate 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,521 @@
1
+ import functools
2
+ import importlib
3
+ import inspect
4
+ import json
5
+ import math
6
+ import operator
7
+ from collections.abc import Callable
8
+ from datetime import datetime
9
+ from enum import Enum, EnumType
10
+ from typing import Any, Optional
11
+
12
+ import inflect
13
+ import networkx as nx
14
+ import strawberry
15
+ from strawberry.extensions import ParserCache, QueryDepthLimiter, ValidationCache
16
+ from strawberry.types.base import StrawberryType
17
+
18
+ from .. import color, converters
19
+ from ..converters import (
20
+ decode_edge_id,
21
+ decode_id,
22
+ edge_label_converter,
23
+ encode_edge_id,
24
+ encode_id,
25
+ node_label_converter,
26
+ )
27
+ from ..enums import GraphType
28
+ from ..modeling import GraphModel
29
+ from ._networkx import NetworkxBuilder
30
+
31
+
32
class GraphQLBuilder(NetworkxBuilder):
    """Builds a GraphQL Schema"""

    # region - Strawberry Types
    # NOTE: the Strawberry types below are documented with `#` comments rather
    # than docstrings so that the generated schema is not affected.

    # Custom scalar; (de)serialization is delegated to the `converters` helpers.
    InfNumber = strawberry.scalar(
        converters.InfNumber,
        description='Integer, Decimal or Float including Infinity and -Infinity',
        serialize=converters.infnum_to_value,
        parse_value=converters.value_to_infnum,
    )

    # A named numeric value (used for type counts and graph measures).
    @strawberry.type
    class Measure:
        name: str
        value: 'GraphQLBuilder.InfNumber'

    # Fields shared by nodes and edges.
    @strawberry.interface
    class GraphElement:
        id: strawberry.ID
        type: str
        label: str
        value: list[strawberry.scalars.JSON] | None
        color: str | None = None
        created: datetime | None
        updated: datetime | None

    # Declared empty; members are added at runtime by
    # `_populate_graph_node_type_enum`.
    @strawberry.enum
    class GraphNodeType(Enum):
        ...  # pragma: no cover

    # Interface implemented by the per-node-type classes generated in
    # `_graphql_types`, which supply the actual `neighbors`/`edges` resolvers.
    @strawberry.interface(description="Represents a Graph Node")
    class GraphNode(GraphElement):
        node_id: strawberry.ID
        magnitude: int
        lineage: str

        @strawberry.field()
        def neighbors(self,
                      type: 'GraphQLBuilder.GraphNodeType | None' = None,
                      children: bool = False) -> list[Optional['GraphQLBuilder.GraphNode']]:
            ...  # pragma: no cover

        @strawberry.field()
        def edges(self) -> list[Optional['GraphQLBuilder.GraphEdge']]:
            ...  # pragma: no cover

    @strawberry.type(description="Represents a Graph Edge")
    class GraphEdge(GraphElement):
        source: 'GraphQLBuilder.GraphNode'
        target: 'GraphQLBuilder.GraphNode'
        weight: float

    # Graph-level fields computed from the wrapped NetworkX graph.
    @strawberry.type
    class Graph:
        # Private: not exposed in the schema, only used by the resolvers below.
        nx_graph: strawberry.Private[nx.Graph]

        @strawberry.field()
        def radius(self) -> 'GraphQLBuilder.InfNumber':
            # Infinity is reported when the graph is not connected.
            return nx.radius(self.nx_graph) if nx.is_connected(self.nx_graph) else math.inf

        @strawberry.field()
        def diameter(self) -> 'GraphQLBuilder.InfNumber':
            # Infinity is reported when the graph is not connected.
            return nx.diameter(self.nx_graph) if nx.is_connected(self.nx_graph) else math.inf

        @strawberry.field()
        def name(self) -> str:
            return self.nx_graph.graph['name']

        @strawberry.field()
        def node_type_counts(self) -> list['GraphQLBuilder.Measure']:
            return [GraphQLBuilder.Measure(name=t, value=c) for t, c in self.nx_graph.graph['node_types'].items()]

        @strawberry.field()
        def edge_type_counts(self) -> list['GraphQLBuilder.Measure']:
            return [GraphQLBuilder.Measure(name=t, value=c) for t, c in self.nx_graph.graph['edge_types'].items()]

        @strawberry.field()
        def node_count(self) -> int:
            return self.nx_graph.number_of_nodes()

        @strawberry.field()
        def edge_count(self) -> int:
            return self.nx_graph.number_of_edges()

        @strawberry.field()
        def order(self) -> int:
            return self.nx_graph.order()

        @strawberry.field()
        def size(self) -> int:
            # NOTE(review): sums the 'weight' edge attribute, so the result may
            # not be an int for non-integer weights - confirm the annotation.
            return self.nx_graph.size(weight='weight')

        # @strawberry.field()
        # def girth(self) -> int:
        #     return min(len(cycle) for cycle in nx.simple_cycles(self.graph))

        @strawberry.field()
        def average_degree(self) -> float:
            node_count = self.nx_graph.number_of_nodes()
            if not node_count:
                # Avoid dividing by zero; return a float to match the annotation.
                return 0.0
            return sum(d for _, d in self.nx_graph.degree()) / node_count

        @strawberry.field()
        def hash(self) -> str:
            return nx.weisfeiler_lehman_graph_hash(self.nx_graph)

        @strawberry.field()
        def created(self) -> datetime:
            return self.nx_graph.graph['created']

    # Each member value is either the name of a function looked up on the
    # top-level `networkx` module, or a `(module path, function name)` pair
    # for functions that are imported from a sub-module (see `graph_measure`
    # in `_graphql_query`).
    @strawberry.enum(description="""
    See NetworkX documentation for explanations:
    https://networkx.org/documentation/stable/reference/index.html
    """)
    class GraphMeasure(Enum):
        is_empty = 'is_empty'
        is_directed = 'is_directed'
        is_weighted = 'is_weighted'
        is_negatively_weighted = 'is_negatively_weighted'
        is_planar = 'is_planar'
        is_regular = 'is_regular'
        is_bipartite = 'is_bipartite'
        is_chordal = 'is_chordal'
        is_eulerian = 'is_eulerian'
        is_semieulerian = 'is_semieulerian'
        has_eulerian_path = 'has_eulerian_path'
        has_bridges = 'has_bridges'
        is_asteroidal_triple_free = 'is_at_free'
        is_directed_acyclic_graph = 'is_directed_acyclic_graph'
        is_aperiodic = 'is_aperiodic'
        is_distance_regular = 'is_distance_regular'
        is_strongly_regular = 'is_strongly_regular'
        is_threshold_graph = ('networkx.algorithms.threshold', 'is_threshold_graph')
        is_connected = 'is_connected'
        is_biconnected = 'is_biconnected'
        is_strongly_connected = 'is_strongly_connected'
        is_weakly_connected = 'is_weakly_connected'
        is_semiconnected = 'is_semiconnected'
        is_attracting_component = 'is_attracting_component'
        is_tournament = ('networkx.algorithms.tournament', 'is_tournament')
        is_tree = 'is_tree'
        is_forest = 'is_forest'
        is_arborescence = 'is_arborescence'
        is_branching = 'is_branching'
        is_triad = 'is_triad'
        radius = 'radius'
        diameter = 'diameter'
        density = 'density'
        number_of_isolates = 'number_of_isolates'
        number_connected_components = 'number_connected_components'
        number_strongly_connected_components = 'number_strongly_connected_components'
        # Fixed: the value previously had a leading space, which made the
        # attribute lookup on `networkx` fail.
        number_weakly_connected_components = 'number_weakly_connected_components'
        number_attracting_components = 'number_attracting_components'
        node_connectivity = 'node_connectivity'
        transitivity = 'transitivity'
        average_clustering = 'average_clustering'
        chordal_graph_treewidth = 'chordal_graph_treewidth'
        degree_assortativity_coefficient = 'degree_assortativity_coefficient'
        degree_pearson_correlation_coefficient = 'degree_pearson_correlation_coefficient'
        local_efficiency = 'local_efficiency'
        global_efficiency = 'global_efficiency'
        flow_hierarchy = 'flow_hierarchy'
        average_shortest_path_length = 'average_shortest_path_length'
        overall_reciprocity = 'overall_reciprocity'
        wiener_index = 'wiener_index'

    # endregion - Strawberry Types

    def __init__(self, model: GraphModel, graph_type: GraphType = GraphType.Graph):
        """Initialize the builder.

        Args:
            model: the graph model, forwarded to the parent builder.
            graph_type: the graph type, forwarded to the parent builder.
        """
        super().__init__(model, graph_type)
        # Optional supplier of a GraphQL type per node type; set in `build`.
        self._node_value_graphql_type_supplier: Callable[[str], StrawberryType | None] | None = None

    @staticmethod
    def add_field_resolver(class_dict: dict, field_name: str, resolver: Callable, graphql_type: Any | None = None):
        """Register `resolver` as a Strawberry field in a class namespace dict.

        Args:
            class_dict: class namespace being assembled; must already contain
                an '__annotations__' dict.
            field_name: name of the field to add.
            resolver: resolver callable; it must declare a return annotation,
                which is used as the field annotation.
            graphql_type: optional explicit GraphQL type passed to
                `strawberry.field`.

        Raises:
            KeyError: if `resolver` has no return annotation or `class_dict`
                has no '__annotations__' entry.
        """
        class_dict[field_name] = strawberry.field(resolver=resolver, graphql_type=graphql_type)
        class_dict['__annotations__'][field_name] = inspect.getfullargspec(resolver).annotations['return']

    @staticmethod
    def _graph_node(node_class: type['GraphQLBuilder.GraphNode'],
                    node: tuple,
                    node_data: dict) -> 'GraphQLBuilder.GraphNode':
        """Instantiate `node_class` from a graph node and its data dict.

        'type', 'value', 'lineage' and 'color' are required keys of
        `node_data`; 'label', 'magnitude', 'created' and 'updated' are
        optional.
        """
        kwargs = {
            'id': encode_id(node),
            'node_id': str(node),
            'type': node_data['type'],
            'label': node_data.get('label', node_label_converter(node)),
            'value': node_data['value'],
            'magnitude': node_data.get('magnitude', 1),
            'lineage': str(node_data['lineage']),
            'color': color.color_hex(node_data['color']),
            'created': node_data.get('created'),
            'updated': node_data.get('updated')
        }

        return node_class(**kwargs)

    def _graph_edge(self, edge: tuple, edge_data: dict):
        """Build a `GraphEdge` from an edge tuple and its data dict.

        The source and target nodes are built with `_graph_node`, using the
        node data stored in `self._graph` and the GraphQL type registered for
        each node's 'type'. 'value' is a required key of `edge_data`.
        """
        graphql_types = self._graphql_types
        nodes_with_data = ((n, self._graph.nodes[n]) for n in edge)
        nodes_args = ((graphql_types.get(d.get('type')), n, d) for n, d in nodes_with_data)
        source, target = tuple(self._graph_node(*args) for args in nodes_args)

        return GraphQLBuilder.GraphEdge(
            id=encode_edge_id(edge),
            source=source,
            target=target,
            type=edge_data.get('type', ''),
            label=edge_data.get('label', edge_label_converter(edge)),
            value=[json.dumps(v, default=str) for v in edge_data['value']],
            weight=edge_data.get('weight', 1.0),
            color=color.color_hex(edge_data.get('color')),
            created=edge_data.get('created'),
            updated=edge_data.get('updated')
        )

    @staticmethod
    def _graphql_type(name: str, type_class: type['GraphQLBuilder.GraphNode']) -> type['GraphQLBuilder.GraphNode']:
        """Wrap `type_class` as a Strawberry type.

        The GraphQL name is the capitalized `name`, suffixed with 'Node'
        unless `name` already ends with 'node' (case-insensitive).
        """
        capitalized_name = name.capitalize()
        return strawberry.type(
            type_class,
            name=f"{capitalized_name}{'' if name.lower().endswith('node') else 'Node'}",
            description=f"Represents a {capitalized_name} Graph Node"
        )

    @staticmethod
    def _graphql_enum(name: str, values: list[str]) -> EnumType:
        """Create a Strawberry enum called `name` whose members map each value to itself."""
        return strawberry.enum(
            Enum(name, {v: v for v in values}),
            name=name,
            description=f"{name} Enumeration"
        )

    # NOTE(review): the cache key includes `model`, so the model must be
    # hashable and is kept alive while its entry is cached.
    @classmethod
    @functools.lru_cache
    def _children_types(cls, model: GraphModel, node_type: str):
        """Return the children types registered in `model` for `node_type` (cached)."""
        return model.node_children_types(node_type).get(node_type, [])

    def _populate_graph_node_type_enum(self, node_types: list[str]):
        """Add `node_types` as members of the `GraphNodeType` enum.

        `GraphNodeType` is a class-level enum, so members that are already
        registered (e.g. by another builder instance) are skipped instead of
        being appended a second time.
        """
        from strawberry.types.enum import EnumValue

        for v in node_types:
            if v in self.GraphNodeType._member_map_:
                continue

            self.GraphNodeType._member_names_.append(v)
            self.GraphNodeType._member_map_[v] = v
            self.GraphNodeType._value2member_map_[v] = v

            self.GraphNodeType.__strawberry_definition__.values.append(
                EnumValue(
                    name=v,
                    value=v,
                    description=f"Graph Node Type: {v}"
                )
            )

    # NOTE(review): `lru_cache` on a method keys on `self` and keeps the
    # builder referenced from the cache (ruff B019); kept as-is to preserve
    # the existing caching behavior.
    @property
    @functools.lru_cache
    def _graphql_types(self) -> dict[str, type['GraphQLBuilder.GraphNode']]:
        """Create (once per builder) a GraphQL node type for every node type in the graph.

        Returns:
            Mapping of node type name to the generated Strawberry type
            implementing `GraphNode`.
        """
        node_types = list(self._graph.graph['node_types'].keys())

        self._populate_graph_node_type_enum(node_types)

        # The resolvers below take their own `self` (the GraphQL node), so keep
        # a separate handle on the builder. The graph is read through it at
        # resolve time, like `get_graph()` in `_graphql_query`, instead of
        # being captured when the types are created.
        builder = self

        def neighbors_resolver():
            # `node_type` is the loop variable of the enclosing method; this
            # factory is called inside the loop, once per node type.
            children_types = set(self._children_types(self.model, node_type))

            def node_neighbors(self,
                               type: 'GraphQLBuilder.GraphNodeType | None' = None,
                               children: bool = False) -> list['GraphQLBuilder.GraphNode']:
                graph = builder._graph
                node = decode_id(self.id)
                # Collect the neighbors once instead of re-querying them for
                # every node of the graph; iterating over `graph.nodes` keeps
                # the original result order.
                neighbor_ids = set(graph.neighbors(node))
                items = (GraphQLBuilder._graph_node(graphql_types[d['type']], n, d)
                         for n, d in graph.nodes(data=True)
                         if n in neighbor_ids)

                if type is not None:
                    items = (item for item in items if item.type == type)

                if children and children_types:
                    items = (item for item in items if item.type in children_types)

                items = list(items)
                return items

            return node_neighbors

        def edges_resolver():
            graph_edge = self._graph_edge

            def node_edges(self) -> list[GraphQLBuilder.GraphEdge | None]:
                graph: nx.Graph = builder._graph
                node = decode_id(self.id)
                return [graph_edge((source, target), data) for source, target, data in graph.edges(node, data=True)]

            return node_edges

        # Create classes for nodes according to their type
        graphql_types: dict[str, type[GraphQLBuilder.GraphNode]] = {}
        for node_type in node_types:
            class_name = node_type.capitalize()
            bases = (GraphQLBuilder.GraphNode,)
            class_dict = {
                '__doc__': f"A {class_name} Graph Node",
                '__annotations__': {}
            }

            # Fixed: the walrus target must be parenthesized on its own,
            # otherwise it is bound to the result of the `is not None` test
            # (a bool) instead of the supplied type.
            if (
                    self._node_value_graphql_type_supplier is not None
                    and (value_graphql_type := self._node_value_graphql_type_supplier(node_type)) is not None
            ):
                # NOTE(review): this sets a class attribute, not an annotation -
                # confirm whether the intent was to override the type of the
                # `value` field.
                class_dict['value'] = list[value_graphql_type]

            self.add_field_resolver(class_dict, 'neighbors', neighbors_resolver())
            self.add_field_resolver(class_dict, 'edges', edges_resolver())

            # noinspection PyTypeChecker
            graphql_type: type[GraphQLBuilder.GraphNode] = type(class_name, bases, class_dict)
            graphql_types[node_type] = GraphQLBuilder._graphql_type(node_type, graphql_type)

        return graphql_types

    def _graphql_query(self):  # noqa: C901
        """Assemble the root `Query` type.

        The query exposes `graph`, `nodes`, `edges`, `measure` and one
        pluralized field per generated node type.
        """
        # inflect engine to generate Plurals when needed
        inflection = inflect.engine()

        # local reference to instance fields used to "inject" into dynamically generated class methods
        def get_graph():
            return self._graph

        graphql_types = self._graphql_types

        # region - Defining GraphQL Query Class dict
        query_class_dict = {'__annotations__': {}}

        # region - Defining GraphQL Query Class dict - graph field
        def graphql_graph(self) -> GraphQLBuilder.Graph:
            return GraphQLBuilder.Graph(nx_graph=get_graph())

        self.add_field_resolver(query_class_dict, 'graph', graphql_graph)

        # endregion

        # region - Defining GraphQL Query Class dict - nodes field
        def graph_nodes_resolver(
                graphql_type: type[GraphQLBuilder.GraphNode] | None = None,
                node_type: str | None = None
        ) -> Callable[[strawberry.ID | None], list[GraphQLBuilder.GraphNode]]:
            # Fixed: both sides of the type comparison are lower-cased, so node
            # types containing uppercase letters can match.
            expected_type = node_type.lower() if node_type else None

            def graph_nodes(self,
                            node_id: strawberry.ID | None = strawberry.UNSET) -> list[GraphQLBuilder.GraphNode]:

                decoded_node_id = node_id and decode_id(node_id)

                graph = get_graph()

                if graphql_type:
                    nodes = (GraphQLBuilder._graph_node(graphql_type, n, d)
                             for n, d in graph.nodes(data=True))
                else:
                    nodes = (GraphQLBuilder._graph_node(graphql_types.get(d['type']), n, d)
                             for n, d in graph.nodes(data=True))

                def filter_node(node):
                    output = True
                    if expected_type:
                        output = node.type.lower() == expected_type

                    if decoded_node_id:
                        output = output and (decode_id(node.id) == decoded_node_id)

                    return output

                items = [node for node in nodes if filter_node(node)]

                return items

            return graph_nodes

        self.add_field_resolver(query_class_dict, 'nodes', graph_nodes_resolver())

        # endregion

        # region - Defining GraphQL Query Class dict - edges field
        def graph_edges_resolver() -> Callable[[strawberry.ID | None], list[GraphQLBuilder.GraphEdge]]:

            graph_edge = self._graph_edge

            def graph_edges(self,
                            edge_id: strawberry.ID | None = strawberry.UNSET) -> list[GraphQLBuilder.GraphEdge]:
                decoded_edge_id = edge_id and decode_edge_id(edge_id)

                graph = get_graph()

                edges = (graph_edge((source, target), data) for source, target, data in graph.edges(data=True))

                def filter_edge(edge):
                    output = True
                    if decoded_edge_id:
                        output = decode_edge_id(edge.id) == decoded_edge_id

                    return output

                return [edge for edge in edges if filter_edge(edge)]

            return graph_edges

        self.add_field_resolver(query_class_dict, 'edges', graph_edges_resolver())
        # endregion

        # region - Defining GraphQL Query Class dict - fields for GraphQL types implementing 'GraphNode' interface
        for node_type, graphql_type in self._graphql_types.items():
            field_name = inflection.plural(node_type)
            resolver = graph_nodes_resolver(graphql_type, node_type)
            self.add_field_resolver(query_class_dict, field_name, resolver)

        # endregion

        # region - Defining GraphQL Query Class dict - field measure for 'GraphMeasure' GraphQL type

        def graph_measure(self, measure: GraphQLBuilder.GraphMeasure) -> GraphQLBuilder.Measure:

            graph = get_graph()

            # A plain string names a function on the `networkx` module; a
            # pair is (module to import, function name).
            if isinstance(measure.value, str):
                method = measure.value
                module = nx
            else:
                method = measure.value[1]
                module = importlib.import_module(measure.value[0])

            value_getter = operator.attrgetter(method)(module)
            value = float(value_getter(graph))
            return GraphQLBuilder.Measure(name=measure.name, value=value)

        # query_class_dict['measure'] = strawberry.field(resolver=graph_measure)
        # query_class_dict['__annotations__']['measure'] = float
        self.add_field_resolver(query_class_dict, 'measure', graph_measure)
        # endregion

        # region - Defining GraphQL Query Class dict - create Query Class and Query Type
        query_class = type('Query', (), query_class_dict)
        query_graphql_type = strawberry.type(query_class, name='Query')
        # endregion

        # endregion - Defining GraphQL Query Class dict

        return query_graphql_type

    def _graphql_mutation(self):
        """Create the root `Mutation` type.

        Its `refresh` mutation re-runs the parent builder's `build` with
        `self._cached_build_kwargs` and returns True.
        """

        refresh_graph = functools.partial(super().build, **self._cached_build_kwargs)

        @strawberry.type
        class Mutation:

            @strawberry.mutation
            def refresh(self) -> bool:
                refresh_graph()
                return True

        return Mutation

    def schema(self) -> strawberry.Schema:
        """Define and return the Strawberry schema (query, mutation, node types, extensions)."""
        return strawberry.Schema(
            query=self._graphql_query(),
            mutation=self._graphql_mutation(),
            types=self._graphql_types.values(),
            extensions=[
                ParserCache(maxsize=100),
                QueryDepthLimiter(max_depth=10),
                ValidationCache(maxsize=100)
            ]
        )

    def build(self,
              node_value_graphql_type_supplier: Callable[[str], StrawberryType | None] | None = None,
              **kwargs: Any) -> strawberry.Schema:
        """Build the underlying graph and return its GraphQL schema.

        Args:
            node_value_graphql_type_supplier: optional callable that receives
                a node type name and returns a Strawberry type for it, or None.
            **kwargs: forwarded to the parent builder's `build`.

        Returns:
            Strawberry GraphQL Schema
        """
        super().build(**kwargs)

        self._node_value_graphql_type_supplier = node_value_graphql_type_supplier

        return self.schema()
@@ -0,0 +1,45 @@
1
+ from typing import Any
2
+
3
+ import networkx_mermaid as nxm
4
+
5
+ from .. import color
6
+ from ..enums import GraphType
7
+ from ..modeling import GraphModel
8
+ from ._networkx import NetworkxBuilder
9
+
10
+
11
class MermaidBuilder(NetworkxBuilder):
    """Build a Mermaid Graph"""

    def __init__(self, model: GraphModel, graph_type: GraphType = GraphType.Graph):
        super().__init__(model, graph_type)

    def build(self,
              orientation: nxm.DiagramOrientation = nxm.DiagramOrientation.LEFT_RIGHT,
              node_shape: nxm.DiagramNodeShape = nxm.DiagramNodeShape.DEFAULT,
              title: str | None = None,
              with_edge_labels: bool = False,
              **kwargs: Any) -> nxm.typing.MermaidDiagram:
        """
        Build a Mermaid Graph

        Args:
            orientation: Orientation of the diagram
                (default: DiagramOrientation.LEFT_RIGHT).
            node_shape: Shape of the diagram nodes
                (default: DiagramNodeShape.DEFAULT).
            title: Title of the diagram (default: None).
                If None, the graph name will be used if available.
                Supplying an empty string will remove the title.
            with_edge_labels: Whether to include edge labels (default: False).
            **kwargs: additional inputs to the node and edge generator functions

        Returns:
            Mermaid Graph
        """
        # The parent builder populates self._graph; it must run first.
        super().build(**kwargs)

        graph = self._graph
        color.convert_colors_to_hex(graph)

        diagram_builder = nxm.DiagramBuilder(orientation=orientation, node_shape=node_shape)
        diagram = diagram_builder.build(graph, title=title, with_edge_labels=with_edge_labels)
        return diagram