graphdatascience 1.10a1__tar.gz → 1.11a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. {graphdatascience-1.10a1/graphdatascience.egg-info → graphdatascience-1.11a2}/PKG-INFO +2 -2
  2. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/base_graph_proc_runner.py +4 -4
  3. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_entity_ops_runner.py +43 -6
  4. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_remote_proc_runner.py +0 -1
  5. graphdatascience-1.11a2/graphdatascience/graph/graph_remote_project_runner.py +41 -0
  6. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph_data_science.py +13 -10
  7. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_graph_constructor.py +13 -40
  8. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_query_runner.py +28 -166
  9. graphdatascience-1.11a2/graphdatascience/query_runner/aura_db_query_runner.py +224 -0
  10. graphdatascience-1.11a2/graphdatascience/query_runner/gds_arrow_client.py +242 -0
  11. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/neo4j_query_runner.py +36 -15
  12. graphdatascience-1.11a2/graphdatascience/session/__init__.py +16 -0
  13. graphdatascience-1.11a2/graphdatascience/session/algorithm_category.py +14 -0
  14. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/aura_api.py +112 -110
  15. graphdatascience-1.11a2/graphdatascience/session/aura_api_responses.py +174 -0
  16. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/aura_graph_data_science.py +11 -5
  17. graphdatascience-1.11a2/graphdatascience/session/aurads_sessions.py +202 -0
  18. graphdatascience-1.11a2/graphdatascience/session/dedicated_sessions.py +147 -0
  19. graphdatascience-1.11a2/graphdatascience/session/gds_sessions.py +107 -0
  20. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/schema.py +0 -3
  21. graphdatascience-1.11a2/graphdatascience/session/session_info.py +40 -0
  22. graphdatascience-1.11a2/graphdatascience/session/session_sizes.py +31 -0
  23. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/system_endpoints.py +2 -2
  24. graphdatascience-1.11a2/graphdatascience/version.py +1 -0
  25. {graphdatascience-1.10a1 → graphdatascience-1.11a2/graphdatascience.egg-info}/PKG-INFO +2 -2
  26. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/SOURCES.txt +7 -1
  27. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/requires.txt +1 -1
  28. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/base.txt +1 -1
  29. graphdatascience-1.10a1/graphdatascience/graph/graph_remote_project_runner.py +0 -40
  30. graphdatascience-1.10a1/graphdatascience/query_runner/aura_db_arrow_query_runner.py +0 -184
  31. graphdatascience-1.10a1/graphdatascience/session/__init__.py +0 -13
  32. graphdatascience-1.10a1/graphdatascience/session/gds_sessions.py +0 -240
  33. graphdatascience-1.10a1/graphdatascience/session/session_sizes.py +0 -33
  34. graphdatascience-1.10a1/graphdatascience/version.py +0 -1
  35. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/LICENSE +0 -0
  36. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/MANIFEST.in +0 -0
  37. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/README.md +0 -0
  38. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/__init__.py +0 -0
  39. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/__init__.py +0 -0
  40. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/algo_endpoints.py +0 -0
  41. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/algo_proc_runner.py +0 -0
  42. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
  43. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/call_builder.py +0 -0
  44. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/call_parameters.py +0 -0
  45. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/caller_base.py +0 -0
  46. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/endpoints.py +0 -0
  47. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/__init__.py +0 -0
  48. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/client_only_endpoint.py +0 -0
  49. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/cypher_warning_handler.py +0 -0
  50. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/endpoint_suggester.py +0 -0
  51. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/gds_not_installed.py +0 -0
  52. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/illegal_attr_checker.py +0 -0
  53. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/unable_to_connect.py +0 -0
  54. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/error/uncallable_namespace.py +0 -0
  55. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/__init__.py +0 -0
  56. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
  57. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
  58. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_create_result.py +0 -0
  59. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_cypher_runner.py +0 -0
  60. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_endpoints.py +0 -0
  61. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_export_runner.py +0 -0
  62. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_object.py +0 -0
  63. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_proc_runner.py +0 -0
  64. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_project_runner.py +0 -0
  65. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_sample_runner.py +0 -0
  66. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/graph_type_check.py +0 -0
  67. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/nx_loader.py +0 -0
  68. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/graph/ogb_loader.py +0 -0
  69. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/ignored_server_endpoints.py +0 -0
  70. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/__init__.py +0 -0
  71. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/graphsage_model.py +0 -0
  72. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/link_prediction_model.py +0 -0
  73. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model.py +0 -0
  74. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
  75. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_beta_proc_runner.py +0 -0
  76. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_endpoints.py +0 -0
  77. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_proc_runner.py +0 -0
  78. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/model_resolver.py +0 -0
  79. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/node_classification_model.py +0 -0
  80. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/node_regression_model.py +0 -0
  81. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/pipeline_model.py +0 -0
  82. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
  83. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/__init__.py +0 -0
  84. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
  85. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
  86. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
  87. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
  88. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
  89. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
  90. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
  91. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
  92. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
  93. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
  94. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
  95. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/pipeline/training_pipeline.py +0 -0
  96. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/py.typed +0 -0
  97. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/__init__.py +0 -0
  98. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/arrow_endpoint_version.py +0 -0
  99. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/cypher_graph_constructor.py +0 -0
  100. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/graph_constructor.py +0 -0
  101. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/query_runner/query_runner.py +0 -0
  102. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/__init__.py +0 -0
  103. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/__init__.py +0 -0
  104. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
  105. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
  106. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/cora/serialize_cora.py +0 -0
  107. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/__init__.py +0 -0
  108. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
  109. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
  110. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
  111. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
  112. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
  113. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
  114. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
  115. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/karate/__init__.py +0 -0
  116. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
  117. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/__init__.py +0 -0
  118. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
  119. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
  120. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
  121. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
  122. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
  123. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
  124. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/__init__.py +0 -0
  125. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/compatible_with.py +0 -0
  126. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/server_version/server_version.py +0 -0
  127. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/dbms_connection_info.py +0 -0
  128. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/session/region_suggester.py +0 -0
  129. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/__init__.py +0 -0
  130. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/system/config_endpoints.py +0 -0
  131. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/__init__.py +0 -0
  132. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
  133. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
  134. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/__init__.py +0 -0
  135. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/util_endpoints.py +0 -0
  136. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience/utils/util_proc_runner.py +0 -0
  137. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/dependency_links.txt +0 -0
  138. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/not-zip-safe +0 -0
  139. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/graphdatascience.egg-info/top_level.txt +0 -0
  140. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/pyproject.toml +0 -0
  141. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/networkx.txt +0 -0
  142. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/requirements/base/ogb.txt +0 -0
  143. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/setup.cfg +0 -0
  144. {graphdatascience-1.10a1 → graphdatascience-1.11a2}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphdatascience
3
- Version: 1.10a1
3
+ Version: 1.11a2
4
4
  Summary: A Python client for the Neo4j Graph Data Science (GDS) library
5
5
  Home-page: https://neo4j.com/product/graph-data-science/
6
6
  Author: Neo4j
@@ -31,7 +31,7 @@ License-File: LICENSE
31
31
  Requires-Dist: multimethod<2.0,>=1.0
32
32
  Requires-Dist: neo4j<6.0,>=4.4.2
33
33
  Requires-Dist: pandas<3.0,>=1.0
34
- Requires-Dist: pyarrow<15.0,>=10.0
34
+ Requires-Dist: pyarrow<16.0,>=11.0
35
35
  Requires-Dist: textdistance<5.0,>=4.0
36
36
  Requires-Dist: tqdm<5.0,>=4.0
37
37
  Requires-Dist: typing-extensions<5.0,>=4.0
@@ -18,6 +18,7 @@ from .graph_entity_ops_runner import (
18
18
  GraphElementPropertyRunner,
19
19
  GraphLabelRunner,
20
20
  GraphNodePropertiesRunner,
21
+ GraphNodePropertyRunner,
21
22
  GraphPropertyRunner,
22
23
  GraphRelationshipPropertiesRunner,
23
24
  GraphRelationshipRunner,
@@ -379,9 +380,9 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
379
380
  )
380
381
 
381
382
  @property
382
- def nodeProperty(self) -> GraphElementPropertyRunner:
383
+ def nodeProperty(self) -> GraphNodePropertyRunner:
383
384
  self._namespace += ".nodeProperty"
384
- return GraphElementPropertyRunner(self._query_runner, self._namespace, self._server_version)
385
+ return GraphNodePropertyRunner(self._query_runner, self._namespace, self._server_version)
385
386
 
386
387
  @property
387
388
  def nodeProperties(self) -> GraphNodePropertiesRunner:
@@ -516,8 +517,7 @@ class BaseGraphProcRunner(UncallableNamespace, IllegalAttrChecker):
516
517
  ).squeeze()
517
518
 
518
519
  @multimethod
519
- def removeNodeProperties(self) -> None:
520
- ...
520
+ def removeNodeProperties(self) -> None: ...
521
521
 
522
522
  @removeNodeProperties.register
523
523
  @graph_type_check
@@ -77,6 +77,26 @@ class GraphElementPropertyRunner(GraphEntityOpsBaseRunner):
77
77
  return self._handle_properties(G, node_properties, node_labels, config)
78
78
 
79
79
 
80
+ class GraphNodePropertyRunner(GraphEntityOpsBaseRunner):
81
+ @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
82
+ @filter_id_func_deprecation_warning()
83
+ def stream(
84
+ self,
85
+ G: Graph,
86
+ node_property: str,
87
+ node_labels: Strings = ["*"],
88
+ db_node_properties: List[str] = [],
89
+ **config: Any,
90
+ ) -> DataFrame:
91
+ self._namespace += ".stream"
92
+
93
+ result = self._handle_properties(G, node_property, node_labels, config)
94
+
95
+ return GraphNodePropertiesRunner._process_result(
96
+ self._query_runner, list(node_property), False, db_node_properties, result, config
97
+ )
98
+
99
+
80
100
  class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
81
101
  @compatible_with("stream", min_inclusive=ServerVersion(2, 2, 0))
82
102
  @filter_id_func_deprecation_warning()
@@ -93,6 +113,19 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
93
113
 
94
114
  result = self._handle_properties(G, node_properties, node_labels, config)
95
115
 
116
+ return GraphNodePropertiesRunner._process_result(
117
+ self._query_runner, node_properties, separate_property_columns, db_node_properties, result, config
118
+ )
119
+
120
+ @staticmethod
121
+ def _process_result(
122
+ query_runner: QueryRunner,
123
+ node_properties: List[str],
124
+ separate_property_columns: bool,
125
+ db_node_properties: List[str],
126
+ result: DataFrame,
127
+ config: Dict[str, Any],
128
+ ) -> DataFrame:
96
129
  # new format was requested, but the query was run via Cypher
97
130
  if separate_property_columns and "propertyValue" in result.keys():
98
131
  wide_result = result.pivot(index=["nodeId"], columns=["nodeProperty"], values="propertyValue")
@@ -106,7 +139,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
106
139
  # old format was requested but the query was run via Arrow
107
140
  elif not separate_property_columns and "propertyValue" not in result.keys():
108
141
  id_vars = ["nodeId", "nodeLabels"] if config.get("listNodeLabels", False) else ["nodeId"]
109
- result = result.melt(id_vars=id_vars).rename(columns={"variable": "nodeProperty", "value": "propertyValue"})
142
+ result = result.melt(id_vars=id_vars, var_name="nodeProperty", value_name="propertyValue")
110
143
 
111
144
  if db_node_properties:
112
145
  duplicate_properties = set(db_node_properties).intersection(set(node_properties))
@@ -116,16 +149,20 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
116
149
  )
117
150
 
118
151
  unique_node_ids = result["nodeId"].drop_duplicates().tolist()
119
- db_properties_df = self._query_runner.run_cypher(
120
- self._build_query(db_node_properties), {"ids": unique_node_ids}
152
+ db_properties_df = query_runner.run_cypher(
153
+ GraphNodePropertiesRunner._build_query(db_node_properties), {"ids": unique_node_ids}
121
154
  )
122
155
 
123
156
  if "propertyValue" not in result.keys():
124
157
  result = result.join(db_properties_df.set_index("nodeId"), on="nodeId")
125
158
  else:
126
- db_properties_df = db_properties_df.melt(id_vars=["nodeId"]).rename(
127
- columns={"variable": "nodeProperty", "value": "propertyValue"}
159
+ db_properties_df = db_properties_df.melt(
160
+ id_vars=["nodeId"], var_name="nodeProperty", value_name="propertyValue"
128
161
  )
162
+
163
+ if "nodeProperty" not in result.keys():
164
+ result["nodeProperty"] = node_properties[0]
165
+
129
166
  result = pd.concat([result, db_properties_df])
130
167
 
131
168
  return result
@@ -140,7 +177,7 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
140
177
  return reduce(add_property, db_node_properties, query_prefix)
141
178
 
142
179
  @compatible_with("write", min_inclusive=ServerVersion(2, 2, 0))
143
- def write(self, G: Graph, node_properties: List[str], node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
180
+ def write(self, G: Graph, node_properties: Strings, node_labels: Strings = ["*"], **config: Any) -> "Series[Any]":
144
181
  self._namespace += ".write"
145
182
  return self._handle_properties(G, node_properties, node_labels, config).squeeze() # type: ignore
146
183
 
@@ -5,5 +5,4 @@ from graphdatascience.graph.graph_remote_project_runner import GraphProjectRemot
5
5
  class GraphRemoteProcRunner(BaseGraphProcRunner):
6
6
  @property
7
7
  def project(self) -> GraphProjectRemoteRunner:
8
- self._namespace += ".project.remoteDb"
9
8
  return GraphProjectRemoteRunner(self._query_runner, self._namespace, self._server_version)
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Optional
4
+
5
+ from ..error.illegal_attr_checker import IllegalAttrChecker
6
+ from ..query_runner.aura_db_query_runner import AuraDbQueryRunner
7
+ from ..server_version.compatible_with import compatible_with
8
+ from .graph_object import Graph
9
+ from graphdatascience.call_parameters import CallParameters
10
+ from graphdatascience.graph.graph_create_result import GraphCreateResult
11
+ from graphdatascience.server_version.server_version import ServerVersion
12
+
13
+
14
+ class GraphProjectRemoteRunner(IllegalAttrChecker):
15
+ @compatible_with("project", min_inclusive=ServerVersion(2, 7, 0))
16
+ def __call__(
17
+ self,
18
+ graph_name: str,
19
+ query: str,
20
+ concurrency: int = 4,
21
+ undirected_relationship_types: Optional[List[str]] = None,
22
+ inverse_indexed_relationship_types: Optional[List[str]] = None,
23
+ ) -> GraphCreateResult:
24
+ if inverse_indexed_relationship_types is None:
25
+ inverse_indexed_relationship_types = []
26
+ if undirected_relationship_types is None:
27
+ undirected_relationship_types = []
28
+
29
+ params = CallParameters(
30
+ graph_name=graph_name,
31
+ query=query,
32
+ concurrency=concurrency,
33
+ undirected_relationship_types=undirected_relationship_types,
34
+ inverse_indexed_relationship_types=inverse_indexed_relationship_types,
35
+ )
36
+
37
+ result = self._query_runner.call_procedure(
38
+ endpoint=AuraDbQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME,
39
+ params=params,
40
+ ).squeeze()
41
+ return GraphCreateResult(Graph(graph_name, self._query_runner, self._server_version), result)
@@ -23,11 +23,12 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
23
23
 
24
24
  def __init__(
25
25
  self,
26
+ /,
26
27
  endpoint: Union[str, Driver, QueryRunner],
27
28
  auth: Optional[Tuple[str, str]] = None,
28
29
  aura_ds: bool = False,
29
30
  database: Optional[str] = None,
30
- arrow: bool = True,
31
+ arrow: Union[str, bool] = True,
31
32
  arrow_disable_server_verification: bool = True,
32
33
  arrow_tls_root_certs: Optional[bytes] = None,
33
34
  bookmarks: Optional[Any] = None,
@@ -43,19 +44,20 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
43
44
  A username, password pair for database authentication.
44
45
  aura_ds : bool, default False
45
46
  A flag that indicates that that the client is used to connect
46
- to a Neo4j Aura instance.
47
+ to a Neo4j AuraDS instance.
47
48
  database: Optional[str], default None
48
49
  The Neo4j database to query against.
49
- arrow : bool, default True
50
- A flag that indicates that the client should use Apache Arrow
51
- for data streaming if it is available on the server.
50
+ arrow : Union[str, bool], default True
51
+ Arrow connection information. This is either a bool or a string.
52
+ If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
53
+ If it is a bool,
54
+ True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint,
55
+ while False will make the client use Bolt for all operations.
52
56
  arrow_disable_server_verification : bool, default True
53
- A flag that indicates that, if the flight client is connecting with
54
- TLS, that it skips server verification. If this is enabled, all
55
- other TLS settings are overridden.
57
+ A flag that overrides other TLS settings and disables server verification for TLS connections.
56
58
  arrow_tls_root_certs : Optional[bytes], default None
57
- PEM-encoded certificates that are used for the connecting to the
58
- Arrow Flight server.
59
+ PEM-encoded certificates that are used for the connection to the
60
+ GDS Arrow Flight server.
59
61
  bookmarks : Optional[Any], default None
60
62
  The Neo4j bookmarks to require a certain state before the next query gets executed.
61
63
  """
@@ -76,6 +78,7 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
76
78
  self._query_runner.encrypted(),
77
79
  arrow_disable_server_verification,
78
80
  arrow_tls_root_certs,
81
+ None if arrow is True else arrow,
79
82
  )
80
83
 
81
84
  super().__init__(self._query_runner, "gds", self._server_version)
@@ -1,17 +1,17 @@
1
+ from __future__ import annotations
2
+
1
3
  import concurrent
2
- import json
3
4
  import math
4
5
  import warnings
5
6
  from concurrent.futures import ThreadPoolExecutor
6
- from typing import Any, Dict, List, Optional
7
+ from typing import Any, Dict, List, NoReturn, Optional
7
8
 
8
9
  import numpy
9
- import pyarrow.flight as flight
10
10
  from pandas import DataFrame
11
11
  from pyarrow import Table
12
12
  from tqdm.auto import tqdm
13
13
 
14
- from .arrow_endpoint_version import ArrowEndpointVersion
14
+ from .gds_arrow_client import GdsArrowClient
15
15
  from .graph_constructor import GraphConstructor
16
16
 
17
17
 
@@ -20,9 +20,8 @@ class ArrowGraphConstructor(GraphConstructor):
20
20
  self,
21
21
  database: str,
22
22
  graph_name: str,
23
- flight_client: flight.FlightClient,
23
+ flight_client: GdsArrowClient,
24
24
  concurrency: int,
25
- arrow_endpoint_version: ArrowEndpointVersion,
26
25
  undirected_relationship_types: Optional[List[str]],
27
26
  chunk_size: int = 10_000,
28
27
  ):
@@ -30,7 +29,6 @@ class ArrowGraphConstructor(GraphConstructor):
30
29
  self._concurrency = concurrency
31
30
  self._graph_name = graph_name
32
31
  self._client = flight_client
33
- self._arrow_endpoint_version = arrow_endpoint_version
34
32
  self._undirected_relationship_types = (
35
33
  [] if undirected_relationship_types is None else undirected_relationship_types
36
34
  )
@@ -47,20 +45,20 @@ class ArrowGraphConstructor(GraphConstructor):
47
45
  if self._undirected_relationship_types:
48
46
  config["undirected_relationship_types"] = self._undirected_relationship_types
49
47
 
50
- self._send_action(
48
+ self._client.send_action(
51
49
  "CREATE_GRAPH",
52
50
  config,
53
51
  )
54
52
 
55
53
  self._send_dfs(node_dfs, "node")
56
54
 
57
- self._send_action("NODE_LOAD_DONE", {"name": self._graph_name})
55
+ self._client.send_action("NODE_LOAD_DONE", {"name": self._graph_name})
58
56
 
59
57
  self._send_dfs(relationship_dfs, "relationship")
60
58
 
61
- self._send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
59
+ self._client.send_action("RELATIONSHIP_LOAD_DONE", {"name": self._graph_name})
62
60
  except (Exception, KeyboardInterrupt) as e:
63
- self._send_action("ABORT", {"name": self._graph_name})
61
+ self._client.send_action("ABORT", {"name": self._graph_name})
64
62
 
65
63
  raise e
66
64
 
@@ -83,31 +81,20 @@ class ArrowGraphConstructor(GraphConstructor):
83
81
 
84
82
  return partitioned_dfs
85
83
 
86
- def _send_action(self, action_type: str, meta_data: Dict[str, Any]) -> None:
87
- action_type = self._versioned_action_type(action_type)
88
- result = self._client.do_action(flight.Action(action_type, json.dumps(meta_data).encode("utf-8")))
89
-
90
- # Consume result fully to sanity check and avoid cancelled streams
91
- collected_result = list(result)
92
- assert len(collected_result) == 1
93
-
94
- json.loads(collected_result[0].body.to_pybytes().decode())
95
-
96
- def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm) -> None:
84
+ def _send_df(self, df: DataFrame, entity_type: str, pbar: tqdm[NoReturn]) -> None:
97
85
  table = Table.from_pandas(df)
98
86
  batches = table.to_batches(self._chunk_size)
99
87
  flight_descriptor = {"name": self._graph_name, "entity_type": entity_type}
100
- flight_descriptor = self._versioned_flight_desriptor(flight_descriptor)
101
88
 
102
- # Write schema
103
- upload_descriptor = flight.FlightDescriptor.for_command(json.dumps(flight_descriptor).encode("utf-8"))
104
- writer, _ = self._client.do_put(upload_descriptor, table.schema)
89
+ writer, _ = self._client.start_put(flight_descriptor, table.schema)
105
90
 
106
91
  with writer:
107
92
  # Write table in chunks
108
93
  for partition in batches:
109
94
  writer.write_batch(partition)
110
95
  pbar.update(partition.num_rows)
96
+ # Force a refresh to avoid the progress bar getting stuck at 0%
97
+ pbar.refresh()
111
98
 
112
99
  def _send_dfs(self, dfs: List[DataFrame], entity_type: str) -> None:
113
100
  desc = "Uploading Nodes" if entity_type == "node" else "Uploading Relationships"
@@ -122,17 +109,3 @@ class ArrowGraphConstructor(GraphConstructor):
122
109
  if not future.exception():
123
110
  continue
124
111
  raise future.exception() # type: ignore
125
-
126
- def _versioned_action_type(self, action_type: str) -> str:
127
- return self._arrow_endpoint_version.prefix() + action_type
128
-
129
- def _versioned_flight_desriptor(self, flight_descriptor: Dict[str, Any]) -> Dict[str, Any]:
130
- return (
131
- flight_descriptor
132
- if self._arrow_endpoint_version == ArrowEndpointVersion.ALPHA
133
- else {
134
- "name": "PUT_MESSAGE",
135
- "version": ArrowEndpointVersion.V1.version(),
136
- "body": flight_descriptor,
137
- }
138
- )
@@ -1,21 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
- import base64
4
- import json
5
- import time
6
3
  import warnings
7
4
  from typing import Any, Dict, List, Optional, Tuple
8
5
 
9
- import pyarrow.flight as flight
10
6
  from pandas import DataFrame
11
- from pyarrow import ChunkedArray, Table, chunked_array
12
- from pyarrow.flight import ClientMiddleware, ClientMiddlewareFactory
13
- from pyarrow.types import is_dictionary # type: ignore
14
7
 
15
8
  from ..call_parameters import CallParameters
16
9
  from ..server_version.server_version import ServerVersion
17
- from .arrow_endpoint_version import ArrowEndpointVersion
18
10
  from .arrow_graph_constructor import ArrowGraphConstructor
11
+ from .gds_arrow_client import GdsArrowClient
19
12
  from .graph_constructor import GraphConstructor
20
13
  from .query_runner import QueryRunner
21
14
  from graphdatascience.server_version.compatible_with import (
@@ -31,58 +24,31 @@ class ArrowQueryRunner(QueryRunner):
31
24
  encrypted: bool = False,
32
25
  disable_server_verification: bool = False,
33
26
  tls_root_certs: Optional[bytes] = None,
27
+ connection_string_override: Optional[str] = None,
34
28
  ) -> QueryRunner:
35
- arrow_info = (
36
- fallback_query_runner.call_procedure(endpoint="gds.debug.arrow", custom_error=False).squeeze().to_dict()
37
- )
38
- server_version = fallback_query_runner.server_version()
39
- listen_address: str = arrow_info.get("advertisedListenAddress", arrow_info["listenAddress"])
40
- arrow_endpoint_version = ArrowEndpointVersion.from_arrow_info(arrow_info.get("versions", []))
41
-
42
- if arrow_info["running"]:
43
- return ArrowQueryRunner(
44
- listen_address,
45
- fallback_query_runner,
46
- server_version,
47
- auth,
48
- encrypted,
49
- disable_server_verification,
50
- tls_root_certs,
51
- arrow_endpoint_version,
52
- )
53
- else:
29
+ if not GdsArrowClient.is_arrow_enabled(fallback_query_runner):
54
30
  return fallback_query_runner
55
31
 
32
+ gds_arrow_client = GdsArrowClient.create(
33
+ fallback_query_runner,
34
+ auth,
35
+ encrypted,
36
+ disable_server_verification,
37
+ tls_root_certs,
38
+ connection_string_override,
39
+ )
40
+
41
+ return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
42
+
56
43
  def __init__(
57
44
  self,
58
- uri: str,
45
+ gds_arrow_client: GdsArrowClient,
59
46
  fallback_query_runner: QueryRunner,
60
47
  server_version: ServerVersion,
61
- auth: Optional[Tuple[str, str]] = None,
62
- encrypted: bool = False,
63
- disable_server_verification: bool = False,
64
- tls_root_certs: Optional[bytes] = None,
65
- arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.ALPHA,
66
48
  ):
67
49
  self._fallback_query_runner = fallback_query_runner
50
+ self._gds_arrow_client = gds_arrow_client
68
51
  self._server_version = server_version
69
- self._arrow_endpoint_version = arrow_endpoint_version
70
-
71
- host, port_string = uri.split(":")
72
-
73
- location = (
74
- flight.Location.for_grpc_tls(host, int(port_string))
75
- if encrypted
76
- else flight.Location.for_grpc_tcp(host, int(port_string))
77
- )
78
-
79
- client_options: Dict[str, Any] = {"disable_server_verification": disable_server_verification}
80
- if auth:
81
- client_options["middleware"] = [AuthFactory(auth)]
82
- if tls_root_certs:
83
- client_options["tls_root_certs"] = tls_root_certs
84
-
85
- self._flight_client = flight.FlightClient(location, **client_options)
86
52
 
87
53
  def warn_about_deprecation(self, old_endpoint: str, new_endpoint: str) -> None:
88
54
  warnings.warn(
@@ -135,7 +101,7 @@ class ArrowQueryRunner(QueryRunner):
135
101
  old_endpoint="gds.graph.streamNodeProperty", new_endpoint="gds.graph.nodeProperty.stream"
136
102
  )
137
103
 
138
- return self._run_arrow_property_get(graph_name, endpoint, config)
104
+ return self._gds_arrow_client.get_property(self.database(), graph_name, endpoint, config)
139
105
  elif (
140
106
  old_endpoint := ("gds.graph.streamNodeProperties" == endpoint)
141
107
  ) or "gds.graph.nodeProperties.stream" == endpoint:
@@ -154,7 +120,8 @@ class ArrowQueryRunner(QueryRunner):
154
120
  self.warn_about_deprecation(
155
121
  old_endpoint="gds.graph.streamNodeProperties", new_endpoint="gds.graph.nodeProperties.stream"
156
122
  )
157
- return self._run_arrow_property_get(
123
+ return self._gds_arrow_client.get_property(
124
+ self.database(),
158
125
  graph_name,
159
126
  endpoint,
160
127
  config,
@@ -175,7 +142,8 @@ class ArrowQueryRunner(QueryRunner):
175
142
  old_endpoint="gds.graph.streamRelationshipProperty",
176
143
  new_endpoint="gds.graph.relationshipProperty.stream",
177
144
  )
178
- return self._run_arrow_property_get(
145
+ return self._gds_arrow_client.get_property(
146
+ self.database(),
179
147
  graph_name,
180
148
  endpoint,
181
149
  {"relationship_property": property_name, "relationship_types": relationship_types},
@@ -197,7 +165,8 @@ class ArrowQueryRunner(QueryRunner):
197
165
  new_endpoint="gds.graph.relationshipProperties.stream",
198
166
  )
199
167
 
200
- return self._run_arrow_property_get(
168
+ return self._gds_arrow_client.get_property(
169
+ self.database(),
201
170
  graph_name,
202
171
  endpoint,
203
172
  {"relationship_properties": property_names, "relationship_types": relationship_types},
@@ -224,7 +193,9 @@ class ArrowQueryRunner(QueryRunner):
224
193
  new_endpoint="gds.graph.relationships.stream",
225
194
  )
226
195
 
227
- return self._run_arrow_property_get(graph_name, endpoint, {"relationship_types": relationship_types})
196
+ return self._gds_arrow_client.get_property(
197
+ self.database(), graph_name, endpoint, {"relationship_types": relationship_types}
198
+ )
228
199
 
229
200
  return self._fallback_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
230
201
 
@@ -254,52 +225,11 @@ class ArrowQueryRunner(QueryRunner):
254
225
 
255
226
  def close(self) -> None:
256
227
  self._fallback_query_runner.close()
257
- # PyArrow 7 did not expose a close method yet
258
- if hasattr(self._flight_client, "close"):
259
- self._flight_client.close()
228
+ self._gds_arrow_client.close()
260
229
 
261
230
  def fallback_query_runner(self) -> QueryRunner:
262
231
  return self._fallback_query_runner
263
232
 
264
- def _run_arrow_property_get(self, graph_name: str, procedure_name: str, configuration: Dict[str, Any]) -> DataFrame:
265
- if not self.database():
266
- raise ValueError(
267
- "For this call you must have explicitly specified a valid Neo4j database to execute on, "
268
- "using `GraphDataScience.set_database`."
269
- )
270
-
271
- payload = {
272
- "database_name": self.database(),
273
- "graph_name": graph_name,
274
- "procedure_name": procedure_name,
275
- "configuration": configuration,
276
- }
277
-
278
- if self._arrow_endpoint_version == ArrowEndpointVersion.V1:
279
- payload = {
280
- "name": "GET_MESSAGE",
281
- "version": ArrowEndpointVersion.V1.version(),
282
- "body": payload,
283
- }
284
-
285
- ticket = flight.Ticket(json.dumps(payload).encode("utf-8"))
286
- get = self._flight_client.do_get(ticket)
287
- arrow_table = get.read_all()
288
-
289
- if configuration.get("list_node_labels", False):
290
- # GDS 2.5 had an inconsistent naming of the node labels column
291
- new_colum_names = ["nodeLabels" if i == "labels" else i for i in arrow_table.column_names]
292
- arrow_table = arrow_table.rename_columns(new_colum_names)
293
-
294
- # Pandas 2.2.0 deprecated an API used by ArrowTable.to_pandas() (< pyarrow 15.0)
295
- warnings.filterwarnings(
296
- "ignore",
297
- category=DeprecationWarning,
298
- message=r"Passing a BlockManager to DataFrame is deprecated",
299
- )
300
-
301
- return self._sanitize_arrow_table(arrow_table).to_pandas() # type: ignore
302
-
303
233
  def create_graph_constructor(
304
234
  self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
305
235
  ) -> GraphConstructor:
@@ -313,75 +243,7 @@ class ArrowQueryRunner(QueryRunner):
313
243
  return ArrowGraphConstructor(
314
244
  database,
315
245
  graph_name,
316
- self._flight_client,
246
+ self._gds_arrow_client,
317
247
  concurrency,
318
- self._arrow_endpoint_version,
319
248
  undirected_relationship_types,
320
249
  )
321
-
322
- def _sanitize_arrow_table(self, arrow_table: Table) -> Table:
323
- dict_encoded_fields = [
324
- (idx, field) for idx, field in enumerate(arrow_table.schema) if is_dictionary(field.type)
325
- ]
326
- for idx, field in dict_encoded_fields:
327
- try:
328
- field.type.to_pandas_dtype()
329
- except NotImplementedError:
330
- # we need to decode the dictionary column before transforming to pandas
331
- if isinstance(arrow_table[field.name], ChunkedArray):
332
- decoded_col = chunked_array([chunk.dictionary_decode() for chunk in arrow_table[field.name].chunks])
333
- else:
334
- decoded_col = arrow_table[field.name].dictionary_decode()
335
- arrow_table = arrow_table.set_column(idx, field.name, decoded_col)
336
- return arrow_table
337
-
338
-
339
- class AuthFactory(ClientMiddlewareFactory): # type: ignore
340
- def __init__(self, auth: Tuple[str, str], *args: Any, **kwargs: Any) -> None:
341
- super().__init__(*args, **kwargs)
342
- self._auth = auth
343
- self._token: Optional[str] = None
344
- self._token_timestamp = 0
345
-
346
- def start_call(self, info: Any) -> "AuthMiddleware":
347
- return AuthMiddleware(self)
348
-
349
- def token(self) -> Optional[str]:
350
- # check whether the token is older than 10 minutes. If so, reset it.
351
- if self._token and int(time.time()) - self._token_timestamp > 600:
352
- self._token = None
353
-
354
- return self._token
355
-
356
- def set_token(self, token: str) -> None:
357
- self._token = token
358
- self._token_timestamp = int(time.time())
359
-
360
- @property
361
- def auth(self) -> Tuple[str, str]:
362
- return self._auth
363
-
364
-
365
- class AuthMiddleware(ClientMiddleware): # type: ignore
366
- def __init__(self, factory: AuthFactory, *args: Any, **kwargs: Any) -> None:
367
- super().__init__(*args, **kwargs)
368
- self._factory = factory
369
-
370
- def received_headers(self, headers: Dict[str, Any]) -> None:
371
- auth_header: str = headers.get("Authorization", None)
372
- if not auth_header:
373
- return
374
- [auth_type, token] = auth_header.split(" ", 1)
375
- if auth_type == "Bearer":
376
- self._factory.set_token(token)
377
-
378
- def sending_headers(self) -> Dict[str, str]:
379
- token = self._factory.token()
380
- if not token:
381
- username, password = self._factory.auth
382
- auth_token = f"{username}:{password}"
383
- auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
384
- # There seems to be a bug, `authorization` must be lower key
385
- return {"authorization": auth_token}
386
- else:
387
- return {"authorization": "Bearer " + token}