graphdatascience 1.9__tar.gz → 1.10a1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphdatascience-1.9/graphdatascience.egg-info → graphdatascience-1.10a1}/PKG-INFO +2 -1
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/__init__.py +1 -1
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/endpoint_suggester.py +1 -1
- graphdatascience-1.9/graphdatascience/graph/graph_proc_runner.py → graphdatascience-1.10a1/graphdatascience/graph/base_graph_proc_runner.py +10 -21
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_entity_ops_runner.py +10 -2
- graphdatascience-1.10a1/graphdatascience/graph/graph_proc_runner.py +15 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_project_runner.py +0 -29
- graphdatascience-1.10a1/graphdatascience/graph/graph_remote_proc_runner.py +9 -0
- graphdatascience-1.10a1/graphdatascience/graph/graph_remote_project_runner.py +40 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/ogb_loader.py +2 -2
- graphdatascience-1.10a1/graphdatascience/query_runner/arrow_endpoint_version.py +35 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/arrow_graph_constructor.py +19 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/arrow_query_runner.py +34 -13
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/aura_db_arrow_query_runner.py +1 -1
- graphdatascience-1.10a1/graphdatascience/session/__init__.py +13 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/aura_api.py +77 -13
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/aura_graph_data_science.py +2 -2
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/dbms_connection_info.py +10 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/gds_sessions.py +80 -18
- graphdatascience-1.10a1/graphdatascience/session/region_suggester.py +17 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/schema.py +4 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/session}/session_sizes.py +10 -0
- graphdatascience-1.10a1/graphdatascience/version.py +1 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1/graphdatascience.egg-info}/PKG-INFO +2 -1
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/SOURCES.txt +12 -7
- {graphdatascience-1.9 → graphdatascience-1.10a1}/setup.py +1 -0
- graphdatascience-1.9/graphdatascience/utils/__init__.py +0 -0
- graphdatascience-1.9/graphdatascience/version.py +0 -1
- {graphdatascience-1.9 → graphdatascience-1.10a1}/LICENSE +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/MANIFEST.in +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/README.md +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/algo_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/algo_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/call_builder.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/call_parameters.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/caller_base.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/client_only_endpoint.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/cypher_warning_handler.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/gds_not_installed.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/illegal_attr_checker.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/unable_to_connect.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/uncallable_namespace.py +0 -0
- {graphdatascience-1.9/graphdatascience/gds_session → graphdatascience-1.10a1/graphdatascience/graph}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_create_result.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_cypher_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_export_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_object.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_sample_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_type_check.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/nx_loader.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph_data_science.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/ignored_server_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/graph → graphdatascience-1.10a1/graphdatascience/model}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/graphsage_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/link_prediction_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/model_resolver.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/node_classification_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/node_regression_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/pipeline_model.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
- {graphdatascience-1.9/graphdatascience/model → graphdatascience-1.10a1/graphdatascience/pipeline}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/pipeline/training_pipeline.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/py.typed +0 -0
- {graphdatascience-1.9/graphdatascience/pipeline → graphdatascience-1.10a1/graphdatascience/query_runner}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/graph_constructor.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/neo4j_query_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/query_runner.py +0 -0
- {graphdatascience-1.9/graphdatascience/query_runner → graphdatascience-1.10a1/graphdatascience/resources}/__init__.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources → graphdatascience-1.10a1/graphdatascience/resources/cora}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/cora/serialize_cora.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources/cora → graphdatascience-1.10a1/graphdatascience/resources/imdb}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
- {graphdatascience-1.9/graphdatascience/resources/imdb → graphdatascience-1.10a1/graphdatascience/resources/karate}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
- {graphdatascience-1.9/graphdatascience/resources/karate → graphdatascience-1.10a1/graphdatascience/resources/lastfm}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.9/graphdatascience/resources/lastfm → graphdatascience-1.10a1/graphdatascience/server_version}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/server_version/compatible_with.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/server_version/server_version.py +0 -0
- {graphdatascience-1.9/graphdatascience/server_version → graphdatascience-1.10a1/graphdatascience/system}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/system/config_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/system/system_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/system → graphdatascience-1.10a1/graphdatascience/topological_lp}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
- {graphdatascience-1.9/graphdatascience/topological_lp → graphdatascience-1.10a1/graphdatascience/utils}/__init__.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/utils/util_endpoints.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/utils/util_proc_runner.py +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/dependency_links.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/not-zip-safe +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/requires.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience.egg-info/top_level.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/pyproject.toml +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/base.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/networkx.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/requirements/base/ogb.txt +0 -0
- {graphdatascience-1.9 → graphdatascience-1.10a1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphdatascience
|
|
3
|
-
Version: 1.9
|
|
3
|
+
Version: 1.10a1
|
|
4
4
|
Summary: A Python client for the Neo4j Graph Data Science (GDS) library
|
|
5
5
|
Home-page: https://neo4j.com/product/graph-data-science/
|
|
6
6
|
Author: Neo4j
|
|
@@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
|
20
20
|
Classifier: Programming Language :: Python :: 3.9
|
|
21
21
|
Classifier: Programming Language :: Python :: 3.10
|
|
22
22
|
Classifier: Programming Language :: Python :: 3.11
|
|
23
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
23
24
|
Classifier: Topic :: Database
|
|
24
25
|
Classifier: Topic :: Scientific/Engineering
|
|
25
26
|
Classifier: Topic :: Software Development
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
from .gds_session.gds_sessions import GdsSessions
|
|
2
1
|
from .graph.graph_create_result import GraphCreateResult
|
|
3
2
|
from .graph.graph_object import Graph
|
|
4
3
|
from .graph_data_science import GraphDataScience
|
|
@@ -13,6 +12,7 @@ from .pipeline.nc_training_pipeline import NCTrainingPipeline
|
|
|
13
12
|
from .pipeline.nr_training_pipeline import NRTrainingPipeline
|
|
14
13
|
from .query_runner.query_runner import QueryRunner
|
|
15
14
|
from .server_version.server_version import ServerVersion
|
|
15
|
+
from .session.gds_sessions import GdsSessions
|
|
16
16
|
from .version import __version__
|
|
17
17
|
|
|
18
18
|
__all__ = [
|
{graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/error/endpoint_suggester.py
RENAMED
|
@@ -9,7 +9,7 @@ def generate_suggestive_error_message(requested_endpoint: str, all_endpoints: Li
|
|
|
9
9
|
MIN_SIMILARITY_FOR_SUGGESTION = 0.9
|
|
10
10
|
|
|
11
11
|
closest_endpoint = None
|
|
12
|
-
curr_max_similarity = 0
|
|
12
|
+
curr_max_similarity = 0.0
|
|
13
13
|
for ep in all_endpoints:
|
|
14
14
|
similarity = textdistance.jaro_winkler(requested_endpoint, ep)
|
|
15
15
|
if similarity >= MIN_SIMILARITY_FOR_SUGGESTION:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import pathlib
|
|
3
3
|
import sys
|
|
4
|
+
import warnings
|
|
4
5
|
from typing import Any, Dict, List, Optional, Union
|
|
5
6
|
|
|
6
7
|
import pandas as pd
|
|
@@ -24,7 +25,6 @@ from .graph_entity_ops_runner import (
|
|
|
24
25
|
)
|
|
25
26
|
from .graph_export_runner import GraphExportRunner
|
|
26
27
|
from .graph_object import Graph
|
|
27
|
-
from .graph_project_runner import GraphProjectRemoteRunner, GraphProjectRunner
|
|
28
28
|
from .graph_sample_runner import GraphSampleRunner
|
|
29
29
|
from .graph_type_check import (
|
|
30
30
|
from_graph_type_check,
|
|
@@ -34,7 +34,6 @@ from .graph_type_check import (
|
|
|
34
34
|
from .ogb_loader import OGBLLoader, OGBNLoader
|
|
35
35
|
from graphdatascience.call_parameters import CallParameters
|
|
36
36
|
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
37
|
-
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
|
|
38
37
|
|
|
39
38
|
Strings = Union[str, List[str]]
|
|
40
39
|
|
|
@@ -42,6 +41,15 @@ is_neo4j_4_driver = ServerVersion.from_string(neo4j_driver_version) < ServerVers
|
|
|
42
41
|
|
|
43
42
|
|
|
44
43
|
class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
44
|
+
def __init__(self, query_runner: Any, namespace: str, server_version: ServerVersion):
|
|
45
|
+
super().__init__(query_runner, namespace, server_version)
|
|
46
|
+
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 14.0)
|
|
47
|
+
warnings.filterwarnings(
|
|
48
|
+
"ignore",
|
|
49
|
+
category=DeprecationWarning,
|
|
50
|
+
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
51
|
+
)
|
|
52
|
+
|
|
45
53
|
@staticmethod
|
|
46
54
|
def _path(package: str, resource: str) -> pathlib.Path:
|
|
47
55
|
if sys.version_info >= (3, 9):
|
|
@@ -558,22 +566,3 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
558
566
|
endpoint=self._namespace,
|
|
559
567
|
params=params,
|
|
560
568
|
).squeeze()
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
class GraphProcRunner(BaseGraphProcRunner):
|
|
564
|
-
@property
|
|
565
|
-
def project(self) -> GraphProjectRunner:
|
|
566
|
-
self._namespace += ".project"
|
|
567
|
-
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
|
|
568
|
-
|
|
569
|
-
@property
|
|
570
|
-
def cypher(self) -> GraphCypherRunner:
|
|
571
|
-
self._namespace += ".project"
|
|
572
|
-
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
576
|
-
@property
|
|
577
|
-
def project(self) -> GraphProjectRemoteRunner:
|
|
578
|
-
self._namespace += ".project.remoteDb"
|
|
579
|
-
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
{graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_entity_ops_runner.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from functools import reduce
|
|
2
2
|
from typing import Any, Dict, List, Type, Union
|
|
3
|
+
from warnings import filterwarnings
|
|
3
4
|
|
|
4
5
|
import pandas as pd
|
|
5
6
|
from pandas import DataFrame, Series
|
|
@@ -26,6 +27,13 @@ class TopologyDataFrame(DataFrame):
|
|
|
26
27
|
return TopologyDataFrame
|
|
27
28
|
|
|
28
29
|
def by_rel_type(self) -> Dict[str, List[List[int]]]:
|
|
30
|
+
# Pandas 2.2.0 deprecated an internal API used by DF.take(indices)
|
|
31
|
+
filterwarnings(
|
|
32
|
+
"ignore",
|
|
33
|
+
category=DeprecationWarning,
|
|
34
|
+
message=r"Passing a BlockManager to TopologyDataFrame is deprecated",
|
|
35
|
+
)
|
|
36
|
+
|
|
29
37
|
gb = self.groupby("relationshipType", observed=True)
|
|
30
38
|
|
|
31
39
|
output = {}
|
|
@@ -242,13 +250,13 @@ class ToUndirectedRunner(IllegalAttrChecker):
|
|
|
242
250
|
|
|
243
251
|
@graph_type_check
|
|
244
252
|
def __call__(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
|
|
245
|
-
return self._run_procedure(G, relationship_type, mutate_relationship_type)
|
|
253
|
+
return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
|
|
246
254
|
|
|
247
255
|
@graph_type_check
|
|
248
256
|
@compatible_with("estimate", min_inclusive=ServerVersion(2, 3, 0))
|
|
249
257
|
def estimate(self, G: Graph, relationship_type: str, mutate_relationship_type: str, **config: Any) -> "Series[Any]":
|
|
250
258
|
self._namespace += ".estimate"
|
|
251
|
-
return self._run_procedure(G, relationship_type, mutate_relationship_type)
|
|
259
|
+
return self._run_procedure(G, relationship_type, mutate_relationship_type, **config)
|
|
252
260
|
|
|
253
261
|
|
|
254
262
|
class GraphRelationshipsRunner(GraphEntityOpsBaseRunner):
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .graph_project_runner import GraphProjectRunner
|
|
2
|
+
from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
|
|
3
|
+
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GraphProcRunner(BaseGraphProcRunner):
|
|
7
|
+
@property
|
|
8
|
+
def project(self) -> GraphProjectRunner:
|
|
9
|
+
self._namespace += ".project"
|
|
10
|
+
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
|
|
11
|
+
|
|
12
|
+
@property
|
|
13
|
+
def cypher(self) -> GraphCypherRunner:
|
|
14
|
+
self._namespace += ".project"
|
|
15
|
+
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
|
{graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/graph/graph_project_runner.py
RENAMED
|
@@ -5,13 +5,10 @@ from typing import Any
|
|
|
5
5
|
from pandas import Series
|
|
6
6
|
|
|
7
7
|
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
8
|
-
from ..gds_session.schema import NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA
|
|
9
8
|
from .graph_object import Graph
|
|
10
9
|
from .graph_type_check import from_graph_type_check
|
|
11
10
|
from graphdatascience.call_parameters import CallParameters
|
|
12
11
|
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
13
|
-
from graphdatascience.server_version.compatible_with import compatible_with
|
|
14
|
-
from graphdatascience.server_version.server_version import ServerVersion
|
|
15
12
|
|
|
16
13
|
|
|
17
14
|
class GraphProjectRunner(IllegalAttrChecker):
|
|
@@ -73,29 +70,3 @@ class GraphProjectBetaRunner(IllegalAttrChecker):
|
|
|
73
70
|
).squeeze()
|
|
74
71
|
|
|
75
72
|
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
79
|
-
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
|
|
80
|
-
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
|
|
81
|
-
placeholder = "<>" # host and token will be added by query runner
|
|
82
|
-
self.map_property_types(config)
|
|
83
|
-
params = CallParameters(
|
|
84
|
-
graph_name=graph_name,
|
|
85
|
-
query=query,
|
|
86
|
-
token=placeholder,
|
|
87
|
-
host=placeholder,
|
|
88
|
-
remote_database=self._query_runner.database(),
|
|
89
|
-
config=config,
|
|
90
|
-
)
|
|
91
|
-
result = self._query_runner.call_procedure(
|
|
92
|
-
endpoint=self._namespace,
|
|
93
|
-
params=params,
|
|
94
|
-
).squeeze()
|
|
95
|
-
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
96
|
-
|
|
97
|
-
@staticmethod
|
|
98
|
-
def map_property_types(config: dict[str, Any]) -> None:
|
|
99
|
-
for key in [NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA]:
|
|
100
|
-
if key in config:
|
|
101
|
-
config[key] = {k: v.value for k, v in config[key].items()}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from graphdatascience.graph.base_graph_proc_runner import BaseGraphProcRunner
|
|
2
|
+
from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemoteRunner
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
6
|
+
@property
|
|
7
|
+
def project(self) -> GraphProjectRemoteRunner:
|
|
8
|
+
self._namespace += ".project.remoteDb"
|
|
9
|
+
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
6
|
+
from ..server_version.compatible_with import compatible_with
|
|
7
|
+
from .graph_object import Graph
|
|
8
|
+
from graphdatascience.call_parameters import CallParameters
|
|
9
|
+
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
10
|
+
from graphdatascience.server_version.server_version import ServerVersion
|
|
11
|
+
from graphdatascience.session.schema import (
|
|
12
|
+
NODE_PROPERTY_SCHEMA,
|
|
13
|
+
RELATIONSHIP_PROPERTY_SCHEMA,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
18
|
+
@compatible_with("project", min_inclusive=ServerVersion(2, 6, 0))
|
|
19
|
+
def __call__(self, graph_name: str, query: str, **config: Any) -> GraphCreateResult:
|
|
20
|
+
placeholder = "<>" # host and token will be added by query runner
|
|
21
|
+
self.map_property_types(config)
|
|
22
|
+
params = CallParameters(
|
|
23
|
+
graph_name=graph_name,
|
|
24
|
+
query=query,
|
|
25
|
+
token=placeholder,
|
|
26
|
+
host=placeholder,
|
|
27
|
+
remote_database=self._query_runner.database(),
|
|
28
|
+
config=config,
|
|
29
|
+
)
|
|
30
|
+
result = self._query_runner.call_procedure(
|
|
31
|
+
endpoint=self._namespace,
|
|
32
|
+
params=params,
|
|
33
|
+
).squeeze()
|
|
34
|
+
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
35
|
+
|
|
36
|
+
@staticmethod
|
|
37
|
+
def map_property_types(config: dict[str, Any]) -> None:
|
|
38
|
+
for key in [NODE_PROPERTY_SCHEMA, RELATIONSHIP_PROPERTY_SCHEMA]:
|
|
39
|
+
if key in config:
|
|
40
|
+
config[key] = {k: v.value for k, v in config[key].items()}
|
|
@@ -314,8 +314,8 @@ class OGBLLoader(OGBLoader):
|
|
|
314
314
|
assert source_labels[i] == source_label
|
|
315
315
|
assert target_labels[i] == target_label
|
|
316
316
|
|
|
317
|
-
source_ids[i] += node_id_offsets[edges["head_type"][i]]
|
|
318
|
-
target_ids[i] += node_id_offsets[edges["tail_type"][i]]
|
|
317
|
+
source_ids[i] += node_id_offsets[edges["head_type"][i]] + edges["head"][i]
|
|
318
|
+
target_ids[i] += node_id_offsets[edges["tail_type"][i]] + edges["tail"][i]
|
|
319
319
|
|
|
320
320
|
rel_types.append(f"{edge_type}_{set_type.upper()}")
|
|
321
321
|
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ArrowEndpointVersion(Enum):
|
|
8
|
+
ALPHA = ""
|
|
9
|
+
V1 = "v1/"
|
|
10
|
+
|
|
11
|
+
def version(self) -> str:
|
|
12
|
+
return self._name_.lower()
|
|
13
|
+
|
|
14
|
+
def prefix(self) -> str:
|
|
15
|
+
return self._value_
|
|
16
|
+
|
|
17
|
+
@staticmethod
|
|
18
|
+
def from_arrow_info(supported_arrow_versions: List[str]) -> ArrowEndpointVersion:
|
|
19
|
+
# Fallback for pre 2.6.0 servers that do not support versions
|
|
20
|
+
if len(supported_arrow_versions) == 0:
|
|
21
|
+
return ArrowEndpointVersion.ALPHA
|
|
22
|
+
|
|
23
|
+
# If the server supports versioned endpoints, we try v1 first
|
|
24
|
+
if ArrowEndpointVersion.V1.version() in supported_arrow_versions:
|
|
25
|
+
return ArrowEndpointVersion.V1
|
|
26
|
+
|
|
27
|
+
if ArrowEndpointVersion.ALPHA.version() in supported_arrow_versions:
|
|
28
|
+
return ArrowEndpointVersion.ALPHA
|
|
29
|
+
|
|
30
|
+
raise UnsupportedArrowEndpointVersion(supported_arrow_versions)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class UnsupportedArrowEndpointVersion(Exception):
|
|
34
|
+
def __init__(self, server_version: List[str]) -> None:
|
|
35
|
+
super().__init__(self, f"Unsupported Arrow endpoint versions: {server_version}")
|
|
@@ -11,6 +11,7 @@ from pandas import DataFrame
|
|
|
11
11
|
from pyarrow import Table
|
|
12
12
|
from tqdm.auto import tqdm
|
|
13
13
|
|
|
14
|
+
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
14
15
|
from .graph_constructor import GraphConstructor
|
|
15
16
|
|
|
16
17
|
|
|
@@ -21,6 +22,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
21
22
|
graph_name: str,
|
|
22
23
|
flight_client: flight.FlightClient,
|
|
23
24
|
concurrency: int,
|
|
25
|
+
arrow_endpoint_version: ArrowEndpointVersion,
|
|
24
26
|
undirected_relationship_types: Optional[List[str]],
|
|
25
27
|
chunk_size: int = 10_000,
|
|
26
28
|
):
|
|
@@ -28,6 +30,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
28
30
|
self._concurrency = concurrency
|
|
29
31
|
self._graph_name = graph_name
|
|
30
32
|
self._client = flight_client
|
|
33
|
+
self._arrow_endpoint_version = arrow_endpoint_version
|
|
31
34
|
self._undirected_relationship_types = (
|
|
32
35
|
[] if undirected_relationship_types is None else undirected_relationship_types
|
|
33
36
|
)
|
|
@@ -81,6 +84,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
81
84
|
return partitioned_dfs
|
|
82
85
|
|
|
83
86
|
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
|
|
87
|
+
action_type = self._versioned_action_type(action_type)
|
|
84
88
|
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
|
|
85
89
|
|
|
86
90
|
# Consume result fully to sanity check and avoid cancelled streams
|
|
@@ -93,6 +97,7 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
93
97
|
table = Table.from_pandas(df)
|
|
94
98
|
batches = table.to_batches(self._chunk_size)
|
|
95
99
|
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
|
|
100
|
+
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
|
|
96
101
|
|
|
97
102
|
# Write schema
|
|
98
103
|
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
|
|
@@ -117,3 +122,17 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
117
122
|
if not future.exception():
|
|
118
123
|
continue
|
|
119
124
|
raise future.exception() # type: ignore
|
|
125
|
+
|
|
126
|
+
def _versioned_action_type(self, action_type: str) -> str:
|
|
127
|
+
return self._arrow_endpoint_version.prefix() + action_type
|
|
128
|
+
|
|
129
|
+
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
|
|
130
|
+
return (
|
|
131
|
+
flight_descriptor
|
|
132
|
+
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
|
|
133
|
+
else {
|
|
134
|
+
"name": "PUT_MESSAGE",
|
|
135
|
+
"version": ArrowEndpointVersion.V1.version(),
|
|
136
|
+
"body": flight_descriptor,
|
|
137
|
+
}
|
|
138
|
+
)
|
{graphdatascience-1.9 → graphdatascience-1.10a1}/graphdatascience/query_runner/arrow_query_runner.py
RENAMED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import base64
|
|
2
4
|
import json
|
|
3
5
|
import time
|
|
@@ -5,13 +7,14 @@ import warnings
|
|
|
5
7
|
from typing import Any, Dict, List, Optional, Tuple
|
|
6
8
|
|
|
7
9
|
import pyarrow.flight as flight
|
|
8
|
-
from pandas import DataFrame, Series
|
|
10
|
+
from pandas import DataFrame
|
|
9
11
|
from pyarrow import ChunkedArray, Table, chunked_array
|
|
10
12
|
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
11
13
|
from pyarrow.types import is_dictionary # type: ignore
|
|
12
14
|
|
|
13
15
|
from ..call_parameters import CallParameters
|
|
14
16
|
from ..server_version.server_version import ServerVersion
|
|
17
|
+
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
15
18
|
from .arrow_graph_constructor import ArrowGraphConstructor
|
|
16
19
|
from .graph_constructor import GraphConstructor
|
|
17
20
|
from .query_runner import QueryRunner
|
|
@@ -28,18 +31,14 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
28
31
|
encrypted: bool = False,
|
|
29
32
|
disable_server_verification: bool = False,
|
|
30
33
|
tls_root_certs: Optional[bytes] = None,
|
|
31
|
-
) ->
|
|
34
|
+
) -> QueryRunner:
|
|
35
|
+
arrow_info = (
|
|
36
|
+
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
|
|
37
|
+
)
|
|
32
38
|
server_version = fallback_query_runner.server_version()
|
|
39
|
+
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
|
|
40
|
+
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
|
|
33
41
|
|
|
34
|
-
yield_fields = (
|
|
35
|
-
["running", "listenAddress"]
|
|
36
|
-
if server_version >= ServerVersion(2, 2, 1)
|
|
37
|
-
else ["running", "advertisedListenAddress"]
|
|
38
|
-
)
|
|
39
|
-
arrow_info: "Series[Any]" = fallback_query_runner.call_procedure(
|
|
40
|
-
endpoint="gds.debug.arrow", yields=yield_fields, custom_error=False
|
|
41
|
-
).squeeze()
|
|
42
|
-
listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"]) # type: ignore
|
|
43
42
|
if arrow_info["running"]:
|
|
44
43
|
return ArrowQueryRunner(
|
|
45
44
|
listen_address,
|
|
@@ -49,6 +48,7 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
49
48
|
encrypted,
|
|
50
49
|
disable_server_verification,
|
|
51
50
|
tls_root_certs,
|
|
51
|
+
arrow_endpoint_version,
|
|
52
52
|
)
|
|
53
53
|
else:
|
|
54
54
|
return fallback_query_runner
|
|
@@ -62,9 +62,11 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
62
62
|
encrypted: bool = False,
|
|
63
63
|
disable_server_verification: bool = False,
|
|
64
64
|
tls_root_certs: Optional[bytes] = None,
|
|
65
|
+
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
|
|
65
66
|
):
|
|
66
67
|
self._fallback_query_runner = fallback_query_runner
|
|
67
68
|
self._server_version = server_version
|
|
69
|
+
self._arrow_endpoint_version = arrow_endpoint_version
|
|
68
70
|
|
|
69
71
|
host, port_string = uri.split(":")
|
|
70
72
|
|
|
@@ -272,8 +274,15 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
272
274
|
"procedure_name": procedure_name,
|
|
273
275
|
"configuration": configuration,
|
|
274
276
|
}
|
|
275
|
-
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
276
277
|
|
|
278
|
+
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
|
|
279
|
+
payload = {
|
|
280
|
+
"name": "GET_MESSAGE",
|
|
281
|
+
"version": ArrowEndpointVersion.V1.version(),
|
|
282
|
+
"body": payload,
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
277
286
|
get = self._flight_client.do_get(ticket)
|
|
278
287
|
arrow_table = get.read_all()
|
|
279
288
|
|
|
@@ -282,6 +291,13 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
282
291
|
new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
|
|
283
292
|
arrow_table = arrow_table.rename_columns(new_colum_names)
|
|
284
293
|
|
|
294
|
+
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
|
|
295
|
+
warnings.filterwarnings(
|
|
296
|
+
"ignore",
|
|
297
|
+
category=DeprecationWarning,
|
|
298
|
+
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
299
|
+
)
|
|
300
|
+
|
|
285
301
|
return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
|
|
286
302
|
|
|
287
303
|
def create_graph_constructor(
|
|
@@ -295,7 +311,12 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
295
311
|
)
|
|
296
312
|
|
|
297
313
|
return ArrowGraphConstructor(
|
|
298
|
-
database,
|
|
314
|
+
database,
|
|
315
|
+
graph_name,
|
|
316
|
+
self._flight_client,
|
|
317
|
+
concurrency,
|
|
318
|
+
self._arrow_endpoint_version,
|
|
319
|
+
undirected_relationship_types,
|
|
299
320
|
)
|
|
300
321
|
|
|
301
322
|
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
|
|
@@ -5,7 +5,7 @@ from pyarrow import flight
|
|
|
5
5
|
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
6
6
|
|
|
7
7
|
from ..call_parameters import CallParameters
|
|
8
|
-
from ..
|
|
8
|
+
from ..session.dbms_connection_info import DbmsConnectionInfo
|
|
9
9
|
from .query_runner import QueryRunner
|
|
10
10
|
from graphdatascience.query_runner.graph_constructor import GraphConstructor
|
|
11
11
|
from graphdatascience.server_version.server_version import ServerVersion
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .dbms_connection_info import DbmsConnectionInfo
|
|
2
|
+
from .gds_sessions import AuraAPICredentials, GdsSessions, SessionInfo
|
|
3
|
+
from .schema import GdsPropertyTypes
|
|
4
|
+
from .session_sizes import SessionSizes
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"GdsSessions",
|
|
8
|
+
"SessionInfo",
|
|
9
|
+
"DbmsConnectionInfo",
|
|
10
|
+
"AuraAPICredentials",
|
|
11
|
+
"SessionSizes",
|
|
12
|
+
"GdsPropertyTypes",
|
|
13
|
+
]
|
|
@@ -4,8 +4,9 @@ import dataclasses
|
|
|
4
4
|
import logging
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
|
+
from collections import defaultdict
|
|
7
8
|
from dataclasses import dataclass
|
|
8
|
-
from typing import Any, List, Optional
|
|
9
|
+
from typing import Any, List, NamedTuple, Optional, Set
|
|
9
10
|
from urllib.parse import urlparse
|
|
10
11
|
|
|
11
12
|
import requests as req
|
|
@@ -14,7 +15,7 @@ from requests import HTTPError
|
|
|
14
15
|
from graphdatascience.version import __version__
|
|
15
16
|
|
|
16
17
|
|
|
17
|
-
@dataclass(repr=True)
|
|
18
|
+
@dataclass(repr=True, frozen=True)
|
|
18
19
|
class InstanceDetails:
|
|
19
20
|
id: str
|
|
20
21
|
name: str
|
|
@@ -31,7 +32,7 @@ class InstanceDetails:
|
|
|
31
32
|
)
|
|
32
33
|
|
|
33
34
|
|
|
34
|
-
@dataclass(repr=True)
|
|
35
|
+
@dataclass(repr=True, frozen=True)
|
|
35
36
|
class InstanceSpecificDetails(InstanceDetails):
|
|
36
37
|
status: str
|
|
37
38
|
connection_url: str
|
|
@@ -54,7 +55,7 @@ class InstanceSpecificDetails(InstanceDetails):
|
|
|
54
55
|
)
|
|
55
56
|
|
|
56
57
|
|
|
57
|
-
@dataclass(repr=True)
|
|
58
|
+
@dataclass(repr=True, frozen=True)
|
|
58
59
|
class InstanceCreateDetails:
|
|
59
60
|
id: str
|
|
60
61
|
username: str
|
|
@@ -70,6 +71,51 @@ class InstanceCreateDetails:
|
|
|
70
71
|
return cls(**{f.name: json[f.name] for f in fields})
|
|
71
72
|
|
|
72
73
|
|
|
74
|
+
class WaitResult(NamedTuple):
|
|
75
|
+
connection_url: str
|
|
76
|
+
error: str
|
|
77
|
+
|
|
78
|
+
@classmethod
|
|
79
|
+
def from_error(cls, error: str) -> WaitResult:
|
|
80
|
+
return cls(connection_url="", error=error)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def from_connection_url(cls, connection_url: str) -> WaitResult:
|
|
84
|
+
return cls(connection_url=connection_url, error="")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass(repr=True, frozen=True)
|
|
88
|
+
class TenantDetails:
|
|
89
|
+
id: str
|
|
90
|
+
ds_type: str
|
|
91
|
+
regions_per_provider: dict[str, Set[str]]
|
|
92
|
+
|
|
93
|
+
@classmethod
|
|
94
|
+
def from_json(cls, json: dict[str, Any]) -> TenantDetails:
|
|
95
|
+
regions_per_provider = defaultdict(set)
|
|
96
|
+
instance_types = set()
|
|
97
|
+
ds_type = None
|
|
98
|
+
|
|
99
|
+
for configs in json["instance_configurations"]:
|
|
100
|
+
type = configs["type"]
|
|
101
|
+
if type.split("-")[1] == "ds":
|
|
102
|
+
regions_per_provider[configs["cloud_provider"]].add(configs["region"])
|
|
103
|
+
ds_type = type
|
|
104
|
+
instance_types.add(configs["type"])
|
|
105
|
+
|
|
106
|
+
id = json["id"]
|
|
107
|
+
if not ds_type:
|
|
108
|
+
raise RuntimeError(
|
|
109
|
+
f"Tenant with id `{id}` cannot create DS instances. Available instances are `{instance_types}`."
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
return cls(
|
|
113
|
+
id=id,
|
|
114
|
+
ds_type=ds_type,
|
|
115
|
+
regions_per_provider=regions_per_provider,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
|
|
73
119
|
class AuraApi:
|
|
74
120
|
class AuraAuthToken:
|
|
75
121
|
access_token: str
|
|
@@ -87,11 +133,19 @@ class AuraApi:
|
|
|
87
133
|
|
|
88
134
|
def __init__(self, client_id: str, client_secret: str, tenant_id: Optional[str] = None) -> None:
|
|
89
135
|
self._dev_env = os.environ.get("AURA_ENV")
|
|
90
|
-
|
|
136
|
+
|
|
137
|
+
if not self._dev_env:
|
|
138
|
+
self._base_uri = "https://api.neo4j.io"
|
|
139
|
+
elif self._dev_env == "staging":
|
|
140
|
+
self._base_uri = "https://api-staging.neo4j.io"
|
|
141
|
+
else:
|
|
142
|
+
self._base_uri = f"https://api-{self._dev_env}.neo4j-dev.io"
|
|
143
|
+
|
|
91
144
|
self._credentials = (client_id, client_secret)
|
|
92
145
|
self._token: Optional[AuraApi.AuraAuthToken] = None
|
|
93
146
|
self._logger = logging.getLogger()
|
|
94
147
|
self._tenant_id = tenant_id if tenant_id else self._get_tenant_id()
|
|
148
|
+
self._tenant_details: Optional[TenantDetails] = None
|
|
95
149
|
|
|
96
150
|
@staticmethod
|
|
97
151
|
def extract_id(uri: str) -> str:
|
|
@@ -103,14 +157,14 @@ class AuraApi:
|
|
|
103
157
|
return host.split(".")[0].split("-")[0]
|
|
104
158
|
|
|
105
159
|
def create_instance(self, name: str, memory: str, cloud_provider: str, region: str) -> InstanceCreateDetails:
|
|
106
|
-
|
|
160
|
+
tenant_details = self.tenant_details()
|
|
161
|
+
|
|
107
162
|
data = {
|
|
108
163
|
"name": name,
|
|
109
164
|
"memory": memory,
|
|
110
165
|
"version": "5",
|
|
111
166
|
"region": region,
|
|
112
|
-
|
|
113
|
-
"type": self._instance_type(),
|
|
167
|
+
"type": tenant_details.ds_type,
|
|
114
168
|
"tenant_id": self._tenant_id,
|
|
115
169
|
"cloud_provider": cloud_provider,
|
|
116
170
|
}
|
|
@@ -172,16 +226,16 @@ class AuraApi:
|
|
|
172
226
|
|
|
173
227
|
def wait_for_instance_running(
|
|
174
228
|
self, instance_id: str, sleep_time: float = 0.2, max_sleep_time: float = 300
|
|
175
|
-
) ->
|
|
229
|
+
) -> WaitResult:
|
|
176
230
|
waited_time = 0.0
|
|
177
231
|
while waited_time <= max_sleep_time:
|
|
178
232
|
instance = self.list_instance(instance_id)
|
|
179
233
|
if instance is None:
|
|
180
|
-
return "Instance is not found -- please retry"
|
|
234
|
+
return WaitResult.from_error("Instance is not found -- please retry")
|
|
181
235
|
elif instance.status in ["deleting", "destroying"]:
|
|
182
|
-
return "Instance is being deleted"
|
|
236
|
+
return WaitResult.from_error("Instance is being deleted")
|
|
183
237
|
elif instance.status == "running":
|
|
184
|
-
return
|
|
238
|
+
return WaitResult.from_connection_url(instance.connection_url)
|
|
185
239
|
else:
|
|
186
240
|
self._logger.debug(
|
|
187
241
|
f"Instance `{instance_id}` is not yet running. "
|
|
@@ -191,7 +245,7 @@ class AuraApi:
|
|
|
191
245
|
waited_time += sleep_time
|
|
192
246
|
time.sleep(sleep_time)
|
|
193
247
|
|
|
194
|
-
return f"Instance is not running after waiting for {waited_time} seconds"
|
|
248
|
+
return WaitResult.from_error(f"Instance is not running after waiting for {waited_time} seconds")
|
|
195
249
|
|
|
196
250
|
def _get_tenant_id(self) -> str:
|
|
197
251
|
response = req.get(
|
|
@@ -209,6 +263,16 @@ class AuraApi:
|
|
|
209
263
|
|
|
210
264
|
return raw_data[0]["id"] # type: ignore
|
|
211
265
|
|
|
266
|
+
def tenant_details(self) -> TenantDetails:
|
|
267
|
+
if not self._tenant_details:
|
|
268
|
+
response = req.get(
|
|
269
|
+
f"{self._base_uri}/v1/tenants/{self._tenant_id}",
|
|
270
|
+
headers=self._build_header(),
|
|
271
|
+
)
|
|
272
|
+
response.raise_for_status()
|
|
273
|
+
self._tenant_details = TenantDetails.from_json(response.json()["data"])
|
|
274
|
+
return self._tenant_details
|
|
275
|
+
|
|
212
276
|
def _build_header(self) -> dict[str, str]:
|
|
213
277
|
return {"Authorization": f"Bearer {self._auth_token()}", "User-agent": f"neo4j-graphdatascience-v{__version__}"}
|
|
214
278
|
|