neo4j-viz 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
neo4j_viz/snowflake.py ADDED
@@ -0,0 +1,344 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import Annotated, Any, Optional
5
+
6
+ from pandas import DataFrame
7
+ from pydantic import (
8
+ AfterValidator,
9
+ BaseModel,
10
+ BeforeValidator,
11
+ )
12
+ from pydantic_core.core_schema import ValidationInfo
13
+ from snowflake.snowpark import Session
14
+ from snowflake.snowpark.exceptions import SnowparkSQLException
15
+ from snowflake.snowpark.types import (
16
+ ArrayType,
17
+ BooleanType,
18
+ ByteType,
19
+ DataType,
20
+ DateType,
21
+ DecimalType,
22
+ DoubleType,
23
+ FloatType,
24
+ GeographyType,
25
+ GeometryType,
26
+ IntegerType,
27
+ LongType,
28
+ MapType,
29
+ ShortType,
30
+ StringType,
31
+ StructField,
32
+ StructType,
33
+ TimestampType,
34
+ TimeType,
35
+ VariantType,
36
+ VectorType,
37
+ )
38
+
39
+ from neo4j_viz import VisualizationGraph
40
+ from neo4j_viz.colors import ColorSpace
41
+ from neo4j_viz.pandas import from_dfs
42
+
43
+
44
def _data_type_name(type: DataType) -> str:
    """Return the SQL type name used by this module for a Snowpark data type.

    Args:
        type: The Snowpark data type to name.

    Returns:
        The mapped SQL type name, or the upper-cased `simple_string()` of the
        type when it is not one of the explicitly mapped types.
    """
    # Checked in order; the first `isinstance` match wins.
    known_types = (
        (StringType, "VARCHAR"),
        (LongType, "BIGINT"),
        (IntegerType, "INT"),
        (DoubleType, "DOUBLE"),
        (DecimalType, "NUMBER"),
        (BooleanType, "BOOLEAN"),
        (ByteType, "TINYINT"),
        (DateType, "DATE"),
        (ShortType, "SMALLINT"),
        (FloatType, "FLOAT"),
        (ArrayType, "ARRAY"),
        (VectorType, "VECTOR"),
        (MapType, "OBJECT"),
        (TimeType, "TIME"),
        (TimestampType, "TIMESTAMP"),
        (VariantType, "VARIANT"),
        (GeographyType, "GEOGRAPHY"),
        (GeometryType, "GEOMETRY"),
    )
    for snowpark_type, sql_name in known_types:
        if isinstance(type, snowpark_type):
            return sql_name

    # This actually does the job much of the time anyway
    return type.simple_string().upper()
84
+
85
+
86
# SQL type names (as produced by `_data_type_name`) accepted for ID columns.
SUPPORTED_ID_TYPES = list(map(_data_type_name, (StringType(), LongType(), IntegerType())))
87
+
88
+
89
+ def _validate_id_column(schema: StructType, column_name: str, index: int, supported_types: list[str]) -> None:
90
+ if column_name.lower() not in [name.lower() for name in schema.names]:
91
+ raise ValueError(f"Schema must contain a `{column_name}` column")
92
+
93
+ field: StructField = schema.fields[index]
94
+
95
+ if field.name.lower() != column_name.lower():
96
+ raise ValueError(f"Column `{column_name}` must have column index {index}")
97
+
98
+ if _data_type_name(field.datatype) not in supported_types:
99
+ raise ValueError(
100
+ f"Column `{column_name}` has invalid type `{_data_type_name(field.datatype)}`. Expected one of [{', '.join(supported_types)}]"
101
+ )
102
+
103
+
104
+ def _validate_viz_node_table(table: str, info: ValidationInfo) -> str:
105
+ context = info.context
106
+ if context and context["session"] is not None:
107
+ session = context["session"]
108
+ try:
109
+ schema = session.table(table).schema
110
+ _validate_id_column(schema, "nodeId", 0, SUPPORTED_ID_TYPES)
111
+ except SnowparkSQLException as e:
112
+ raise ValueError(f"Table '{table}' does not exist or is not accessible.") from e
113
+ return table
114
+
115
+
116
+ def _validate_viz_relationship_table(
117
+ table: str,
118
+ info: ValidationInfo,
119
+ ) -> str:
120
+ context = info.context
121
+ if context and context["session"] is not None:
122
+ session = context["session"]
123
+ try:
124
+ schema = session.table(table).schema
125
+ _validate_id_column(schema, "sourceNodeId", 0, SUPPORTED_ID_TYPES)
126
+ _validate_id_column(schema, "targetNodeId", 1, SUPPORTED_ID_TYPES)
127
+ except SnowparkSQLException as e:
128
+ raise ValueError(f"Table '{table}' does not exist or is not accessible.") from e
129
+ return table
130
+
131
+
132
+ def _parse_identifier_groups(identifier: str) -> list[str]:
133
+ """
134
+ Parses a table identifier into a list of individual identifier groups.
135
+
136
+ This function handles identifiers that may include double-quoted segments
137
+ and ensures proper validation of the identifier's structure. It raises
138
+ errors for invalid formats, such as unbalanced quotes, invalid characters,
139
+ or improper use of dots.
140
+
141
+ Args:
142
+ identifier (str): The input string identifier to parse.
143
+
144
+ Returns:
145
+ list[str]: A list of parsed identifier groups.
146
+
147
+ Raises:
148
+ ValueError: If the identifier contains:
149
+ - Empty double quotes.
150
+ - Consecutive dots outside of double quotes.
151
+ - Unbalanced double quotes.
152
+ - Invalid characters in unquoted segments.
153
+ - Improper placement of dots around double-quoted segments.
154
+ """
155
+ inside = False # Tracks whether the current character is inside double quotes
156
+ quoted_starts = [] # Stores the start indices of double-quoted segments
157
+ quoted_ends = [] # Stores the end indices of double-quoted segments
158
+ remaining = "" # Stores the unquoted part of the identifier
159
+ previous_is_dot = False # Tracks if the previous character was a dot
160
+
161
+ for i, c in enumerate(identifier):
162
+ if c == '"':
163
+ if not inside:
164
+ quoted_starts.append(i + 1) # Mark the start of a quoted segment
165
+ previous_is_dot = False
166
+ else:
167
+ quoted_ends.append(i) # Mark the end of a quoted segment
168
+ if quoted_ends[-1] - quoted_starts[-1] == 0:
169
+ raise ValueError("Empty double quotes")
170
+ inside = not inside # Toggle the inside state
171
+ else:
172
+ if not inside:
173
+ remaining += c # Append unquoted characters to `remaining`
174
+ if c == ".":
175
+ if previous_is_dot:
176
+ raise ValueError("Not ok to have consecutive dots outside of double quote")
177
+ previous_is_dot = True
178
+ else:
179
+ previous_is_dot = False
180
+
181
+ if len(quoted_starts) != len(quoted_ends):
182
+ raise ValueError("Unbalanced double quotes")
183
+
184
+ for quoted_start in quoted_starts:
185
+ if quoted_start > 1:
186
+ if identifier[quoted_start - 2] != ".":
187
+ raise ValueError("Only dot character may precede before double quoted identifier")
188
+
189
+ for quoted_end in quoted_ends:
190
+ if quoted_end < len(identifier) - 1:
191
+ if identifier[quoted_end + 1] != ".":
192
+ raise ValueError("Only dot character may follow double quoted identifier")
193
+
194
+ words = remaining.split(".") # Split the unquoted part by dots
195
+ for word in words:
196
+ if len(word) == 0:
197
+ continue
198
+ if word.lower()[0] not in "abcdefghijklmnopqrstuvwxyz_":
199
+ raise ValueError(f"Invalid first character in identifier {word}. Only a-z, A-Z, and _ are allowed.")
200
+ if not set(word.lower()).issubset(set("abcdefghijklmnopqrstuvwxyz$_0123456789")):
201
+ raise ValueError(f"Invalid characters in identifier {word}. Only a-z, A-Z, 0-9, _, and $ are allowed.")
202
+
203
+ empty_words_idx = [i for i, w in enumerate(words) if w == ""]
204
+ for i in range(len(quoted_starts)):
205
+ # Replace empty words with their corresponding quoted segments
206
+ words[empty_words_idx[i]] = f'"{identifier[quoted_starts[i] : quoted_ends[i]]}"'
207
+
208
+ return words
209
+
210
+
211
+ def _validate_table_name(table: str) -> str:
212
+ if not isinstance(table, str):
213
+ raise TypeError(f"Table name must be a string, got {type(table).__name__}")
214
+
215
+ try:
216
+ words = _parse_identifier_groups(table)
217
+ except ValueError as e:
218
+ raise ValueError(f"Invalid table name '{table}'. {str(e)}") from e
219
+
220
+ if len(words) not in {1, 3}:
221
+ raise ValueError(
222
+ f"Invalid table name '{table}'. Table names must be in the format '<database>.<schema>.<table>' or '<table>'"
223
+ )
224
+
225
+ return table
226
+
227
+
228
# A table name whose syntax is checked by `_validate_table_name` before the
# regular `str` validation runs.
Table = Annotated[str, BeforeValidator(_validate_table_name)]

# Table names that are additionally checked against the table schema when a
# session is present in the validation context.
VizNodeTable = Annotated[Table, AfterValidator(_validate_viz_node_table)]
VizRelationshipTable = Annotated[Table, AfterValidator(_validate_viz_relationship_table)]
232
+
233
+
234
# How the rows of a relationship table are turned into relationships
# (see `_map_tables` for the handling of each member).
class Orientation(Enum):
    NATURAL = "natural"  # keep source -> target as stored
    UNDIRECTED = "undirected"  # emit both the stored and the swapped direction
    REVERSE = "reverse"  # swap source and target
238
+
239
+
240
+ def _to_lower(value: str) -> str:
241
+ return value.lower() if value and isinstance(value, str) else value
242
+
243
+
244
# An `Orientation` whose string input is lower-cased before validation, so
# that e.g. "NATURAL" is accepted as well as "natural".
LowercaseOrientation = Annotated[Orientation, BeforeValidator(_to_lower)]
245
+
246
+
247
# Configuration of one relationship table: the node tables its endpoints
# refer to and the orientation of its relationships. Unknown keys are rejected.
class VizRelationshipTableConfig(BaseModel, extra="forbid"):
    # Node table that the `sourceNodeId` values refer to
    sourceTable: VizNodeTable
    # Node table that the `targetNodeId` values refer to
    targetTable: VizNodeTable
    # Direction handling for the relationships; defaults to NATURAL
    orientation: Optional[LowercaseOrientation] = Orientation.NATURAL
251
+
252
+
253
# Top-level project configuration validated by `from_snowflake`. Unknown keys
# are rejected.
class VizProjectConfig(BaseModel, extra="forbid"):
    # NOTE(review): not referenced by the other code visible in this module
    defaultTablePrefix: Optional[str] = None
    # Names of the tables holding nodes
    nodeTables: list[VizNodeTable]
    # Relationship table name -> configuration of that table
    relationshipTables: dict[VizRelationshipTable, VizRelationshipTableConfig]
257
+
258
+
259
+ def _map_tables(
260
+ session: Session, project_model: VizProjectConfig
261
+ ) -> tuple[list[DataFrame], list[DataFrame], list[str]]:
262
+ offset = 0
263
+ to_internal = {}
264
+ node_dfs = []
265
+ for table in project_model.nodeTables:
266
+ df = session.table(table).to_pandas()
267
+ internal_ids = range(offset, offset + df.shape[0])
268
+ to_internal[table] = df[["NODEID"]].copy()
269
+ to_internal[table]["INTERNALID"] = internal_ids
270
+ offset += df.shape[0]
271
+
272
+ df["SNOWFLAKEID"] = df["NODEID"]
273
+ df["NODEID"] = internal_ids
274
+
275
+ node_dfs.append(df)
276
+
277
+ rel_dfs = []
278
+ rel_table_names = []
279
+ for table, rel_table_config in project_model.relationshipTables.items():
280
+ df = session.table(table).to_pandas()
281
+
282
+ source_table = rel_table_config.sourceTable
283
+ target_table = rel_table_config.targetTable
284
+
285
+ df = df.merge(to_internal[source_table], left_on="SOURCENODEID", right_on="NODEID")
286
+ df.drop(["SOURCENODEID", "NODEID"], axis=1, inplace=True)
287
+ df.rename({"INTERNALID": "SOURCENODEID"}, axis=1, inplace=True)
288
+ df = df.merge(to_internal[target_table], left_on="TARGETNODEID", right_on="NODEID")
289
+ df.drop(["TARGETNODEID", "NODEID"], axis=1, inplace=True)
290
+ df.rename({"INTERNALID": "TARGETNODEID"}, axis=1, inplace=True)
291
+
292
+ if (
293
+ rel_table_config.orientation == Orientation.NATURAL
294
+ or rel_table_config.orientation == Orientation.UNDIRECTED
295
+ ):
296
+ rel_dfs.append(df)
297
+ rel_table_names.append(table)
298
+
299
+ if rel_table_config.orientation == Orientation.REVERSE:
300
+ df_rev = df.rename(columns={"SOURCENODEID": "TARGETNODEID", "TARGETNODEID": "SOURCENODEID"}, copy=False)
301
+ rel_dfs.append(df_rev)
302
+ rel_table_names.append(table)
303
+
304
+ if rel_table_config.orientation == Orientation.UNDIRECTED:
305
+ df_rev = df.rename(columns={"SOURCENODEID": "TARGETNODEID", "TARGETNODEID": "SOURCENODEID"}, copy=True)
306
+ rel_dfs.append(df_rev)
307
+ rel_table_names.append(table)
308
+
309
+ return node_dfs, rel_dfs, rel_table_names
310
+
311
+
312
def from_snowflake(
    session: Session,
    project_config: dict[str, Any],
    node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
) -> VisualizationGraph:
    """Build a `VisualizationGraph` from the Snowflake tables named in `project_config`.

    The configuration is validated as a `VizProjectConfig`, with `session`
    passed in the validation context. If no node table has a `CAPTION`
    column, every node gets its table name as caption; the same applies to
    relationships. Nodes are then colored by caption using a discrete color
    space.

    Args:
        session: Session used to validate and read the tables.
        project_config: Project configuration; see `VizProjectConfig`.
        node_radius_min_max: Forwarded to `from_dfs`.

    Returns:
        The resulting visualization graph.
    """
    project_model = VizProjectConfig.model_validate(project_config, strict=False, context={"session": session})
    node_dfs, rel_dfs, rel_table_names = _map_tables(session, project_model)

    def table_caption(table_name: str) -> str:
        # Use the identifier parser instead of `split(".")` so that a dot
        # inside a double-quoted table name is not taken for a separator.
        return _parse_identifier_groups(table_name)[-1]

    if not any("CAPTION" in node_df.columns for node_df in node_dfs):
        # `_map_tables` returns one node DataFrame per entry of `nodeTables`, in order
        for table_name, node_df in zip(project_model.nodeTables, node_dfs):
            node_df["caption"] = table_caption(table_name)

    if not any("CAPTION" in rel_df.columns for rel_df in rel_dfs):
        for table_name, rel_df in zip(rel_table_names, rel_dfs):
            rel_df["caption"] = table_caption(table_name)

    VG = from_dfs(node_dfs, rel_dfs, node_radius_min_max)

    VG.color_nodes(field="caption", color_space=ColorSpace.DISCRETE)

    return VG
@@ -5,6 +5,7 @@ from collections.abc import Iterable
5
5
  from typing import Any, Callable, Hashable, Optional, Union
6
6
 
7
7
  from IPython.display import HTML
8
+ from pydantic.alias_generators import to_snake
8
9
  from pydantic_extra_types.color import Color, ColorType
9
10
 
10
11
  from .colors import NEO4J_COLORS_CONTINUOUS, NEO4J_COLORS_DISCRETE, ColorSpace, ColorsType
@@ -277,7 +278,7 @@ class VisualizationGraph:
277
278
  return node.properties.get(attribute)
278
279
  else:
279
280
  assert field is not None
280
- attribute = field
281
+ attribute = to_snake(field)
281
282
 
282
283
  def node_to_attr(node: Node) -> Any:
283
284
  return getattr(node, attribute)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: neo4j-viz
3
- Version: 0.4.1
3
+ Version: 0.5.0
4
4
  Summary: A simple graph visualization tool
5
5
  Author-email: Neo4j <team-gds@neo4j.org>
6
6
  Project-URL: Homepage, https://neo4j.com/
@@ -34,7 +34,7 @@ Requires-Dist: pydantic-extra-types<3,>=2
34
34
  Requires-Dist: enum-tools==0.12.0
35
35
  Provides-Extra: dev
36
36
  Requires-Dist: ruff==0.11.8; extra == "dev"
37
- Requires-Dist: mypy==1.15.0; extra == "dev"
37
+ Requires-Dist: mypy==1.17.1; extra == "dev"
38
38
  Requires-Dist: pytest==8.3.4; extra == "dev"
39
39
  Requires-Dist: selenium==4.32.0; extra == "dev"
40
40
  Requires-Dist: ipykernel==6.29.5; extra == "dev"
@@ -55,6 +55,8 @@ Provides-Extra: gds
55
55
  Requires-Dist: graphdatascience<2,>=1; extra == "gds"
56
56
  Provides-Extra: neo4j
57
57
  Requires-Dist: neo4j; extra == "neo4j"
58
+ Provides-Extra: snowflake
59
+ Requires-Dist: snowflake-snowpark-python<2,>=1; extra == "snowflake"
58
60
  Provides-Extra: notebook
59
61
  Requires-Dist: ipykernel>=6.29.5; extra == "notebook"
60
62
  Requires-Dist: pykernel>=0.1.6; extra == "notebook"
@@ -62,7 +64,7 @@ Requires-Dist: neo4j>=5.26.0; extra == "notebook"
62
64
  Requires-Dist: ipywidgets>=8.0.0; extra == "notebook"
63
65
  Requires-Dist: palettable>=3.3.3; extra == "notebook"
64
66
  Requires-Dist: matplotlib>=3.9.4; extra == "notebook"
65
- Requires-Dist: snowflake-snowpark-python==1.26.0; extra == "notebook"
67
+ Requires-Dist: snowflake-snowpark-python==1.37.0; extra == "notebook"
66
68
 
67
69
  # Graph Visualization for Python by Neo4j
68
70
 
@@ -1,22 +1,23 @@
1
1
  neo4j_viz/__init__.py,sha256=Q-VZlJe3_kAow_-F_-9RsHCQfbOfv5on26YD9ihw27o,504
2
2
  neo4j_viz/colors.py,sha256=IvOCTmCu7WTMna_wNLZ3GrThTwFyIoKtNkmZYDLdYac,6694
3
- neo4j_viz/gds.py,sha256=ux41zwbfBoeH-A4lqzTAYbdM4d-4JwmXv4ooVzlFflI,7595
3
+ neo4j_viz/gds.py,sha256=I6G69KmX7gmEBmfN0vfo9K1_5DAM9orJ3HL2-RPnSsg,8170
4
4
  neo4j_viz/gql_create.py,sha256=K33cT6dOj8eJPGNNJXiXlCfLIzNxTwcW4n_2AG3_zaY,14751
5
5
  neo4j_viz/neo4j.py,sha256=8oNhsEd33wayyNlDi5KirG_vKgvVc3nJgAyvAZuKcNw,7296
6
6
  neo4j_viz/node.py,sha256=MiLoghsn2NLs_iV65NuW7u3iaxP8MTKoNy6La9TdreY,3886
7
7
  neo4j_viz/node_size.py,sha256=c_sMtQSD8eJ_6Y0Kr6ku0LOs9VoEDxfYCUUzUWZ-1Xo,1197
8
8
  neo4j_viz/nvl.py,sha256=ZN3tyWar9ugR88r5N6txW3ThfNEWOt5A1KzrrRnLKwk,5262
9
9
  neo4j_viz/options.py,sha256=eOpiLcIfFvUiPoozyT44F9MHGRkqCfBZFmh0u_6DfwY,6400
10
- neo4j_viz/pandas.py,sha256=x_hD1IAFKKAo-cwgVMXLnavP0zkMIhB28Csw83vztNo,5755
10
+ neo4j_viz/pandas.py,sha256=7ac8kY2GQfLzh64Hn_V7OWdu4UEDG_P5Lb7FWdN24Hk,5776
11
11
  neo4j_viz/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  neo4j_viz/relationship.py,sha256=sRgGjNzlqt6wPmiB3WXBTxR_W_5Z40VofpHUZvkRS1I,4143
13
- neo4j_viz/visualization_graph.py,sha256=Fizmzm8p4p3Fbd3gsiTymkXgI3crpdmr5ipv4gFv3SY,13760
13
+ neo4j_viz/snowflake.py,sha256=9FO7acOHOzLmjMeJxaZW6rbyoJyI08dQEvnbD91LOu8,12176
14
+ neo4j_viz/visualization_graph.py,sha256=aQhItfpPqKlebzldFRAtzamEyw-fxh4C11jO3X_51U4,13817
14
15
  neo4j_viz/resources/icons/screenshot.svg,sha256=Ns9Yi2Iq4lIaiFvzc0pXBmjxt4fcmBO-I4cI8Xiu1HE,311
15
16
  neo4j_viz/resources/icons/zoom-in.svg,sha256=PsO5yFkA1JnGM2QV_qxHKG13qmoR-RrlWARpaXNp5qU,415
16
17
  neo4j_viz/resources/icons/zoom-out.svg,sha256=OQRADAoe2bxbCeFufg6W22nR41q5NlI8QspT9l5pXUw,400
17
- neo4j_viz/resources/nvl_entrypoint/base.js,sha256=jYzVHzrBe32hPUgznTDRBNri3urSJ083GDnFmXOkAvc,1811792
18
+ neo4j_viz/resources/nvl_entrypoint/base.js,sha256=SQm93kmdN6ZIDlXWgtWPUXQBVzYp1Td4UHkRi9-_fjw,1815362
18
19
  neo4j_viz/resources/nvl_entrypoint/styles.css,sha256=JjeTSB9OJT2KMfb8yFUUMLMG7Rzrf3o60hSCD547zTk,1123
19
- neo4j_viz-0.4.1.dist-info/METADATA,sha256=zlnSq2VYXVOmpXmal4gS8LVM64wbb3Zoh4FFARJXKTg,7074
20
- neo4j_viz-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
- neo4j_viz-0.4.1.dist-info/top_level.txt,sha256=jPUM3z8MOtxqDanc2VzqkxG4HJn8aaq4S7rnCFNk_Vs,10
22
- neo4j_viz-0.4.1.dist-info/RECORD,,
20
+ neo4j_viz-0.5.0.dist-info/METADATA,sha256=vBmmB4PcBfuHcj-yypqOWc5IS_Fe8oBeoVBGq1JiLXI,7169
21
+ neo4j_viz-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ neo4j_viz-0.5.0.dist-info/top_level.txt,sha256=jPUM3z8MOtxqDanc2VzqkxG4HJn8aaq4S7rnCFNk_Vs,10
23
+ neo4j_viz-0.5.0.dist-info/RECORD,,