neo4j-viz 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
neo4j_viz/gds.py CHANGED
@@ -7,7 +7,6 @@ from uuid import uuid4
7
7
 
8
8
  import pandas as pd
9
9
  from graphdatascience import Graph, GraphDataScience
10
- from pandas import Series
11
10
 
12
11
  from .pandas import _from_dfs
13
12
  from .visualization_graph import VisualizationGraph
@@ -24,22 +23,25 @@ def _fetch_node_dfs(
24
23
  }
25
24
 
26
25
 
27
- def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
28
- relationship_properties = G.relationship_properties()
29
- assert isinstance(relationship_properties, Series)
26
+ def _fetch_rel_dfs(gds: GraphDataScience, G: Graph) -> list[pd.DataFrame]:
27
+ rel_types = G.relationship_types()
30
28
 
31
- relationship_properties_per_type = relationship_properties.tolist()
32
- property_set: set[str] = set()
33
- for props in relationship_properties_per_type:
34
- if props:
35
- property_set.update(props)
29
+ rel_props = {rel_type: G.relationship_properties(rel_type) for rel_type in rel_types}
36
30
 
37
- if len(property_set) > 0:
38
- return gds.graph.relationshipProperties.stream(
39
- G, relationship_properties=list(property_set), separate_property_columns=True
40
- )
31
+ rel_dfs: list[pd.DataFrame] = []
32
+ # Have to call per stream per relationship type as there was a bug in GDS < 2.21
33
+ for rel_type, props in rel_props.items():
34
+ assert isinstance(props, list)
35
+ if len(props) > 0:
36
+ rel_df = gds.graph.relationshipProperties.stream(
37
+ G, relationship_types=rel_type, relationship_properties=list(props), separate_property_columns=True
38
+ )
39
+ else:
40
+ rel_df = gds.graph.relationships.stream(G, relationship_types=[rel_type])
41
+
42
+ rel_dfs.append(rel_df)
41
43
 
42
- return gds.graph.relationships.stream(G)
44
+ return rel_dfs
43
45
 
44
46
 
45
47
  def from_gds(
@@ -131,7 +133,7 @@ def from_gds(
131
133
  for df in node_dfs.values():
132
134
  df.drop(columns=[property_name], inplace=True)
133
135
 
134
- rel_df = _fetch_rel_df(gds, G_fetched)
136
+ rel_dfs = _fetch_rel_dfs(gds, G_fetched)
135
137
  finally:
136
138
  if G_fetched.name() != G.name():
137
139
  G_fetched.drop()
@@ -146,7 +148,10 @@ def from_gds(
146
148
  if size_property is not None:
147
149
  if "size" in all_actual_node_properties and size_property != "size":
148
150
  node_props_df.rename(columns={"size": "__size"}, inplace=True)
149
- node_props_df.rename(columns={size_property: "size"}, inplace=True)
151
+ if size_property not in additional_node_properties:
152
+ node_props_df.rename(columns={size_property: "size"}, inplace=True)
153
+ else:
154
+ node_props_df["size"] = node_props_df[size_property]
150
155
 
151
156
  for lbl, df in node_dfs.items():
152
157
  if "labels" in all_actual_node_properties:
@@ -161,12 +166,13 @@ def from_gds(
161
166
  if "caption" not in all_actual_node_properties:
162
167
  node_df["caption"] = node_df["labels"].astype(str)
163
168
 
164
- if "caption" not in rel_df.columns:
165
- rel_df["caption"] = rel_df["relationshipType"]
169
+ for rel_df in rel_dfs:
170
+ if "caption" not in rel_df.columns:
171
+ rel_df["caption"] = rel_df["relationshipType"]
166
172
 
167
173
  try:
168
174
  return _from_dfs(
169
- node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
175
+ node_df, rel_dfs, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"}, dropna=True
170
176
  )
171
177
  except ValueError as e:
172
178
  err_msg = str(e)
neo4j_viz/pandas.py CHANGED
@@ -74,7 +74,7 @@ def _parse_nodes(
74
74
  has_size = True
75
75
  nodes = []
76
76
  for node_df in node_dfs_iter:
77
- has_size &= "size" in node_df.columns
77
+ has_size &= "size" in [c.lower() for c in node_df.columns]
78
78
  for _, row in node_df.iterrows():
79
79
  if dropna:
80
80
  row = row.dropna(inplace=False)