graphdatascience 1.9__tar.gz → 1.10a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. {graphdatascience-1.9/graphdatascience.egg-info → graphdatascience-1.10a1}/PKG-INFO +2 -1
  2. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/__init__.py +1 -1
  3. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/endpoint_suggester.py +1 -1
  4. graphdatascience-1.9/graphdatascience/graph/graph_proc_runner.py → graphdatascience-1.10a1/graphdatascience/graph/base_graph_proc_runner.py +10 -21
  5. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_entity_ops_runner.py +10 -2
  6. graphdatascience-1.10a1/graphdatascience/graph/graph_proc_runner.py +15 -0
  7. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_project_runner.py +0 -29
  8. graphdatascience-1.10a1/graphdatascience/graph/graph_remote_proc_runner.py +9 -0
  9. graphdatascience-1.10a1/graphdatascience/graph/graph_remote_project_runner.py +40 -0
  10. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/ogb_loader.py +2 -2
  11. graphdatascience-1.10a1/graphdatascience/query_runner/arrow_endpoint_version.py +35 -0
  12. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/arrow_graph_constructor.py +19 -0
  13. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/arrow_query_runner.py +34 -13
  14. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/aura_db_arrow_query_runner.py +1 -1
  15. graphdatascience-1.10a1/graphdatascience/session/__init__.py +13 -0
  16. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/aura_api.py +77 -13
  17. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/aura_graph_data_science.py +2 -2
  18. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/dbms_connection_info.py +10 -0
  19. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/gds_sessions.py +80 -18
  20. graphdatascience-1.10a1/graphdatascience/session/region_suggester.py +17 -0
  21. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/schema.py +4 -0
  22. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/session_sizes.py +10 -0
  23. graphdatascience-1.10a1/graphdatascience/version.py +1 -0
  24. {graphdatascience-1.9 → graphdatascience-1.10a1/graphdatascience.egg-info}/PKG-INFO +2 -1
  25. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/SOURCES.txt +12 -7
  26. {graphdatascience-1.9 → graphdatascience-1.10a1}/setup.py +1 -0
  27. graphdatascience-1.9/graphdatascience/utils/__init__.py +0 -0
  28. graphdatascience-1.9/graphdatascience/version.py +0 -1
  29. {graphdatascience-1.9 → graphdatascience-1.10a1}/LICENSE +0 -0
  30. {graphdatascience-1.9 → graphdatascience-1.10a1}/MANIFEST.in +0 -0
  31. {graphdatascience-1.9 → graphdatascience-1.10a1}/README.md +0 -0
  32. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/__init__.py +0 -0
  33. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/algo_endpoints.py +0 -0
  34. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/algo_proc_runner.py +0 -0
  35. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
  36. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/call_builder.py +0 -0
  37. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/call_parameters.py +0 -0
  38. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/caller_base.py +0 -0
  39. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/endpoints.py +0 -0
  40. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/__init__.py +0 -0
  41. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/client_only_endpoint.py +0 -0
  42. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/cypher_warning_handler.py +0 -0
  43. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/gds_not_installed.py +0 -0
  44. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/illegal_attr_checker.py +0 -0
  45. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/unable_to_connect.py +0 -0
  46. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/uncallable_namespace.py +0 -0
  47. {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/graph}/__init__.py +0 -0
  48. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
  49. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
  50. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_create_result.py +0 -0
  51. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_cypher_runner.py +0 -0
  52. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_endpoints.py +0 -0
  53. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_export_runner.py +0 -0
  54. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_object.py +0 -0
  55. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_sample_runner.py +0 -0
  56. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_type_check.py +0 -0
  57. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/nx_loader.py +0 -0
  58. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph_data_science.py +0 -0
  59. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/ignored_server_endpoints.py +0 -0
  60. {graphdatascience-1.9/graphdatascience/graph → graphdatascience-1.10a1/graphdatascience/model}/__init__.py +0 -0
  61. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/graphsage_model.py +0 -0
  62. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/link_prediction_model.py +0 -0
  63. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model.py +0 -0
  64. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
  65. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_beta_proc_runner.py +0 -0
  66. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_endpoints.py +0 -0
  67. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_proc_runner.py +0 -0
  68. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_resolver.py +0 -0
  69. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/node_classification_model.py +0 -0
  70. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/node_regression_model.py +0 -0
  71. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/pipeline_model.py +0 -0
  72. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
  73. {graphdatascience-1.9/graphdatascience/model → graphdatascience-1.10a1/graphdatascience/pipeline}/__init__.py +0 -0
  74. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
  75. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
  76. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
  77. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
  78. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
  79. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
  80. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
  81. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
  82. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
  83. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
  84. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
  85. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/training_pipeline.py +0 -0
  86. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/py.typed +0 -0
  87. {graphdatascience-1.9/graphdatascience/pipeline → graphdatascience-1.10a1/graphdatascience/query_runner}/__init__.py +0 -0
  88. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
  89. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/graph_constructor.py +0 -0
  90. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/neo4j_query_runner.py +0 -0
  91. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/query_runner.py +0 -0
  92. {graphdatascience-1.9/graphdatascience/query_runner → graphdatascience-1.10a1/graphdatascience/resources}/__init__.py +0 -0
  93. {graphdatascience-1.9/graphdatascience/resources → graphdatascience-1.10a1/graphdatascience/resources/cora}/__init__.py +0 -0
  94. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
  95. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
  96. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/serialize_cora.py +0 -0
  97. {graphdatascience-1.9/graphdatascience/resources/cora → graphdatascience-1.10a1/graphdatascience/resources/imdb}/__init__.py +0 -0
  98. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
  99. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
  100. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
  101. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
  102. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
  103. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
  104. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
  105. {graphdatascience-1.9/graphdatascience/resources/imdb → graphdatascience-1.10a1/graphdatascience/resources/karate}/__init__.py +0 -0
  106. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
  107. {graphdatascience-1.9/graphdatascience/resources/karate → graphdatascience-1.10a1/graphdatascience/resources/lastfm}/__init__.py +0 -0
  108. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
  109. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
  110. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
  111. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
  112. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
  113. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
  114. {graphdatascience-1.9/graphdatascience/resources/lastfm → graphdatascience-1.10a1/graphdatascience/server_version}/__init__.py +0 -0
  115. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/server_version/compatible_with.py +0 -0
  116. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/server_version/server_version.py +0 -0
  117. {graphdatascience-1.9/graphdatascience/server_version → graphdatascience-1.10a1/graphdatascience/system}/__init__.py +0 -0
  118. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/system/config_endpoints.py +0 -0
  119. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/system/system_endpoints.py +0 -0
  120. {graphdatascience-1.9/graphdatascience/system → graphdatascience-1.10a1/graphdatascience/topological_lp}/__init__.py +0 -0
  121. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
  122. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
  123. {graphdatascience-1.9/graphdatascience/topological_lp → graphdatascience-1.10a1/graphdatascience/utils}/__init__.py +0 -0
  124. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/utils/util_endpoints.py +0 -0
  125. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/utils/util_proc_runner.py +0 -0
  126. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/dependency_links.txt +0 -0
  127. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/not-zip-safe +0 -0
  128. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/requires.txt +0 -0
  129. {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/top_level.txt +0 -0
  130. {graphdatascience-1.9 → graphdatascience-1.10a1}/pyproject.toml +0 -0
  131. {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/base.txt +0 -0
  132. {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/networkx.txt +0 -0
  133. {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/ogb.txt +0 -0
  134. {graphdatascience-1.9 → graphdatascience-1.10a1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphdatascience
3
- Version: 1.9
3
+ Version: 1.10a1
4
4
  Summary: A Python client for the Neo4j Graph Data Science (GDS) library
5
5
  Home-page: https://neo4j.com/product/graph-data-science/
6
6
  Author: Neo4j
@@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.8
20
20
  Classifier: Programming Language :: Python :: 3.9
21
21
  Classifier: Programming Language :: Python :: 3.10
22
22
  Classifier: Programming Language :: Python :: 3.11
23
+ Classifier: Programming Language :: Python :: 3.12
23
24
  Classifier: Topic :: Database
24
25
  Classifier: Topic :: Scientific/Engineering
25
26
  Classifier: Topic :: Software Development
@@ -1,4 +1,3 @@
1
- from .gds_session.gds_sessions import GdsSessions
2
1
  from .graph.graph_create_result import GraphCreateResult
3
2
  from .graph.graph_object import Graph
4
3
  from .graph_data_science import GraphDataScience
@@ -13,6 +12,7 @@ from .pipeline.nc_training_pipeline import NCTrainingPipeline
13
12
  from .pipeline.nr_training_pipeline import NRTrainingPipeline
14
13
  from .query_runner.query_runner import QueryRunner
15
14
  from .server_version.server_version import ServerVersion
15
+ from .session.gds_sessions import GdsSessions
16
16
  from .version import __version__
17
17
 
18
18
  __all__ = [
@@ -9,7 +9,7 @@ def generate_suggestive_error_message(requested_endpoint: str, all_endpoints: Li
9
9
  MIN_SIMILARITY_FOR_SUGGESTION = 0.9
10
10
 
11
11
  closest_endpoint = None
12
- curr_max_similarity = 0
12
+ curr_max_similarity = 0.0
13
13
  for ep in all_endpoints:
14
14
  similarity = textdistance.jaro_winkler(requested_endpoint, ep)
15
15
  if similarity >= MIN_SIMILARITY_FOR_SUGGESTION:
@@ -1,6 +1,7 @@
1
1
  import os
2
2
  import pathlib
3
3
  import sys
4
+ import warnings
4
5
  from typing import Any, Dict, List, Optional, Union
5
6
 
6
7
  import pandas as pd
@@ -24,7 +25,6 @@ from .graph_entity_ops_runner import (
24
25
  )
25
26
  from .graph_export_runner import GraphExportRunner
26
27
  from .graph_object import Graph
27
- from .graph_project_runner import GraphProjectRemoteRunner, GraphProjectRunner
28
28
  from .graph_sample_runner import GraphSampleRunner
29
29
  from .graph_type_check import (
30
30
  from_graph_type_check,
@@ -34,7 +34,6 @@ from .graph_type_check import (
34
34
  from .ogb_loader import OGBLLoader, OGBNLoader
35
35
  from graphdatascience.call_parameters import CallParameters
36
36
  from graphdatascience.graph.graph_create_result import GraphCreateResult
37
- from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
38
37
 
39
38
  Strings = Union[str, List[str]]
40
39
 
@@ -42,6 +41,15 @@ is_neo4j_4_driver = ServerVersion.from_string(neo4j_driver_version) < ServerVers
42
41
 
43
42
 
44
43
  class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
44
+ def __init__(self, query_runner: Any, namespace: str, server_version: ServerVersion):
45
+ super().__init__(query_runner, namespace, server_version)
46
+ # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 14.0)
47
+ warnings.filterwarnings(
48
+ "ignore",
49
+ category=DeprecationWarning,
50
+ message=r"Passing a BlockManager to DataFrame is deprecated",
51
+ )
52
+
45
53
  @staticmethod
46
54
  def _path(package: str, resource: str) -> pathlib.Path:
47
55
  if sys.version_info >= (3, 9):
@@ -558,22 +566,3 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
558
566
  endpoint=self._namespace,
559
567
  params=params,
560
568
  ).squeeze()
561
-
562
-
563
- class GraphProcRunner(BaseGraphProcRunner):
564
- @property
565
- def project(self) -> GraphProjectRunner:
566
- self._namespace += ".project"
567
- return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
568
-
569
- @property
570
- def cypher(self) -> GraphCypherRunner:
571
- self._namespace += ".project"
572
- return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
573
-
574
-
575
- class GraphRemoteProcRunner(BaseGraphProcRunner):
576
- @property
577
- def project(self) -> GraphProjectRemoteRunner:
578
- self._namespace += ".project.remoteDb"
579
- return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
@@ -1,5 +1,6 @@
1
1
  from functools import reduce
2
2
  from typing import Any, Dict, List, Type, Union
3
+ from warnings import filterwarnings
3
4
 
4
5
  import pandas as pd
5
6
  from pandas import DataFrame, Series
@@ -26,6 +27,13 @@ class TopologyDataFrame(DataFrame):
26
27
  return TopologyDataFrame
27
28
 
28
29
  def by_rel_type(self) -> Dict[str, List[List[int]]]:
30
+ # Pandas 2.2.0 deprecated an internal API used by DF.take(indices)
31
+ filterwarnings(
32
+ "ignore",
33
+ category=DeprecationWarning,
34
+ message=r"Passing a BlockManager to TopologyDataFrame is deprecated",
35
+ )
36
+
29
37
  gb = self.groupby("relationshipType", observed=True)
30
38
 
31
39
  output = {}
@@ -242,13 +250,13 @@ class ToUndirectedRunner(IllegalAttrChecker):
242
250
 
243
251
  @graph_type_check
244
252
  def __call__(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
245
- return self._run_procedure(G, relationship_type, mutate_relationship_type)
253
+ return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
246
254
 
247
255
  @graph_type_check
248
256
  @compatible_with("estimate", min_inclusive=ServerVersion(2, 3, 0))
249
257
  def estimate(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
250
258
  self._namespace += ".estimate"
251
- return self._run_procedure(G, relationship_type, mutate_relationship_type)
259
+ return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
252
260
 
253
261
 
254
262
  class GraphRelationshipsRunner(GraphEntityOpsBaseRunner):
@@ -0,0 +1,15 @@
1
+ from .graph_project_runner import GraphProjectRunner
2
+ from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
3
+ from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
4
+
5
+
6
+ class GraphProcRunner(BaseGraphProcRunner):
7
+ @property
8
+ def project(self) -> GraphProjectRunner:
9
+ self._namespace += ".project"
10
+ return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
11
+
12
+ @property
13
+ def cypher(self) -> GraphCypherRunner:
14
+ self._namespace += ".project"
15
+ return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
@@ -5,13 +5,10 @@ from typing import Any
5
5
  from pandas import Series
6
6
 
7
7
  from ..error.illegal_attr_checker import IllegalAttrChecker
8
- from ..gds_session.schema import NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA
9
8
  from .graph_object import Graph
10
9
  from .graph_type_check import from_graph_type_check
11
10
  from graphdatascience.call_parameters import CallParameters
12
11
  from graphdatascience.graph.graph_create_result import GraphCreateResult
13
- from graphdatascience.server_version.compatible_with import compatible_with
14
- from graphdatascience.server_version.server_version import ServerVersion
15
12
 
16
13
 
17
14
  class GraphProjectRunner(IllegalAttrChecker):
@@ -73,29 +70,3 @@ class GraphProjectBetaRunner(IllegalAttrChecker):
73
70
  ).squeeze()
74
71
 
75
72
  return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
76
-
77
-
78
- class GraphProjectRemoteRunner(IllegalAttrChecker):
79
- @compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
80
- def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
81
- placeholder = "<>" # host and token will be added by query runner
82
- self.map_property_types(config)
83
- params = CallParameters(
84
- graph_name=graph_name,
85
- query=query,
86
- token=placeholder,
87
- host=placeholder,
88
- remote_database=self._query_runner.database(),
89
- config=config,
90
- )
91
- result = self._query_runner.call_procedure(
92
- endpoint=self._namespace,
93
- params=params,
94
- ).squeeze()
95
- return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
96
-
97
- @staticmethod
98
- def map_property_types(config: dict[str, Any]) -> None:
99
- for key in [NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA]:
100
- if key in config:
101
- config[key] = {k: v.value for k, v in config[key].items()}
@@ -0,0 +1,9 @@
1
+ from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
2
+ from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemoteRunner
3
+
4
+
5
+ class GraphRemoteProcRunner(BaseGraphProcRunner):
6
+ @property
7
+ def project(self) -> GraphProjectRemoteRunner:
8
+ self._namespace += ".project.remoteDb"
9
+ return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from ..error.illegal_attr_checker import IllegalAttrChecker
6
+ from ..server_version.compatible_with import compatible_with
7
+ from .graph_object import Graph
8
+ from graphdatascience.call_parameters import CallParameters
9
+ from graphdatascience.graph.graph_create_result import GraphCreateResult
10
+ from graphdatascience.server_version.server_version import ServerVersion
11
+ from graphdatascience.session.schema import (
12
+ NODE_PROPERTY_SCHEMA,
13
+ RELATIONSHIP_PROPERTY_SCHEMA,
14
+ )
15
+
16
+
17
+ class GraphProjectRemoteRunner(IllegalAttrChecker):
18
+ @compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
19
+ def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
20
+ placeholder = "<>" # host and token will be added by query runner
21
+ self.map_property_types(config)
22
+ params = CallParameters(
23
+ graph_name=graph_name,
24
+ query=query,
25
+ token=placeholder,
26
+ host=placeholder,
27
+ remote_database=self._query_runner.database(),
28
+ config=config,
29
+ )
30
+ result = self._query_runner.call_procedure(
31
+ endpoint=self._namespace,
32
+ params=params,
33
+ ).squeeze()
34
+ return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
35
+
36
+ @staticmethod
37
+ def map_property_types(config: dict[str, Any]) -> None:
38
+ for key in [NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA]:
39
+ if key in config:
40
+ config[key] = {k: v.value for k, v in config[key].items()}
@@ -314,8 +314,8 @@ class OGBLLoader(OGBLoader):
314
314
  assert source_labels[i] == source_label
315
315
  assert target_labels[i] == target_label
316
316
 
317
- source_ids[i] += node_id_offsets[edges["head_type"][i]]
318
- target_ids[i] += node_id_offsets[edges["tail_type"][i]]
317
+ source_ids[i] += node_id_offsets[edges["head_type"][i]] + edges["head"][i]
318
+ target_ids[i] += node_id_offsets[edges["tail_type"][i]] + edges["tail"][i]
319
319
 
320
320
  rel_types.append(f"{edge_type}_{set_type.upper()}")
321
321
 
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import List
5
+
6
+
7
+ class ArrowEndpointVersion(Enum):
8
+ ALPHA = ""
9
+ V1 = "v1/"
10
+
11
+ def version(self) -> str:
12
+ return self._name_.lower()
13
+
14
+ def prefix(self) -> str:
15
+ return self._value_
16
+
17
+ @staticmethod
18
+ def from_arrow_info(supported_arrow_versions: List[str]) -> ArrowEndpointVersion:
19
+ # Fallback for pre 2.6.0 servers that do not support versions
20
+ if len(supported_arrow_versions) == 0:
21
+ return ArrowEndpointVersion.ALPHA
22
+
23
+ # If the server supports versioned endpoints, we try v1 first
24
+ if ArrowEndpointVersion.V1.version() in supported_arrow_versions:
25
+ return ArrowEndpointVersion.V1
26
+
27
+ if ArrowEndpointVersion.ALPHA.version() in supported_arrow_versions:
28
+ return ArrowEndpointVersion.ALPHA
29
+
30
+ raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
31
+
32
+
33
+ class UnsupportedArrowEndpointVersion(Exception):
34
+ def __init__(self, server_version: List[str]) -> None:
35
+ super().__init__(self, f"Unsupported Arrow endpoint versions: {server_version}")
@@ -11,6 +11,7 @@ from pandas import DataFrame
11
11
  from pyarrow import Table
12
12
  from tqdm.auto import tqdm
13
13
 
14
+ from .arrow_endpoint_version import ArrowEndpointVersion
14
15
  from .graph_constructor import GraphConstructor
15
16
 
16
17
 
@@ -21,6 +22,7 @@ class ArrowGraphConstructor(GraphConstructor):
21
22
  graph_name: str,
22
23
  flight_client: flight.FlightClient,
23
24
  concurrency: int,
25
+ arrow_endpoint_version: ArrowEndpointVersion,
24
26
  undirected_relationship_types: Optional[List[str]],
25
27
  chunk_size: int = 10_000,
26
28
  ):
@@ -28,6 +30,7 @@ class ArrowGraphConstructor(GraphConstructor):
28
30
  self._concurrency = concurrency
29
31
  self._graph_name = graph_name
30
32
  self._client = flight_client
33
+ self._arrow_endpoint_version = arrow_endpoint_version
31
34
  self._undirected_relationship_types = (
32
35
  [] if undirected_relationship_types is None else undirected_relationship_types
33
36
  )
@@ -81,6 +84,7 @@ class ArrowGraphConstructor(GraphConstructor):
81
84
  return partitioned_dfs
82
85
 
83
86
  def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
87
+ action_type = self._versioned_action_type(action_type)
84
88
  result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
85
89
 
86
90
  # Consume result fully to sanity check and avoid cancelled streams
@@ -93,6 +97,7 @@ class ArrowGraphConstructor(GraphConstructor):
93
97
  table = Table.from_pandas(df)
94
98
  batches = table.to_batches(self._chunk_size)
95
99
  flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
100
+ flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
96
101
 
97
102
  # Write schema
98
103
  upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
@@ -117,3 +122,17 @@ class ArrowGraphConstructor(GraphConstructor):
117
122
  if not future.exception():
118
123
  continue
119
124
  raise future.exception() # type: ignore
125
+
126
+ def _versioned_action_type(self, action_type: str) -> str:
127
+ return self._arrow_endpoint_version.prefix() + action_type
128
+
129
+ def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
130
+ return (
131
+ flight_descriptor
132
+ if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
133
+ else {
134
+ "name": "PUT_MESSAGE",
135
+ "version": ArrowEndpointVersion.V1.version(),
136
+ "body": flight_descriptor,
137
+ }
138
+ )
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import base64
2
4
  import json
3
5
  import time
@@ -5,13 +7,14 @@ import warnings
5
7
  from typing import Any, Dict, List, Optional, Tuple
6
8
 
7
9
  import pyarrow.flight as flight
8
- from pandas import DataFrame, Series
10
+ from pandas import DataFrame
9
11
  from pyarrow import ChunkedArray, Table, chunked_array
10
12
  from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
11
13
  from pyarrow.types import is_dictionary # type: ignore
12
14
 
13
15
  from ..call_parameters import CallParameters
14
16
  from ..server_version.server_version import ServerVersion
17
+ from .arrow_endpoint_version import ArrowEndpointVersion
15
18
  from .arrow_graph_constructor import ArrowGraphConstructor
16
19
  from .graph_constructor import GraphConstructor
17
20
  from .query_runner import QueryRunner
@@ -28,18 +31,14 @@ class ArrowQueryRunner(QueryRunner):
28
31
  encrypted: bool = False,
29
32
  disable_server_verification: bool = False,
30
33
  tls_root_certs: Optional[bytes] = None,
31
- ) -> "QueryRunner":
34
+ ) -> QueryRunner:
35
+ arrow_info = (
36
+ fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
37
+ )
32
38
  server_version = fallback_query_runner.server_version()
39
+ listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40
+ arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
33
41
 
34
- yield_fields = (
35
- ["running", "listenAddress"]
36
- if server_version >= ServerVersion(2, 2, 1)
37
- else ["running", "advertisedListenAddress"]
38
- )
39
- arrow_info: "Series[Any]" = fallback_query_runner.call_procedure(
40
- endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
41
- ).squeeze()
42
- listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
43
42
  if arrow_info["running"]:
44
43
  return ArrowQueryRunner(
45
44
  listen_address,
@@ -49,6 +48,7 @@ class ArrowQueryRunner(QueryRunner):
49
48
  encrypted,
50
49
  disable_server_verification,
51
50
  tls_root_certs,
51
+ arrow_endpoint_version,
52
52
  )
53
53
  else:
54
54
  return fallback_query_runner
@@ -62,9 +62,11 @@ class ArrowQueryRunner(QueryRunner):
62
62
  encrypted: bool = False,
63
63
  disable_server_verification: bool = False,
64
64
  tls_root_certs: Optional[bytes] = None,
65
+ arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
65
66
  ):
66
67
  self._fallback_query_runner = fallback_query_runner
67
68
  self._server_version = server_version
69
+ self._arrow_endpoint_version = arrow_endpoint_version
68
70
 
69
71
  host, port_string = uri.split(":")
70
72
 
@@ -272,8 +274,15 @@ class ArrowQueryRunner(QueryRunner):
272
274
  "procedure_name": procedure_name,
273
275
  "configuration": configuration,
274
276
  }
275
- ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
276
277
 
278
+ if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
279
+ payload = {
280
+ "name": "GET_MESSAGE",
281
+ "version": ArrowEndpointVersion.V1.version(),
282
+ "body": payload,
283
+ }
284
+
285
+ ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
277
286
  get = self._flight_client.do_get(ticket)
278
287
  arrow_table = get.read_all()
279
288
 
@@ -282,6 +291,13 @@ class ArrowQueryRunner(QueryRunner):
282
291
  new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
283
292
  arrow_table = arrow_table.rename_columns(new_colum_names)
284
293
 
294
+ # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
295
+ warnings.filterwarnings(
296
+ "ignore",
297
+ category=DeprecationWarning,
298
+ message=r"Passing a BlockManager to DataFrame is deprecated",
299
+ )
300
+
285
301
  return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
286
302
 
287
303
  def create_graph_constructor(
@@ -295,7 +311,12 @@ class ArrowQueryRunner(QueryRunner):
295
311
  )
296
312
 
297
313
  return ArrowGraphConstructor(
298
- database, graph_name, self._flight_client, concurrency, undirected_relationship_types
314
+ database,
315
+ graph_name,
316
+ self._flight_client,
317
+ concurrency,
318
+ self._arrow_endpoint_version,
319
+ undirected_relationship_types,
299
320
  )
300
321
 
301
322
  def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
@@ -5,7 +5,7 @@ from pyarrow import flight
5
5
  from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
6
6
 
7
7
  from ..call_parameters import CallParameters
8
- from ..gds_session.dbms_connection_info import DbmsConnectionInfo
8
+ from ..session.dbms_connection_info import DbmsConnectionInfo
9
9
  from .query_runner import QueryRunner
10
10
  from graphdatascience.query_runner.graph_constructor import GraphConstructor
11
11
  from graphdatascience.server_version.server_version import ServerVersion
@@ -0,0 +1,13 @@
1
+ from .dbms_connection_info import DbmsConnectionInfo
2
+ from .gds_sessions import AuraAPICredentials, GdsSessions, SessionInfo
3
+ from .schema import GdsPropertyTypes
4
+ from .session_sizes import SessionSizes
5
+
6
+ __all__ = [
7
+ "GdsSessions",
8
+ "SessionInfo",
9
+ "DbmsConnectionInfo",
10
+ "AuraAPICredentials",
11
+ "SessionSizes",
12
+ "GdsPropertyTypes",
13
+ ]
@@ -4,8 +4,9 @@ import dataclasses
4
4
  import logging
5
5
  import os
6
6
  import time
7
+ from collections import defaultdict
7
8
  from dataclasses import dataclass
8
- from typing import Any, List, Optional
9
+ from typing import Any, List, NamedTuple, Optional, Set
9
10
  from urllib.parse import urlparse
10
11
 
11
12
  import requests as req
@@ -14,7 +15,7 @@ from requests import HTTPError
14
15
  from graphdatascience.version import __version__
15
16
 
16
17
 
17
- @dataclass(repr=True)
18
+ @dataclass(repr=True, frozen=True)
18
19
  class InstanceDetails:
19
20
  id: str
20
21
  name: str
@@ -31,7 +32,7 @@ class InstanceDetails:
31
32
  )
32
33
 
33
34
 
34
- @dataclass(repr=True)
35
+ @dataclass(repr=True, frozen=True)
35
36
  class InstanceSpecificDetails(InstanceDetails):
36
37
  status: str
37
38
  connection_url: str
@@ -54,7 +55,7 @@ class InstanceSpecificDetails(InstanceDetails):
54
55
  )
55
56
 
56
57
 
57
- @dataclass(repr=True)
58
+ @dataclass(repr=True, frozen=True)
58
59
  class InstanceCreateDetails:
59
60
  id: str
60
61
  username: str
@@ -70,6 +71,51 @@ class InstanceCreateDetails:
70
71
  return cls(**{f.name: json[f.name] for f in fields})
71
72
 
72
73
 
74
+ class WaitResult(NamedTuple):
75
+ connection_url: str
76
+ error: str
77
+
78
+ @classmethod
79
+ def from_error(cls, error: str) -> WaitResult:
80
+ return cls(connection_url="", error=error)
81
+
82
+ @classmethod
83
+ def from_connection_url(cls, connection_url: str) -> WaitResult:
84
+ return cls(connection_url=connection_url, error="")
85
+
86
+
87
+ @dataclass(repr=True, frozen=True)
88
+ class TenantDetails:
89
+ id: str
90
+ ds_type: str
91
+ regions_per_provider: dict[str, Set[str]]
92
+
93
+ @classmethod
94
+ def from_json(cls, json: dict[str, Any]) -> TenantDetails:
95
+ regions_per_provider = defaultdict(set)
96
+ instance_types = set()
97
+ ds_type = None
98
+
99
+ for configs in json["instance_configurations"]:
100
+ type = configs["type"]
101
+ if type.split("-")[1] == "ds":
102
+ regions_per_provider[configs["cloud_provider"]].add(configs["region"])
103
+ ds_type = type
104
+ instance_types.add(configs["type"])
105
+
106
+ id = json["id"]
107
+ if not ds_type:
108
+ raise RuntimeError(
109
+ f"Tenant with id `{id}` cannot create DS instances. Available instances are `{instance_types}`."
110
+ )
111
+
112
+ return cls(
113
+ id=id,
114
+ ds_type=ds_type,
115
+ regions_per_provider=regions_per_provider,
116
+ )
117
+
118
+
73
119
  class AuraApi:
74
120
  class AuraAuthToken:
75
121
  access_token: str
@@ -87,11 +133,19 @@ class AuraApi:
87
133
 
88
134
  def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
89
135
  self._dev_env = os.environ.get("AURA_ENV")
90
- self._base_uri = "https://api.neo4j.io" if not self._dev_env else f"https://api-{self._dev_env}.neo4j-dev.io"
136
+
137
+ if not self._dev_env:
138
+ self._base_uri = "https://api.neo4j.io"
139
+ elif self._dev_env == "staging":
140
+ self._base_uri = "https://api-staging.neo4j.io"
141
+ else:
142
+ self._base_uri = f"https://api-{self._dev_env}.neo4j-dev.io"
143
+
91
144
  self._credentials = (client_id, client_secret)
92
145
  self._token: Optional[AuraApi.AuraAuthToken] = None
93
146
  self._logger = logging.getLogger()
94
147
  self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
148
+ self._tenant_details: Optional[TenantDetails] = None
95
149
 
96
150
  @staticmethod
97
151
  def extract_id(uri: str) -> str:
@@ -103,14 +157,14 @@ class AuraApi:
103
157
  return host.split(".")[0].split("-")[0]
104
158
 
105
159
  def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
106
- # TODO should give more control here
160
+ tenant_details = self.tenant_details()
161
+
107
162
  data = {
108
163
  "name": name,
109
164
  "memory": memory,
110
165
  "version": "5",
111
166
  "region": region,
112
- # TODO should be figured out from the tenant details in the future
113
- "type": self._instance_type(),
167
+ "type": tenant_details.ds_type,
114
168
  "tenant_id": self._tenant_id,
115
169
  "cloud_provider": cloud_provider,
116
170
  }
@@ -172,16 +226,16 @@ class AuraApi:
172
226
 
173
227
  def wait_for_instance_running(
174
228
  self, instance_id: str, sleep_time: float = 0.2, max_sleep_time: float = 300
175
- ) -> Optional[str]:
229
+ ) -> WaitResult:
176
230
  waited_time = 0.0
177
231
  while waited_time <= max_sleep_time:
178
232
  instance = self.list_instance(instance_id)
179
233
  if instance is None:
180
- return "Instance is not found -- please retry"
234
+ return WaitResult.from_error("Instance is not found -- please retry")
181
235
  elif instance.status in ["deleting", "destroying"]:
182
- return "Instance is being deleted"
236
+ return WaitResult.from_error("Instance is being deleted")
183
237
  elif instance.status == "running":
184
- return None
238
+ return WaitResult.from_connection_url(instance.connection_url)
185
239
  else:
186
240
  self._logger.debug(
187
241
  f"Instance `{instance_id}` is not yet running. "
@@ -191,7 +245,7 @@ class AuraApi:
191
245
  waited_time += sleep_time
192
246
  time.sleep(sleep_time)
193
247
 
194
- return f"Instance is not running after waiting for {waited_time} seconds"
248
+ return WaitResult.from_error(f"Instance is not running after waiting for {waited_time} seconds")
195
249
 
196
250
  def _get_tenant_id(self) -> str:
197
251
  response = req.get(
@@ -209,6 +263,16 @@ class AuraApi:
209
263
 
210
264
  return raw_data[0]["id"] # type: ignore
211
265
 
266
+ def tenant_details(self) -> TenantDetails:
267
+ if not self._tenant_details:
268
+ response = req.get(
269
+ f"{self._base_uri}/v1/tenants/{self._tenant_id}",
270
+ headers=self._build_header(),
271
+ )
272
+ response.raise_for_status()
273
+ self._tenant_details = TenantDetails.from_json(response.json()["data"])
274
+ return self._tenant_details
275
+
212
276
  def _build_header(self) -> dict[str, str]:
213
277
  return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
214
278