graphdatascience 1.10a1__tar.gz → 1.11a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphdatascience-1.10a1/graphdatascience.egg-info → graphdatascience-1.11a2}/PKG-INFO +2 -2
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/base_graph_proc_runner.py +4 -4
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_entity_ops_runner.py +43 -6
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_remote_proc_runner.py +0 -1
- graphdatascience-1.11a2/graphdatascience/graph/graph_remote_project_runner.py +41 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph_data_science.py +13 -10
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_graph_constructor.py +13 -40
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_query_runner.py +28 -166
- graphdatascience-1.11a2/graphdatascience/query_runner/aura_db_query_runner.py +224 -0
- graphdatascience-1.11a2/graphdatascience/query_runner/gds_arrow_client.py +242 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/neo4j_query_runner.py +36 -15
- graphdatascience-1.11a2/graphdatascience/session/__init__.py +16 -0
- graphdatascience-1.11a2/graphdatascience/session/algorithm_category.py +14 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/aura_api.py +112 -110
- graphdatascience-1.11a2/graphdatascience/session/aura_api_responses.py +174 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/aura_graph_data_science.py +11 -5
- graphdatascience-1.11a2/graphdatascience/session/aurads_sessions.py +202 -0
- graphdatascience-1.11a2/graphdatascience/session/dedicated_sessions.py +147 -0
- graphdatascience-1.11a2/graphdatascience/session/gds_sessions.py +107 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/schema.py +0 -3
- graphdatascience-1.11a2/graphdatascience/session/session_info.py +40 -0
- graphdatascience-1.11a2/graphdatascience/session/session_sizes.py +31 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/system_endpoints.py +2 -2
- graphdatascience-1.11a2/graphdatascience/version.py +1 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2/graphdatascience.egg-info}/PKG-INFO +2 -2
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/SOURCES.txt +7 -1
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/requires.txt +1 -1
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/base.txt +1 -1
- graphdatascience-1.10a1/graphdatascience/graph/graph_remote_project_runner.py +0 -40
- graphdatascience-1.10a1/graphdatascience/query_runner/aura_db_arrow_query_runner.py +0 -184
- graphdatascience-1.10a1/graphdatascience/session/__init__.py +0 -13
- graphdatascience-1.10a1/graphdatascience/session/gds_sessions.py +0 -240
- graphdatascience-1.10a1/graphdatascience/session/session_sizes.py +0 -33
- graphdatascience-1.10a1/graphdatascience/version.py +0 -1
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/LICENSE +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/MANIFEST.in +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/README.md +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/algo_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/algo_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/call_builder.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/call_parameters.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/caller_base.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/client_only_endpoint.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/cypher_warning_handler.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/endpoint_suggester.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/gds_not_installed.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/illegal_attr_checker.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/unable_to_connect.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/uncallable_namespace.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_create_result.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_cypher_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_export_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_object.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_project_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_sample_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_type_check.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/nx_loader.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/ogb_loader.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/ignored_server_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/graphsage_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/link_prediction_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_beta_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_resolver.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/node_classification_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/node_regression_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/pipeline_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/training_pipeline.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/py.typed +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_endpoint_version.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/graph_constructor.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/query_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/serialize_cora.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/karate/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/compatible_with.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/server_version.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/dbms_connection_info.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/region_suggester.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/config_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/__init__.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/util_endpoints.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/util_proc_runner.py +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/dependency_links.txt +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/not-zip-safe +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/top_level.txt +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/pyproject.toml +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/networkx.txt +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/ogb.txt +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/setup.cfg +0 -0
- {graphdatascience-1.10a1 → graphdatascience-1.11a2}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphdatascience
|
|
3
|
-
Version: 1.10a1
|
|
3
|
+
Version: 1.11a2
|
|
4
4
|
Summary: A Python client for the Neo4j Graph Data Science (GDS) library
|
|
5
5
|
Home-page: https://neo4j.com/product/graph-data-science/
|
|
6
6
|
Author: Neo4j
|
|
@@ -31,7 +31,7 @@ License-File: LICENSE
|
|
|
31
31
|
Requires-Dist: multimethod<2.0,>=1.0
|
|
32
32
|
Requires-Dist: neo4j<6.0,>=4.4.2
|
|
33
33
|
Requires-Dist: pandas<3.0,>=1.0
|
|
34
|
-
Requires-Dist: pyarrow<
|
|
34
|
+
Requires-Dist: pyarrow<16.0,>=11.0
|
|
35
35
|
Requires-Dist: textdistance<5.0,>=4.0
|
|
36
36
|
Requires-Dist: tqdm<5.0,>=4.0
|
|
37
37
|
Requires-Dist: typing-extensions<5.0,>=4.0
|
{graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/base_graph_proc_runner.py
RENAMED
|
@@ -18,6 +18,7 @@ from .graph_entity_ops_runner import (
|
|
|
18
18
|
GraphElementPropertyRunner,
|
|
19
19
|
GraphLabelRunner,
|
|
20
20
|
GraphNodePropertiesRunner,
|
|
21
|
+
GraphNodePropertyRunner,
|
|
21
22
|
GraphPropertyRunner,
|
|
22
23
|
GraphRelationshipPropertiesRunner,
|
|
23
24
|
GraphRelationshipRunner,
|
|
@@ -379,9 +380,9 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
379
380
|
)
|
|
380
381
|
|
|
381
382
|
@property
|
|
382
|
-
def nodeProperty(self) ->
|
|
383
|
+
def nodeProperty(self) -> GraphNodePropertyRunner:
|
|
383
384
|
self._namespace += ".nodeProperty"
|
|
384
|
-
return
|
|
385
|
+
return GraphNodePropertyRunner(self._query_runner, self._namespace, self._server_version)
|
|
385
386
|
|
|
386
387
|
@property
|
|
387
388
|
def nodeProperties(self) -> GraphNodePropertiesRunner:
|
|
@@ -516,8 +517,7 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
516
517
|
).squeeze()
|
|
517
518
|
|
|
518
519
|
@multimethod
|
|
519
|
-
def removeNodeProperties(self) -> None:
|
|
520
|
-
...
|
|
520
|
+
def removeNodeProperties(self) -> None: ...
|
|
521
521
|
|
|
522
522
|
@removeNodeProperties.register
|
|
523
523
|
@graph_type_check
|
|
@@ -77,6 +77,26 @@ class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
|
|
|
77
77
|
return self._handle_properties(G, node_properties, node_labels, config)
|
|
78
78
|
|
|
79
79
|
|
|
80
|
+
class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
|
|
81
|
+
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
82
|
+
@filter_id_func_deprecation_warning()
|
|
83
|
+
def stream(
|
|
84
|
+
self,
|
|
85
|
+
G: Graph,
|
|
86
|
+
node_property: str,
|
|
87
|
+
node_labels: Strings = ["*"],
|
|
88
|
+
db_node_properties: List[str] = [],
|
|
89
|
+
**config: Any,
|
|
90
|
+
) -> DataFrame:
|
|
91
|
+
self._namespace += ".stream"
|
|
92
|
+
|
|
93
|
+
result = self._handle_properties(G, node_property, node_labels, config)
|
|
94
|
+
|
|
95
|
+
return GraphNodePropertiesRunner._process_result(
|
|
96
|
+
self._query_runner, list(node_property), False, db_node_properties, result, config
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
|
|
80
100
|
class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
81
101
|
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
82
102
|
@filter_id_func_deprecation_warning()
|
|
@@ -93,6 +113,19 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
93
113
|
|
|
94
114
|
result = self._handle_properties(G, node_properties, node_labels, config)
|
|
95
115
|
|
|
116
|
+
return GraphNodePropertiesRunner._process_result(
|
|
117
|
+
self._query_runner, node_properties, separate_property_columns, db_node_properties, result, config
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
@staticmethod
|
|
121
|
+
def _process_result(
|
|
122
|
+
query_runner: QueryRunner,
|
|
123
|
+
node_properties: List[str],
|
|
124
|
+
separate_property_columns: bool,
|
|
125
|
+
db_node_properties: List[str],
|
|
126
|
+
result: DataFrame,
|
|
127
|
+
config: Dict[str, Any],
|
|
128
|
+
) -> DataFrame:
|
|
96
129
|
# new format was requested, but the query was run via Cypher
|
|
97
130
|
if separate_property_columns and "propertyValue" in result.keys():
|
|
98
131
|
wide_result = result.pivot(index=["nodeId"], columns=["nodeProperty"], values="propertyValue")
|
|
@@ -106,7 +139,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
106
139
|
# old format was requested but the query was run via Arrow
|
|
107
140
|
elif not separate_property_columns and "propertyValue" not in result.keys():
|
|
108
141
|
id_vars = ["nodeId", "nodeLabels"] if config.get("listNodeLabels", False) else ["nodeId"]
|
|
109
|
-
result = result.melt(id_vars=id_vars
|
|
142
|
+
result = result.melt(id_vars=id_vars, var_name="nodeProperty", value_name="propertyValue")
|
|
110
143
|
|
|
111
144
|
if db_node_properties:
|
|
112
145
|
duplicate_properties = set(db_node_properties).intersection(set(node_properties))
|
|
@@ -116,16 +149,20 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
116
149
|
)
|
|
117
150
|
|
|
118
151
|
unique_node_ids = result["nodeId"].drop_duplicates().tolist()
|
|
119
|
-
db_properties_df =
|
|
120
|
-
|
|
152
|
+
db_properties_df = query_runner.run_cypher(
|
|
153
|
+
GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
|
|
121
154
|
)
|
|
122
155
|
|
|
123
156
|
if "propertyValue" not in result.keys():
|
|
124
157
|
result = result.join(db_properties_df.set_index("nodeId"), on="nodeId")
|
|
125
158
|
else:
|
|
126
|
-
db_properties_df = db_properties_df.melt(
|
|
127
|
-
|
|
159
|
+
db_properties_df = db_properties_df.melt(
|
|
160
|
+
id_vars=["nodeId"], var_name="nodeProperty", value_name="propertyValue"
|
|
128
161
|
)
|
|
162
|
+
|
|
163
|
+
if "nodeProperty" not in result.keys():
|
|
164
|
+
result["nodeProperty"] = node_properties[0]
|
|
165
|
+
|
|
129
166
|
result = pd.concat([result, db_properties_df])
|
|
130
167
|
|
|
131
168
|
return result
|
|
@@ -140,7 +177,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
140
177
|
return reduce(add_property, db_node_properties, query_prefix)
|
|
141
178
|
|
|
142
179
|
@compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
|
|
143
|
-
def write(self, G: Graph, node_properties:
|
|
180
|
+
def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
|
|
144
181
|
self._namespace += ".write"
|
|
145
182
|
return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore
|
|
146
183
|
|
|
@@ -5,5 +5,4 @@ from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemot
|
|
|
5
5
|
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
6
6
|
@property
|
|
7
7
|
def project(self) -> GraphProjectRemoteRunner:
|
|
8
|
-
self._namespace += ".project.remoteDb"
|
|
9
8
|
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
6
|
+
from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
|
|
7
|
+
from ..server_version.compatible_with import compatible_with
|
|
8
|
+
from .graph_object import Graph
|
|
9
|
+
from graphdatascience.call_parameters import CallParameters
|
|
10
|
+
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
11
|
+
from graphdatascience.server_version.server_version import ServerVersion
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
15
|
+
@compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
|
|
16
|
+
def __call__(
|
|
17
|
+
self,
|
|
18
|
+
graph_name: str,
|
|
19
|
+
query: str,
|
|
20
|
+
concurrency: int = 4,
|
|
21
|
+
undirected_relationship_types: Optional[List[str]] = None,
|
|
22
|
+
inverse_indexed_relationship_types: Optional[List[str]] = None,
|
|
23
|
+
) -> GraphCreateResult:
|
|
24
|
+
if inverse_indexed_relationship_types is None:
|
|
25
|
+
inverse_indexed_relationship_types = []
|
|
26
|
+
if undirected_relationship_types is None:
|
|
27
|
+
undirected_relationship_types = []
|
|
28
|
+
|
|
29
|
+
params = CallParameters(
|
|
30
|
+
graph_name=graph_name,
|
|
31
|
+
query=query,
|
|
32
|
+
concurrency=concurrency,
|
|
33
|
+
undirected_relationship_types=undirected_relationship_types,
|
|
34
|
+
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
result = self._query_runner.call_procedure(
|
|
38
|
+
endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
|
|
39
|
+
params=params,
|
|
40
|
+
).squeeze()
|
|
41
|
+
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
@@ -23,11 +23,12 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
23
23
|
|
|
24
24
|
def __init__(
|
|
25
25
|
self,
|
|
26
|
+
/,
|
|
26
27
|
endpoint: Union[str, Driver, QueryRunner],
|
|
27
28
|
auth: Optional[Tuple[str, str]] = None,
|
|
28
29
|
aura_ds: bool = False,
|
|
29
30
|
database: Optional[str] = None,
|
|
30
|
-
arrow: bool = True,
|
|
31
|
+
arrow: Union[str, bool] = True,
|
|
31
32
|
arrow_disable_server_verification: bool = True,
|
|
32
33
|
arrow_tls_root_certs: Optional[bytes] = None,
|
|
33
34
|
bookmarks: Optional[Any] = None,
|
|
@@ -43,19 +44,20 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
43
44
|
A username, password pair for database authentication.
|
|
44
45
|
aura_ds : bool, default False
|
|
45
46
|
A flag that indicates that that the client is used to connect
|
|
46
|
-
to a Neo4j
|
|
47
|
+
to a Neo4j AuraDS instance.
|
|
47
48
|
database: Optional[str], default None
|
|
48
49
|
The Neo4j database to query against.
|
|
49
|
-
arrow : bool, default True
|
|
50
|
-
|
|
51
|
-
|
|
50
|
+
arrow : Union[str, bool], default True
|
|
51
|
+
Arrow connection information. This is either a bool or a string.
|
|
52
|
+
If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
|
|
53
|
+
If it is a bool,
|
|
54
|
+
True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint,
|
|
55
|
+
while False will make the client use Bolt for all operations.
|
|
52
56
|
arrow_disable_server_verification : bool, default True
|
|
53
|
-
A flag that
|
|
54
|
-
TLS, that it skips server verification. If this is enabled, all
|
|
55
|
-
other TLS settings are overridden.
|
|
57
|
+
A flag that overrides other TLS settings and disables server verification for TLS connections.
|
|
56
58
|
arrow_tls_root_certs : Optional[bytes], default None
|
|
57
|
-
PEM-encoded certificates that are used for the
|
|
58
|
-
Arrow Flight server.
|
|
59
|
+
PEM-encoded certificates that are used for the connection to the
|
|
60
|
+
GDS Arrow Flight server.
|
|
59
61
|
bookmarks : Optional[Any], default None
|
|
60
62
|
The Neo4j bookmarks to require a certain state before the next query gets executed.
|
|
61
63
|
"""
|
|
@@ -76,6 +78,7 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
76
78
|
self._query_runner.encrypted(),
|
|
77
79
|
arrow_disable_server_verification,
|
|
78
80
|
arrow_tls_root_certs,
|
|
81
|
+
None if arrow is True else arrow,
|
|
79
82
|
)
|
|
80
83
|
|
|
81
84
|
super().__init__(self._query_runner, "gds", self._server_version)
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import concurrent
|
|
2
|
-
import json
|
|
3
4
|
import math
|
|
4
5
|
import warnings
|
|
5
6
|
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
-
from typing import Any, Dict, List, Optional
|
|
7
|
+
from typing import Any, Dict, List, NoReturn, Optional
|
|
7
8
|
|
|
8
9
|
import numpy
|
|
9
|
-
import pyarrow.flight as flight
|
|
10
10
|
from pandas import DataFrame
|
|
11
11
|
from pyarrow import Table
|
|
12
12
|
from tqdm.auto import tqdm
|
|
13
13
|
|
|
14
|
-
from .
|
|
14
|
+
from .gds_arrow_client import GdsArrowClient
|
|
15
15
|
from .graph_constructor import GraphConstructor
|
|
16
16
|
|
|
17
17
|
|
|
@@ -20,9 +20,8 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
20
20
|
self,
|
|
21
21
|
database: str,
|
|
22
22
|
graph_name: str,
|
|
23
|
-
flight_client:
|
|
23
|
+
flight_client: GdsArrowClient,
|
|
24
24
|
concurrency: int,
|
|
25
|
-
arrow_endpoint_version: ArrowEndpointVersion,
|
|
26
25
|
undirected_relationship_types: Optional[List[str]],
|
|
27
26
|
chunk_size: int = 10_000,
|
|
28
27
|
):
|
|
@@ -30,7 +29,6 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
30
29
|
self._concurrency = concurrency
|
|
31
30
|
self._graph_name = graph_name
|
|
32
31
|
self._client = flight_client
|
|
33
|
-
self._arrow_endpoint_version = arrow_endpoint_version
|
|
34
32
|
self._undirected_relationship_types = (
|
|
35
33
|
[] if undirected_relationship_types is None else undirected_relationship_types
|
|
36
34
|
)
|
|
@@ -47,20 +45,20 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
47
45
|
if self._undirected_relationship_types:
|
|
48
46
|
config["undirected_relationship_types"] = self._undirected_relationship_types
|
|
49
47
|
|
|
50
|
-
self.
|
|
48
|
+
self._client.send_action(
|
|
51
49
|
"CREATE_GRAPH",
|
|
52
50
|
config,
|
|
53
51
|
)
|
|
54
52
|
|
|
55
53
|
self._send_dfs(node_dfs, "node")
|
|
56
54
|
|
|
57
|
-
self.
|
|
55
|
+
self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})
|
|
58
56
|
|
|
59
57
|
self._send_dfs(relationship_dfs, "relationship")
|
|
60
58
|
|
|
61
|
-
self.
|
|
59
|
+
self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
|
|
62
60
|
except (Exception, KeyboardInterrupt) as e:
|
|
63
|
-
self.
|
|
61
|
+
self._client.send_action("ABORT", {"name": self._graph_name})
|
|
64
62
|
|
|
65
63
|
raise e
|
|
66
64
|
|
|
@@ -83,31 +81,20 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
83
81
|
|
|
84
82
|
return partitioned_dfs
|
|
85
83
|
|
|
86
|
-
def
|
|
87
|
-
action_type = self._versioned_action_type(action_type)
|
|
88
|
-
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
|
|
89
|
-
|
|
90
|
-
# Consume result fully to sanity check and avoid cancelled streams
|
|
91
|
-
collected_result = list(result)
|
|
92
|
-
assert len(collected_result) == 1
|
|
93
|
-
|
|
94
|
-
json.loads(collected_result[0].body.to_pybytes().decode())
|
|
95
|
-
|
|
96
|
-
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
|
|
84
|
+
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
|
|
97
85
|
table = Table.from_pandas(df)
|
|
98
86
|
batches = table.to_batches(self._chunk_size)
|
|
99
87
|
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
|
|
100
|
-
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
|
|
101
88
|
|
|
102
|
-
|
|
103
|
-
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
|
|
104
|
-
writer, _ = self._client.do_put(upload_descriptor, table.schema)
|
|
89
|
+
writer, _ = self._client.start_put(flight_descriptor, table.schema)
|
|
105
90
|
|
|
106
91
|
with writer:
|
|
107
92
|
# Write table in chunks
|
|
108
93
|
for partition in batches:
|
|
109
94
|
writer.write_batch(partition)
|
|
110
95
|
pbar.update(partition.num_rows)
|
|
96
|
+
# Force a refresh to avoid the progress bar getting stuck at 0%
|
|
97
|
+
pbar.refresh()
|
|
111
98
|
|
|
112
99
|
def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
|
|
113
100
|
desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships"
|
|
@@ -122,17 +109,3 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
122
109
|
if not future.exception():
|
|
123
110
|
continue
|
|
124
111
|
raise future.exception() # type: ignore
|
|
125
|
-
|
|
126
|
-
def _versioned_action_type(self, action_type: str) -> str:
|
|
127
|
-
return self._arrow_endpoint_version.prefix() + action_type
|
|
128
|
-
|
|
129
|
-
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
|
|
130
|
-
return (
|
|
131
|
-
flight_descriptor
|
|
132
|
-
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
|
|
133
|
-
else {
|
|
134
|
-
"name": "PUT_MESSAGE",
|
|
135
|
-
"version": ArrowEndpointVersion.V1.version(),
|
|
136
|
-
"body": flight_descriptor,
|
|
137
|
-
}
|
|
138
|
-
)
|
|
@@ -1,21 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import base64
|
|
4
|
-
import json
|
|
5
|
-
import time
|
|
6
3
|
import warnings
|
|
7
4
|
from typing import Any, Dict, List, Optional, Tuple
|
|
8
5
|
|
|
9
|
-
import pyarrow.flight as flight
|
|
10
6
|
from pandas import DataFrame
|
|
11
|
-
from pyarrow import ChunkedArray, Table, chunked_array
|
|
12
|
-
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
13
|
-
from pyarrow.types import is_dictionary # type: ignore
|
|
14
7
|
|
|
15
8
|
from ..call_parameters import CallParameters
|
|
16
9
|
from ..server_version.server_version import ServerVersion
|
|
17
|
-
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
18
10
|
from .arrow_graph_constructor import ArrowGraphConstructor
|
|
11
|
+
from .gds_arrow_client import GdsArrowClient
|
|
19
12
|
from .graph_constructor import GraphConstructor
|
|
20
13
|
from .query_runner import QueryRunner
|
|
21
14
|
from graphdatascience.server_version.compatible_with import (
|
|
@@ -31,58 +24,31 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
31
24
|
encrypted: bool = False,
|
|
32
25
|
disable_server_verification: bool = False,
|
|
33
26
|
tls_root_certs: Optional[bytes] = None,
|
|
27
|
+
connection_string_override: Optional[str] = None,
|
|
34
28
|
) -> QueryRunner:
|
|
35
|
-
arrow_info = (
|
36
|
-
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
|
|
37
|
-
)
|
|
38
|
-
server_version = fallback_query_runner.server_version()
|
|
39
|
-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
|
|
40
|
-
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
|
|
41
|
-
|
|
42
|
-
if arrow_info["running"]:
|
|
43
|
-
return ArrowQueryRunner(
|
|
44
|
-
listen_address,
|
|
45
|
-
fallback_query_runner,
|
|
46
|
-
server_version,
|
|
47
|
-
auth,
|
|
48
|
-
encrypted,
|
|
49
|
-
disable_server_verification,
|
|
50
|
-
tls_root_certs,
|
|
51
|
-
arrow_endpoint_version,
|
|
52
|
-
)
|
|
53
|
-
else:
|
|
29
|
+
if not GdsArrowClient.is_arrow_enabled(fallback_query_runner):
|
|
54
30
|
return fallback_query_runner
|
|
55
31
|
|
|
32
|
+
gds_arrow_client = GdsArrowClient.create(
|
|
33
|
+
fallback_query_runner,
|
|
34
|
+
auth,
|
|
35
|
+
encrypted,
|
|
36
|
+
disable_server_verification,
|
|
37
|
+
tls_root_certs,
|
|
38
|
+
connection_string_override,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
|
|
42
|
+
|
|
56
43
|
def __init__(
|
|
57
44
|
self,
|
|
58
|
-
uri: str,
|
45
|
+
gds_arrow_client: GdsArrowClient,
|
|
59
46
|
fallback_query_runner: QueryRunner,
|
|
60
47
|
server_version: ServerVersion,
|
|
61
|
-
auth: Optional[Tuple[str, str]] = None,
|
|
62
|
-
encrypted: bool = False,
|
|
63
|
-
disable_server_verification: bool = False,
|
|
64
|
-
tls_root_certs: Optional[bytes] = None,
|
|
65
|
-
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
|
|
66
48
|
):
|
|
67
49
|
self._fallback_query_runner = fallback_query_runner
|
|
50
|
+
self._gds_arrow_client = gds_arrow_client
|
|
68
51
|
self._server_version = server_version
|
|
69
|
-
self._arrow_endpoint_version = arrow_endpoint_version
|
|
70
|
-
|
|
71
|
-
host, port_string = uri.split(":")
|
|
72
|
-
|
|
73
|
-
location = (
|
|
74
|
-
flight.Location.for_grpc_tls(host, int(port_string))
|
|
75
|
-
if encrypted
|
|
76
|
-
else flight.Location.for_grpc_tcp(host, int(port_string))
|
|
77
|
-
)
|
|
78
|
-
|
|
79
|
-
client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
|
|
80
|
-
if auth:
|
|
81
|
-
client_options["middleware"] = [AuthFactory(auth)]
|
|
82
|
-
if tls_root_certs:
|
|
83
|
-
client_options["tls_root_certs"] = tls_root_certs
|
|
84
|
-
|
|
85
|
-
self._flight_client = flight.FlightClient(location, **client_options)
|
|
86
52
|
|
|
87
53
|
def warn_about_deprecation(self, old_endpoint: str, new_endpoint: str) -> None:
|
|
88
54
|
warnings.warn(
|
|
@@ -135,7 +101,7 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
135
101
|
old_endpoint="gds.graph.streamNodeProperty", new_endpoint="gds.graph.nodeProperty.stream"
|
|
136
102
|
)
|
|
137
103
|
|
|
138
|
-
return self._run_arrow_property_get(graph_name, endpoint, config)
|
|
104
|
+
return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, config)
|
|
139
105
|
elif (
|
|
140
106
|
old_endpoint := ("gds.graph.streamNodeProperties" == endpoint)
|
|
141
107
|
) or "gds.graph.nodeProperties.stream" == endpoint:
|
|
@@ -154,7 +120,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
154
120
|
self.warn_about_deprecation(
|
|
155
121
|
old_endpoint="gds.graph.streamNodeProperties", new_endpoint="gds.graph.nodeProperties.stream"
|
|
156
122
|
)
|
|
157
|
-
return self._run_arrow_property_get(
|
|
123
|
+
return self._gds_arrow_client.get_property(
|
|
124
|
+
self.database(),
|
|
158
125
|
graph_name,
|
|
159
126
|
endpoint,
|
|
160
127
|
config,
|
|
@@ -175,7 +142,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
175
142
|
old_endpoint="gds.graph.streamRelationshipProperty",
|
|
176
143
|
new_endpoint="gds.graph.relationshipProperty.stream",
|
|
177
144
|
)
|
|
178
|
-
return self._run_arrow_property_get(
|
|
145
|
+
return self._gds_arrow_client.get_property(
|
|
146
|
+
self.database(),
|
|
179
147
|
graph_name,
|
|
180
148
|
endpoint,
|
|
181
149
|
{"relationship_property": property_name, "relationship_types": relationship_types},
|
|
@@ -197,7 +165,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
197
165
|
new_endpoint="gds.graph.relationshipProperties.stream",
|
|
198
166
|
)
|
|
199
167
|
|
|
200
|
-
return self._run_arrow_property_get(
|
|
168
|
+
return self._gds_arrow_client.get_property(
|
|
169
|
+
self.database(),
|
|
201
170
|
graph_name,
|
|
202
171
|
endpoint,
|
|
203
172
|
{"relationship_properties": property_names, "relationship_types": relationship_types},
|
|
@@ -224,7 +193,9 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
224
193
|
new_endpoint="gds.graph.relationships.stream",
|
|
225
194
|
)
|
|
226
195
|
|
|
227
|
-
return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
|
|
196
|
+
return self._gds_arrow_client.get_property(
|
|
197
|
+
self.database(), graph_name, endpoint, {"relationship_types": relationship_types}
|
|
198
|
+
)
|
|
228
199
|
|
|
229
200
|
return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
|
|
230
201
|
|
|
@@ -254,52 +225,11 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
254
225
|
|
|
255
226
|
def close(self) -> None:
|
|
256
227
|
self._fallback_query_runner.close()
|
|
257
|
-
|
|
258
|
-
if hasattr(self._flight_client, "close"):
|
|
259
|
-
self._flight_client.close()
|
|
228
|
+
self._gds_arrow_client.close()
|
|
260
229
|
|
|
261
230
|
def fallback_query_runner(self) -> QueryRunner:
|
|
262
231
|
return self._fallback_query_runner
|
|
263
232
|
|
|
264
|
-
def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame:
|
|
265
|
-
if not self.database():
|
|
266
|
-
raise ValueError(
|
|
267
|
-
"For this call you must have explicitly specified a valid Neo4j database to execute on, "
|
|
268
|
-
"using `GraphDataScience.set_database`."
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
payload = {
|
|
272
|
-
"database_name": self.database(),
|
|
273
|
-
"graph_name": graph_name,
|
|
274
|
-
"procedure_name": procedure_name,
|
|
275
|
-
"configuration": configuration,
|
|
276
|
-
}
|
|
277
|
-
|
|
278
|
-
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
|
|
279
|
-
payload = {
|
|
280
|
-
"name": "GET_MESSAGE",
|
|
281
|
-
"version": ArrowEndpointVersion.V1.version(),
|
|
282
|
-
"body": payload,
|
|
283
|
-
}
|
|
284
|
-
|
|
285
|
-
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
286
|
-
get = self._flight_client.do_get(ticket)
|
|
287
|
-
arrow_table = get.read_all()
|
|
288
|
-
|
|
289
|
-
if configuration.get("list_node_labels", False):
|
|
290
|
-
# GDS 2.5 had an inconsistent naming of the node labels column
|
|
291
|
-
new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
|
|
292
|
-
arrow_table = arrow_table.rename_columns(new_colum_names)
|
|
293
|
-
|
|
294
|
-
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
|
|
295
|
-
warnings.filterwarnings(
|
|
296
|
-
"ignore",
|
|
297
|
-
category=DeprecationWarning,
|
|
298
|
-
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
299
|
-
)
|
|
300
|
-
|
|
301
|
-
return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
|
|
302
|
-
|
|
303
233
|
def create_graph_constructor(
|
|
304
234
|
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
|
|
305
235
|
) -> GraphConstructor:
|
|
@@ -313,75 +243,7 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
313
243
|
return ArrowGraphConstructor(
|
|
314
244
|
database,
|
|
315
245
|
graph_name,
|
|
316
|
-
self._flight_client,
|
|
246
|
+
self._gds_arrow_client,
|
|
317
247
|
concurrency,
|
|
318
|
-
self._arrow_endpoint_version,
|
|
319
248
|
undirected_relationship_types,
|
|
320
249
|
)
|
|
321
|
-
|
|
322
|
-
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
|
|
323
|
-
dict_encoded_fields = [
|
|
324
|
-
(idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
|
|
325
|
-
]
|
|
326
|
-
for idx, field in dict_encoded_fields:
|
|
327
|
-
try:
|
|
328
|
-
field.type.to_pandas_dtype()
|
|
329
|
-
except NotImplementedError:
|
|
330
|
-
# we need to decode the dictionary column before transforming to pandas
|
|
331
|
-
if isinstance(arrow_table[field.name], ChunkedArray):
|
|
332
|
-
decoded_col = chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks])
|
|
333
|
-
else:
|
|
334
|
-
decoded_col = arrow_table[field.name].dictionary_decode()
|
|
335
|
-
arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
|
|
336
|
-
return arrow_table
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
class AuthFactory(ClientMiddlewareFactory): # type: ignore
|
|
340
|
-
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
|
|
341
|
-
super().__init__(*args, **kwargs)
|
|
342
|
-
self._auth = auth
|
|
343
|
-
self._token: Optional[str] = None
|
|
344
|
-
self._token_timestamp = 0
|
|
345
|
-
|
|
346
|
-
def start_call(self, info: Any) -> "AuthMiddleware":
|
|
347
|
-
return AuthMiddleware(self)
|
|
348
|
-
|
|
349
|
-
def token(self) -> Optional[str]:
|
|
350
|
-
# check whether the token is older than 10 minutes. If so, reset it.
|
|
351
|
-
if self._token and int(time.time()) - self._token_timestamp > 600:
|
|
352
|
-
self._token = None
|
|
353
|
-
|
|
354
|
-
return self._token
|
|
355
|
-
|
|
356
|
-
def set_token(self, token: str) -> None:
|
|
357
|
-
self._token = token
|
|
358
|
-
self._token_timestamp = int(time.time())
|
|
359
|
-
|
|
360
|
-
@property
|
|
361
|
-
def auth(self) -> Tuple[str, str]:
|
|
362
|
-
return self._auth
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
class AuthMiddleware(ClientMiddleware): # type: ignore
|
|
366
|
-
def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
|
|
367
|
-
super().__init__(*args, **kwargs)
|
|
368
|
-
self._factory = factory
|
|
369
|
-
|
|
370
|
-
def received_headers(self, headers: Dict[str, Any]) -> None:
|
|
371
|
-
auth_header: str = headers.get("Authorization", None)
|
|
372
|
-
if not auth_header:
|
|
373
|
-
return
|
|
374
|
-
[auth_type, token] = auth_header.split(" ", 1)
|
|
375
|
-
if auth_type == "Bearer":
|
|
376
|
-
self._factory.set_token(token)
|
|
377
|
-
|
|
378
|
-
def sending_headers(self) -> Dict[str, str]:
|
|
379
|
-
token = self._factory.token()
|
|
380
|
-
if not token:
|
|
381
|
-
username, password = self._factory.auth
|
|
382
|
-
auth_token = f"{username}:{password}"
|
|
383
|
-
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
|
|
384
|
-
# There seems to be a bug, `authorization` must be lower key
|
|
385
|
-
return {"authorization": auth_token}
|
|
386
|
-
else:
|
|
387
|
-
return {"authorization": "Bearer " + token}
|