graphdatascience 1.9__tar.gz → 1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphdatascience-1.9/graphdatascience.egg-info → graphdatascience-1.10}/PKG-INFO +3 -2
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/__init__.py +1 -1
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/endpoint_suggester.py +1 -1
- graphdatascience-1.9/graphdatascience/graph/graph_proc_runner.py → graphdatascience-1.10/graphdatascience/graph/base_graph_proc_runner.py +13 -23
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_entity_ops_runner.py +52 -7
- graphdatascience-1.10/graphdatascience/graph/graph_proc_runner.py +15 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_project_runner.py +0 -29
- graphdatascience-1.10/graphdatascience/graph/graph_remote_proc_runner.py +9 -0
- graphdatascience-1.10/graphdatascience/graph/graph_remote_project_runner.py +38 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/ogb_loader.py +2 -2
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph_data_science.py +7 -4
- graphdatascience-1.10/graphdatascience/query_runner/arrow_endpoint_version.py +35 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/arrow_graph_constructor.py +25 -2
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/arrow_query_runner.py +44 -14
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/aura_db_arrow_query_runner.py +1 -1
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/neo4j_query_runner.py +36 -15
- graphdatascience-1.10/graphdatascience/session/__init__.py +13 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/aura_api.py +77 -13
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/aura_graph_data_science.py +2 -2
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/dbms_connection_info.py +10 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/gds_sessions.py +82 -22
- graphdatascience-1.10/graphdatascience/session/region_suggester.py +17 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/schema.py +4 -3
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/session}/session_sizes.py +10 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/system/system_endpoints.py +2 -2
- graphdatascience-1.10/graphdatascience/version.py +1 -0
- {graphdatascience-1.9 → graphdatascience-1.10/graphdatascience.egg-info}/PKG-INFO +3 -2
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience.egg-info/SOURCES.txt +12 -7
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience.egg-info/requires.txt +1 -1
- {graphdatascience-1.9 → graphdatascience-1.10}/requirements/base/base.txt +1 -1
- {graphdatascience-1.9 → graphdatascience-1.10}/setup.py +1 -0
- graphdatascience-1.9/graphdatascience/utils/__init__.py +0 -0
- graphdatascience-1.9/graphdatascience/version.py +0 -1
- {graphdatascience-1.9 → graphdatascience-1.10}/LICENSE +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/MANIFEST.in +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/README.md +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/algo/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/algo/algo_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/algo/algo_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/call_builder.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/call_parameters.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/caller_base.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/client_only_endpoint.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/cypher_warning_handler.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/gds_not_installed.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/illegal_attr_checker.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/unable_to_connect.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/error/uncallable_namespace.py +0 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10/graphdatascience/graph}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_create_result.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_cypher_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_export_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_object.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_sample_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_type_check.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/nx_loader.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/ignored_server_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/graph → graphdatascience-1.10/graphdatascience/model}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/graphsage_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/link_prediction_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/model_resolver.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/node_classification_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/node_regression_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/pipeline_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
- {graphdatascience-1.9/graphdatascience/model → graphdatascience-1.10/graphdatascience/pipeline}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/pipeline/training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/py.typed +0 -0
- {graphdatascience-1.9/graphdatascience/pipeline → graphdatascience-1.10/graphdatascience/query_runner}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/graph_constructor.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/query_runner.py +0 -0
- {graphdatascience-1.9/graphdatascience/query_runner → graphdatascience-1.10/graphdatascience/resources}/__init__.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources → graphdatascience-1.10/graphdatascience/resources/cora}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/cora/serialize_cora.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources/cora → graphdatascience-1.10/graphdatascience/resources/imdb}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources/imdb → graphdatascience-1.10/graphdatascience/resources/karate}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
- {graphdatascience-1.9/graphdatascience/resources/karate → graphdatascience-1.10/graphdatascience/resources/lastfm}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.9/graphdatascience/resources/lastfm → graphdatascience-1.10/graphdatascience/server_version}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/server_version/compatible_with.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/server_version/server_version.py +0 -0
- {graphdatascience-1.9/graphdatascience/server_version → graphdatascience-1.10/graphdatascience/system}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/system/config_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/system → graphdatascience-1.10/graphdatascience/topological_lp}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/topological_lp → graphdatascience-1.10/graphdatascience/utils}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/utils/util_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/utils/util_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience.egg-info/dependency_links.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience.egg-info/not-zip-safe +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience.egg-info/top_level.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/pyproject.toml +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/requirements/base/networkx.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/requirements/base/ogb.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphdatascience
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.10
|
|
4
4
|
Summary: A Python client for the Neo4j Graph Data Science (GDS) library
|
|
5
5
|
Home-page: https://neo4j.com/product/graph-data-science/
|
|
6
6
|
Author: Neo4j
|
|
@@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.9
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.10
|
|
22
22
|
Classifier: Programming Language :: Python :: 3.11
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
24
|
Classifier: Topic :: Database
|
|
24
25
|
Classifier: Topic :: Scientific/Engineering
|
|
25
26
|
Classifier: Topic :: Software Development
|
|
@@ -30,7 +31,7 @@ License-File: LICENSE
|
|
|
30
31
|
Requires-Dist: multimethod<2.0,>=1.0
|
|
31
32
|
Requires-Dist: neo4j<6.0,>=4.4.2
|
|
32
33
|
Requires-Dist: pandas<3.0,>=1.0
|
|
33
|
-
Requires-Dist: pyarrow<
|
|
34
|
+
Requires-Dist: pyarrow<16.0,>=11.0
|
|
34
35
|
Requires-Dist: textdistance<5.0,>=4.0
|
|
35
36
|
Requires-Dist: tqdm<5.0,>=4.0
|
|
36
37
|
Requires-Dist: typing-extensions<5.0,>=4.0
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from .gds_session.gds_sessions import GdsSessions
|
|
2
1
|
from .graph.graph_create_result import GraphCreateResult
|
|
3
2
|
from .graph.graph_object import Graph
|
|
4
3
|
from .graph_data_science import GraphDataScience
|
|
@@ -13,6 +12,7 @@ from .pipeline.nc_training_pipeline import NCTrainingPipeline
|
|
|
13
12
|
from .pipeline.nr_training_pipeline import NRTrainingPipeline
|
|
14
13
|
from .query_runner.query_runner import QueryRunner
|
|
15
14
|
from .server_version.server_version import ServerVersion
|
|
15
|
+
from .session.gds_sessions import GdsSessions
|
|
16
16
|
from .version import __version__
|
|
17
17
|
|
|
18
18
|
__all__ = [
|
|
@@ -9,7 +9,7 @@ def generate_suggestive_error_message(requested_endpoint: str, all_endpoints: Li
|
|
|
9
9
|
MIN_SIMILARITY_FOR_SUGGESTION = 0.9
|
|
10
10
|
|
|
11
11
|
closest_endpoint = None
|
|
12
|
-
curr_max_similarity = 0
|
|
12
|
+
curr_max_similarity = 0.0
|
|
13
13
|
for ep in all_endpoints:
|
|
14
14
|
similarity = textdistance.jaro_winkler(requested_endpoint, ep)
|
|
15
15
|
if similarity >= MIN_SIMILARITY_FOR_SUGGESTION:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import pathlib
|
|
3
3
|
import sys
|
|
4
|
+
import warnings
|
|
4
5
|
from typing import Any, Dict, List, Optional, Union
|
|
5
6
|
|
|
6
7
|
import pandas as pd
|
|
@@ -17,6 +18,7 @@ from .graph_entity_ops_runner import (
|
|
|
17
18
|
GraphElementPropertyRunner,
|
|
18
19
|
GraphLabelRunner,
|
|
19
20
|
GraphNodePropertiesRunner,
|
|
21
|
+
GraphNodePropertyRunner,
|
|
20
22
|
GraphPropertyRunner,
|
|
21
23
|
GraphRelationshipPropertiesRunner,
|
|
22
24
|
GraphRelationshipRunner,
|
|
@@ -24,7 +26,6 @@ from .graph_entity_ops_runner import (
|
|
|
24
26
|
)
|
|
25
27
|
from .graph_export_runner import GraphExportRunner
|
|
26
28
|
from .graph_object import Graph
|
|
27
|
-
from .graph_project_runner import GraphProjectRemoteRunner, GraphProjectRunner
|
|
28
29
|
from .graph_sample_runner import GraphSampleRunner
|
|
29
30
|
from .graph_type_check import (
|
|
30
31
|
from_graph_type_check,
|
|
@@ -34,7 +35,6 @@ from .graph_type_check import (
|
|
|
34
35
|
from .ogb_loader import OGBLLoader, OGBNLoader
|
|
35
36
|
from graphdatascience.call_parameters import CallParameters
|
|
36
37
|
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
37
|
-
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
|
|
38
38
|
|
|
39
39
|
Strings = Union[str, List[str]]
|
|
40
40
|
|
|
@@ -42,6 +42,15 @@ is_neo4j_4_driver = ServerVersion.from_string(neo4j_driver_version) < ServerVers
|
|
|
42
42
|
|
|
43
43
|
|
|
44
44
|
class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
45
|
+
def __init__(self, query_runner: Any, namespace: str, server_version: ServerVersion):
|
|
46
|
+
super().__init__(query_runner, namespace, server_version)
|
|
47
|
+
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 14.0)
|
|
48
|
+
warnings.filterwarnings(
|
|
49
|
+
"ignore",
|
|
50
|
+
category=DeprecationWarning,
|
|
51
|
+
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
52
|
+
)
|
|
53
|
+
|
|
45
54
|
@staticmethod
|
|
46
55
|
def _path(package: str, resource: str) -> pathlib.Path:
|
|
47
56
|
if sys.version_info >= (3, 9):
|
|
@@ -371,9 +380,9 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
371
380
|
)
|
|
372
381
|
|
|
373
382
|
@property
|
|
374
|
-
def nodeProperty(self) ->
|
|
383
|
+
def nodeProperty(self) -> GraphNodePropertyRunner:
|
|
375
384
|
self._namespace += ".nodeProperty"
|
|
376
|
-
return
|
|
385
|
+
return GraphNodePropertyRunner(self._query_runner, self._namespace, self._server_version)
|
|
377
386
|
|
|
378
387
|
@property
|
|
379
388
|
def nodeProperties(self) -> GraphNodePropertiesRunner:
|
|
@@ -558,22 +567,3 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
558
567
|
endpoint=self._namespace,
|
|
559
568
|
params=params,
|
|
560
569
|
).squeeze()
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
class GraphProcRunner(BaseGraphProcRunner):
|
|
564
|
-
@property
|
|
565
|
-
def project(self) -> GraphProjectRunner:
|
|
566
|
-
self._namespace += ".project"
|
|
567
|
-
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
|
|
568
|
-
|
|
569
|
-
@property
|
|
570
|
-
def cypher(self) -> GraphCypherRunner:
|
|
571
|
-
self._namespace += ".project"
|
|
572
|
-
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
576
|
-
@property
|
|
577
|
-
def project(self) -> GraphProjectRemoteRunner:
|
|
578
|
-
self._namespace += ".project.remoteDb"
|
|
579
|
-
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
{graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_entity_ops_runner.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from functools import reduce
|
|
2
2
|
from typing import Any, Dict, List, Type, Union
|
|
3
|
+
from warnings import filterwarnings
|
|
3
4
|
|
|
4
5
|
import pandas as pd
|
|
5
6
|
from pandas import DataFrame, Series
|
|
@@ -26,6 +27,13 @@ class TopologyDataFrame(DataFrame):
|
|
|
26
27
|
return TopologyDataFrame
|
|
27
28
|
|
|
28
29
|
def by_rel_type(self) -> Dict[str, List[List[int]]]:
|
|
30
|
+
# Pandas 2.2.0 deprecated an internal API used by DF.take(indices)
|
|
31
|
+
filterwarnings(
|
|
32
|
+
"ignore",
|
|
33
|
+
category=DeprecationWarning,
|
|
34
|
+
message=r"Passing a BlockManager to TopologyDataFrame is deprecated",
|
|
35
|
+
)
|
|
36
|
+
|
|
29
37
|
gb = self.groupby("relationshipType", observed=True)
|
|
30
38
|
|
|
31
39
|
output = {}
|
|
@@ -69,6 +77,26 @@ class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
|
|
|
69
77
|
return self._handle_properties(G, node_properties, node_labels, config)
|
|
70
78
|
|
|
71
79
|
|
|
80
|
+
class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
|
|
81
|
+
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
82
|
+
@filter_id_func_deprecation_warning()
|
|
83
|
+
def stream(
|
|
84
|
+
self,
|
|
85
|
+
G: Graph,
|
|
86
|
+
node_property: str,
|
|
87
|
+
node_labels: Strings = ["*"],
|
|
88
|
+
db_node_properties: List[str] = [],
|
|
89
|
+
**config: Any,
|
|
90
|
+
) -> DataFrame:
|
|
91
|
+
self._namespace += ".stream"
|
|
92
|
+
|
|
93
|
+
result = self._handle_properties(G, node_property, node_labels, config)
|
|
94
|
+
|
|
95
|
+
return GraphNodePropertiesRunner._process_result(
|
|
96
|
+
self._query_runner, list(node_property), False, db_node_properties, result, config
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
|
|
72
100
|
class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
73
101
|
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
74
102
|
@filter_id_func_deprecation_warning()
|
|
@@ -85,6 +113,19 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
85
113
|
|
|
86
114
|
result = self._handle_properties(G, node_properties, node_labels, config)
|
|
87
115
|
|
|
116
|
+
return GraphNodePropertiesRunner._process_result(
|
|
117
|
+
self._query_runner, node_properties, separate_property_columns, db_node_properties, result, config
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
@staticmethod
|
|
121
|
+
def _process_result(
|
|
122
|
+
query_runner: QueryRunner,
|
|
123
|
+
node_properties: List[str],
|
|
124
|
+
separate_property_columns: bool,
|
|
125
|
+
db_node_properties: List[str],
|
|
126
|
+
result: DataFrame,
|
|
127
|
+
config: Dict[str, Any],
|
|
128
|
+
) -> DataFrame:
|
|
88
129
|
# new format was requested, but the query was run via Cypher
|
|
89
130
|
if separate_property_columns and "propertyValue" in result.keys():
|
|
90
131
|
wide_result = result.pivot(index=["nodeId"], columns=["nodeProperty"], values="propertyValue")
|
|
@@ -98,7 +139,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
98
139
|
# old format was requested but the query was run via Arrow
|
|
99
140
|
elif not separate_property_columns and "propertyValue" not in result.keys():
|
|
100
141
|
id_vars = ["nodeId", "nodeLabels"] if config.get("listNodeLabels", False) else ["nodeId"]
|
|
101
|
-
result = result.melt(id_vars=id_vars
|
|
142
|
+
result = result.melt(id_vars=id_vars, var_name="nodeProperty", value_name="propertyValue")
|
|
102
143
|
|
|
103
144
|
if db_node_properties:
|
|
104
145
|
duplicate_properties = set(db_node_properties).intersection(set(node_properties))
|
|
@@ -108,16 +149,20 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
108
149
|
)
|
|
109
150
|
|
|
110
151
|
unique_node_ids = result["nodeId"].drop_duplicates().tolist()
|
|
111
|
-
db_properties_df =
|
|
112
|
-
|
|
152
|
+
db_properties_df = query_runner.run_cypher(
|
|
153
|
+
GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
|
|
113
154
|
)
|
|
114
155
|
|
|
115
156
|
if "propertyValue" not in result.keys():
|
|
116
157
|
result = result.join(db_properties_df.set_index("nodeId"), on="nodeId")
|
|
117
158
|
else:
|
|
118
|
-
db_properties_df = db_properties_df.melt(
|
|
119
|
-
|
|
159
|
+
db_properties_df = db_properties_df.melt(
|
|
160
|
+
id_vars=["nodeId"], var_name="nodeProperty", value_name="propertyValue"
|
|
120
161
|
)
|
|
162
|
+
|
|
163
|
+
if "nodeProperty" not in result.keys():
|
|
164
|
+
result["nodeProperty"] = node_properties[0]
|
|
165
|
+
|
|
121
166
|
result = pd.concat([result, db_properties_df])
|
|
122
167
|
|
|
123
168
|
return result
|
|
@@ -242,13 +287,13 @@ class ToUndirectedRunner(IllegalAttrChecker):
|
|
|
242
287
|
|
|
243
288
|
@graph_type_check
|
|
244
289
|
def __call__(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
|
|
245
|
-
return self._run_procedure(G, relationship_type, mutate_relationship_type)
|
|
290
|
+
return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
|
|
246
291
|
|
|
247
292
|
@graph_type_check
|
|
248
293
|
@compatible_with("estimate", min_inclusive=ServerVersion(2, 3, 0))
|
|
249
294
|
def estimate(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
|
|
250
295
|
self._namespace += ".estimate"
|
|
251
|
-
return self._run_procedure(G, relationship_type, mutate_relationship_type)
|
|
296
|
+
return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
|
|
252
297
|
|
|
253
298
|
|
|
254
299
|
class GraphRelationshipsRunner(GraphEntityOpsBaseRunner):
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .graph_project_runner import GraphProjectRunner
|
|
2
|
+
from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
|
|
3
|
+
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GraphProcRunner(BaseGraphProcRunner):
|
|
7
|
+
@property
|
|
8
|
+
def project(self) -> GraphProjectRunner:
|
|
9
|
+
self._namespace += ".project"
|
|
10
|
+
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
|
|
11
|
+
|
|
12
|
+
@property
|
|
13
|
+
def cypher(self) -> GraphCypherRunner:
|
|
14
|
+
self._namespace += ".project"
|
|
15
|
+
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
|
{graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/graph/graph_project_runner.py
RENAMED
|
@@ -5,13 +5,10 @@ from typing import Any
|
|
|
5
5
|
from pandas import Series
|
|
6
6
|
|
|
7
7
|
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
8
|
-
from ..gds_session.schema import NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA
|
|
9
8
|
from .graph_object import Graph
|
|
10
9
|
from .graph_type_check import from_graph_type_check
|
|
11
10
|
from graphdatascience.call_parameters import CallParameters
|
|
12
11
|
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
13
|
-
from graphdatascience.server_version.compatible_with import compatible_with
|
|
14
|
-
from graphdatascience.server_version.server_version import ServerVersion
|
|
15
12
|
|
|
16
13
|
|
|
17
14
|
class GraphProjectRunner(IllegalAttrChecker):
|
|
@@ -73,29 +70,3 @@ class GraphProjectBetaRunner(IllegalAttrChecker):
|
|
|
73
70
|
).squeeze()
|
|
74
71
|
|
|
75
72
|
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
79
|
-
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
|
|
80
|
-
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
|
|
81
|
-
placeholder = "<>" # host and token will be added by query runner
|
|
82
|
-
self.map_property_types(config)
|
|
83
|
-
params = CallParameters(
|
|
84
|
-
graph_name=graph_name,
|
|
85
|
-
query=query,
|
|
86
|
-
token=placeholder,
|
|
87
|
-
host=placeholder,
|
|
88
|
-
remote_database=self._query_runner.database(),
|
|
89
|
-
config=config,
|
|
90
|
-
)
|
|
91
|
-
result = self._query_runner.call_procedure(
|
|
92
|
-
endpoint=self._namespace,
|
|
93
|
-
params=params,
|
|
94
|
-
).squeeze()
|
|
95
|
-
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
96
|
-
|
|
97
|
-
@staticmethod
|
|
98
|
-
def map_property_types(config: dict[str, Any]) -> None:
|
|
99
|
-
for key in [NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA]:
|
|
100
|
-
if key in config:
|
|
101
|
-
config[key] = {k: v.value for k, v in config[key].items()}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
|
|
2
|
+
from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemoteRunner
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
6
|
+
@property
|
|
7
|
+
def project(self) -> GraphProjectRemoteRunner:
|
|
8
|
+
self._namespace += ".project.remoteDb"
|
|
9
|
+
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
6
|
+
from ..server_version.compatible_with import compatible_with
|
|
7
|
+
from .graph_object import Graph
|
|
8
|
+
from graphdatascience.call_parameters import CallParameters
|
|
9
|
+
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
10
|
+
from graphdatascience.server_version.server_version import ServerVersion
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
14
|
+
_SCHEMA_KEYS = ["nodePropertySchema", "relationshipPropertySchema"]
|
|
15
|
+
|
|
16
|
+
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
|
|
17
|
+
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
|
|
18
|
+
placeholder = "<>" # host and token will be added by query runner
|
|
19
|
+
self.map_property_types(config)
|
|
20
|
+
params = CallParameters(
|
|
21
|
+
graph_name=graph_name,
|
|
22
|
+
query=query,
|
|
23
|
+
token=placeholder,
|
|
24
|
+
host=placeholder,
|
|
25
|
+
remote_database=self._query_runner.database(),
|
|
26
|
+
config=config,
|
|
27
|
+
)
|
|
28
|
+
result = self._query_runner.call_procedure(
|
|
29
|
+
endpoint=self._namespace,
|
|
30
|
+
params=params,
|
|
31
|
+
).squeeze()
|
|
32
|
+
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
33
|
+
|
|
34
|
+
@staticmethod
|
|
35
|
+
def map_property_types(config: dict[str, Any]) -> None:
|
|
36
|
+
for key in GraphProjectRemoteRunner._SCHEMA_KEYS:
|
|
37
|
+
if key in config:
|
|
38
|
+
config[key] = {k: v.value for k, v in config[key].items()}
|
|
@@ -314,8 +314,8 @@ class OGBLLoader(OGBLoader):
|
|
|
314
314
|
assert source_labels[i] == source_label
|
|
315
315
|
assert target_labels[i] == target_label
|
|
316
316
|
|
|
317
|
-
source_ids[i] += node_id_offsets[edges["head_type"][i]]
|
|
318
|
-
target_ids[i] += node_id_offsets[edges["tail_type"][i]]
|
|
317
|
+
source_ids[i] += node_id_offsets[edges["head_type"][i]] + edges["head"][i]
|
|
318
|
+
target_ids[i] += node_id_offsets[edges["tail_type"][i]] + edges["tail"][i]
|
|
319
319
|
|
|
320
320
|
rel_types.append(f"{edge_type}_{set_type.upper()}")
|
|
321
321
|
|
|
@@ -23,11 +23,12 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
23
23
|
|
|
24
24
|
def __init__(
|
|
25
25
|
self,
|
|
26
|
+
/,
|
|
26
27
|
endpoint: Union[str, Driver, QueryRunner],
|
|
27
28
|
auth: Optional[Tuple[str, str]] = None,
|
|
28
29
|
aura_ds: bool = False,
|
|
29
30
|
database: Optional[str] = None,
|
|
30
|
-
arrow: bool = True,
|
|
31
|
+
arrow: Union[str, bool] = True,
|
|
31
32
|
arrow_disable_server_verification: bool = True,
|
|
32
33
|
arrow_tls_root_certs: Optional[bytes] = None,
|
|
33
34
|
bookmarks: Optional[Any] = None,
|
|
@@ -46,9 +47,10 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
46
47
|
to a Neo4j Aura instance.
|
|
47
48
|
database: Optional[str], default None
|
|
48
49
|
The Neo4j database to query against.
|
|
49
|
-
arrow : bool, default True
|
|
50
|
-
|
|
51
|
-
for data streaming if it is available on the server.
|
|
50
|
+
arrow : Union[str, bool], default True
|
|
51
|
+
Arrow connection information. Either a flag that indicates whether the client should use Apache Arrow
|
|
52
|
+
for data streaming if it is available on the server. True means discover the connection URI from the server.
|
|
53
|
+
A connection URI (str) can also be provided.
|
|
52
54
|
arrow_disable_server_verification : bool, default True
|
|
53
55
|
A flag that indicates that, if the flight client is connecting with
|
|
54
56
|
TLS, that it skips server verification. If this is enabled, all
|
|
@@ -76,6 +78,7 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
|
76
78
|
self._query_runner.encrypted(),
|
|
77
79
|
arrow_disable_server_verification,
|
|
78
80
|
arrow_tls_root_certs,
|
|
81
|
+
None if arrow is True else arrow,
|
|
79
82
|
)
|
|
80
83
|
|
|
81
84
|
super().__init__(self._query_runner, "gds", self._server_version)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ArrowEndpointVersion(Enum):
|
|
8
|
+
ALPHA = ""
|
|
9
|
+
V1 = "v1/"
|
|
10
|
+
|
|
11
|
+
def version(self) -> str:
|
|
12
|
+
return self._name_.lower()
|
|
13
|
+
|
|
14
|
+
def prefix(self) -> str:
|
|
15
|
+
return self._value_
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def from_arrow_info(supported_arrow_versions: List[str]) -> ArrowEndpointVersion:
|
|
19
|
+
# Fallback for pre 2.6.0 servers that do not support versions
|
|
20
|
+
if len(supported_arrow_versions) == 0:
|
|
21
|
+
return ArrowEndpointVersion.ALPHA
|
|
22
|
+
|
|
23
|
+
# If the server supports versioned endpoints, we try v1 first
|
|
24
|
+
if ArrowEndpointVersion.V1.version() in supported_arrow_versions:
|
|
25
|
+
return ArrowEndpointVersion.V1
|
|
26
|
+
|
|
27
|
+
if ArrowEndpointVersion.ALPHA.version() in supported_arrow_versions:
|
|
28
|
+
return ArrowEndpointVersion.ALPHA
|
|
29
|
+
|
|
30
|
+
raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UnsupportedArrowEndpointVersion(Exception):
|
|
34
|
+
def __init__(self, server_version: List[str]) -> None:
|
|
35
|
+
super().__init__(self, f"Unsupported Arrow endpoint versions: {server_version}")
|
|
@@ -1,9 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import concurrent
|
|
2
4
|
import json
|
|
3
5
|
import math
|
|
4
6
|
import warnings
|
|
5
7
|
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
-
from typing import Any, Dict, List, Optional
|
|
8
|
+
from typing import Any, Dict, List, NoReturn, Optional
|
|
7
9
|
|
|
8
10
|
import numpy
|
|
9
11
|
import pyarrow.flight as flight
|
|
@@ -11,6 +13,7 @@ from pandas import DataFrame
|
|
|
11
13
|
from pyarrow import Table
|
|
12
14
|
from tqdm.auto import tqdm
|
|
13
15
|
|
|
16
|
+
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
14
17
|
from .graph_constructor import GraphConstructor
|
|
15
18
|
|
|
16
19
|
|
|
@@ -21,6 +24,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
21
24
|
graph_name: str,
|
|
22
25
|
flight_client: flight.FlightClient,
|
|
23
26
|
concurrency: int,
|
|
27
|
+
arrow_endpoint_version: ArrowEndpointVersion,
|
|
24
28
|
undirected_relationship_types: Optional[List[str]],
|
|
25
29
|
chunk_size: int = 10_000,
|
|
26
30
|
):
|
|
@@ -28,6 +32,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
28
32
|
self._concurrency = concurrency
|
|
29
33
|
self._graph_name = graph_name
|
|
30
34
|
self._client = flight_client
|
|
35
|
+
self._arrow_endpoint_version = arrow_endpoint_version
|
|
31
36
|
self._undirected_relationship_types = (
|
|
32
37
|
[] if undirected_relationship_types is None else undirected_relationship_types
|
|
33
38
|
)
|
|
@@ -81,6 +86,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
81
86
|
return partitioned_dfs
|
|
82
87
|
|
|
83
88
|
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
|
|
89
|
+
action_type = self._versioned_action_type(action_type)
|
|
84
90
|
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
|
|
85
91
|
|
|
86
92
|
# Consume result fully to sanity check and avoid cancelled streams
|
|
@@ -89,10 +95,11 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
89
95
|
|
|
90
96
|
json.loads(collected_result[0].body.to_pybytes().decode())
|
|
91
97
|
|
|
92
|
-
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
|
|
98
|
+
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
|
|
93
99
|
table = Table.from_pandas(df)
|
|
94
100
|
batches = table.to_batches(self._chunk_size)
|
|
95
101
|
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
|
|
102
|
+
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
|
|
96
103
|
|
|
97
104
|
# Write schema
|
|
98
105
|
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
|
|
@@ -103,6 +110,8 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
103
110
|
for partition in batches:
|
|
104
111
|
writer.write_batch(partition)
|
|
105
112
|
pbar.update(partition.num_rows)
|
|
113
|
+
# Force a refresh to avoid the progress bar getting stuck at 0%
|
|
114
|
+
pbar.refresh()
|
|
106
115
|
|
|
107
116
|
def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
|
|
108
117
|
desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships"
|
|
@@ -117,3 +126,17 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
117
126
|
if not future.exception():
|
|
118
127
|
continue
|
|
119
128
|
raise future.exception() # type: ignore
|
|
129
|
+
|
|
130
|
+
def _versioned_action_type(self, action_type: str) -> str:
|
|
131
|
+
return self._arrow_endpoint_version.prefix() + action_type
|
|
132
|
+
|
|
133
|
+
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
|
|
134
|
+
return (
|
|
135
|
+
flight_descriptor
|
|
136
|
+
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
|
|
137
|
+
else {
|
|
138
|
+
"name": "PUT_MESSAGE",
|
|
139
|
+
"version": ArrowEndpointVersion.V1.version(),
|
|
140
|
+
"body": flight_descriptor,
|
|
141
|
+
}
|
|
142
|
+
)
|
{graphdatascience-1.9 → graphdatascience-1.10}/graphdatascience/query_runner/arrow_query_runner.py
RENAMED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import base64
|
|
2
4
|
import json
|
|
3
5
|
import time
|
|
@@ -5,13 +7,14 @@ import warnings
|
|
|
5
7
|
from typing import Any, Dict, List, Optional, Tuple
|
|
6
8
|
|
|
7
9
|
import pyarrow.flight as flight
|
|
8
|
-
from pandas import DataFrame
|
|
10
|
+
from pandas import DataFrame
|
|
9
11
|
from pyarrow import ChunkedArray, Table, chunked_array
|
|
10
12
|
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
11
13
|
from pyarrow.types import is_dictionary # type: ignore
|
|
12
14
|
|
|
13
15
|
from ..call_parameters import CallParameters
|
|
14
16
|
from ..server_version.server_version import ServerVersion
|
|
17
|
+
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
15
18
|
from .arrow_graph_constructor import ArrowGraphConstructor
|
|
16
19
|
from .graph_constructor import GraphConstructor
|
|
17
20
|
from .query_runner import QueryRunner
|
|
@@ -28,27 +31,29 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
28
31
|
encrypted: bool = False,
|
|
29
32
|
disable_server_verification: bool = False,
|
|
30
33
|
tls_root_certs: Optional[bytes] = None,
|
|
31
|
-
|
|
34
|
+
connection_string_override: Optional[str] = None,
|
|
35
|
+
) -> QueryRunner:
|
|
36
|
+
arrow_info = (
|
|
37
|
+
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
|
|
38
|
+
)
|
|
32
39
|
server_version = fallback_query_runner.server_version()
|
|
40
|
+
connection_string: str
|
|
41
|
+
if connection_string_override is not None:
|
|
42
|
+
connection_string = connection_string_override
|
|
43
|
+
else:
|
|
44
|
+
connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
|
|
45
|
+
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
|
|
33
46
|
|
|
34
|
-
yield_fields = (
|
|
35
|
-
["running", "listenAddress"]
|
|
36
|
-
if server_version >= ServerVersion(2, 2, 1)
|
|
37
|
-
else ["running", "advertisedListenAddress"]
|
|
38
|
-
)
|
|
39
|
-
arrow_info: "Series[Any]" = fallback_query_runner.call_procedure(
|
|
40
|
-
endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
|
|
41
|
-
).squeeze()
|
|
42
|
-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
|
|
43
47
|
if arrow_info["running"]:
|
|
44
48
|
return ArrowQueryRunner(
|
|
45
|
-
|
|
49
|
+
connection_string,
|
|
46
50
|
fallback_query_runner,
|
|
47
51
|
server_version,
|
|
48
52
|
auth,
|
|
49
53
|
encrypted,
|
|
50
54
|
disable_server_verification,
|
|
51
55
|
tls_root_certs,
|
|
56
|
+
arrow_endpoint_version,
|
|
52
57
|
)
|
|
53
58
|
else:
|
|
54
59
|
return fallback_query_runner
|
|
@@ -62,9 +67,11 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
62
67
|
encrypted: bool = False,
|
|
63
68
|
disable_server_verification: bool = False,
|
|
64
69
|
tls_root_certs: Optional[bytes] = None,
|
|
70
|
+
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
|
|
65
71
|
):
|
|
66
72
|
self._fallback_query_runner = fallback_query_runner
|
|
67
73
|
self._server_version = server_version
|
|
74
|
+
self._arrow_endpoint_version = arrow_endpoint_version
|
|
68
75
|
|
|
69
76
|
host, port_string = uri.split(":")
|
|
70
77
|
|
|
@@ -272,8 +279,15 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
272
279
|
"procedure_name": procedure_name,
|
|
273
280
|
"configuration": configuration,
|
|
274
281
|
}
|
|
275
|
-
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
276
282
|
|
|
283
|
+
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
|
|
284
|
+
payload = {
|
|
285
|
+
"name": "GET_MESSAGE",
|
|
286
|
+
"version": ArrowEndpointVersion.V1.version(),
|
|
287
|
+
"body": payload,
|
|
288
|
+
}
|
|
289
|
+
|
|
290
|
+
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
277
291
|
get = self._flight_client.do_get(ticket)
|
|
278
292
|
arrow_table = get.read_all()
|
|
279
293
|
|
|
@@ -282,6 +296,13 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
282
296
|
new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
|
|
283
297
|
arrow_table = arrow_table.rename_columns(new_colum_names)
|
|
284
298
|
|
|
299
|
+
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
|
|
300
|
+
warnings.filterwarnings(
|
|
301
|
+
"ignore",
|
|
302
|
+
category=DeprecationWarning,
|
|
303
|
+
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
304
|
+
)
|
|
305
|
+
|
|
285
306
|
return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
|
|
286
307
|
|
|
287
308
|
def create_graph_constructor(
|
|
@@ -295,10 +316,19 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
295
316
|
)
|
|
296
317
|
|
|
297
318
|
return ArrowGraphConstructor(
|
|
298
|
-
database,
|
|
319
|
+
database,
|
|
320
|
+
graph_name,
|
|
321
|
+
self._flight_client,
|
|
322
|
+
concurrency,
|
|
323
|
+
self._arrow_endpoint_version,
|
|
324
|
+
undirected_relationship_types,
|
|
299
325
|
)
|
|
300
326
|
|
|
301
327
|
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
|
|
328
|
+
# empty columns cannot be used to build a chunked_array in pyarrow
|
|
329
|
+
if len(arrow_table) == 0:
|
|
330
|
+
return arrow_table
|
|
331
|
+
|
|
302
332
|
dict_encoded_fields = [
|
|
303
333
|
(idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
|
|
304
334
|
]
|
|
@@ -5,7 +5,7 @@ from pyarrow import flight
|
|
|
5
5
|
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
6
6
|
|
|
7
7
|
from ..call_parameters import CallParameters
|
|
8
|
-
from ..
|
|
8
|
+
from ..session.dbms_connection_info import DbmsConnectionInfo
|
|
9
9
|
from .query_runner import QueryRunner
|
|
10
10
|
from graphdatascience.query_runner.graph_constructor import GraphConstructor
|
|
11
11
|
from graphdatascience.server_version.server_version import ServerVersion
|