neo4j-viz 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
neo4j_viz/gds.py CHANGED
@@ -8,16 +8,26 @@ from uuid import uuid4
8
8
  import pandas as pd
9
9
  from graphdatascience import Graph, GraphDataScience
10
10
 
11
+ from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace
12
+
11
13
  from .pandas import _from_dfs
12
14
  from .visualization_graph import VisualizationGraph
13
15
 
14
16
 
15
17
  def _fetch_node_dfs(
16
- gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str]
18
+ gds: GraphDataScience,
19
+ G: Graph,
20
+ node_properties_by_label: dict[str, list[str]],
21
+ node_labels: list[str],
22
+ additional_db_node_properties: list[str],
17
23
  ) -> dict[str, pd.DataFrame]:
18
24
  return {
19
25
  lbl: gds.graph.nodeProperties.stream(
20
- G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True
26
+ G,
27
+ node_properties=node_properties_by_label[lbl],
28
+ node_labels=[lbl],
29
+ separate_property_columns=True,
30
+ db_node_properties=additional_db_node_properties,
21
31
  )
22
32
  for lbl in node_labels
23
33
  }
@@ -47,17 +57,20 @@ def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]:
47
57
  def from_gds(
48
58
  gds: GraphDataScience,
49
59
  G: Graph,
50
- size_property: Optional[str] = None,
51
- additional_node_properties: Optional[list[str]] = None,
52
- node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
60
+ node_properties: Optional[list[str]] = None,
61
+ db_node_properties: Optional[list[str]] = None,
53
62
  max_node_count: int = 10_000,
54
63
  ) -> VisualizationGraph:
55
64
  """
56
65
  Create a VisualizationGraph from a GraphDataScience object and a Graph object.
57
66
 
58
- All `additional_node_properties` will be included in the visualization graph.
59
- If the properties are named as the fields of the `Node` class, they will be included as top level fields of the
60
- created `Node` objects. Otherwise, they will be included in the `properties` dictionary.
67
+ By default:
68
+
69
+ * the caption of a node will be based on its `labels`.
70
+ * the caption of a relationship will be based on its `relationshipType`.
71
+ * the color of nodes will be set based on their label, unless there are more than 12 unique labels.
72
+
73
+ All `node_properties` and `db_node_properties` will be included in the visualization graph under the `properties` field.
61
74
  Additionally, a new "labels" node property will be added, containing the node labels of the node.
62
75
  Similarly for relationships, a new "relationshipType" property will be added.
63
76
 
@@ -67,45 +80,36 @@ def from_gds(
67
80
  GraphDataScience object.
68
81
  G : Graph
69
82
  Graph object.
70
- size_property : str, optional
71
- Property to use for node size, by default None.
72
- additional_node_properties : list[str], optional
83
+ node_properties : list[str], optional
73
84
  Additional properties to include in the visualization node, by default None which means that all node
74
- properties will be fetched.
75
- node_radius_min_max : tuple[float, float], optional
76
- Minimum and maximum node radius, by default (3, 60).
77
- To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
85
+ properties from the Graph will be fetched.
86
+ db_node_properties : list[str], optional
87
+ Additional node properties to fetch from the database, by default None. Only works if the graph was projected from the database.
78
88
  max_node_count : int, optional
79
89
  The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts
80
90
  if its node count exceeds this number.
81
91
  """
92
+ if db_node_properties is None:
93
+ db_node_properties = []
94
+
82
95
  node_properties_from_gds = G.node_properties()
83
96
  assert isinstance(node_properties_from_gds, pd.Series)
84
97
  actual_node_properties: dict[str, list[str]] = cast(dict[str, list[str]], node_properties_from_gds.to_dict())
85
98
  all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values()))
86
99
 
87
- if size_property is not None:
88
- if size_property not in all_actual_node_properties:
89
- raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
90
-
91
100
  node_properties_by_label_sets: dict[str, set[str]] = dict()
92
- if additional_node_properties is None:
101
+ if node_properties is None:
93
102
  node_properties_by_label_sets = {k: set(v) for k, v in actual_node_properties.items()}
94
103
  else:
95
- for prop in additional_node_properties:
104
+ for prop in node_properties:
96
105
  if prop not in all_actual_node_properties:
97
106
  raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'")
98
107
 
99
108
  for label, props in actual_node_properties.items():
100
109
  node_properties_by_label_sets[label] = {
101
- prop for prop in actual_node_properties[label] if prop in additional_node_properties
110
+ prop for prop in actual_node_properties[label] if prop in node_properties
102
111
  }
103
112
 
104
- if size_property is not None:
105
- # For some reason mypy are unable to understand that this is dict[str, set[str]]
106
- for label, props in node_properties_by_label_sets.items(): # type: ignore
107
- props.add(size_property) # type: ignore
108
-
109
113
  node_properties_by_label = {k: list(v) for k, v in node_properties_by_label_sets.items()}
110
114
 
111
115
  node_count = G.node_count()
@@ -129,7 +133,9 @@ def from_gds(
129
133
  for props in node_properties_by_label.values():
130
134
  props.append(property_name)
131
135
 
132
- node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels())
136
+ node_dfs = _fetch_node_dfs(
137
+ gds, G_fetched, node_properties_by_label, G_fetched.node_labels(), db_node_properties
138
+ )
133
139
  if property_name is not None:
134
140
  for df in node_dfs.values():
135
141
  df.drop(columns=[property_name], inplace=True)
@@ -145,14 +151,7 @@ def from_gds(
145
151
  if property_name is not None and property_name in df.columns:
146
152
  df.drop(columns=[property_name], inplace=True)
147
153
 
148
- node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
149
- if size_property is not None:
150
- if "size" in all_actual_node_properties and size_property != "size":
151
- node_props_df.rename(columns={"size": "__size"}, inplace=True)
152
- if additional_node_properties is not None and size_property not in additional_node_properties:
153
- node_props_df.rename(columns={size_property: "size"}, inplace=True)
154
- else:
155
- node_props_df["size"] = node_props_df[size_property]
154
+ node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates(subset=["nodeId"])
156
155
 
157
156
  for lbl, df in node_dfs.items():
158
157
  if "labels" in all_actual_node_properties:
@@ -164,22 +163,22 @@ def from_gds(
164
163
 
165
164
  node_df = node_props_df.merge(node_labels_df, on="nodeId")
166
165
 
167
- if "caption" not in all_actual_node_properties:
168
- node_df["caption"] = node_df["labels"].astype(str)
166
+ try:
167
+ VG = _from_dfs(node_df, rel_dfs, dropna=True)
169
168
 
170
- for rel_df in rel_dfs:
171
- if "caption" not in rel_df.columns:
172
- rel_df["caption"] = rel_df["relationshipType"]
169
+ for node in VG.nodes:
170
+ node.caption = ":".join([label for label in node.properties["labels"]])
171
+ for rel in VG.relationships:
172
+ rel.caption = rel.properties.get("relationshipType")
173
173
 
174
- try:
175
- return _from_dfs(
176
- node_df, rel_dfs, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
177
- )
174
+ number_of_colors = node_df["labels"].drop_duplicates().count()
175
+ if number_of_colors <= len(NEO4J_COLORS_DISCRETE):
176
+ VG.color_nodes(property="labels", color_space=ColorSpace.DISCRETE)
177
+
178
+ return VG
178
179
  except ValueError as e:
179
180
  err_msg = str(e)
180
181
  if "column" in err_msg:
181
182
  err_msg = err_msg.replace("column", "property")
182
- if ("'size'" in err_msg) and (size_property is not None):
183
- err_msg = err_msg.replace("'size'", f"'{size_property}'")
184
183
  raise ValueError(err_msg)
185
184
  raise e
neo4j_viz/gql_create.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Optional
5
5
  from pydantic import BaseModel, ValidationError
6
6
 
7
7
  from neo4j_viz import Node, Relationship, VisualizationGraph
8
+ from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace
8
9
 
9
10
 
10
11
  def _parse_value(value_str: str) -> Any:
@@ -91,10 +92,7 @@ def _parse_value(value_str: str) -> Any:
91
92
  return value_str.strip("'\"")
92
93
 
93
94
 
94
- def _parse_prop_str(
95
- query: str, prop_str: str, prop_start: int, top_level_keys: set[str]
96
- ) -> tuple[dict[str, Any], dict[str, Any]]:
97
- top_level: dict[str, Any] = {}
95
+ def _parse_prop_str(query: str, prop_str: str, prop_start: int) -> dict[str, Any]:
98
96
  props: dict[str, Any] = {}
99
97
  depth = 0
100
98
  in_string = None
@@ -115,10 +113,7 @@ def _parse_prop_str(
115
113
  k, v = pair.split(":", 1)
116
114
  k = k.strip().strip("'\"")
117
115
 
118
- if k in top_level_keys:
119
- top_level[k] = _parse_value(v)
120
- else:
121
- props[k] = _parse_value(v)
116
+ props[k] = _parse_value(v)
122
117
 
123
118
  start_idx = i + 1
124
119
  else:
@@ -133,17 +128,12 @@ def _parse_prop_str(
133
128
  k, v = pair.split(":", 1)
134
129
  k = k.strip().strip("'\"")
135
130
 
136
- if k in top_level_keys:
137
- top_level[k] = _parse_value(v)
138
- else:
139
- props[k] = _parse_value(v)
131
+ props[k] = _parse_value(v)
140
132
 
141
- return top_level, props
133
+ return props
142
134
 
143
135
 
144
- def _parse_labels_and_props(
145
- query: str, s: str, top_level_keys: set[str]
146
- ) -> tuple[Optional[str], dict[str, Any], dict[str, Any]]:
136
+ def _parse_labels_and_props(query: str, s: str) -> tuple[Optional[str], dict[str, Any]]:
147
137
  prop_match = re.search(r"\{(.*)\}", s)
148
138
  prop_str = ""
149
139
  if prop_match:
@@ -155,9 +145,8 @@ def _parse_labels_and_props(
155
145
  final_alias = raw_alias if raw_alias else None
156
146
 
157
147
  if prop_str:
158
- top_level, props = _parse_prop_str(query, prop_str, prop_start, top_level_keys)
148
+ props = _parse_prop_str(query, prop_str, prop_start)
159
149
  else:
160
- top_level = {}
161
150
  props = {}
162
151
 
163
152
  label_list = [lbl.strip() for lbl in alias_labels[1:]]
@@ -165,7 +154,7 @@ def _parse_labels_and_props(
165
154
  props["__labels"] = props["labels"]
166
155
  props["labels"] = sorted(label_list)
167
156
 
168
- return final_alias, top_level, props
157
+ return final_alias, props
169
158
 
170
159
 
171
160
  def _get_snippet(q: str, idx: int, context: int = 15) -> str:
@@ -175,21 +164,20 @@ def _get_snippet(q: str, idx: int, context: int = 15) -> str:
175
164
  return q[start:end].replace("\n", " ")
176
165
 
177
166
 
178
- def from_gql_create(
179
- query: str,
180
- size_property: Optional[str] = None,
181
- node_caption: Optional[str] = "labels",
182
- relationship_caption: Optional[str] = "type",
183
- node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
184
- ) -> VisualizationGraph:
167
+ def from_gql_create(query: str) -> VisualizationGraph:
185
168
  """
186
169
  Parse a GQL CREATE query and return a VisualizationGraph object representing the graph it creates.
187
170
 
188
171
  All node and relationship properties will be included in the visualization graph.
189
- If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
190
- top level fields of the respective objects. Otherwise, they will be included in the `properties` dictionary.
172
+ All properties of nodes and relationships will be included in the `properties` dictionary of the respective objects.
191
173
  Additionally, a "labels" property will be added for nodes and a "type" property for relationships.
192
174
 
175
+ By default:
176
+
177
+ * the caption of a node will be based on its `labels`.
178
+ * the caption of a relationship will be based on its `type`.
179
+ * the color of nodes will be set based on their label, unless there are more than 12 unique labels.
180
+
193
181
  Please note that this function is not a full GQL parser, it only handles CREATE queries that do not contain
194
182
  other clauses like MATCH, WHERE, RETURN, etc, or any Cypher function calls.
195
183
  It also does not handle all possible GQL syntax, but it should work for most common cases.
@@ -199,15 +187,6 @@ def from_gql_create(
199
187
  ----------
200
188
  query : str
201
189
  The GQL CREATE query to parse
202
- size_property : str, optional
203
- Property to use for node size, by default None.
204
- node_caption : str, optional
205
- Property to use as the node caption, by default the node labels will be used.
206
- relationship_caption : str, optional
207
- Property to use as the relationship caption, by default the relationship type will be used.
208
- node_radius_min_max : tuple[float, float], optional
209
- Minimum and maximum node radius, by default (3, 60).
210
- To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
211
190
  """
212
191
 
213
192
  query = query.strip()
@@ -251,19 +230,9 @@ def from_gql_create(
251
230
  node_pattern = re.compile(r"^\(([^)]*)\)$")
252
231
  rel_pattern = re.compile(r"^\(([^)]*)\)-\s*\[\s*:(\w+)\s*(\{[^}]*\})?\s*\]->\(([^)]*)\)$")
253
232
 
254
- node_top_level_keys = Node.all_validation_aliases(exempted_fields=["id"])
255
- rel_top_level_keys = Relationship.all_validation_aliases(exempted_fields=["id", "source", "target"])
256
-
257
233
  def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) -> None:
258
234
  for err in e.errors():
259
235
  loc = err["loc"][0]
260
- if (loc == "size") and size_property is not None:
261
- loc = size_property
262
- if loc == "caption":
263
- if (entity_type == Node) and (node_caption is not None):
264
- loc = node_caption
265
- elif (entity_type == Relationship) and (relationship_caption is not None):
266
- loc = relationship_caption
267
236
  raise ValueError(
268
237
  f"Error for {entity_type.__name__.lower()} property '{loc}' with provided input '{err['input']}'. Reason: {err['msg']}"
269
238
  )
@@ -277,14 +246,14 @@ def from_gql_create(
277
246
  node_m = node_pattern.match(part)
278
247
  if node_m:
279
248
  alias_labels_props = node_m.group(1).strip()
280
- alias, top_level, props = _parse_labels_and_props(query, alias_labels_props, node_top_level_keys)
249
+ alias, props = _parse_labels_and_props(query, alias_labels_props)
281
250
  if not alias:
282
251
  alias = f"_anon_{anonymous_count}"
283
252
  anonymous_count += 1
284
253
  if alias not in alias_to_id:
285
254
  alias_to_id[alias] = str(uuid.uuid4())
286
255
  try:
287
- nodes.append(Node(id=alias_to_id[alias], **top_level, properties=props))
256
+ nodes.append(Node(id=alias_to_id[alias], properties=props))
288
257
  except ValidationError as e:
289
258
  _parse_validation_error(e, Node)
290
259
 
@@ -296,14 +265,14 @@ def from_gql_create(
296
265
  right_node = rel_m.group(4).strip()
297
266
 
298
267
  # Parse left node pattern
299
- left_alias, left_top_level, left_props = _parse_labels_and_props(query, left_node, node_top_level_keys)
268
+ left_alias, left_props = _parse_labels_and_props(query, left_node)
300
269
  if not left_alias:
301
270
  left_alias = f"_anon_{anonymous_count}"
302
271
  anonymous_count += 1
303
272
  if left_alias not in alias_to_id:
304
273
  alias_to_id[left_alias] = str(uuid.uuid4())
305
274
  try:
306
- nodes.append(Node(id=alias_to_id[left_alias], **left_top_level, properties=left_props))
275
+ nodes.append(Node(id=alias_to_id[left_alias], properties=left_props))
307
276
  except ValidationError as e:
308
277
  _parse_validation_error(e, Node)
309
278
  elif left_alias not in alias_to_id:
@@ -311,14 +280,14 @@ def from_gql_create(
311
280
  raise ValueError(f"Relationship references unknown node alias: '{left_alias}' near: `{snippet}`.")
312
281
 
313
282
  # Parse right node pattern
314
- right_alias, right_top_level, right_props = _parse_labels_and_props(query, right_node, node_top_level_keys)
283
+ right_alias, right_props = _parse_labels_and_props(query, right_node)
315
284
  if not right_alias:
316
285
  right_alias = f"_anon_{anonymous_count}"
317
286
  anonymous_count += 1
318
287
  if right_alias not in alias_to_id:
319
288
  alias_to_id[right_alias] = str(uuid.uuid4())
320
289
  try:
321
- nodes.append(Node(id=alias_to_id[right_alias], **right_top_level, properties=right_props))
290
+ nodes.append(Node(id=alias_to_id[right_alias], properties=right_props))
322
291
  except ValidationError as e:
323
292
  _parse_validation_error(e, Node)
324
293
  elif right_alias not in alias_to_id:
@@ -331,9 +300,8 @@ def from_gql_create(
331
300
  if rel_props_str:
332
301
  inner_str = rel_props_str.strip("{}").strip()
333
302
  prop_start = query.index(inner_str, query.index(inner_str))
334
- top_level, props = _parse_prop_str(query, inner_str, prop_start, rel_top_level_keys)
303
+ props = _parse_prop_str(query, inner_str, prop_start)
335
304
  else:
336
- top_level = {}
337
305
  props = {}
338
306
  if "type" in props:
339
307
  props["__type"] = props["type"]
@@ -345,7 +313,6 @@ def from_gql_create(
345
313
  id=rel_id,
346
314
  source=alias_to_id[left_alias],
347
315
  target=alias_to_id[right_alias],
348
- **top_level,
349
316
  properties=props,
350
317
  )
351
318
  )
@@ -357,29 +324,15 @@ def from_gql_create(
357
324
  snippet = part[:30]
358
325
  raise ValueError(f"Invalid element in CREATE near: `{snippet}`.")
359
326
 
360
- if size_property is not None:
361
- for node in nodes:
362
- node.size = node.properties.get(size_property)
363
- if node_caption is not None:
364
- for node in nodes:
365
- if node_caption == "labels":
366
- if len(node.properties["labels"]) > 0:
367
- node.caption = ":".join([label for label in node.properties["labels"]])
368
- else:
369
- node.caption = str(node.properties.get(node_caption))
370
- if relationship_caption is not None:
371
- for rel in relationships:
372
- if relationship_caption == "type":
373
- rel.caption = rel.properties["type"]
374
- else:
375
- rel.caption = str(rel.properties.get(relationship_caption))
376
-
377
327
  VG = VisualizationGraph(nodes=nodes, relationships=relationships)
378
- if (node_radius_min_max is not None) and (size_property is not None):
379
- try:
380
- VG.resize_nodes(node_radius_min_max=node_radius_min_max)
381
- except TypeError:
382
- loc = "size" if size_property is None else size_property
383
- raise ValueError(f"Error for node property '{loc}'. Reason: must be a numerical value")
328
+
329
+ for node in VG.nodes:
330
+ node.caption = ":".join([label for label in node.properties["labels"]])
331
+ for rel in VG.relationships:
332
+ rel.caption = rel.properties.get("type")
333
+
334
+ number_of_colors = len({str(n.properties.get("labels")) for n in VG.nodes})
335
+ if number_of_colors <= len(NEO4J_COLORS_DISCRETE):
336
+ VG.color_nodes(property="labels", color_space=ColorSpace.DISCRETE)
384
337
 
385
338
  return VG
neo4j_viz/neo4j.py CHANGED
@@ -7,6 +7,7 @@ import neo4j.graph
7
7
  from neo4j import Driver, Result, RoutingControl
8
8
  from pydantic import BaseModel, ValidationError
9
9
 
10
+ from neo4j_viz.colors import NEO4J_COLORS_DISCRETE, ColorSpace
10
11
  from neo4j_viz.node import Node
11
12
  from neo4j_viz.relationship import Relationship
12
13
  from neo4j_viz.visualization_graph import VisualizationGraph
@@ -22,18 +23,18 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->
22
23
 
23
24
  def from_neo4j(
24
25
  data: Union[neo4j.graph.Graph, Result, Driver],
25
- size_property: Optional[str] = None,
26
- node_caption: Optional[str] = "labels",
27
- relationship_caption: Optional[str] = "type",
28
- node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
29
26
  row_limit: int = 10_000,
30
27
  ) -> VisualizationGraph:
31
28
  """
32
29
  Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`.
33
30
 
34
- All node and relationship properties will be included in the visualization graph.
35
- If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
36
- top level fields of the respective objects. Otherwise, they will be included in the `properties` dictionary.
31
+ By default:
32
+
33
+ * the caption of a node will be based on its `labels`.
34
+ * the caption of a relationship will be based on its `type`.
35
+ * the color of nodes will be set based on their label, unless there are more than 12 unique labels.
36
+
37
+ All node and relationship properties will be included in the visualization graph under the `properties` field.
37
38
  Additionally, a "labels" property will be added for nodes and a "type" property for relationships.
38
39
 
39
40
  Parameters
@@ -41,15 +42,6 @@ def from_neo4j(
41
42
  data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver]
42
43
  Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in
43
44
  which case a simple default query will be executed internally to retrieve the graph data.
44
- size_property : str, optional
45
- Property to use for node size, by default None.
46
- node_caption : str, optional
47
- Property to use as the node caption, by default the node labels will be used.
48
- relationship_caption : str, optional
49
- Property to use as the relationship caption, by default the relationship type will be used.
50
- node_radius_min_max : tuple[float, float], optional
51
- Minimum and maximum node radius, by default (3, 60).
52
- To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
53
45
  row_limit : int, optional
54
46
  Maximum number of rows to return from the query, by default 10_000.
55
47
  This is only used if a `neo4j.Driver` is passed as `result` argument, otherwise the limit is ignored.
@@ -77,117 +69,62 @@ def from_neo4j(
77
69
  else:
78
70
  raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`")
79
71
 
80
- all_node_field_aliases = Node.all_validation_aliases()
81
- all_rel_field_aliases = Relationship.all_validation_aliases()
82
-
83
- try:
84
- nodes = [
85
- _map_node(node, all_node_field_aliases, size_property, caption_property=node_caption)
86
- for node in graph.nodes
87
- ]
88
- except ValueError as e:
89
- err_msg = str(e)
90
- if ("'size'" in err_msg) and (size_property is not None):
91
- err_msg = err_msg.replace("'size'", f"'{size_property}'")
92
- elif ("'caption'" in err_msg) and (node_caption is not None):
93
- err_msg = err_msg.replace("'caption'", f"'{node_caption}'")
94
- raise ValueError(err_msg)
72
+ nodes = [_map_node(node) for node in graph.nodes]
95
73
 
96
74
  relationships = []
97
- try:
98
- for rel in graph.relationships:
99
- mapped_rel = _map_relationship(rel, all_rel_field_aliases, caption_property=relationship_caption)
100
- if mapped_rel:
101
- relationships.append(mapped_rel)
102
- except ValueError as e:
103
- err_msg = str(e)
104
- if ("'caption'" in err_msg) and (relationship_caption is not None):
105
- err_msg = err_msg.replace("'caption'", f"'{relationship_caption}'")
106
- raise ValueError(err_msg)
75
+
76
+ for rel in graph.relationships:
77
+ mapped_rel = _map_relationship(rel)
78
+ if mapped_rel:
79
+ relationships.append(mapped_rel)
107
80
 
108
81
  VG = VisualizationGraph(nodes, relationships)
109
82
 
110
- if (node_radius_min_max is not None) and (size_property is not None):
111
- VG.resize_nodes(node_radius_min_max=node_radius_min_max)
83
+ for node in VG.nodes:
84
+ node.caption = ":".join(node.properties["labels"])
85
+ for r in VG.relationships:
86
+ r.caption = r.properties["type"]
87
+
88
+ number_of_colors = len({n.caption for n in VG.nodes})
89
+ if number_of_colors <= len(NEO4J_COLORS_DISCRETE):
90
+ VG.color_nodes(field="caption", color_space=ColorSpace.DISCRETE, colors=NEO4J_COLORS_DISCRETE)
112
91
 
113
92
  return VG
114
93
 
115
94
 
116
95
  def _map_node(
117
96
  node: neo4j.graph.Node,
118
- all_node_field_aliases: set[str],
119
- size_property: Optional[str],
120
- caption_property: Optional[str],
121
97
  ) -> Node:
122
- top_level_fields = {"id": node.element_id}
123
-
124
- if size_property:
125
- top_level_fields["size"] = node.get(size_property)
126
-
127
98
  labels = sorted([label for label in node.labels])
128
- if caption_property:
129
- if caption_property == "labels":
130
- if len(labels) > 0:
131
- top_level_fields["caption"] = ":".join([label for label in labels])
132
- else:
133
- top_level_fields["caption"] = str(node.get(caption_property))
134
-
135
- properties = {}
136
- for prop, value in node.items():
137
- if prop not in all_node_field_aliases:
138
- properties[prop] = value
139
- continue
140
99
 
141
- if prop in top_level_fields:
142
- properties[prop] = value
143
- continue
144
-
145
- top_level_fields[prop] = value
100
+ properties = {prop: value for prop, value in node.items()}
146
101
 
147
102
  if "labels" in properties:
148
103
  properties["__labels"] = properties["labels"]
149
104
  properties["labels"] = labels
150
105
 
151
106
  try:
152
- viz_node = Node(**top_level_fields, properties=properties)
107
+ viz_node = Node(id=node.element_id, properties=properties)
153
108
  except ValidationError as e:
154
109
  _parse_validation_error(e, Node)
155
110
 
156
111
  return viz_node
157
112
 
158
113
 
159
- def _map_relationship(
160
- rel: neo4j.graph.Relationship, all_rel_field_aliases: set[str], caption_property: Optional[str]
161
- ) -> Optional[Relationship]:
114
+ def _map_relationship(rel: neo4j.graph.Relationship) -> Optional[Relationship]:
162
115
  if rel.start_node is None or rel.end_node is None:
163
116
  return None
164
117
 
165
- top_level_fields = {"id": rel.element_id, "source": rel.start_node.element_id, "target": rel.end_node.element_id}
166
-
167
- if caption_property:
168
- if caption_property == "type":
169
- top_level_fields["caption"] = rel.type
170
- else:
171
- top_level_fields["caption"] = str(rel.get(caption_property))
172
-
173
- properties = {}
174
- for prop, value in rel.items():
175
- if prop not in all_rel_field_aliases:
176
- properties[prop] = value
177
- continue
178
-
179
- if prop in top_level_fields:
180
- properties[prop] = value
181
- continue
182
-
183
- top_level_fields[prop] = value
118
+ properties = {prop: value for prop, value in rel.items()}
184
119
 
185
120
  if "type" in properties:
186
121
  properties["__type"] = properties["type"]
187
122
  properties["type"] = rel.type
188
123
 
189
124
  try:
190
- viz_rel = Relationship(**top_level_fields, properties=properties)
125
+ viz_rel = Relationship(
126
+ id=rel.element_id, source=rel.start_node.element_id, target=rel.end_node.element_id, properties=properties
127
+ )
191
128
  except ValidationError as e:
192
129
  _parse_validation_error(e, Relationship)
193
130
 
neo4j_viz/node.py CHANGED
@@ -30,6 +30,7 @@ class Node(
30
30
  validation_alias=create_aliases,
31
31
  serialization_alias=lambda field_name: to_camel(field_name),
32
32
  ),
33
+ validate_assignment=True,
33
34
  ):
34
35
  """
35
36
  A node in a graph to visualize.
@@ -90,10 +91,8 @@ class Node(
90
91
  return self.model_dump(exclude_none=True, by_alias=True)
91
92
 
92
93
  @staticmethod
93
- def all_validation_aliases(exempted_fields: Optional[list[str]] = None) -> set[str]:
94
- if exempted_fields is None:
95
- exempted_fields = []
96
-
97
- by_field = [v.validation_alias.choices for k, v in Node.model_fields.items() if k not in exempted_fields] # type: ignore
94
+ def basic_fields_validation_aliases() -> set[str]:
95
+ mandatory_fields = ["id"]
96
+ by_field = [v.validation_alias.choices for k, v in Node.model_fields.items() if k in mandatory_fields] # type: ignore
98
97
 
99
98
  return {str(alias) for aliases in by_field for alias in aliases}
neo4j_viz/options.py CHANGED
@@ -22,6 +22,9 @@ class CaptionAlignment(str, Enum):
22
22
  @enum_tools.documentation.document_enum
23
23
  class Layout(str, Enum):
24
24
  FORCE_DIRECTED = "forcedirected"
25
+ """
26
+ The force-directed layout uses a physics simulation to position the nodes.
27
+ """
25
28
  HIERARCHICAL = "hierarchical"
26
29
  """
27
30
  The nodes are then arranged by the directionality of their relationships