neo4j-viz 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
neo4j_viz/gds.py CHANGED
@@ -1,44 +1,47 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from itertools import chain
4
5
  from typing import Optional
5
6
  from uuid import uuid4
6
7
 
7
8
  import pandas as pd
8
9
  from graphdatascience import Graph, GraphDataScience
9
- from pandas import Series
10
10
 
11
11
  from .pandas import _from_dfs
12
12
  from .visualization_graph import VisualizationGraph
13
13
 
14
14
 
15
15
  def _fetch_node_dfs(
16
- gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
16
+ gds: GraphDataScience, G: Graph, node_properties_by_label: dict[str, list[str]], node_labels: list[str]
17
17
  ) -> dict[str, pd.DataFrame]:
18
18
  return {
19
19
  lbl: gds.graph.nodeProperties.stream(
20
- G, node_properties=node_properties, node_labels=[lbl], separate_property_columns=True
20
+ G, node_properties=node_properties_by_label[lbl], node_labels=[lbl], separate_property_columns=True
21
21
  )
22
22
  for lbl in node_labels
23
23
  }
24
24
 
25
25
 
26
- def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
27
- relationship_properties = G.relationship_properties()
28
- assert isinstance(relationship_properties, Series)
26
+ def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]:
27
+ rel_types = G.relationship_types()
29
28
 
30
- relationship_properties_per_type = relationship_properties.tolist()
31
- property_set: set[str] = set()
32
- for props in relationship_properties_per_type:
33
- if props:
34
- property_set.update(props)
29
+ rel_props = {rel_type: G.relationship_properties(rel_type) for rel_type in rel_types}
35
30
 
36
- if len(property_set) > 0:
37
- return gds.graph.relationshipProperties.stream(
38
- G, relationship_properties=list(property_set), separate_property_columns=True
39
- )
31
+ rel_dfs: list[pd.DataFrame] = []
32
+ # Have to call stream once per relationship type, as there was a bug in GDS < 2.21
33
+ for rel_type, props in rel_props.items():
34
+ assert isinstance(props, list)
35
+ if len(props) > 0:
36
+ rel_df = gds.graph.relationshipProperties.stream(
37
+ G, relationship_types=rel_type, relationship_properties=list(props), separate_property_columns=True
38
+ )
39
+ else:
40
+ rel_df = gds.graph.relationships.stream(G, relationship_types=[rel_type])
41
+
42
+ rel_dfs.append(rel_df)
40
43
 
41
- return gds.graph.relationships.stream(G)
44
+ return rel_dfs
42
45
 
43
46
 
44
47
  def from_gds(
@@ -56,6 +59,7 @@ def from_gds(
56
59
  If the properties are named as the fields of the `Node` class, they will be included as top level fields of the
57
60
  created `Node` objects. Otherwise, they will be included in the `properties` dictionary.
58
61
  Additionally, a new "labels" node property will be added, containing the node labels of the node.
62
+ Similarly for relationships, a new "relationshipType" property will be added.
59
63
 
60
64
  Parameters
61
65
  ----------
@@ -77,27 +81,37 @@ def from_gds(
77
81
  """
78
82
  node_properties_from_gds = G.node_properties()
79
83
  assert isinstance(node_properties_from_gds, pd.Series)
80
- actual_node_properties = list(chain.from_iterable(node_properties_from_gds.to_dict().values()))
84
+ actual_node_properties = node_properties_from_gds.to_dict()
85
+ all_actual_node_properties = list(chain.from_iterable(actual_node_properties.values()))
81
86
 
82
- if size_property is not None and size_property not in actual_node_properties:
83
- raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
87
+ if size_property is not None:
88
+ if size_property not in all_actual_node_properties:
89
+ raise ValueError(f"There is no node property '{size_property}' in graph '{G.name()}'")
84
90
 
85
91
  if additional_node_properties is None:
86
- additional_node_properties = actual_node_properties
92
+ node_properties_by_label = {k: set(v) for k, v in actual_node_properties.items()}
87
93
  else:
88
94
  for prop in additional_node_properties:
89
- if prop not in actual_node_properties:
95
+ if prop not in all_actual_node_properties:
90
96
  raise ValueError(f"There is no node property '{prop}' in graph '{G.name()}'")
91
97
 
92
- node_properties = set()
93
- if additional_node_properties is not None:
94
- node_properties.update(additional_node_properties)
98
+ node_properties_by_label = {}
99
+ for label, props in actual_node_properties.items():
100
+ node_properties_by_label[label] = {
101
+ prop for prop in actual_node_properties[label] if prop in additional_node_properties
102
+ }
103
+
95
104
  if size_property is not None:
96
- node_properties.add(size_property)
97
- node_properties = list(node_properties)
105
+ for label, props in node_properties_by_label.items():
106
+ props.add(size_property)
107
+
108
+ node_properties_by_label = {k: list(v) for k, v in node_properties_by_label.items()}
98
109
 
99
110
  node_count = G.node_count()
100
111
  if node_count > max_node_count:
112
+ warnings.warn(
113
+ f"The '{G.name()}' projection's node count ({G.node_count()}) exceeds `max_node_count` ({max_node_count}), so subsampling will be applied. Increase `max_node_count` if needed"
114
+ )
101
115
  sampling_ratio = float(max_node_count) / node_count
102
116
  sample_name = f"neo4j-viz_sample_{uuid4()}"
103
117
  G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
@@ -107,14 +121,19 @@ def from_gds(
107
121
  property_name = None
108
122
  try:
109
123
  # Since GDS does not allow us to only fetch node IDs, we add the degree property
110
- # as a temporary property to ensure that we have at least one property to fetch
111
- if len(actual_node_properties) == 0:
124
+ # as a temporary property to ensure that we have at least one property for each label to fetch
125
+ if sum([len(props) == 0 for props in node_properties_by_label.values()]) > 0:
112
126
  property_name = f"neo4j-viz_property_{uuid4()}"
113
127
  gds.degree.mutate(G_fetched, mutateProperty=property_name)
114
- node_properties = [property_name]
128
+ for props in node_properties_by_label.values():
129
+ props.append(property_name)
130
+
131
+ node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties_by_label, G_fetched.node_labels())
132
+ if property_name is not None:
133
+ for df in node_dfs.values():
134
+ df.drop(columns=[property_name], inplace=True)
115
135
 
116
- node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
117
- rel_df = _fetch_rel_df(gds, G_fetched)
136
+ rel_dfs = _fetch_rel_dfs(gds, G_fetched)
118
137
  finally:
119
138
  if G_fetched.name() != G.name():
120
139
  G_fetched.drop()
@@ -122,32 +141,39 @@ def from_gds(
122
141
  gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])
123
142
 
124
143
  for df in node_dfs.values():
125
- df.rename(columns={"nodeId": "id"}, inplace=True)
126
144
  if property_name is not None and property_name in df.columns:
127
145
  df.drop(columns=[property_name], inplace=True)
128
- rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
129
146
 
130
147
  node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
131
148
  if size_property is not None:
132
- if "size" in actual_node_properties and size_property != "size":
149
+ if "size" in all_actual_node_properties and size_property != "size":
133
150
  node_props_df.rename(columns={"size": "__size"}, inplace=True)
134
- node_props_df.rename(columns={size_property: "size"}, inplace=True)
151
+ if size_property not in additional_node_properties:
152
+ node_props_df.rename(columns={size_property: "size"}, inplace=True)
153
+ else:
154
+ node_props_df["size"] = node_props_df[size_property]
135
155
 
136
156
  for lbl, df in node_dfs.items():
137
- if "labels" in actual_node_properties:
157
+ if "labels" in all_actual_node_properties:
138
158
  df.rename(columns={"labels": "__labels"}, inplace=True)
139
159
  df["labels"] = lbl
140
160
 
141
- node_labels_df = pd.concat([df[["id", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
142
- node_labels_df = node_labels_df.groupby("id").agg({"labels": list})
161
+ node_labels_df = pd.concat([df[["nodeId", "labels"]] for df in node_dfs.values()], ignore_index=True, axis=0)
162
+ node_labels_df = node_labels_df.groupby("nodeId").agg({"labels": list})
143
163
 
144
- node_df = node_props_df.merge(node_labels_df, on="id")
164
+ node_df = node_props_df.merge(node_labels_df, on="nodeId")
145
165
 
146
- if "caption" not in actual_node_properties:
166
+ if "caption" not in all_actual_node_properties:
147
167
  node_df["caption"] = node_df["labels"].astype(str)
148
168
 
169
+ for rel_df in rel_dfs:
170
+ if "caption" not in rel_df.columns:
171
+ rel_df["caption"] = rel_df["relationshipType"]
172
+
149
173
  try:
150
- return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
174
+ return _from_dfs(
175
+ node_df, rel_dfs, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
176
+ )
151
177
  except ValueError as e:
152
178
  err_msg = str(e)
153
179
  if "column" in err_msg:
neo4j_viz/neo4j.py CHANGED
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from typing import Optional, Union
4
5
 
5
6
  import neo4j.graph
6
- from neo4j import Result
7
+ from neo4j import Driver, Result, RoutingControl
7
8
  from pydantic import BaseModel, ValidationError
8
9
 
9
10
  from neo4j_viz.node import Node
@@ -20,14 +21,15 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->
20
21
 
21
22
 
22
23
  def from_neo4j(
23
- result: Union[neo4j.graph.Graph, Result],
24
+ data: Union[neo4j.graph.Graph, Result, Driver],
24
25
  size_property: Optional[str] = None,
25
26
  node_caption: Optional[str] = "labels",
26
27
  relationship_caption: Optional[str] = "type",
27
28
  node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
29
+ row_limit: int = 10_000,
28
30
  ) -> VisualizationGraph:
29
31
  """
30
- Create a VisualizationGraph from a Neo4j Graph or Neo4j Result object.
32
+ Create a VisualizationGraph from a Neo4j `Graph`, Neo4j `Result` or Neo4j `Driver`.
31
33
 
32
34
  All node and relationship properties will be included in the visualization graph.
33
35
  If the properties are named as the fields of the `Node` or `Relationship` classes, they will be included as
@@ -36,8 +38,9 @@ def from_neo4j(
36
38
 
37
39
  Parameters
38
40
  ----------
39
- result : Union[neo4j.graph.Graph, Result]
40
- Query result either in shape of a Graph or result.
41
+ data : Union[neo4j.graph.Graph, neo4j.Result, neo4j.Driver]
42
+ Either a query result in the shape of a `neo4j.graph.Graph` or `neo4j.Result`, or a `neo4j.Driver` in
43
+ which case a simple default query will be executed internally to retrieve the graph data.
41
44
  size_property : str, optional
42
45
  Property to use for node size, by default None.
43
46
  node_caption : str, optional
@@ -47,14 +50,32 @@ def from_neo4j(
47
50
  node_radius_min_max : tuple[float, float], optional
48
51
  Minimum and maximum node radius, by default (3, 60).
49
52
  To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
53
+ row_limit : int, optional
54
+ Maximum number of rows to return from the query, by default 10_000.
55
+ This is only used if a `neo4j.Driver` is passed as the `data` argument; otherwise the limit is ignored.
50
56
  """
51
57
 
52
- if isinstance(result, Result):
53
- graph = result.graph()
54
- elif isinstance(result, neo4j.graph.Graph):
55
- graph = result
58
+ if isinstance(data, Result):
59
+ graph = data.graph()
60
+ elif isinstance(data, neo4j.graph.Graph):
61
+ graph = data
62
+ elif isinstance(data, Driver):
63
+ rel_count = data.execute_query(
64
+ "MATCH ()-[r]->() RETURN count(r) as count",
65
+ routing_=RoutingControl.READ,
66
+ result_transformer_=Result.single,
67
+ ).get("count") # type: ignore[union-attr]
68
+ if rel_count > row_limit:
69
+ warnings.warn(
70
+ f"Database relationship count ({rel_count}) exceeds `row_limit` ({row_limit}), so limiting will be applied. Increase the `row_limit` if needed"
71
+ )
72
+ graph = data.execute_query(
73
+ f"MATCH (n)-[r]->(m) RETURN n,r,m LIMIT {row_limit}",
74
+ routing_=RoutingControl.READ,
75
+ result_transformer_=Result.graph,
76
+ )
56
77
  else:
57
- raise ValueError(f"Invalid input type `{type(result)}`. Expected `neo4j.Graph` or `neo4j.Result`")
78
+ raise ValueError(f"Invalid input type `{type(data)}`. Expected `neo4j.Graph`, `neo4j.Result` or `neo4j.Driver`")
58
79
 
59
80
  all_node_field_aliases = Node.all_validation_aliases()
60
81
  all_rel_field_aliases = Relationship.all_validation_aliases()
neo4j_viz/pandas.py CHANGED
@@ -27,12 +27,19 @@ def _parse_validation_error(e: ValidationError, entity_type: type[BaseModel]) ->
27
27
 
28
28
 
29
29
  def _from_dfs(
30
- node_dfs: Optional[DFS_TYPE],
31
- rel_dfs: DFS_TYPE,
30
+ node_dfs: Optional[DFS_TYPE] = None,
31
+ rel_dfs: Optional[DFS_TYPE] = None,
32
32
  node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
33
33
  rename_properties: Optional[dict[str, str]] = None,
34
+ dropna: bool = False,
34
35
  ) -> VisualizationGraph:
35
- relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties)
36
+ if node_dfs is None and rel_dfs is None:
37
+ raise ValueError("At least one of `node_dfs` or `rel_dfs` must be provided")
38
+
39
+ if rel_dfs is None:
40
+ relationships = []
41
+ else:
42
+ relationships = _parse_relationships(rel_dfs, rename_properties=rename_properties, dropna=dropna)
36
43
 
37
44
  if node_dfs is None:
38
45
  has_size = False
@@ -42,7 +49,7 @@ def _from_dfs(
42
49
  node_ids.add(rel.target)
43
50
  nodes = [Node(id=id) for id in node_ids]
44
51
  else:
45
- nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties)
52
+ nodes, has_size = _parse_nodes(node_dfs, rename_properties=rename_properties, dropna=dropna)
46
53
 
47
54
  VG = VisualizationGraph(nodes=nodes, relationships=relationships)
48
55
 
@@ -52,7 +59,9 @@ def _from_dfs(
52
59
  return VG
53
60
 
54
61
 
55
- def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> tuple[list[Node], bool]:
62
+ def _parse_nodes(
63
+ node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
64
+ ) -> tuple[list[Node], bool]:
56
65
  if isinstance(node_dfs, DataFrame):
57
66
  node_dfs_iter: Iterable[DataFrame] = [node_dfs]
58
67
  elif node_dfs is None:
@@ -65,8 +74,10 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
65
74
  has_size = True
66
75
  nodes = []
67
76
  for node_df in node_dfs_iter:
68
- has_size &= "size" in node_df.columns
77
+ has_size &= "size" in [c.lower() for c in node_df.columns]
69
78
  for _, row in node_df.iterrows():
79
+ if dropna:
80
+ row = row.dropna(inplace=False)
70
81
  top_level = {}
71
82
  properties = {}
72
83
  for key, value in row.to_dict().items():
@@ -85,7 +96,9 @@ def _parse_nodes(node_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]
85
96
  return nodes, has_size
86
97
 
87
98
 
88
- def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]]) -> list[Relationship]:
99
+ def _parse_relationships(
100
+ rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str, str]], dropna: bool = False
101
+ ) -> list[Relationship]:
89
102
  all_rel_field_aliases = Relationship.all_validation_aliases()
90
103
 
91
104
  if isinstance(rel_dfs, DataFrame):
@@ -96,6 +109,8 @@ def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str
96
109
 
97
110
  for rel_df in rel_dfs_iter:
98
111
  for _, row in rel_df.iterrows():
112
+ if dropna:
113
+ row = row.dropna(inplace=False)
99
114
  top_level = {}
100
115
  properties = {}
101
116
  for key, value in row.to_dict().items():
@@ -115,8 +130,8 @@ def _parse_relationships(rel_dfs: DFS_TYPE, rename_properties: Optional[dict[str
115
130
 
116
131
 
117
132
  def from_dfs(
118
- node_dfs: Optional[DFS_TYPE],
119
- rel_dfs: DFS_TYPE,
133
+ node_dfs: Optional[DFS_TYPE] = None,
134
+ rel_dfs: Optional[DFS_TYPE] = None,
120
135
  node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
121
136
  ) -> VisualizationGraph:
122
137
  """
@@ -128,14 +143,15 @@ def from_dfs(
128
143
 
129
144
  Parameters
130
145
  ----------
131
- node_dfs: Optional[Union[DataFrame, Iterable[DataFrame]]]
146
+ node_dfs: Optional[Union[DataFrame, Iterable[DataFrame]]], optional
132
147
  DataFrame or iterable of DataFrames containing node data.
133
148
  If None, the nodes will be created from the source and target node ids in the rel_dfs.
134
- rel_dfs: Union[DataFrame, Iterable[DataFrame]]
149
+ rel_dfs: Optional[Union[DataFrame, Iterable[DataFrame]]], optional
135
150
  DataFrame or iterable of DataFrames containing relationship data.
151
+ If None, no relationships will be created.
136
152
  node_radius_min_max : tuple[float, float], optional
137
153
  Minimum and maximum node radius.
138
154
  To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
139
155
  """
140
156
 
141
- return _from_dfs(node_dfs, rel_dfs, node_radius_min_max)
157
+ return _from_dfs(node_dfs, rel_dfs, node_radius_min_max, dropna=False)