graphdatascience 1.11a1__tar.gz → 1.11a3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphdatascience-1.11a1/graphdatascience.egg-info → graphdatascience-1.11a3}/PKG-INFO +3 -2
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/base_graph_proc_runner.py +3 -3
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_entity_ops_runner.py +13 -8
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_remote_proc_runner.py +0 -1
- graphdatascience-1.11a3/graphdatascience/graph/graph_remote_project_runner.py +47 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/arrow_graph_constructor.py +7 -38
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/arrow_query_runner.py +27 -181
- graphdatascience-1.11a3/graphdatascience/query_runner/aura_db_query_runner.py +170 -0
- graphdatascience-1.11a3/graphdatascience/query_runner/gds_arrow_client.py +254 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/__init__.py +2 -3
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/aura_api.py +21 -11
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/aura_api_responses.py +27 -17
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/aura_graph_data_science.py +11 -5
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/aurads_sessions.py +18 -8
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/dedicated_sessions.py +35 -7
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/session_info.py +2 -1
- graphdatascience-1.11a3/graphdatascience/session/session_sizes.py +70 -0
- graphdatascience-1.11a3/graphdatascience/version.py +1 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3/graphdatascience.egg-info}/PKG-INFO +3 -2
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience.egg-info/SOURCES.txt +2 -2
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience.egg-info/requires.txt +2 -1
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/requirements/base/base.txt +2 -1
- graphdatascience-1.11a1/graphdatascience/graph/graph_remote_project_runner.py +0 -38
- graphdatascience-1.11a1/graphdatascience/query_runner/aura_db_arrow_query_runner.py +0 -184
- graphdatascience-1.11a1/graphdatascience/session/schema.py +0 -13
- graphdatascience-1.11a1/graphdatascience/session/session_sizes.py +0 -31
- graphdatascience-1.11a1/graphdatascience/version.py +0 -1
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/LICENSE +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/MANIFEST.in +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/README.md +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/algo/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/algo/algo_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/algo/algo_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/call_builder.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/call_parameters.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/caller_base.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/client_only_endpoint.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/cypher_warning_handler.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/endpoint_suggester.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/gds_not_installed.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/illegal_attr_checker.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/unable_to_connect.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/error/uncallable_namespace.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_create_result.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_cypher_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_export_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_object.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_project_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_sample_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/graph_type_check.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/nx_loader.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/ogb_loader.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph_data_science.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/ignored_server_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/graphsage_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/link_prediction_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model_beta_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/model_resolver.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/node_classification_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/node_regression_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/pipeline_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/pipeline/training_pipeline.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/py.typed +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/arrow_endpoint_version.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/graph_constructor.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/neo4j_query_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/query_runner/query_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/cora/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/cora/serialize_cora.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/karate/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/server_version/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/server_version/compatible_with.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/server_version/server_version.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/algorithm_category.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/dbms_connection_info.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/gds_sessions.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/session/region_suggester.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/system/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/system/config_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/system/system_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/topological_lp/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/utils/__init__.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/utils/util_endpoints.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/utils/util_proc_runner.py +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience.egg-info/dependency_links.txt +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience.egg-info/not-zip-safe +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience.egg-info/top_level.txt +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/pyproject.toml +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/requirements/base/networkx.txt +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/requirements/base/ogb.txt +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/setup.cfg +0 -0
- {graphdatascience-1.11a1 → graphdatascience-1.11a3}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: graphdatascience
|
|
3
|
-
Version: 1.11a1
|
|
3
|
+
Version: 1.11a3
|
|
4
4
|
Summary: A Python client for the Neo4j Graph Data Science (GDS) library
|
|
5
5
|
Home-page: https://neo4j.com/product/graph-data-science/
|
|
6
6
|
Author: Neo4j
|
|
@@ -30,8 +30,9 @@ Description-Content-Type: text/markdown
|
|
|
30
30
|
License-File: LICENSE
|
|
31
31
|
Requires-Dist: multimethod<2.0,>=1.0
|
|
32
32
|
Requires-Dist: neo4j<6.0,>=4.4.2
|
|
33
|
+
Requires-Dist: numpy<2.0
|
|
33
34
|
Requires-Dist: pandas<3.0,>=1.0
|
|
34
|
-
Requires-Dist: pyarrow<
|
|
35
|
+
Requires-Dist: pyarrow<17.0,>=14.0.1
|
|
35
36
|
Requires-Dist: textdistance<5.0,>=4.0
|
|
36
37
|
Requires-Dist: tqdm<5.0,>=4.0
|
|
37
38
|
Requires-Dist: typing-extensions<5.0,>=4.0
|
{graphdatascience-1.11a1 → graphdatascience-1.11a3}/graphdatascience/graph/base_graph_proc_runner.py
RENAMED
|
@@ -15,12 +15,12 @@ from ..error.uncallable_namespace import UncallableNamespace
|
|
|
15
15
|
from ..server_version.compatible_with import compatible_with
|
|
16
16
|
from ..server_version.server_version import ServerVersion
|
|
17
17
|
from .graph_entity_ops_runner import (
|
|
18
|
-
GraphElementPropertyRunner,
|
|
19
18
|
GraphLabelRunner,
|
|
20
19
|
GraphNodePropertiesRunner,
|
|
21
20
|
GraphNodePropertyRunner,
|
|
22
21
|
GraphPropertyRunner,
|
|
23
22
|
GraphRelationshipPropertiesRunner,
|
|
23
|
+
GraphRelationshipPropertyRunner,
|
|
24
24
|
GraphRelationshipRunner,
|
|
25
25
|
GraphRelationshipsRunner,
|
|
26
26
|
)
|
|
@@ -390,9 +390,9 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
390
390
|
return GraphNodePropertiesRunner(self._query_runner, self._namespace, self._server_version)
|
|
391
391
|
|
|
392
392
|
@property
|
|
393
|
-
def relationshipProperty(self) -> GraphElementPropertyRunner:
|
|
393
|
+
def relationshipProperty(self) -> GraphRelationshipPropertyRunner:
|
|
394
394
|
self._namespace += ".relationshipProperty"
|
|
395
|
-
return GraphElementPropertyRunner(self._query_runner, self._namespace, self._server_version)
|
|
395
|
+
return GraphRelationshipPropertyRunner(self._query_runner, self._namespace, self._server_version)
|
|
396
396
|
|
|
397
397
|
@property
|
|
398
398
|
def relationshipProperties(self) -> GraphRelationshipPropertiesRunner:
|
|
@@ -70,13 +70,6 @@ class GraphEntityOpsBaseRunner(UncallableNamespace, IllegalAttrChecker):
|
|
|
70
70
|
)
|
|
71
71
|
|
|
72
72
|
|
|
73
|
-
class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
|
|
74
|
-
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
75
|
-
def stream(self, G: Graph, node_properties: str, node_labels: Strings = ["*"], **config: Any) -> DataFrame:
|
|
76
|
-
self._namespace += ".stream"
|
|
77
|
-
return self._handle_properties(G, node_properties, node_labels, config)
|
|
78
|
-
|
|
79
|
-
|
|
80
73
|
class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
|
|
81
74
|
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
82
75
|
@filter_id_func_deprecation_warning()
|
|
@@ -177,7 +170,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
177
170
|
return reduce(add_property, db_node_properties, query_prefix)
|
|
178
171
|
|
|
179
172
|
@compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
|
|
180
|
-
def write(self, G: Graph, node_properties:
|
|
173
|
+
def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
|
|
181
174
|
self._namespace += ".write"
|
|
182
175
|
return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore
|
|
183
176
|
|
|
@@ -197,6 +190,16 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
197
190
|
).squeeze()
|
|
198
191
|
|
|
199
192
|
|
|
193
|
+
class GraphRelationshipPropertyRunner(GraphEntityOpsBaseRunner):
|
|
194
|
+
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
195
|
+
def stream(
|
|
196
|
+
self, G: Graph, relationship_property: str, relationship_types: Strings = ["*"], **config: Any
|
|
197
|
+
) -> DataFrame:
|
|
198
|
+
self._namespace += ".stream"
|
|
199
|
+
relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types
|
|
200
|
+
return self._handle_properties(G, relationship_property, relationship_types, config)
|
|
201
|
+
|
|
202
|
+
|
|
200
203
|
class GraphRelationshipPropertiesRunner(GraphEntityOpsBaseRunner):
|
|
201
204
|
@compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
|
|
202
205
|
def stream(
|
|
@@ -209,6 +212,8 @@ class GraphRelationshipPropertiesRunner(GraphEntityOpsBaseRunner):
|
|
|
209
212
|
) -> DataFrame:
|
|
210
213
|
self._namespace += ".stream"
|
|
211
214
|
|
|
215
|
+
relationship_types = [relationship_types] if isinstance(relationship_types, str) else relationship_types
|
|
216
|
+
|
|
212
217
|
result = self._handle_properties(G, relationship_properties, relationship_types, config)
|
|
213
218
|
|
|
214
219
|
# new format was requested, but the query was run via Cypher
|
|
@@ -5,5 +5,4 @@ from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemot
|
|
|
5
5
|
class GraphRemoteProcRunner(BaseGraphProcRunner):
|
|
6
6
|
@property
|
|
7
7
|
def project(self) -> GraphProjectRemoteRunner:
|
|
8
|
-
self._namespace += ".project.remoteDb"
|
|
9
8
|
return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from ..error.illegal_attr_checker import IllegalAttrChecker
|
|
6
|
+
from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
|
|
7
|
+
from ..server_version.compatible_with import compatible_with
|
|
8
|
+
from .graph_object import Graph
|
|
9
|
+
from graphdatascience.call_parameters import CallParameters
|
|
10
|
+
from graphdatascience.graph.graph_create_result import GraphCreateResult
|
|
11
|
+
from graphdatascience.server_version.server_version import ServerVersion
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GraphProjectRemoteRunner(IllegalAttrChecker):
|
|
15
|
+
@compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
|
|
16
|
+
def __call__(
|
|
17
|
+
self,
|
|
18
|
+
graph_name: str,
|
|
19
|
+
query: str,
|
|
20
|
+
concurrency: int = 4,
|
|
21
|
+
undirected_relationship_types: Optional[List[str]] = None,
|
|
22
|
+
inverse_indexed_relationship_types: Optional[List[str]] = None,
|
|
23
|
+
batch_size: Optional[int] = None,
|
|
24
|
+
) -> GraphCreateResult:
|
|
25
|
+
if inverse_indexed_relationship_types is None:
|
|
26
|
+
inverse_indexed_relationship_types = []
|
|
27
|
+
if undirected_relationship_types is None:
|
|
28
|
+
undirected_relationship_types = []
|
|
29
|
+
|
|
30
|
+
arrow_configuration = {}
|
|
31
|
+
if batch_size is not None:
|
|
32
|
+
arrow_configuration["batchSize"] = batch_size
|
|
33
|
+
|
|
34
|
+
params = CallParameters(
|
|
35
|
+
graph_name=graph_name,
|
|
36
|
+
query=query,
|
|
37
|
+
concurrency=concurrency,
|
|
38
|
+
undirected_relationship_types=undirected_relationship_types,
|
|
39
|
+
inverse_indexed_relationship_types=inverse_indexed_relationship_types,
|
|
40
|
+
arrow_configuration=arrow_configuration,
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
result = self._query_runner.call_procedure(
|
|
44
|
+
endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
|
|
45
|
+
params=params,
|
|
46
|
+
).squeeze()
|
|
47
|
+
return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
|
|
@@ -1,19 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import concurrent
|
|
4
|
-
import json
|
|
5
4
|
import math
|
|
6
5
|
import warnings
|
|
7
6
|
from concurrent.futures import ThreadPoolExecutor
|
|
8
7
|
from typing import Any, Dict, List, NoReturn, Optional
|
|
9
8
|
|
|
10
9
|
import numpy
|
|
11
|
-
import pyarrow.flight as flight
|
|
12
10
|
from pandas import DataFrame
|
|
13
11
|
from pyarrow import Table
|
|
14
12
|
from tqdm.auto import tqdm
|
|
15
13
|
|
|
16
|
-
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
14
|
+
from .gds_arrow_client import GdsArrowClient
|
|
17
15
|
from .graph_constructor import GraphConstructor
|
|
18
16
|
|
|
19
17
|
|
|
@@ -22,9 +20,8 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
22
20
|
self,
|
|
23
21
|
database: str,
|
|
24
22
|
graph_name: str,
|
|
25
|
-
flight_client: flight.FlightClient,
|
|
23
|
+
flight_client: GdsArrowClient,
|
|
26
24
|
concurrency: int,
|
|
27
|
-
arrow_endpoint_version: ArrowEndpointVersion,
|
|
28
25
|
undirected_relationship_types: Optional[List[str]],
|
|
29
26
|
chunk_size: int = 10_000,
|
|
30
27
|
):
|
|
@@ -32,7 +29,6 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
32
29
|
self._concurrency = concurrency
|
|
33
30
|
self._graph_name = graph_name
|
|
34
31
|
self._client = flight_client
|
|
35
|
-
self._arrow_endpoint_version = arrow_endpoint_version
|
|
36
32
|
self._undirected_relationship_types = (
|
|
37
33
|
[] if undirected_relationship_types is None else undirected_relationship_types
|
|
38
34
|
)
|
|
@@ -49,20 +45,20 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
49
45
|
if self._undirected_relationship_types:
|
|
50
46
|
config["undirected_relationship_types"] = self._undirected_relationship_types
|
|
51
47
|
|
|
52
|
-
self._send_action(
|
|
48
|
+
self._client.send_action(
|
|
53
49
|
"CREATE_GRAPH",
|
|
54
50
|
config,
|
|
55
51
|
)
|
|
56
52
|
|
|
57
53
|
self._send_dfs(node_dfs, "node")
|
|
58
54
|
|
|
59
|
-
self._send_action("NODE_LOAD_DONE", {"name": self._graph_name})
|
|
55
|
+
self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})
|
|
60
56
|
|
|
61
57
|
self._send_dfs(relationship_dfs, "relationship")
|
|
62
58
|
|
|
63
|
-
self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
|
|
59
|
+
self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
|
|
64
60
|
except (Exception, KeyboardInterrupt) as e:
|
|
65
|
-
self._send_action("ABORT", {"name": self._graph_name})
|
|
61
|
+
self._client.send_action("ABORT", {"name": self._graph_name})
|
|
66
62
|
|
|
67
63
|
raise e
|
|
68
64
|
|
|
@@ -85,25 +81,12 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
85
81
|
|
|
86
82
|
return partitioned_dfs
|
|
87
83
|
|
|
88
|
-
def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
|
|
89
|
-
action_type = self._versioned_action_type(action_type)
|
|
90
|
-
result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
|
|
91
|
-
|
|
92
|
-
# Consume result fully to sanity check and avoid cancelled streams
|
|
93
|
-
collected_result = list(result)
|
|
94
|
-
assert len(collected_result) == 1
|
|
95
|
-
|
|
96
|
-
json.loads(collected_result[0].body.to_pybytes().decode())
|
|
97
|
-
|
|
98
84
|
def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
|
|
99
85
|
table = Table.from_pandas(df)
|
|
100
86
|
batches = table.to_batches(self._chunk_size)
|
|
101
87
|
flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
|
|
102
|
-
flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
|
|
103
88
|
|
|
104
|
-
|
|
105
|
-
upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
|
|
106
|
-
writer, _ = self._client.do_put(upload_descriptor, table.schema)
|
|
89
|
+
writer, _ = self._client.start_put(flight_descriptor, table.schema)
|
|
107
90
|
|
|
108
91
|
with writer:
|
|
109
92
|
# Write table in chunks
|
|
@@ -126,17 +109,3 @@ class ArrowGraphConstructor(GraphConstructor):
|
|
|
126
109
|
if not future.exception():
|
|
127
110
|
continue
|
|
128
111
|
raise future.exception() # type: ignore
|
|
129
|
-
|
|
130
|
-
def _versioned_action_type(self, action_type: str) -> str:
|
|
131
|
-
return self._arrow_endpoint_version.prefix() + action_type
|
|
132
|
-
|
|
133
|
-
def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
|
|
134
|
-
return (
|
|
135
|
-
flight_descriptor
|
|
136
|
-
if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
|
|
137
|
-
else {
|
|
138
|
-
"name": "PUT_MESSAGE",
|
|
139
|
-
"version": ArrowEndpointVersion.V1.version(),
|
|
140
|
-
"body": flight_descriptor,
|
|
141
|
-
}
|
|
142
|
-
)
|
|
@@ -1,21 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import base64
|
|
4
|
-
import json
|
|
5
|
-
import time
|
|
6
3
|
import warnings
|
|
7
4
|
from typing import Any, Dict, List, Optional, Tuple
|
|
8
5
|
|
|
9
|
-
import pyarrow.flight as flight
|
|
10
6
|
from pandas import DataFrame
|
|
11
|
-
from pyarrow import ChunkedArray, Table, chunked_array
|
|
12
|
-
from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
|
|
13
|
-
from pyarrow.types import is_dictionary # type: ignore
|
|
14
7
|
|
|
15
8
|
from ..call_parameters import CallParameters
|
|
16
9
|
from ..server_version.server_version import ServerVersion
|
|
17
|
-
from .arrow_endpoint_version import ArrowEndpointVersion
|
|
18
10
|
from .arrow_graph_constructor import ArrowGraphConstructor
|
|
11
|
+
from .gds_arrow_client import GdsArrowClient
|
|
19
12
|
from .graph_constructor import GraphConstructor
|
|
20
13
|
from .query_runner import QueryRunner
|
|
21
14
|
from graphdatascience.server_version.compatible_with import (
|
|
@@ -33,61 +26,29 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
33
26
|
tls_root_certs: Optional[bytes] = None,
|
|
34
27
|
connection_string_override: Optional[str] = None,
|
|
35
28
|
) -> QueryRunner:
|
|
36
|
-
|
|
37
|
-
fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
|
|
38
|
-
)
|
|
39
|
-
server_version = fallback_query_runner.server_version()
|
|
40
|
-
connection_string: str
|
|
41
|
-
if connection_string_override is not None:
|
|
42
|
-
connection_string = connection_string_override
|
|
43
|
-
else:
|
|
44
|
-
connection_string = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
|
|
45
|
-
arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
|
|
46
|
-
|
|
47
|
-
if arrow_info["running"]:
|
|
48
|
-
return ArrowQueryRunner(
|
|
49
|
-
connection_string,
|
|
50
|
-
fallback_query_runner,
|
|
51
|
-
server_version,
|
|
52
|
-
auth,
|
|
53
|
-
encrypted,
|
|
54
|
-
disable_server_verification,
|
|
55
|
-
tls_root_certs,
|
|
56
|
-
arrow_endpoint_version,
|
|
57
|
-
)
|
|
58
|
-
else:
|
|
29
|
+
if not GdsArrowClient.is_arrow_enabled(fallback_query_runner):
|
|
59
30
|
return fallback_query_runner
|
|
60
31
|
|
|
32
|
+
gds_arrow_client = GdsArrowClient.create(
|
|
33
|
+
fallback_query_runner,
|
|
34
|
+
auth,
|
|
35
|
+
encrypted,
|
|
36
|
+
disable_server_verification,
|
|
37
|
+
tls_root_certs,
|
|
38
|
+
connection_string_override,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
|
|
42
|
+
|
|
61
43
|
def __init__(
|
|
62
44
|
self,
|
|
63
|
-
|
|
45
|
+
gds_arrow_client: GdsArrowClient,
|
|
64
46
|
fallback_query_runner: QueryRunner,
|
|
65
47
|
server_version: ServerVersion,
|
|
66
|
-
auth: Optional[Tuple[str, str]] = None,
|
|
67
|
-
encrypted: bool = False,
|
|
68
|
-
disable_server_verification: bool = False,
|
|
69
|
-
tls_root_certs: Optional[bytes] = None,
|
|
70
|
-
arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
|
|
71
48
|
):
|
|
72
49
|
self._fallback_query_runner = fallback_query_runner
|
|
50
|
+
self._gds_arrow_client = gds_arrow_client
|
|
73
51
|
self._server_version = server_version
|
|
74
|
-
self._arrow_endpoint_version = arrow_endpoint_version
|
|
75
|
-
|
|
76
|
-
host, port_string = uri.split(":")
|
|
77
|
-
|
|
78
|
-
location = (
|
|
79
|
-
flight.Location.for_grpc_tls(host, int(port_string))
|
|
80
|
-
if encrypted
|
|
81
|
-
else flight.Location.for_grpc_tcp(host, int(port_string))
|
|
82
|
-
)
|
|
83
|
-
|
|
84
|
-
client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
|
|
85
|
-
if auth:
|
|
86
|
-
client_options["middleware"] = [AuthFactory(auth)]
|
|
87
|
-
if tls_root_certs:
|
|
88
|
-
client_options["tls_root_certs"] = tls_root_certs
|
|
89
|
-
|
|
90
|
-
self._flight_client = flight.FlightClient(location, **client_options)
|
|
91
52
|
|
|
92
53
|
def warn_about_deprecation(self, old_endpoint: str, new_endpoint: str) -> None:
|
|
93
54
|
warnings.warn(
|
|
@@ -140,7 +101,7 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
140
101
|
old_endpoint="gds.graph.streamNodeProperty", new_endpoint="gds.graph.nodeProperty.stream"
|
|
141
102
|
)
|
|
142
103
|
|
|
143
|
-
return self.
|
|
104
|
+
return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, config)
|
|
144
105
|
elif (
|
|
145
106
|
old_endpoint := ("gds.graph.streamNodeProperties" == endpoint)
|
|
146
107
|
) or "gds.graph.nodeProperties.stream" == endpoint:
|
|
@@ -159,7 +120,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
159
120
|
self.warn_about_deprecation(
|
|
160
121
|
old_endpoint="gds.graph.streamNodeProperties", new_endpoint="gds.graph.nodeProperties.stream"
|
|
161
122
|
)
|
|
162
|
-
return self.
|
|
123
|
+
return self._gds_arrow_client.get_property(
|
|
124
|
+
self.database(),
|
|
163
125
|
graph_name,
|
|
164
126
|
endpoint,
|
|
165
127
|
config,
|
|
@@ -180,7 +142,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
180
142
|
old_endpoint="gds.graph.streamRelationshipProperty",
|
|
181
143
|
new_endpoint="gds.graph.relationshipProperty.stream",
|
|
182
144
|
)
|
|
183
|
-
return self.
|
|
145
|
+
return self._gds_arrow_client.get_property(
|
|
146
|
+
self.database(),
|
|
184
147
|
graph_name,
|
|
185
148
|
endpoint,
|
|
186
149
|
{"relationship_property": property_name, "relationship_types": relationship_types},
|
|
@@ -202,7 +165,8 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
202
165
|
new_endpoint="gds.graph.relationshipProperties.stream",
|
|
203
166
|
)
|
|
204
167
|
|
|
205
|
-
return self.
|
|
168
|
+
return self._gds_arrow_client.get_property(
|
|
169
|
+
self.database(),
|
|
206
170
|
graph_name,
|
|
207
171
|
endpoint,
|
|
208
172
|
{"relationship_properties": property_names, "relationship_types": relationship_types},
|
|
@@ -229,7 +193,9 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
229
193
|
new_endpoint="gds.graph.relationships.stream",
|
|
230
194
|
)
|
|
231
195
|
|
|
232
|
-
return self.
|
|
196
|
+
return self._gds_arrow_client.get_property(
|
|
197
|
+
self.database(), graph_name, endpoint, {"relationship_types": relationship_types}
|
|
198
|
+
)
|
|
233
199
|
|
|
234
200
|
return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
|
|
235
201
|
|
|
@@ -259,52 +225,11 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
259
225
|
|
|
260
226
|
def close(self) -> None:
|
|
261
227
|
self._fallback_query_runner.close()
|
|
262
|
-
|
|
263
|
-
if hasattr(self._flight_client, "close"):
|
|
264
|
-
self._flight_client.close()
|
|
228
|
+
self._gds_arrow_client.close()
|
|
265
229
|
|
|
266
230
|
def fallback_query_runner(self) -> QueryRunner:
|
|
267
231
|
return self._fallback_query_runner
|
|
268
232
|
|
|
269
|
-
def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame:
|
|
270
|
-
if not self.database():
|
|
271
|
-
raise ValueError(
|
|
272
|
-
"For this call you must have explicitly specified a valid Neo4j database to execute on, "
|
|
273
|
-
"using `GraphDataScience.set_database`."
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
payload = {
|
|
277
|
-
"database_name": self.database(),
|
|
278
|
-
"graph_name": graph_name,
|
|
279
|
-
"procedure_name": procedure_name,
|
|
280
|
-
"configuration": configuration,
|
|
281
|
-
}
|
|
282
|
-
|
|
283
|
-
if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
|
|
284
|
-
payload = {
|
|
285
|
-
"name": "GET_COMMAND",
|
|
286
|
-
"version": ArrowEndpointVersion.V1.version(),
|
|
287
|
-
"body": payload,
|
|
288
|
-
}
|
|
289
|
-
|
|
290
|
-
ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
|
|
291
|
-
get = self._flight_client.do_get(ticket)
|
|
292
|
-
arrow_table = get.read_all()
|
|
293
|
-
|
|
294
|
-
if configuration.get("list_node_labels", False):
|
|
295
|
-
# GDS 2.5 had an inconsistent naming of the node labels column
|
|
296
|
-
new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
|
|
297
|
-
arrow_table = arrow_table.rename_columns(new_colum_names)
|
|
298
|
-
|
|
299
|
-
# Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
|
|
300
|
-
warnings.filterwarnings(
|
|
301
|
-
"ignore",
|
|
302
|
-
category=DeprecationWarning,
|
|
303
|
-
message=r"Passing a BlockManager to DataFrame is deprecated",
|
|
304
|
-
)
|
|
305
|
-
|
|
306
|
-
return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
|
|
307
|
-
|
|
308
233
|
def create_graph_constructor(
|
|
309
234
|
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
|
|
310
235
|
) -> GraphConstructor:
|
|
@@ -318,86 +243,7 @@ class ArrowQueryRunner(QueryRunner):
|
|
|
318
243
|
return ArrowGraphConstructor(
|
|
319
244
|
database,
|
|
320
245
|
graph_name,
|
|
321
|
-
self.
|
|
246
|
+
self._gds_arrow_client,
|
|
322
247
|
concurrency,
|
|
323
|
-
self._arrow_endpoint_version,
|
|
324
248
|
undirected_relationship_types,
|
|
325
249
|
)
|
|
326
|
-
|
|
327
|
-
def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
|
|
328
|
-
# empty columns cannot be used to build a chunked_array in pyarrow
|
|
329
|
-
if len(arrow_table) == 0:
|
|
330
|
-
return arrow_table
|
|
331
|
-
|
|
332
|
-
dict_encoded_fields = [
|
|
333
|
-
(idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
|
|
334
|
-
]
|
|
335
|
-
for idx, field in dict_encoded_fields:
|
|
336
|
-
try:
|
|
337
|
-
field.type.to_pandas_dtype()
|
|
338
|
-
except NotImplementedError:
|
|
339
|
-
# we need to decode the dictionary column before transforming to pandas
|
|
340
|
-
if isinstance(arrow_table[field.name], ChunkedArray):
|
|
341
|
-
decoded_col = chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks])
|
|
342
|
-
else:
|
|
343
|
-
decoded_col = arrow_table[field.name].dictionary_decode()
|
|
344
|
-
arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
|
|
345
|
-
return arrow_table
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
class AuthFactory(ClientMiddlewareFactory): # type: ignore
|
|
349
|
-
def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
|
|
350
|
-
super().__init__(*args, **kwargs)
|
|
351
|
-
self._auth = auth
|
|
352
|
-
self._token: Optional[str] = None
|
|
353
|
-
self._token_timestamp = 0
|
|
354
|
-
|
|
355
|
-
def start_call(self, info: Any) -> "AuthMiddleware":
|
|
356
|
-
return AuthMiddleware(self)
|
|
357
|
-
|
|
358
|
-
def token(self) -> Optional[str]:
|
|
359
|
-
# check whether the token is older than 10 minutes. If so, reset it.
|
|
360
|
-
if self._token and int(time.time()) - self._token_timestamp > 600:
|
|
361
|
-
self._token = None
|
|
362
|
-
|
|
363
|
-
return self._token
|
|
364
|
-
|
|
365
|
-
def set_token(self, token: str) -> None:
|
|
366
|
-
self._token = token
|
|
367
|
-
self._token_timestamp = int(time.time())
|
|
368
|
-
|
|
369
|
-
@property
|
|
370
|
-
def auth(self) -> Tuple[str, str]:
|
|
371
|
-
return self._auth
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
class AuthMiddleware(ClientMiddleware): # type: ignore
|
|
375
|
-
def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
|
|
376
|
-
super().__init__(*args, **kwargs)
|
|
377
|
-
self._factory = factory
|
|
378
|
-
|
|
379
|
-
def received_headers(self, headers: Dict[str, Any]) -> None:
|
|
380
|
-
auth_header = headers.get("authorization", None)
|
|
381
|
-
if not auth_header:
|
|
382
|
-
return
|
|
383
|
-
|
|
384
|
-
# the result is always a list
|
|
385
|
-
header_value = auth_header[0]
|
|
386
|
-
|
|
387
|
-
if not isinstance(header_value, str):
|
|
388
|
-
raise ValueError(f"Incompatible header value received from server: `{header_value}`")
|
|
389
|
-
|
|
390
|
-
auth_type, token = header_value.split(" ", 1)
|
|
391
|
-
if auth_type == "Bearer":
|
|
392
|
-
self._factory.set_token(token)
|
|
393
|
-
|
|
394
|
-
def sending_headers(self) -> Dict[str, str]:
|
|
395
|
-
token = self._factory.token()
|
|
396
|
-
if not token:
|
|
397
|
-
username, password = self._factory.auth
|
|
398
|
-
auth_token = f"{username}:{password}"
|
|
399
|
-
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
|
|
400
|
-
# There seems to be a bug, `authorization` must be lower key
|
|
401
|
-
return {"authorization": auth_token}
|
|
402
|
-
else:
|
|
403
|
-
return {"authorization": "Bearer " + token}
|