graphdatascience 1.8__tar.gz → 1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. {graphdatascience-1.8/graphdatascience.egg-info → graphdatascience-1.9}/PKG-INFO +4 -2
  2. {graphdatascience-1.8 → graphdatascience-1.9}/README.md +1 -0
  3. graphdatascience-1.9/graphdatascience/__init__.py +36 -0
  4. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/algo_proc_runner.py +3 -9
  5. graphdatascience-1.9/graphdatascience/call_parameters.py +8 -0
  6. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/caller_base.py +5 -1
  7. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/endpoints.py +1 -6
  8. graphdatascience-1.9/graphdatascience/gds_session/aura_api.py +236 -0
  9. graphdatascience-1.9/graphdatascience/gds_session/aura_graph_data_science.py +181 -0
  10. graphdatascience-1.9/graphdatascience/gds_session/dbms_connection_info.py +14 -0
  11. graphdatascience-1.9/graphdatascience/gds_session/gds_sessions.py +178 -0
  12. graphdatascience-1.9/graphdatascience/gds_session/schema.py +12 -0
  13. graphdatascience-1.9/graphdatascience/gds_session/session_sizes.py +23 -0
  14. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_alpha_proc_runner.py +6 -8
  15. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_beta_proc_runner.py +8 -8
  16. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_create_result.py +4 -4
  17. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_cypher_runner.py +1 -1
  18. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_endpoints.py +0 -7
  19. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_entity_ops_runner.py +77 -79
  20. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_export_runner.py +5 -8
  21. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_object.py +16 -13
  22. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_proc_runner.py +110 -83
  23. graphdatascience-1.9/graphdatascience/graph/graph_project_runner.py +101 -0
  24. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_sample_runner.py +19 -21
  25. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph_data_science.py +36 -96
  26. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/graphsage_model.py +8 -5
  27. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/link_prediction_model.py +2 -2
  28. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model.py +26 -16
  29. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_alpha_proc_runner.py +15 -20
  30. graphdatascience-1.9/graphdatascience/model/model_beta_proc_runner.py +31 -0
  31. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_proc_runner.py +23 -33
  32. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/node_classification_model.py +8 -5
  33. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/node_regression_model.py +2 -2
  34. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/simple_rel_embedding_model.py +44 -96
  35. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/classification_training_pipeline.py +17 -12
  36. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/lp_pipeline_create_runner.py +3 -3
  37. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/lp_training_pipeline.py +8 -9
  38. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nc_pipeline_create_runner.py +3 -3
  39. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nc_training_pipeline.py +6 -5
  40. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nr_pipeline_create_runner.py +3 -3
  41. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/nr_training_pipeline.py +12 -11
  42. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_proc_runner.py +10 -14
  43. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/training_pipeline.py +44 -38
  44. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/arrow_query_runner.py +104 -26
  45. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/aura_db_arrow_query_runner.py +67 -41
  46. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/cypher_graph_constructor.py +3 -3
  47. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/neo4j_query_runner.py +103 -9
  48. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/query_runner.py +25 -5
  49. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/server_version/server_version.py +5 -6
  50. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/system/config_endpoints.py +5 -8
  51. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/system/system_endpoints.py +23 -37
  52. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +2 -2
  53. graphdatascience-1.9/graphdatascience/utils/__init__.py +0 -0
  54. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/utils/util_endpoints.py +11 -13
  55. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/utils/util_proc_runner.py +6 -6
  56. graphdatascience-1.9/graphdatascience/version.py +1 -0
  57. {graphdatascience-1.8 → graphdatascience-1.9/graphdatascience.egg-info}/PKG-INFO +4 -2
  58. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/SOURCES.txt +8 -1
  59. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/requires.txt +2 -1
  60. {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/base.txt +2 -1
  61. graphdatascience-1.8/graphdatascience/__init__.py +0 -5
  62. graphdatascience-1.8/graphdatascience/graph/graph_alpha_project_runner.py +0 -17
  63. graphdatascience-1.8/graphdatascience/graph/graph_project_runner.py +0 -65
  64. graphdatascience-1.8/graphdatascience/model/model_beta_proc_runner.py +0 -37
  65. graphdatascience-1.8/graphdatascience/version.py +0 -1
  66. {graphdatascience-1.8 → graphdatascience-1.9}/LICENSE +0 -0
  67. {graphdatascience-1.8 → graphdatascience-1.9}/MANIFEST.in +0 -0
  68. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/__init__.py +0 -0
  69. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/algo_endpoints.py +0 -0
  70. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
  71. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/call_builder.py +0 -0
  72. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/__init__.py +0 -0
  73. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/client_only_endpoint.py +0 -0
  74. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/cypher_warning_handler.py +0 -0
  75. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/endpoint_suggester.py +0 -0
  76. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/gds_not_installed.py +0 -0
  77. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/illegal_attr_checker.py +0 -0
  78. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/unable_to_connect.py +0 -0
  79. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/error/uncallable_namespace.py +0 -0
  80. {graphdatascience-1.8/graphdatascience/graph → graphdatascience-1.9/graphdatascience/gds_session}/__init__.py +0 -0
  81. {graphdatascience-1.8/graphdatascience/model → graphdatascience-1.9/graphdatascience/graph}/__init__.py +0 -0
  82. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/graph_type_check.py +0 -0
  83. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/nx_loader.py +0 -0
  84. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/graph/ogb_loader.py +0 -0
  85. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/ignored_server_endpoints.py +0 -0
  86. {graphdatascience-1.8/graphdatascience/pipeline → graphdatascience-1.9/graphdatascience/model}/__init__.py +0 -0
  87. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_endpoints.py +0 -0
  88. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/model_resolver.py +0 -0
  89. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/model/pipeline_model.py +0 -0
  90. {graphdatascience-1.8/graphdatascience/query_runner → graphdatascience-1.9/graphdatascience/pipeline}/__init__.py +0 -0
  91. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
  92. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
  93. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
  94. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/py.typed +0 -0
  95. {graphdatascience-1.8/graphdatascience/resources → graphdatascience-1.9/graphdatascience/query_runner}/__init__.py +0 -0
  96. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/arrow_graph_constructor.py +0 -0
  97. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/query_runner/graph_constructor.py +0 -0
  98. {graphdatascience-1.8/graphdatascience/resources/cora → graphdatascience-1.9/graphdatascience/resources}/__init__.py +0 -0
  99. {graphdatascience-1.8/graphdatascience/resources/imdb → graphdatascience-1.9/graphdatascience/resources/cora}/__init__.py +0 -0
  100. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
  101. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
  102. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/cora/serialize_cora.py +0 -0
  103. {graphdatascience-1.8/graphdatascience/resources/karate → graphdatascience-1.9/graphdatascience/resources/imdb}/__init__.py +0 -0
  104. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
  105. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
  106. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
  107. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
  108. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
  109. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
  110. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
  111. {graphdatascience-1.8/graphdatascience/resources/lastfm → graphdatascience-1.9/graphdatascience/resources/karate}/__init__.py +0 -0
  112. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
  113. {graphdatascience-1.8/graphdatascience/server_version → graphdatascience-1.9/graphdatascience/resources/lastfm}/__init__.py +0 -0
  114. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
  115. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
  116. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
  117. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
  118. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
  119. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
  120. {graphdatascience-1.8/graphdatascience/system → graphdatascience-1.9/graphdatascience/server_version}/__init__.py +0 -0
  121. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/server_version/compatible_with.py +0 -0
  122. {graphdatascience-1.8/graphdatascience/topological_lp → graphdatascience-1.9/graphdatascience/system}/__init__.py +0 -0
  123. {graphdatascience-1.8/graphdatascience/utils → graphdatascience-1.9/graphdatascience/topological_lp}/__init__.py +0 -0
  124. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
  125. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/dependency_links.txt +0 -0
  126. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/not-zip-safe +0 -0
  127. {graphdatascience-1.8 → graphdatascience-1.9}/graphdatascience.egg-info/top_level.txt +0 -0
  128. {graphdatascience-1.8 → graphdatascience-1.9}/pyproject.toml +0 -0
  129. {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/networkx.txt +0 -0
  130. {graphdatascience-1.8 → graphdatascience-1.9}/requirements/base/ogb.txt +0 -0
  131. {graphdatascience-1.8 → graphdatascience-1.9}/setup.cfg +0 -0
  132. {graphdatascience-1.8 → graphdatascience-1.9}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphdatascience
3
- Version: 1.8
3
+ Version: 1.9
4
4
  Summary: A Python client for the Neo4j Graph Data Science (GDS) library
5
5
  Home-page: https://neo4j.com/product/graph-data-science/
6
6
  Author: Neo4j
@@ -30,10 +30,11 @@ License-File: LICENSE
30
30
  Requires-Dist: multimethod<2.0,>=1.0
31
31
  Requires-Dist: neo4j<6.0,>=4.4.2
32
32
  Requires-Dist: pandas<3.0,>=1.0
33
- Requires-Dist: pyarrow<14.0,>=4.0
33
+ Requires-Dist: pyarrow<15.0,>=10.0
34
34
  Requires-Dist: textdistance<5.0,>=4.0
35
35
  Requires-Dist: tqdm<5.0,>=4.0
36
36
  Requires-Dist: typing-extensions<5.0,>=4.0
37
+ Requires-Dist: requests
37
38
  Provides-Extra: ogb
38
39
  Requires-Dist: ogb<2.0,>=1.0; extra == "ogb"
39
40
  Provides-Extra: networkx
@@ -125,6 +126,7 @@ Full end-to-end examples in Jupyter ready-to-run notebooks can be found in the [
125
126
  * [Sampling, Export and Integration with PyG example](examples/import-sample-export-gnn.ipynb)
126
127
  * [Load data to a projected graph via graph construction](examples/load-data-via-graph-construction.ipynb)
127
128
  * [Heterogeneous Node Classification with HashGNN and Autotuning](https://github.com/neo4j/graph-data-science-client/tree/main/examples/heterogeneous-node-classification-with-hashgnn.ipynb)
129
+ * [Perform inference using pre-trained KGE models](examples/kge-predict-transe-pyg-train.ipynb)
128
130
 
129
131
 
130
132
  ## Documentation
@@ -84,6 +84,7 @@ Full end-to-end examples in Jupyter ready-to-run notebooks can be found in the [
84
84
  * [Sampling, Export and Integration with PyG example](examples/import-sample-export-gnn.ipynb)
85
85
  * [Load data to a projected graph via graph construction](examples/load-data-via-graph-construction.ipynb)
86
86
  * [Heterogeneous Node Classification with HashGNN and Autotuning](https://github.com/neo4j/graph-data-science-client/tree/main/examples/heterogeneous-node-classification-with-hashgnn.ipynb)
87
+ * [Perform inference using pre-trained KGE models](examples/kge-predict-transe-pyg-train.ipynb)
87
88
 
88
89
 
89
90
  ## Documentation
@@ -0,0 +1,36 @@
1
+ from .gds_session.gds_sessions import GdsSessions
2
+ from .graph.graph_create_result import GraphCreateResult
3
+ from .graph.graph_object import Graph
4
+ from .graph_data_science import GraphDataScience
5
+ from .model.graphsage_model import GraphSageModel
6
+ from .model.link_prediction_model import LinkFeature, LPModel
7
+ from .model.node_classification_model import NCModel
8
+ from .model.node_regression_model import NRModel
9
+ from .model.pipeline_model import NodePropertyStep
10
+ from .model.simple_rel_embedding_model import SimpleRelEmbeddingModel
11
+ from .pipeline.lp_training_pipeline import LPTrainingPipeline
12
+ from .pipeline.nc_training_pipeline import NCTrainingPipeline
13
+ from .pipeline.nr_training_pipeline import NRTrainingPipeline
14
+ from .query_runner.query_runner import QueryRunner
15
+ from .server_version.server_version import ServerVersion
16
+ from .version import __version__
17
+
18
+ __all__ = [
19
+ "GraphDataScience",
20
+ "GdsSessions",
21
+ "QueryRunner",
22
+ "__version__",
23
+ "ServerVersion",
24
+ "Graph",
25
+ "GraphCreateResult",
26
+ "LPTrainingPipeline",
27
+ "NCTrainingPipeline",
28
+ "NRTrainingPipeline",
29
+ "NodePropertyStep",
30
+ "LinkFeature",
31
+ "LPModel",
32
+ "NCModel",
33
+ "NRModel",
34
+ "GraphSageModel",
35
+ "SimpleRelEmbeddingModel",
36
+ ]
@@ -7,21 +7,15 @@ from ..error.illegal_attr_checker import IllegalAttrChecker
7
7
  from ..graph.graph_object import Graph
8
8
  from ..graph.graph_type_check import graph_type_check
9
9
  from ..model.graphsage_model import GraphSageModel
10
+ from graphdatascience.call_parameters import CallParameters
10
11
 
11
12
 
12
13
  class AlgoProcRunner(IllegalAttrChecker, ABC):
13
14
  @graph_type_check
14
15
  def _run_procedure(self, G: Graph, config: Dict[str, Any], with_logging: bool = True) -> DataFrame:
15
- query = f"CALL {self._namespace}($graph_name, $config)"
16
+ params = CallParameters(graph_name=G.name(), config=config)
16
17
 
17
- params: Dict[str, Any] = {}
18
- params["graph_name"] = G.name()
19
- params["config"] = config
20
-
21
- if with_logging:
22
- return self._query_runner.run_query_with_logging(query, params)
23
- else:
24
- return self._query_runner.run_query(query, params)
18
+ return self._query_runner.call_procedure(endpoint=self._namespace, params=params, logging=with_logging)
25
19
 
26
20
  @graph_type_check
27
21
  def estimate(self, G: Graph, **config: Any) -> "Series[Any]":
@@ -0,0 +1,8 @@
1
+ from typing import Any, OrderedDict
2
+
3
+
4
class CallParameters(OrderedDict[str, Any]):
    """Ordered mapping from procedure parameter names to their values.

    Building an instance through keyword arguments, ``CallParameters(a=1, b=2)``,
    keeps the argument order on Python 3.6 and later.
    """

    def placeholder_str(self) -> str:
        """Return the parameter names as ``$name`` placeholders joined by ``", "``."""
        placeholders = (f"${name}" for name in self)
        return ", ".join(placeholders)
@@ -13,7 +13,11 @@ class CallerBase(ABC):
13
13
  self._server_version = server_version
14
14
 
15
15
  def _raise_suggestive_error_message(self, requested_endpoint: str) -> NoReturn:
16
- list_result = self._query_runner.run_query("CALL gds.list() YIELD name", custom_error=False)
16
+ list_result = self._query_runner.call_procedure(
17
+ endpoint="gds.list",
18
+ yields=["name"],
19
+ custom_error=False,
20
+ )
17
21
  all_endpoints = list_result["name"].tolist()
18
22
 
19
23
  raise SyntaxError(generate_suggestive_error_message(requested_endpoint, all_endpoints))
@@ -3,11 +3,7 @@ from .algo.single_mode_algo_endpoints import (
3
3
  SingleModeAlphaAlgoEndpoints,
4
4
  )
5
5
  from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
6
- from .graph.graph_endpoints import (
7
- GraphAlphaEndpoints,
8
- GraphBetaEndpoints,
9
- GraphEndpoints,
10
- )
6
+ from .graph.graph_endpoints import GraphAlphaEndpoints, GraphBetaEndpoints
11
7
  from .model.model_endpoints import (
12
8
  ModelAlphaEndpoints,
13
9
  ModelBetaEndpoints,
@@ -39,7 +35,6 @@ class DirectEndpoints(
39
35
  SingleModeAlgoEndpoints,
40
36
  DirectSystemEndpoints,
41
37
  DirectUtilEndpoints,
42
- GraphEndpoints,
43
38
  PipelineEndpoints,
44
39
  ModelEndpoints,
45
40
  ConfigEndpoints,
@@ -0,0 +1,236 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import logging
5
+ import os
6
+ import time
7
+ from dataclasses import dataclass
8
+ from typing import Any, List, Optional
9
+ from urllib.parse import urlparse
10
+
11
+ import requests as req
12
+ from requests import HTTPError
13
+
14
+ from graphdatascience.version import __version__
15
+
16
+
17
@dataclass(repr=True)
class InstanceDetails:
    """Summary fields describing an instance, as read from a JSON payload."""

    id: str
    name: str
    tenant_id: str
    cloud_provider: str

    @classmethod
    def fromJson(cls, json: dict[str, Any]) -> InstanceDetails:
        """Build the details from a decoded JSON dict.

        All four field keys must be present; a missing key raises ``KeyError``.
        Keys that are not fields are ignored.
        """
        field_keys = ("id", "name", "tenant_id", "cloud_provider")
        return cls(**{key: json[key] for key in field_keys})
32
+
33
+
34
@dataclass(repr=True)
class InstanceSpecificDetails(InstanceDetails):
    """Instance summary extended with status, connection, sizing and placement fields."""

    status: str
    connection_url: str
    memory: str
    type: str
    region: str

    @classmethod
    def fromJson(cls, json: dict[str, Any]) -> InstanceSpecificDetails:
        """Build the details from a decoded JSON dict.

        ``connection_url`` and ``memory`` fall back to an empty string when
        absent; every other field key is required and raises ``KeyError`` if
        it is missing.
        """
        required_keys = ("id", "name", "tenant_id", "cloud_provider", "status", "type", "region")
        values = {key: json[key] for key in required_keys}

        # These two keys may be missing from the payload.
        for optional_key in ("connection_url", "memory"):
            values[optional_key] = json.get(optional_key, "")

        return cls(**values)
55
+
56
+
57
@dataclass(repr=True)
class InstanceCreateDetails:
    """Identifier, credentials and connection URL of a newly created instance."""

    id: str
    username: str
    password: str
    connection_url: str

    @classmethod
    def from_json(cls, json: dict[str, Any]) -> InstanceCreateDetails:
        """Build the details from a decoded JSON dict.

        Parameters
        ----------
        json: dict[str, Any]
            Payload that must contain a key for every field of this class.
            Additional keys are ignored.

        Raises
        ------
        RuntimeError
            If any field key is missing. The message lists field names and the
            keys that were received, never the values, because the payload
            carries the instance password.
        """
        expected = [f.name for f in dataclasses.fields(cls)]
        missing = [name for name in expected if name not in json]
        if missing:
            raise RuntimeError(
                f"Missing required field(s) `{missing}`. Expected `{expected}` but got keys `{list(json)}`"
            )

        return cls(**{name: json[name] for name in expected})
71
+
72
+
73
class AuraApi:
    """Client for the Aura HTTP API: instance management, tenant lookup and OAuth tokens.

    The base URI is ``https://api.neo4j.io`` unless the ``AURA_ENV`` environment
    variable is set, in which case ``https://api-<AURA_ENV>.neo4j-dev.io`` is used.
    """

    class AuraAuthToken:
        """An OAuth access token together with its absolute expiry time."""

        access_token: str
        expires_at: int
        token_type: str

        def __init__(self, json: dict[str, Any]) -> None:
            self.access_token = json["access_token"]
            expires_in: int = json["expires_in"]
            # Keep the absolute expiry (epoch seconds) so it can be compared with the clock later.
            self.expires_at = int(time.time()) + expires_in
            self.token_type = json["token_type"]

        def is_expired(self) -> bool:
            """Return True once the current time has reached the expiry time."""
            return self.expires_at <= int(time.time())

    def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
        """
        Parameters
        ----------
        client_id: str
            Client id used as the user part of the token request's basic auth.
        client_secret: str
            Client secret used as the password part of the token request's basic auth.
        tenant_id: Optional[str]
            Tenant to operate on. When omitted (or empty) it is looked up through
            the API, which requires the account to have exactly one tenant.
        """
        self._dev_env = os.environ.get("AURA_ENV")
        self._base_uri = "https://api.neo4j.io" if not self._dev_env else f"https://api-{self._dev_env}.neo4j-dev.io"
        self._credentials = (client_id, client_secret)
        self._token: Optional[AuraApi.AuraAuthToken] = None
        self._logger = logging.getLogger()
        # Resolved last: the lookup issues an authenticated request and needs the fields above.
        self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()

    @staticmethod
    def extract_id(uri: str) -> str:
        """Return the part of the URI's host name before the first ``.`` and the first ``-``.

        Raises
        ------
        RuntimeError
            If no host name can be parsed from ``uri``.
        """
        host = urlparse(uri).hostname

        if not host:
            raise RuntimeError(f"Could not parse the uri `{uri}`.")

        return host.split(".")[0].split("-")[0]

    def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
        """Request a new instance and return its creation details.

        Raises
        ------
        requests.HTTPError
            If the API answers with an error status. The response body is
            logged at error level before the exception propagates.
        """
        # TODO should give more control here
        data = {
            "name": name,
            "memory": memory,
            "version": "5",
            "region": region,
            # TODO should be figured out from the tenant details in the future
            "type": self._instance_type(),
            "tenant_id": self._tenant_id,
            "cloud_provider": cloud_provider,
        }

        response = req.post(
            f"{self._base_uri}/v1/instances",
            json=data,
            headers=self._build_header(),
        )

        try:
            response.raise_for_status()
        except HTTPError:
            # Log the raw text: decoding a non-JSON error body would raise and hide the HTTP error.
            self._logger.error("Creating the instance failed: %s", response.text)
            raise

        return InstanceCreateDetails.from_json(response.json()["data"])

    def delete_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
        """Delete an instance; return its details, or None if the API answers 404."""
        response = req.delete(
            f"{self._base_uri}/v1/instances/{instance_id}",
            headers=self._build_header(),
        )

        if response.status_code == 404:
            return None

        response.raise_for_status()

        return InstanceSpecificDetails.fromJson(response.json()["data"])

    def list_instances(self) -> List[InstanceDetails]:
        """Return the instances listed by the API for the configured tenant."""
        response = req.get(
            f"{self._base_uri}/v1/instances",
            headers=self._build_header(),
            params={"tenantId": self._tenant_id},
        )

        response.raise_for_status()

        raw_data = response.json()["data"]

        return [InstanceDetails.fromJson(i) for i in raw_data]

    def list_instance(self, instance_id: str) -> Optional[InstanceSpecificDetails]:
        """Return the details of one instance, or None if the API answers 404."""
        response = req.get(
            f"{self._base_uri}/v1/instances/{instance_id}",
            headers=self._build_header(),
        )

        if response.status_code == 404:
            return None

        response.raise_for_status()

        raw_data = response.json()["data"]

        return InstanceSpecificDetails.fromJson(raw_data)

    def wait_for_instance_running(
        self, instance_id: str, sleep_time: float = 0.2, max_sleep_time: float = 300
    ) -> Optional[str]:
        """Poll the instance until its status is ``running``.

        Parameters
        ----------
        instance_id: str
            The instance to poll.
        sleep_time: float
            Seconds to sleep between two polls.
        max_sleep_time: float
            Upper bound on the accumulated sleep time before giving up.

        Returns
        -------
        None when the instance is running, otherwise a message describing why
        waiting stopped (not found, being deleted, or timed out).
        """
        waited_time = 0.0
        while waited_time <= max_sleep_time:
            instance = self.list_instance(instance_id)
            if instance is None:
                return "Instance is not found -- please retry"
            elif instance.status in ["deleting", "destroying"]:
                return "Instance is being deleted"
            elif instance.status == "running":
                return None
            else:
                self._logger.debug(
                    f"Instance `{instance_id}` is not yet running. "
                    f"Current status: {instance.status}. "
                    f"Retrying in {sleep_time} seconds..."
                )
            waited_time += sleep_time
            time.sleep(sleep_time)

        return f"Instance is not running after waiting for {waited_time} seconds"

    def _get_tenant_id(self) -> str:
        """Return the id of the account's only tenant.

        Raises
        ------
        RuntimeError
            If the account has no tenant, or more than one.
        """
        response = req.get(
            f"{self._base_uri}/v1/tenants",
            headers=self._build_header(),
        )
        response.raise_for_status()

        raw_data = response.json()["data"]

        if not raw_data:
            raise RuntimeError("This account does not have access to any tenant.")

        if len(raw_data) != 1:
            raise RuntimeError(
                f"This account has access to multiple tenants `{raw_data}`. Please specify which one to use."
            )

        return raw_data[0]["id"]  # type: ignore

    def _build_header(self) -> dict[str, str]:
        """Return the request headers: bearer authorization plus the client's user agent."""
        return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}

    def _auth_token(self) -> str:
        """Return a usable access token, requesting a new one if none is cached or it has expired."""
        if self._token is None or self._token.is_expired():
            self._token = self._update_token()
        return self._token.access_token

    def _update_token(self) -> AuraApi.AuraAuthToken:
        """Request a new token with the client-credentials grant and the stored credentials."""
        data = {
            "grant_type": "client_credentials",
        }

        self._logger.debug("Updating oauth token")

        response = req.post(
            f"{self._base_uri}/oauth/token", data=data, auth=(self._credentials[0], self._credentials[1])
        )

        response.raise_for_status()

        return AuraApi.AuraAuthToken(response.json())

    def _instance_type(self) -> str:
        """Return ``enterprise-ds``, or ``professional-ds`` when ``AURA_ENV`` is set."""
        return "enterprise-ds" if not self._dev_env else "professional-ds"
@@ -0,0 +1,181 @@
1
+ from typing import Any, Callable, Dict, Optional
2
+
3
+ from pandas import DataFrame
4
+
5
+ from graphdatascience.call_builder import IndirectCallBuilder
6
+ from graphdatascience.endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
7
+ from graphdatascience.error.uncallable_namespace import UncallableNamespace
8
+ from graphdatascience.gds_session.dbms_connection_info import DbmsConnectionInfo
9
+ from graphdatascience.graph.graph_proc_runner import GraphRemoteProcRunner
10
+ from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
11
+ from graphdatascience.query_runner.aura_db_arrow_query_runner import (
12
+ AuraDbArrowQueryRunner,
13
+ )
14
+ from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
15
+ from graphdatascience.server_version.server_version import ServerVersion
16
+
17
+
18
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
    """
    Primary API class for interacting with Neo4j AuraDB + Graph Data Science.
    Always bind this object to a variable called `gds`.
    """

    def __init__(
        self,
        gds_session_connection_info: DbmsConnectionInfo,
        aura_db_connection_info: DbmsConnectionInfo,
        delete_fn: Callable[[], bool],
        arrow_disable_server_verification: bool = True,
        arrow_tls_root_certs: Optional[bytes] = None,
        bookmarks: Optional[Any] = None,
    ):
        """
        Connect to a GDS session and to the AuraDB instance it works with.

        Parameters
        ----------
        gds_session_connection_info: DbmsConnectionInfo
            URI and credentials of the GDS session
        aura_db_connection_info: DbmsConnectionInfo
            URI and credentials of the AuraDB instance
        delete_fn: Callable[[], bool]
            callback invoked by `delete()`; its result is returned from there
        arrow_disable_server_verification: bool
            passed on to the Arrow query runner
        arrow_tls_root_certs: Optional[bytes]
            passed on to the Arrow query runner
        bookmarks: Optional[Any]
            Neo4j bookmarks set on the AuraDB query runner

        Raises
        ------
        RuntimeError
            if the GDS server version of the session is older than 2.6.0
        """
        gds_neo4j_query_runner = Neo4jQueryRunner.create(
            gds_session_connection_info.uri, gds_session_connection_info.auth(), aura_ds=True
        )
        gds_query_runner = ArrowQueryRunner.create(
            gds_neo4j_query_runner,
            gds_session_connection_info.auth(),
            gds_neo4j_query_runner.encrypted(),
            arrow_disable_server_verification,
            arrow_tls_root_certs,
        )

        self._server_version = gds_query_runner.server_version()

        if self._server_version < ServerVersion(2, 6, 0):
            # Adjacent literals instead of a backslash continuation inside the string,
            # which would embed the next line's indentation in the message.
            raise RuntimeError(
                f"AuraDB connection info was provided but GDS version {self._server_version} "
                "does not support connecting to AuraDB"
            )

        self._db_query_runner = Neo4jQueryRunner.create(
            aura_db_connection_info.uri,
            aura_db_connection_info.auth(),
            aura_ds=True,
            server_version=self._server_version,
        )
        self._db_query_runner.set_bookmarks(bookmarks)

        # we need to explicitly set these as the default value is None
        # which signals the driver to use the default configured database
        # from the dbms.
        gds_query_runner.set_database("neo4j")
        self._db_query_runner.set_database("neo4j")

        self._query_runner = AuraDbArrowQueryRunner(
            gds_query_runner, self._db_query_runner, self._db_query_runner.encrypted(), aura_db_connection_info
        )

        self._delete_fn = delete_fn

        super().__init__(self._query_runner, "gds", self._server_version)

    def run_cypher(
        self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
    ) -> DataFrame:
        """
        Run a Cypher query against the AuraDB instance.

        Parameters
        ----------
        query: str
            the Cypher query
        params: Dict[str, Any]
            parameters to the query
        database: str
            the database on which to run the query

        Returns
        -------
        The query result as a DataFrame
        """
        # This will avoid calling valid gds procedures through a raw string
        return self._db_query_runner.run_cypher(query, params, database, False)

    @property
    def graph(self) -> GraphRemoteProcRunner:
        """The graph endpoints, served by a `GraphRemoteProcRunner` under `<namespace>.graph`."""
        return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)

    @property
    def alpha(self) -> AlphaEndpoints:
        """The endpoints under the `gds.alpha` namespace."""
        return AlphaEndpoints(self._query_runner, "gds.alpha", self._server_version)

    @property
    def beta(self) -> BetaEndpoints:
        """The endpoints under the `gds.beta` namespace."""
        return BetaEndpoints(self._query_runner, "gds.beta", self._server_version)

    def __getattr__(self, attr: str) -> IndirectCallBuilder:
        # Only reached for attributes not found through normal lookup.
        return IndirectCallBuilder(self._query_runner, f"gds.{attr}", self._server_version)

    def set_database(self, database: str) -> None:
        """
        Set the database which queries are run against.

        Parameters
        ----------
        database: str
            The name of the database to run queries against.
        """
        self._db_query_runner.set_database(database)

    def set_bookmarks(self, bookmarks: Any) -> None:
        """
        Set Neo4j bookmarks to require a certain state before the next query gets executed

        Parameters
        ----------
        bookmarks: Bookmark(s)
            The Neo4j bookmarks defining the required state
        """
        self._db_query_runner.set_bookmarks(bookmarks)

    def database(self) -> Optional[str]:
        """
        Get the database which queries are run against.

        Returns
        -------
        The name of the database.
        """
        return self._db_query_runner.database()

    def bookmarks(self) -> Optional[Any]:
        """
        Get the Neo4j bookmarks defining the currently required states for queries to execute

        Returns
        -------
        The (possibly None) Neo4j bookmarks defining the currently required state
        """
        return self._db_query_runner.bookmarks()

    def last_bookmarks(self) -> Optional[Any]:
        """
        Get the Neo4j bookmarks defining the state following the most recently called query

        Returns
        -------
        The (possibly None) Neo4j bookmarks defining the state following the most recently called query
        """
        return self._db_query_runner.last_bookmarks()

    def driver_config(self) -> Dict[str, Any]:
        """
        Get the configuration used to create the underlying driver used to make queries to Neo4j.

        Returns
        -------
        The configuration as a dictionary.
        """
        return self._query_runner.driver_config()

    def delete(self) -> bool:
        """
        Delete a GDS session.

        Closes this object first, then calls the `delete_fn` given at construction.

        Returns
        -------
        The result of `delete_fn`.
        """
        self.close()
        return self._delete_fn()

    def close(self) -> None:
        """
        Close the GraphDataScience object and release any resources held by it.
        """
        self._query_runner.close()
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Tuple
5
+
6
+
7
@dataclass
class DbmsConnectionInfo:
    """A DBMS connection URI together with the username and password to use with it."""

    uri: str
    username: str
    password: str

    def auth(self) -> Tuple[str, str]:
        """Return the credentials as a ``(username, password)`` tuple."""
        credentials = (self.username, self.password)
        return credentials