graphdatascience 1.8__tar.gz → 1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphdatascience-1.8/graphdatascience.egg-info → graphdatascience-1.9}/PKG-INFO +4 -2
- {graphdatascience-1.8 → graphdatascience-1.9}/README.md +1 -0
- graphdatascience-1.9/graphdatascience/__init__.py +36 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/algo_proc_runner.py +3 -9
- graphdatascience-1.9/graphdatascience/call_parameters.py +8 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/caller_base.py +5 -1
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/endpoints.py +1 -6
- graphdatascience-1.9/graphdatascience/gds_session/aura_api.py +236 -0
- graphdatascience-1.9/graphdatascience/gds_session/aura_graph_data_science.py +181 -0
- graphdatascience-1.9/graphdatascience/gds_session/dbms_connection_info.py +14 -0
- graphdatascience-1.9/graphdatascience/gds_session/gds_sessions.py +178 -0
- graphdatascience-1.9/graphdatascience/gds_session/schema.py +12 -0
- graphdatascience-1.9/graphdatascience/gds_session/session_sizes.py +23 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_alpha_proc_runner.py +6 -8
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_beta_proc_runner.py +8 -8
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_create_result.py +4 -4
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_cypher_runner.py +1 -1
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_endpoints.py +0 -7
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_entity_ops_runner.py +77 -79
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_export_runner.py +5 -8
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_object.py +16 -13
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_proc_runner.py +110 -83
- graphdatascience-1.9/graphdatascience/graph/graph_project_runner.py +101 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_sample_runner.py +19 -21
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph_data_science.py +36 -96
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/graphsage_model.py +8 -5
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/link_prediction_model.py +2 -2
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model.py +26 -16
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_alpha_proc_runner.py +15 -20
- graphdatascience-1.9/graphdatascience/model/model_beta_proc_runner.py +31 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_proc_runner.py +23 -33
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/node_classification_model.py +8 -5
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/node_regression_model.py +2 -2
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/simple_rel_embedding_model.py +44 -96
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/classification_training_pipeline.py +17 -12
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/lp_pipeline_create_runner.py +3 -3
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/lp_training_pipeline.py +8 -9
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nc_pipeline_create_runner.py +3 -3
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nc_training_pipeline.py +6 -5
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nr_pipeline_create_runner.py +3 -3
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nr_training_pipeline.py +12 -11
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_proc_runner.py +10 -14
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/training_pipeline.py +44 -38
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/arrow_query_runner.py +104 -26
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/aura_db_arrow_query_runner.py +67 -41
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/cypher_graph_constructor.py +3 -3
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/neo4j_query_runner.py +103 -9
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/query_runner.py +25 -5
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/server_version/server_version.py +5 -6
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/system/config_endpoints.py +5 -8
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/system/system_endpoints.py +23 -37
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +2 -2
- graphdatascience-1.9/graphdatascience/utils/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/utils/util_endpoints.py +11 -13
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/utils/util_proc_runner.py +6 -6
- graphdatascience-1.9/graphdatascience/version.py +1 -0
- {graphdatascience-1.8 → graphdatascience-1.9/graphdatascience.egg-info}/PKG-INFO +4 -2
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/SOURCES.txt +8 -1
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/requires.txt +2 -1
- {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/base.txt +2 -1
- graphdatascience-1.8/graphdatascience/__init__.py +0 -5
- graphdatascience-1.8/graphdatascience/graph/graph_alpha_project_runner.py +0 -17
- graphdatascience-1.8/graphdatascience/graph/graph_project_runner.py +0 -65
- graphdatascience-1.8/graphdatascience/model/model_beta_proc_runner.py +0 -37
- graphdatascience-1.8/graphdatascience/version.py +0 -1
- {graphdatascience-1.8 → graphdatascience-1.9}/LICENSE +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/MANIFEST.in +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/algo_endpoints.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/call_builder.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/client_only_endpoint.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/cypher_warning_handler.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/endpoint_suggester.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/gds_not_installed.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/illegal_attr_checker.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/unable_to_connect.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/uncallable_namespace.py +0 -0
- {graphdatascience-1.8/graphdatascience/graph → graphdatascience-1.9/graphdatascience/gds_session}/__init__.py +0 -0
- {graphdatascience-1.8/graphdatascience/model → graphdatascience-1.9/graphdatascience/graph}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_type_check.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/nx_loader.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/ogb_loader.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/ignored_server_endpoints.py +0 -0
- {graphdatascience-1.8/graphdatascience/pipeline → graphdatascience-1.9/graphdatascience/model}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_endpoints.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_resolver.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/pipeline_model.py +0 -0
- {graphdatascience-1.8/graphdatascience/query_runner → graphdatascience-1.9/graphdatascience/pipeline}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/py.typed +0 -0
- {graphdatascience-1.8/graphdatascience/resources → graphdatascience-1.9/graphdatascience/query_runner}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/arrow_graph_constructor.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/graph_constructor.py +0 -0
- {graphdatascience-1.8/graphdatascience/resources/cora → graphdatascience-1.9/graphdatascience/resources}/__init__.py +0 -0
- {graphdatascience-1.8/graphdatascience/resources/imdb → graphdatascience-1.9/graphdatascience/resources/cora}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/serialize_cora.py +0 -0
- {graphdatascience-1.8/graphdatascience/resources/karate → graphdatascience-1.9/graphdatascience/resources/imdb}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
- {graphdatascience-1.8/graphdatascience/resources/lastfm → graphdatascience-1.9/graphdatascience/resources/karate}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
- {graphdatascience-1.8/graphdatascience/server_version → graphdatascience-1.9/graphdatascience/resources/lastfm}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.8/graphdatascience/system → graphdatascience-1.9/graphdatascience/server_version}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/server_version/compatible_with.py +0 -0
- {graphdatascience-1.8/graphdatascience/topological_lp → graphdatascience-1.9/graphdatascience/system}/__init__.py +0 -0
- {graphdatascience-1.8/graphdatascience/utils → graphdatascience-1.9/graphdatascience/topological_lp}/__init__.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/dependency_links.txt +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/not-zip-safe +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/top_level.txt +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/pyproject.toml +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/networkx.txt +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/ogb.txt +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/setup.cfg +0 -0
- {graphdatascience-1.8 → graphdatascience-1.9}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphdatascience
|
|
3
|
-
Version: 1.8
|
|
3
|
+
Version: 1.9
|
|
4
4
|
Summary: A Python client for the Neo4j Graph Data Science (GDS) library
|
|
5
5
|
Home-page: https://neo4j.com/product/graph-data-science/
|
|
6
6
|
Author: Neo4j
|
|
@@ -30,10 +30,11 @@ License-File: LICENSE
|
|
|
30
30
|
Requires-Dist: multimethod<2.0,>=1.0
|
|
31
31
|
Requires-Dist: neo4j<6.0,>=4.4.2
|
|
32
32
|
Requires-Dist: pandas<3.0,>=1.0
|
|
33
|
-
Requires-Dist: pyarrow<
|
|
33
|
+
Requires-Dist: pyarrow<15.0,>=10.0
|
|
34
34
|
Requires-Dist: textdistance<5.0,>=4.0
|
|
35
35
|
Requires-Dist: tqdm<5.0,>=4.0
|
|
36
36
|
Requires-Dist: typing-extensions<5.0,>=4.0
|
|
37
|
+
Requires-Dist: requests
|
|
37
38
|
Provides-Extra: ogb
|
|
38
39
|
Requires-Dist: ogb<2.0,>=1.0; extra == "ogb"
|
|
39
40
|
Provides-Extra: networkx
|
|
@@ -125,6 +126,7 @@ Full end-to-end examples in Jupyter ready-to-run notebooks can be found in the [
|
|
|
125
126
|
* [Sampling, Export and Integration with PyG example](examples/import-sample-export-gnn.ipynb)
|
|
126
127
|
* [Load data to a projected graph via graph construction](examples/load-data-via-graph-construction.ipynb)
|
|
127
128
|
* [Heterogeneous Node Classification with HashGNN and Autotuning](https://github.com/neo4j/graph-data-science-client/tree/main/examples/heterogeneous-node-classification-with-hashgnn.ipynb)
|
|
129
|
+
* [Perform inference using pre-trained KGE models](examples/kge-predict-transe-pyg-train.ipynb)
|
|
128
130
|
|
|
129
131
|
|
|
130
132
|
## Documentation
|
|
@@ -84,6 +84,7 @@ Full end-to-end examples in Jupyter ready-to-run notebooks can be found in the [
|
|
|
84
84
|
* [Sampling, Export and Integration with PyG example](examples/import-sample-export-gnn.ipynb)
|
|
85
85
|
* [Load data to a projected graph via graph construction](examples/load-data-via-graph-construction.ipynb)
|
|
86
86
|
* [Heterogeneous Node Classification with HashGNN and Autotuning](https://github.com/neo4j/graph-data-science-client/tree/main/examples/heterogeneous-node-classification-with-hashgnn.ipynb)
|
|
87
|
+
* [Perform inference using pre-trained KGE models](examples/kge-predict-transe-pyg-train.ipynb)
|
|
87
88
|
|
|
88
89
|
|
|
89
90
|
## Documentation
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from .gds_session.gds_sessions import GdsSessions
|
|
2
|
+
from .graph.graph_create_result import GraphCreateResult
|
|
3
|
+
from .graph.graph_object import Graph
|
|
4
|
+
from .graph_data_science import GraphDataScience
|
|
5
|
+
from .model.graphsage_model import GraphSageModel
|
|
6
|
+
from .model.link_prediction_model import LinkFeature, LPModel
|
|
7
|
+
from .model.node_classification_model import NCModel
|
|
8
|
+
from .model.node_regression_model import NRModel
|
|
9
|
+
from .model.pipeline_model import NodePropertyStep
|
|
10
|
+
from .model.simple_rel_embedding_model import SimpleRelEmbeddingModel
|
|
11
|
+
from .pipeline.lp_training_pipeline import LPTrainingPipeline
|
|
12
|
+
from .pipeline.nc_training_pipeline import NCTrainingPipeline
|
|
13
|
+
from .pipeline.nr_training_pipeline import NRTrainingPipeline
|
|
14
|
+
from .query_runner.query_runner import QueryRunner
|
|
15
|
+
from .server_version.server_version import ServerVersion
|
|
16
|
+
from .version import __version__
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"GraphDataScience",
|
|
20
|
+
"GdsSessions",
|
|
21
|
+
"QueryRunner",
|
|
22
|
+
"__version__",
|
|
23
|
+
"ServerVersion",
|
|
24
|
+
"Graph",
|
|
25
|
+
"GraphCreateResult",
|
|
26
|
+
"LPTrainingPipeline",
|
|
27
|
+
"NCTrainingPipeline",
|
|
28
|
+
"NRTrainingPipeline",
|
|
29
|
+
"NodePropertyStep",
|
|
30
|
+
"LinkFeature",
|
|
31
|
+
"LPModel",
|
|
32
|
+
"NCModel",
|
|
33
|
+
"NRModel",
|
|
34
|
+
"GraphSageModel",
|
|
35
|
+
"SimpleRelEmbeddingModel",
|
|
36
|
+
]
|
|
@@ -7,21 +7,15 @@ from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
|
7
7
|
from ..graph.graph_object import Graph
|
|
8
8
|
from ..graph.graph_type_check import graph_type_check
|
|
9
9
|
from ..model.graphsage_model import GraphSageModel
|
|
10
|
+
from graphdatascience.call_parameters import CallParameters
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class AlgoProcRunner(IllegalAttrChecker, ABC):
|
|
13
14
|
@graph_type_check
|
|
14
15
|
def _run_procedure(self, G: Graph, config: Dict[str, Any], with_logging: bool = True) -> DataFrame:
|
|
15
|
-
|
|
16
|
+
params = CallParameters(graph_name=G.name(), config=config)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
-
params["graph_name"] = G.name()
|
|
19
|
-
params["config"] = config
|
|
20
|
-
|
|
21
|
-
if with_logging:
|
|
22
|
-
return self._query_runner.run_query_with_logging(query, params)
|
|
23
|
-
else:
|
|
24
|
-
return self._query_runner.run_query(query, params)
|
|
18
|
+
return self._query_runner.call_procedure(endpoint=self._namespace, params=params, logging=with_logging)
|
|
25
19
|
|
|
26
20
|
@graph_type_check
|
|
27
21
|
def estimate(self, G: Graph, **config: Any) -> "Series[Any]":
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from typing import Any, OrderedDict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CallParameters(OrderedDict[str, Any]):
    """Ordered mapping from procedure-call parameter names to their values."""

    # Keyword arguments keep their order on Python 3.6+, so building an instance
    # via CallParameters(**kwargs) preserves the parameter order as well.

    def placeholder_str(self) -> str:
        """Return the keys as ``$name`` placeholders joined by ``", "``.

        The placeholders appear in insertion order; an empty mapping yields
        an empty string.
        """
        placeholders = (f"${name}" for name in self)
        return ", ".join(placeholders)
|
|
@@ -13,7 +13,11 @@ class CallerBase(ABC):
|
|
|
13
13
|
self._server_version = server_version
|
|
14
14
|
|
|
15
15
|
def _raise_suggestive_error_message(self, requested_endpoint: str) -> NoReturn:
|
|
16
|
-
list_result = self._query_runner.
|
|
16
|
+
list_result = self._query_runner.call_procedure(
|
|
17
|
+
endpoint="gds.list",
|
|
18
|
+
yields=["name"],
|
|
19
|
+
custom_error=False,
|
|
20
|
+
)
|
|
17
21
|
all_endpoints = list_result["name"].tolist()
|
|
18
22
|
|
|
19
23
|
raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints))
|
|
@@ -3,11 +3,7 @@ from .algo.single_mode_algo_endpoints import (
|
|
|
3
3
|
SingleModeAlphaAlgoEndpoints,
|
|
4
4
|
)
|
|
5
5
|
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
|
|
6
|
-
from .graph.graph_endpoints import (
|
|
7
|
-
GraphAlphaEndpoints,
|
|
8
|
-
GraphBetaEndpoints,
|
|
9
|
-
GraphEndpoints,
|
|
10
|
-
)
|
|
6
|
+
from .graph.graph_endpoints import GraphAlphaEndpoints, GraphBetaEndpoints
|
|
11
7
|
from .model.model_endpoints import (
|
|
12
8
|
ModelAlphaEndpoints,
|
|
13
9
|
ModelBetaEndpoints,
|
|
@@ -39,7 +35,6 @@ class DirectEndpoints(
|
|
|
39
35
|
SingleModeAlgoEndpoints,
|
|
40
36
|
DirectSystemEndpoints,
|
|
41
37
|
DirectUtilEndpoints,
|
|
42
|
-
GraphEndpoints,
|
|
43
38
|
PipelineEndpoints,
|
|
44
39
|
ModelEndpoints,
|
|
45
40
|
ConfigEndpoints,
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import dataclasses
|
|
4
|
+
import logging
|
|
5
|
+
import os
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, List, Optional
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
|
|
11
|
+
import requests as req
|
|
12
|
+
from requests import HTTPError
|
|
13
|
+
|
|
14
|
+
from graphdatascience.version import __version__
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(repr=True)
class InstanceDetails:
    """Basic description of an instance, built from a JSON object."""

    id: str
    name: str
    tenant_id: str
    cloud_provider: str

    @classmethod
    def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
        """Create an instance from a JSON object.

        Parameters
        ----------
        json: dict[str, Any]
            mapping that must contain the keys ``id``, ``name``, ``tenant_id``
            and ``cloud_provider``; any other keys are ignored

        Raises
        ------
        KeyError
            if one of the four keys is absent
        """
        wanted = ("id", "name", "tenant_id", "cloud_provider")
        picked = {key: json[key] for key in wanted}
        return cls(**picked)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(repr=True)
class InstanceSpecificDetails(InstanceDetails):
    """Extended description of a single instance, built from a JSON object."""

    status: str
    connection_url: str
    memory: str
    type: str
    region: str

    @classmethod
    def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
        """Create an instance from a JSON object.

        Parameters
        ----------
        json: dict[str, Any]
            mapping that must contain ``id``, ``name``, ``tenant_id``,
            ``cloud_provider``, ``status``, ``type`` and ``region``;
            ``connection_url`` and ``memory`` are optional and default to ``""``

        Raises
        ------
        KeyError
            if one of the mandatory keys is absent
        """
        mandatory = ("id", "name", "tenant_id", "cloud_provider", "status", "type", "region")
        values = {key: json[key] for key in mandatory}
        # These two keys may be missing from the payload; fall back to an empty string.
        for optional_key in ("connection_url", "memory"):
            values[optional_key] = json.get(optional_key, "")
        return cls(**values)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass(repr=True)
class InstanceCreateDetails:
    """Details of a newly created instance, including its credentials."""

    id: str
    username: str
    password: str
    connection_url: str

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
        """Create an instance from a JSON object.

        Parameters
        ----------
        json: dict[str, Any]
            mapping that must contain one key per dataclass field; any other
            keys are ignored

        Raises
        ------
        RuntimeError
            if at least one field is absent from ``json``
        """
        expected = [f.name for f in dataclasses.fields(cls)]
        missing = [name for name in expected if name not in json]
        if missing:
            # Report key names only: the payload carries the instance password,
            # which must not end up in exception messages or logs.
            raise RuntimeError(
                f"Missing required field(s) `{missing}`. Expected `{expected}` but got keys `{list(json)}`"
            )

        return cls(**{name: json[name] for name in expected})
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class AuraApi:
    """Client for the HTTP API used to create, inspect and delete instances.

    Requests are authenticated with an OAuth bearer token that is fetched with
    the client credentials and cached until it expires. The base URI is
    ``https://api.neo4j.io`` unless the ``AURA_ENV`` environment variable is set.
    """

    class AuraAuthToken:
        """OAuth access token together with its absolute expiry time."""

        access_token: str
        expires_in: int
        expires_at: int
        token_type: str

        def __init__(self, json: dict[str, Any]) -> None:
            """Read ``access_token``, ``expires_in`` and ``token_type`` from ``json``."""
            self.access_token = json["access_token"]
            self.expires_in = json["expires_in"]
            # Keep an absolute deadline (epoch seconds) so that the expiry check
            # does not depend on when the token is inspected.
            self.expires_at = int(time.time()) + self.expires_in
            self.token_type = json["token_type"]

        def is_expired(self) -> bool:
            """Return True once the current time has reached ``expires_at``."""
            return self.expires_at <= int(time.time())

    def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
        """
        Parameters
        ----------
        client_id: str
            client id used to request OAuth tokens
        client_secret: str
            client secret used to request OAuth tokens
        tenant_id: Optional[str]
            tenant to operate on; when omitted (or empty) it is looked up via
            the API, which requires the account to have exactly one tenant
        """
        self._dev_env = os.environ.get("AURA_ENV")
        self._base_uri = "https://api.neo4j.io" if not self._dev_env else f"https://api-{self._dev_env}.neo4j-dev.io"
        self._credentials = (client_id, client_secret)
        self._token: Optional[AuraApi.AuraAuthToken] = None
        self._logger = logging.getLogger()
        # Must come last: looking up the tenant already issues an authenticated request.
        self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()

    @staticmethod
    def extract_id(uri: str) -> str:
        """Return the part of the URI's host name before the first ``.`` or ``-``.

        Raises
        ------
        RuntimeError
            if no host name can be parsed from ``uri``
        """
        host = urlparse(uri).hostname

        if not host:
            raise RuntimeError(f"Could not parse the uri `{uri}`.")

        return host.split(".")[0].split("-")[0]

    def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
        """Create a new instance and return its id, credentials and connection URL.

        Raises
        ------
        requests.HTTPError
            if the API answers with an error status; the response body is
            logged before the error is re-raised
        """
        # TODO should give more control here
        data = {
            "name": name,
            "memory": memory,
            "version": "5",
            "region": region,
            # TODO should be figured out from the tenant details in the future
            "type": self._instance_type(),
            "tenant_id": self._tenant_id,
            "cloud_provider": cloud_provider,
        }

        response = req.post(
            f"{self._base_uri}/v1/instances",
            json=data,
            headers=self._build_header(),
        )

        try:
            response.raise_for_status()
        except HTTPError:
            # Log the raw text: an error body is not guaranteed to be valid JSON,
            # and failing to decode it would mask the original HTTP error.
            self._logger.error("Creating the instance failed: %s", response.text)
            raise

        return InstanceCreateDetails.from_json(response.json()["data"])

    def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
        """Delete the instance; return its details, or None if the API answers 404."""
        response = req.delete(
            f"{self._base_uri}/v1/instances/{instance_id}",
            headers=self._build_header(),
        )

        if response.status_code == 404:
            return None

        response.raise_for_status()

        return InstanceSpecificDetails.fromJson(response.json()["data"])

    def list_instances(self) -> List[InstanceDetails]:
        """Return the instances the API lists for the configured tenant."""
        response = req.get(
            f"{self._base_uri}/v1/instances",
            headers=self._build_header(),
            params={"tenantId": self._tenant_id},
        )

        response.raise_for_status()

        raw_data = response.json()["data"]

        return [InstanceDetails.fromJson(i) for i in raw_data]

    def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
        """Return the details of one instance, or None if the API answers 404."""
        response = req.get(
            f"{self._base_uri}/v1/instances/{instance_id}",
            headers=self._build_header(),
        )

        if response.status_code == 404:
            return None

        response.raise_for_status()

        raw_data = response.json()["data"]

        return InstanceSpecificDetails.fromJson(raw_data)

    def wait_for_instance_running(
        self, instance_id: str, sleep_time: float = 0.2, max_sleep_time: float = 300
    ) -> Optional[str]:
        """Poll the instance until its status is ``running``.

        Parameters
        ----------
        instance_id: str
            id of the instance to poll
        sleep_time: float
            seconds to sleep between two polls
        max_sleep_time: float
            polling stops once the accumulated sleep time exceeds this value

        Returns:
            None once the instance is running, otherwise a message describing
            why waiting was given up (not found, being deleted, or timed out)
        """
        waited_time = 0.0
        while waited_time <= max_sleep_time:
            instance = self.list_instance(instance_id)
            if instance is None:
                return "Instance is not found -- please retry"
            elif instance.status in ["deleting", "destroying"]:
                return "Instance is being deleted"
            elif instance.status == "running":
                return None
            else:
                self._logger.debug(
                    f"Instance `{instance_id}` is not yet running. "
                    f"Current status: {instance.status}. "
                    f"Retrying in {sleep_time} seconds..."
                )
            waited_time += sleep_time
            time.sleep(sleep_time)

        return f"Instance is not running after waiting for {waited_time} seconds"

    def _get_tenant_id(self) -> str:
        """Return the id of the account's only tenant.

        Raises
        ------
        RuntimeError
            if the account has no tenant, or more than one
        """
        response = req.get(
            f"{self._base_uri}/v1/tenants",
            headers=self._build_header(),
        )
        response.raise_for_status()

        raw_data = response.json()["data"]

        if not raw_data:
            raise RuntimeError("This account does not have access to any tenant.")

        if len(raw_data) > 1:
            raise RuntimeError(
                f"This account has access to multiple tenants `{raw_data}`. Please specify which one to use."
            )

        return raw_data[0]["id"]  # type: ignore

    def _build_header(self) -> dict[str, str]:
        """Return the request headers: bearer authorization and client user agent."""
        return {
            "Authorization": f"Bearer {self._auth_token()}",
            "User-agent": f"neo4j-graphdatascience-v{__version__}",
        }

    def _auth_token(self) -> str:
        """Return a valid access token, fetching a new one if none is cached or it expired."""
        if self._token is None or self._token.is_expired():
            self._token = self._update_token()
        return self._token.access_token

    def _update_token(self) -> AuraAuthToken:
        """Request a fresh token from the OAuth endpoint using the client credentials."""
        data = {
            "grant_type": "client_credentials",
        }

        self._logger.debug("Updating oauth token")

        response = req.post(
            f"{self._base_uri}/oauth/token", data=data, auth=(self._credentials[0], self._credentials[1])
        )

        response.raise_for_status()

        return AuraApi.AuraAuthToken(response.json())

    def _instance_type(self) -> str:
        """Return ``enterprise-ds``, or ``professional-ds`` when ``AURA_ENV`` is set."""
        return "enterprise-ds" if not self._dev_env else "professional-ds"
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
from typing import Any, Callable, Dict, Optional
|
|
2
|
+
|
|
3
|
+
from pandas import DataFrame
|
|
4
|
+
|
|
5
|
+
from graphdatascience.call_builder import IndirectCallBuilder
|
|
6
|
+
from graphdatascience.endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
|
|
7
|
+
from graphdatascience.error.uncallable_namespace import UncallableNamespace
|
|
8
|
+
from graphdatascience.gds_session.dbms_connection_info import DbmsConnectionInfo
|
|
9
|
+
from graphdatascience.graph.graph_proc_runner import GraphRemoteProcRunner
|
|
10
|
+
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
|
|
11
|
+
from graphdatascience.query_runner.aura_db_arrow_query_runner import (
|
|
12
|
+
AuraDbArrowQueryRunner,
|
|
13
|
+
)
|
|
14
|
+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
|
|
15
|
+
from graphdatascience.server_version.server_version import ServerVersion
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
|
|
19
|
+
"""
|
|
20
|
+
Primary API class for interacting with Neo4j AuraDB + Graph Data Science.
|
|
21
|
+
Always bind this object to a variable called `gds`.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
gds_session_connection_info: DbmsConnectionInfo,
|
|
27
|
+
aura_db_connection_info: DbmsConnectionInfo,
|
|
28
|
+
delete_fn: Callable[[], bool],
|
|
29
|
+
arrow_disable_server_verification: bool = True,
|
|
30
|
+
arrow_tls_root_certs: Optional[bytes] = None,
|
|
31
|
+
bookmarks: Optional[Any] = None,
|
|
32
|
+
):
|
|
33
|
+
gds_neo4j_query_runner = Neo4jQueryRunner.create(
|
|
34
|
+
gds_session_connection_info.uri, gds_session_connection_info.auth(), aura_ds=True
|
|
35
|
+
)
|
|
36
|
+
gds_query_runner = ArrowQueryRunner.create(
|
|
37
|
+
gds_neo4j_query_runner,
|
|
38
|
+
gds_session_connection_info.auth(),
|
|
39
|
+
gds_neo4j_query_runner.encrypted(),
|
|
40
|
+
arrow_disable_server_verification,
|
|
41
|
+
arrow_tls_root_certs,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
self._server_version = gds_query_runner.server_version()
|
|
45
|
+
|
|
46
|
+
if self._server_version < ServerVersion(2, 6, 0):
|
|
47
|
+
raise RuntimeError(
|
|
48
|
+
f"AuraDB connection info was provided but GDS version {self._server_version} \
|
|
49
|
+
does not support connecting to AuraDB"
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
self._db_query_runner = Neo4jQueryRunner.create(
|
|
53
|
+
aura_db_connection_info.uri,
|
|
54
|
+
aura_db_connection_info.auth(),
|
|
55
|
+
aura_ds=True,
|
|
56
|
+
server_version=self._server_version,
|
|
57
|
+
)
|
|
58
|
+
self._db_query_runner.set_bookmarks(bookmarks)
|
|
59
|
+
|
|
60
|
+
# we need to explicitly set these as the default value is None
|
|
61
|
+
# which signals the driver to use the default configured database
|
|
62
|
+
# from the dbms.
|
|
63
|
+
gds_query_runner.set_database("neo4j")
|
|
64
|
+
self._db_query_runner.set_database("neo4j")
|
|
65
|
+
|
|
66
|
+
self._query_runner = AuraDbArrowQueryRunner(
|
|
67
|
+
gds_query_runner, self._db_query_runner, self._db_query_runner.encrypted(), aura_db_connection_info
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
self._delete_fn = delete_fn
|
|
71
|
+
|
|
72
|
+
super().__init__(self._query_runner, "gds", self._server_version)
|
|
73
|
+
|
|
74
|
+
def run_cypher(
    self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
) -> DataFrame:
    """
    Execute a Cypher query on the AuraDB instance and return its result.

    Parameters
    ----------
    query: str
        the Cypher query to execute
    params: Dict[str, Any]
        parameters passed along with the query (optional)
    database: str
        the database on which to run the query (optional)

    Returns
    -------
    The query result as a DataFrame
    """
    # The query goes straight to the AuraDB runner rather than the combined
    # session runner, which keeps raw Cypher strings from being used to call
    # GDS procedures.
    db_runner = self._db_query_runner
    # NOTE(review): the trailing positional False presumably switches off the
    # runner's custom error handling -- confirm against the runner's signature.
    return db_runner.run_cypher(query, params, database, False)
|
|
94
|
+
|
|
95
|
+
@property
def graph(self) -> GraphRemoteProcRunner:
    """Runner for graph operations, rooted at ``<namespace>.graph``."""
    graph_namespace = f"{self._namespace}.graph"
    return GraphRemoteProcRunner(self._query_runner, graph_namespace, self._server_version)
|
|
98
|
+
|
|
99
|
+
@property
def alpha(self) -> AlphaEndpoints:
    """Endpoints living under the ``gds.alpha`` namespace."""
    alpha_namespace = "gds.alpha"
    return AlphaEndpoints(self._query_runner, alpha_namespace, self._server_version)
|
|
102
|
+
|
|
103
|
+
@property
def beta(self) -> BetaEndpoints:
    """Endpoints living under the ``gds.beta`` namespace."""
    beta_namespace = "gds.beta"
    return BetaEndpoints(self._query_runner, beta_namespace, self._server_version)
|
|
106
|
+
|
|
107
|
+
def __getattr__(self, attr: str) -> IndirectCallBuilder:
    """
    Resolve an unknown attribute into a call builder for ``gds.<attr>``.

    Python only invokes this hook after normal attribute lookup has failed.

    Parameters
    ----------
    attr: str
        the attribute name that was not found on the object

    Returns
    -------
    An IndirectCallBuilder rooted at ``gds.<attr>``

    Raises
    ------
    AttributeError
        if ``attr`` starts with an underscore
    """
    # Private and dunder names are never GDS endpoints. Answering them with a
    # call builder makes hasattr()/copy/pickle probes succeed wrongly, and it
    # recurses without bound when `_query_runner` itself is not set yet
    # (e.g. on an instance created without running __init__).
    if attr.startswith("_"):
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)
|
|
109
|
+
|
|
110
|
+
def set_database(self, database: str) -> None:
    """
    Select the database that subsequent queries are run against.

    Only the AuraDB query runner is updated.

    Parameters
    ----------
    database: str
        The name of the database to run queries against.
    """
    db_runner = self._db_query_runner
    db_runner.set_database(database)
|
|
120
|
+
|
|
121
|
+
def set_bookmarks(self, bookmarks: Any) -> None:
    """
    Set the Neo4j bookmarks that must be reached before the next query executes.

    Parameters
    ----------
    bookmarks: Bookmark(s)
        The Neo4j bookmarks defining the required state
    """
    db_runner = self._db_query_runner
    db_runner.set_bookmarks(bookmarks)
|
|
131
|
+
|
|
132
|
+
def database(self) -> Optional[str]:
    """
    Report the database that queries are currently run against.

    Returns
    -------
    The name of the database, as reported by the AuraDB query runner.
    """
    db_runner = self._db_query_runner
    return db_runner.database()
|
|
140
|
+
|
|
141
|
+
def bookmarks(self) -> Optional[Any]:
    """
    Report the Neo4j bookmarks that queries currently have to wait for.

    Returns
    -------
    The (possibly None) Neo4j bookmarks defining the currently required state
    """
    db_runner = self._db_query_runner
    return db_runner.bookmarks()
|
|
150
|
+
|
|
151
|
+
def last_bookmarks(self) -> Optional[Any]:
    """
    Report the Neo4j bookmarks produced by the most recently called query.

    Returns
    -------
    The (possibly None) Neo4j bookmarks defining the state following the most recently called query
    """
    db_runner = self._db_query_runner
    return db_runner.last_bookmarks()
|
|
160
|
+
|
|
161
|
+
def driver_config(self) -> Dict[str, Any]:
    """
    Return the configuration the underlying driver was created with.

    Returns
    -------
    The configuration as a dictionary, as reported by the session's query runner.
    """
    runner = self._query_runner
    return runner.driver_config()
|
|
169
|
+
|
|
170
|
+
def delete(self) -> bool:
    """
    Delete the GDS session.

    The object is closed first; afterwards the delete callback supplied at
    construction time is invoked.

    Returns
    -------
    Whatever the delete callback returns.
    """
    self.close()
    deleted = self._delete_fn()
    return deleted
|
|
176
|
+
|
|
177
|
+
def close(self) -> None:
    """
    Close this object by closing its query runner, releasing the resources it holds.
    """
    runner = self._query_runner
    runner.close()
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class DbmsConnectionInfo:
    """Connection details (URI plus credentials) for a DBMS."""

    # Connection URI of the DBMS.
    uri: str
    # User name used for authentication.
    username: str
    # Password used for authentication; deliberately masked in repr().
    password: str

    def __repr__(self) -> str:
        # The dataclass-generated repr would print the password verbatim, which
        # leaks it into logs and tracebacks. @dataclass keeps an explicitly
        # defined __repr__, so this masked version is the one that gets used.
        return f"{type(self).__qualname__}(uri={self.uri!r}, username={self.username!r}, password='***')"

    def auth(self) -> Tuple[str, str]:
        """Return the credentials as a ``(username, password)`` tuple."""
        return self.username, self.password
|