graphdatascience 1.16__tar.gz → 1.17a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. {graphdatascience-1.16/graphdatascience.egg-info → graphdatascience-1.17a1}/PKG-INFO +2 -1
  2. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_cypher_runner.py +3 -1
  3. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_entity_ops_runner.py +5 -1
  4. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_object.py +1 -1
  5. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph_data_science.py +78 -11
  6. graphdatascience-1.17a1/graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py +191 -0
  7. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/arrow_query_runner.py +12 -10
  8. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/cypher_graph_constructor.py +5 -1
  9. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/gds_arrow_client.py +41 -17
  10. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/neo4j_query_runner.py +22 -8
  11. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/query_mode.py +8 -0
  12. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/query_runner.py +2 -0
  13. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/session_query_runner.py +4 -2
  14. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/standalone_session_query_runner.py +2 -0
  15. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/aura_api.py +10 -1
  16. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/aura_graph_data_science.py +14 -7
  17. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/dbms/protocol_resolver.py +3 -1
  18. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/dedicated_sessions.py +4 -0
  19. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/gds_sessions.py +4 -1
  20. graphdatascience-1.17a1/graphdatascience/topological_lp/__init__.py +0 -0
  21. graphdatascience-1.17a1/graphdatascience/utils/__init__.py +0 -0
  22. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/utils/direct_util_endpoints.py +3 -1
  23. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/utils/util_remote_proc_runner.py +4 -2
  24. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/version.py +1 -1
  25. {graphdatascience-1.16 → graphdatascience-1.17a1/graphdatascience.egg-info}/PKG-INFO +2 -1
  26. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience.egg-info/SOURCES.txt +3 -0
  27. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience.egg-info/requires.txt +1 -0
  28. {graphdatascience-1.16 → graphdatascience-1.17a1}/requirements/base/base.txt +1 -0
  29. {graphdatascience-1.16 → graphdatascience-1.17a1}/LICENSE +0 -0
  30. {graphdatascience-1.16 → graphdatascience-1.17a1}/MANIFEST.in +0 -0
  31. {graphdatascience-1.16 → graphdatascience-1.17a1}/README.md +0 -0
  32. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/__init__.py +0 -0
  33. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/algo/__init__.py +0 -0
  34. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/algo/algo_endpoints.py +0 -0
  35. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/algo/algo_proc_runner.py +0 -0
  36. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/algo/single_mode_algo_endpoints.py +0 -0
  37. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/call_builder.py +0 -0
  38. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/call_parameters.py +0 -0
  39. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/caller_base.py +0 -0
  40. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/endpoints.py +0 -0
  41. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/__init__.py +0 -0
  42. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/client_only_endpoint.py +0 -0
  43. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/cypher_warning_handler.py +0 -0
  44. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/endpoint_suggester.py +0 -0
  45. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/gds_not_installed.py +0 -0
  46. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/illegal_attr_checker.py +0 -0
  47. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/unable_to_connect.py +0 -0
  48. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/error/uncallable_namespace.py +0 -0
  49. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/__init__.py +0 -0
  50. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/base_graph_proc_runner.py +0 -0
  51. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_alpha_proc_runner.py +0 -0
  52. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_beta_proc_runner.py +0 -0
  53. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_create_result.py +0 -0
  54. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_endpoints.py +0 -0
  55. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_export_runner.py +0 -0
  56. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_proc_runner.py +0 -0
  57. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_project_runner.py +0 -0
  58. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_remote_proc_runner.py +0 -0
  59. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_remote_project_runner.py +0 -0
  60. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_sample_runner.py +0 -0
  61. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/graph_type_check.py +0 -0
  62. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/nx_loader.py +0 -0
  63. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/graph/ogb_loader.py +0 -0
  64. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/ignored_server_endpoints.py +0 -0
  65. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/__init__.py +0 -0
  66. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/graphsage_model.py +0 -0
  67. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/link_prediction_model.py +0 -0
  68. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model.py +0 -0
  69. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model_alpha_proc_runner.py +0 -0
  70. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model_beta_proc_runner.py +0 -0
  71. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model_endpoints.py +0 -0
  72. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model_proc_runner.py +0 -0
  73. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/model_resolver.py +0 -0
  74. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/node_classification_model.py +0 -0
  75. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/node_regression_model.py +0 -0
  76. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/pipeline_model.py +0 -0
  77. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/model/simple_rel_embedding_model.py +0 -0
  78. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/__init__.py +0 -0
  79. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/classification_training_pipeline.py +0 -0
  80. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/lp_pipeline_create_runner.py +0 -0
  81. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/lp_training_pipeline.py +0 -0
  82. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/nc_pipeline_create_runner.py +0 -0
  83. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/nc_training_pipeline.py +0 -0
  84. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/nr_pipeline_create_runner.py +0 -0
  85. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/nr_training_pipeline.py +0 -0
  86. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/pipeline_alpha_proc_runner.py +0 -0
  87. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/pipeline_beta_proc_runner.py +0 -0
  88. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/pipeline_endpoints.py +0 -0
  89. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/pipeline_proc_runner.py +0 -0
  90. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/pipeline/training_pipeline.py +0 -0
  91. {graphdatascience-1.16/graphdatascience/query_runner → graphdatascience-1.17a1/graphdatascience/procedure_surface}/__init__.py +0 -0
  92. {graphdatascience-1.16/graphdatascience/query_runner/progress → graphdatascience-1.17a1/graphdatascience/procedure_surface/arrow}/__init__.py +0 -0
  93. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/py.typed +0 -0
  94. {graphdatascience-1.16/graphdatascience/query_runner/protocol → graphdatascience-1.17a1/graphdatascience/query_runner}/__init__.py +0 -0
  95. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/arrow_authentication.py +0 -0
  96. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/arrow_endpoint_version.py +0 -0
  97. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/arrow_graph_constructor.py +0 -0
  98. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/arrow_info.py +0 -0
  99. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/graph_constructor.py +0 -0
  100. {graphdatascience-1.16/graphdatascience/resources → graphdatascience-1.17a1/graphdatascience/query_runner/progress}/__init__.py +0 -0
  101. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/progress/progress_provider.py +0 -0
  102. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/progress/query_progress_logger.py +0 -0
  103. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/progress/query_progress_provider.py +0 -0
  104. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/progress/static_progress_provider.py +0 -0
  105. {graphdatascience-1.16/graphdatascience/resources/cora → graphdatascience-1.17a1/graphdatascience/query_runner/protocol}/__init__.py +0 -0
  106. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/protocol/project_protocols.py +0 -0
  107. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/protocol/status.py +0 -0
  108. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/protocol/write_protocols.py +0 -0
  109. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/query_runner/termination_flag.py +0 -0
  110. {graphdatascience-1.16/graphdatascience/resources/imdb → graphdatascience-1.17a1/graphdatascience/resources}/__init__.py +0 -0
  111. {graphdatascience-1.16/graphdatascience/resources/karate → graphdatascience-1.17a1/graphdatascience/resources/cora}/__init__.py +0 -0
  112. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/cora/cora_nodes.parquet.gzip +0 -0
  113. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/cora/cora_rels.parquet.gzip +0 -0
  114. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/cora/serialize_cora.py +0 -0
  115. {graphdatascience-1.16/graphdatascience/resources/lastfm → graphdatascience-1.17a1/graphdatascience/resources/imdb}/__init__.py +0 -0
  116. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_acted_in.parquet.gzip +0 -0
  117. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_actors.parquet.gzip +0 -0
  118. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_directed_in.parquet.gzip +0 -0
  119. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_directors.parquet.gzip +0 -0
  120. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_movies_with_genre.parquet.gzip +0 -0
  121. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/imdb_movies_without_genre.parquet.gzip +0 -0
  122. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/imdb/serialize_imdb.py +0 -0
  123. {graphdatascience-1.16/graphdatascience/retry_utils → graphdatascience-1.17a1/graphdatascience/resources/karate}/__init__.py +0 -0
  124. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/karate/karate_club.parquet.gzip +0 -0
  125. {graphdatascience-1.16/graphdatascience/semantic_version → graphdatascience-1.17a1/graphdatascience/resources/lastfm}/__init__.py +0 -0
  126. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/artist_nodes.parquet.gzip +0 -0
  127. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/serialize_lastfm.py +0 -0
  128. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/user_friend_df_directed.parquet.gzip +0 -0
  129. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/user_listen_artist_rels.parquet.gzip +0 -0
  130. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/user_nodes.parquet.gzip +0 -0
  131. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/resources/lastfm/user_tag_artist_rels.parquet.gzip +0 -0
  132. {graphdatascience-1.16/graphdatascience/server_version → graphdatascience-1.17a1/graphdatascience/retry_utils}/__init__.py +0 -0
  133. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/retry_utils/retry_config.py +0 -0
  134. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/retry_utils/retry_utils.py +0 -0
  135. {graphdatascience-1.16/graphdatascience/session/dbms → graphdatascience-1.17a1/graphdatascience/semantic_version}/__init__.py +0 -0
  136. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/semantic_version/semantic_version.py +0 -0
  137. {graphdatascience-1.16/graphdatascience/system → graphdatascience-1.17a1/graphdatascience/server_version}/__init__.py +0 -0
  138. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/server_version/compatible_with.py +0 -0
  139. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/server_version/server_version.py +0 -0
  140. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/__init__.py +0 -0
  141. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/algorithm_category.py +0 -0
  142. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/aura_api_responses.py +0 -0
  143. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/aura_api_token_authentication.py +0 -0
  144. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/aurads_sessions.py +0 -0
  145. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/cloud_location.py +0 -0
  146. {graphdatascience-1.16/graphdatascience/topological_lp → graphdatascience-1.17a1/graphdatascience/session/dbms}/__init__.py +0 -0
  147. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/dbms/protocol_version.py +0 -0
  148. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/dbms_connection_info.py +0 -0
  149. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/region_suggester.py +0 -0
  150. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/session_info.py +0 -0
  151. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/session/session_sizes.py +0 -0
  152. {graphdatascience-1.16/graphdatascience/utils → graphdatascience-1.17a1/graphdatascience/system}/__init__.py +0 -0
  153. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/system/config_endpoints.py +0 -0
  154. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/system/system_endpoints.py +0 -0
  155. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/topological_lp/topological_lp_alpha_runner.py +0 -0
  156. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/topological_lp/topological_lp_endpoints.py +0 -0
  157. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/utils/util_node_property_func_runner.py +0 -0
  158. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience/utils/util_proc_runner.py +0 -0
  159. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience.egg-info/dependency_links.txt +0 -0
  160. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience.egg-info/not-zip-safe +0 -0
  161. {graphdatascience-1.16 → graphdatascience-1.17a1}/graphdatascience.egg-info/top_level.txt +0 -0
  162. {graphdatascience-1.16 → graphdatascience-1.17a1}/pyproject.toml +0 -0
  163. {graphdatascience-1.16 → graphdatascience-1.17a1}/requirements/base/networkx.txt +0 -0
  164. {graphdatascience-1.16 → graphdatascience-1.17a1}/requirements/base/ogb.txt +0 -0
  165. {graphdatascience-1.16 → graphdatascience-1.17a1}/requirements/base/rust-ext.txt +0 -0
  166. {graphdatascience-1.16 → graphdatascience-1.17a1}/setup.cfg +0 -0
  167. {graphdatascience-1.16 → graphdatascience-1.17a1}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphdatascience
3
- Version: 1.16
3
+ Version: 1.17a1
4
4
  Summary: A Python client for the Neo4j Graph Data Science (GDS) library
5
5
  Home-page: https://neo4j.com/product/graph-data-science/
6
6
  Author: Neo4j
@@ -38,6 +38,7 @@ Requires-Dist: tqdm<5.0,>=4.0
38
38
  Requires-Dist: typing-extensions<5.0,>=4.0
39
39
  Requires-Dist: requests
40
40
  Requires-Dist: tenacity>=9.0
41
+ Requires-Dist: pydantic>=2.11
41
42
  Provides-Extra: ogb
42
43
  Requires-Dist: ogb<2.0,>=1.0; extra == "ogb"
43
44
  Provides-Extra: networkx
@@ -6,6 +6,8 @@ from typing import Any, Optional
6
6
 
7
7
  from pandas import Series
8
8
 
9
+ from graphdatascience.query_runner.query_mode import QueryMode
10
+
9
11
  from ..caller_base import CallerBase
10
12
  from ..query_runner.query_runner import QueryRunner
11
13
  from ..server_version.server_version import ServerVersion
@@ -46,7 +48,7 @@ class GraphCypherRunner(CallerBase):
46
48
  GraphCypherRunner._verify_query_ends_with_return_clause(self._namespace, query)
47
49
 
48
50
  result: Optional[dict[str, Any]] = self._query_runner.run_retryable_cypher(
49
- query, params, database, custom_error=False
51
+ query, params, database, custom_error=False, mode=QueryMode.READ
50
52
  ).squeeze()
51
53
 
52
54
  if not result:
@@ -5,6 +5,8 @@ from warnings import filterwarnings
5
5
  import pandas as pd
6
6
  from pandas import DataFrame, Series
7
7
 
8
+ from graphdatascience.query_runner.query_mode import QueryMode
9
+
8
10
  from ..call_parameters import CallParameters
9
11
  from ..error.cypher_warning_handler import (
10
12
  filter_id_func_deprecation_warning,
@@ -164,7 +166,9 @@ class GraphNodePropertiesRunner(GraphEntityOpsBaseRunner):
164
166
  unique_node_ids = result["nodeId"].drop_duplicates().tolist()
165
167
 
166
168
  db_properties_df = query_runner.run_retryable_cypher(
167
- GraphNodePropertiesRunner._build_query(db_node_properties), params={"ids": unique_node_ids}
169
+ GraphNodePropertiesRunner._build_query(db_node_properties),
170
+ params={"ids": unique_node_ids},
171
+ mode=QueryMode.READ,
168
172
  )
169
173
 
170
174
  if "propertyValue" not in result.keys():
@@ -68,7 +68,7 @@ class Graph:
68
68
  """
69
69
  return self._graph_info(["database"]) # type: ignore
70
70
 
71
- def configuration(self) -> "Series[Any]":
71
+ def configuration(self) -> Series[Any]:
72
72
  """
73
73
  Returns:
74
74
  the configuration of the graph
@@ -9,6 +9,7 @@ from neo4j import Driver
9
9
  from pandas import DataFrame
10
10
 
11
11
  from graphdatascience.query_runner.arrow_authentication import UsernamePasswordAuthentication
12
+ from graphdatascience.query_runner.query_mode import QueryMode
12
13
 
13
14
  from .call_builder import IndirectCallBuilder
14
15
  from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
@@ -41,6 +42,7 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
41
42
  arrow_tls_root_certs: Optional[bytes] = None,
42
43
  bookmarks: Optional[Any] = None,
43
44
  show_progress: bool = True,
45
+ arrow_client_options: Optional[dict[str, Any]] = None,
44
46
  ):
45
47
  """
46
48
  Construct a new GraphDataScience object.
@@ -64,14 +66,22 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
64
66
  - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
65
67
  - False will make the client use Bolt for all operations.
66
68
  arrow_disable_server_verification : bool, default True
69
+ .. deprecated:: 1.16
70
+ Use arrow_client_options instead
71
+
67
72
  A flag that overrides other TLS settings and disables server verification for TLS connections.
68
73
  arrow_tls_root_certs : Optional[bytes], default None
74
+ .. deprecated:: 1.16
75
+ Use arrow_client_options instead
76
+
69
77
  PEM-encoded certificates that are used for the connection to the
70
78
  GDS Arrow Flight server.
71
79
  bookmarks : Optional[Any], default None
72
80
  The Neo4j bookmarks to require a certain state before the next query gets executed.
73
81
  show_progress : bool, default True
74
82
  A flag to indicate whether to show progress bars for running procedures.
83
+ arrow_client_options : Optional[dict[str, Any]], default None
84
+ Additional options to be passed to the Arrow Flight client.
75
85
  """
76
86
  if aura_ds:
77
87
  GraphDataScience._validate_endpoint(endpoint)
@@ -104,14 +114,19 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
104
114
  username, password = auth
105
115
  arrow_auth = UsernamePasswordAuthentication(username, password)
106
116
 
117
+ if arrow_client_options is None:
118
+ arrow_client_options = {}
119
+ if arrow_disable_server_verification:
120
+ arrow_client_options["disable_server_verification"] = True
121
+ if arrow_tls_root_certs is not None:
122
+ arrow_client_options["tls_root_certs"] = arrow_tls_root_certs
107
123
  self._query_runner = ArrowQueryRunner.create(
108
124
  self._query_runner,
109
- arrow_info,
110
- arrow_auth,
111
- self._query_runner.encrypted(),
112
- arrow_disable_server_verification,
113
- arrow_tls_root_certs,
114
- None if arrow is True else arrow,
125
+ arrow_info=arrow_info,
126
+ arrow_authentication=arrow_auth,
127
+ encrypted=self._query_runner.encrypted(),
128
+ arrow_client_options=arrow_client_options,
129
+ connection_string_override=None if arrow is True else arrow,
115
130
  )
116
131
 
117
132
  self._query_runner.set_show_progress(show_progress)
@@ -199,7 +214,12 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
199
214
  return self._query_runner.last_bookmarks()
200
215
 
201
216
  def run_cypher(
202
- self, query: str, params: Optional[dict[str, Any]] = None, database: Optional[str] = None
217
+ self,
218
+ query: str,
219
+ params: Optional[dict[str, Any]] = None,
220
+ database: Optional[str] = None,
221
+ retryable: bool = False,
222
+ mode: QueryMode = QueryMode.WRITE,
203
223
  ) -> DataFrame:
204
224
  """
205
225
  Run a Cypher query
@@ -212,6 +232,10 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
212
232
  parameters to the query
213
233
  database: str
214
234
  the database on which to run the query
235
+ retryable: bool
236
+ whether the query can be automatically retried. Make sure the query is idempotent if set to True.
237
+ mode: QueryMode
238
+ the query mode to use (READ or WRITE). Set based on the operation performed in the query.
215
239
 
216
240
  Returns:
217
241
  The query result as a DataFrame
@@ -222,8 +246,10 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
222
246
  if isinstance(self._query_runner, ArrowQueryRunner):
223
247
  qr = self._query_runner.fallback_query_runner()
224
248
 
225
- # not using qr.run_retryable_cypher as we dont know if it can be retried
226
- return qr.run_cypher(query, params, database, False)
249
+ if retryable:
250
+ return qr.run_retryable_cypher(query, params, database, custom_error=False, mode=mode)
251
+ else:
252
+ return qr.run_cypher(query, params, database, custom_error=False, mode=mode)
227
253
 
228
254
  def driver_config(self) -> dict[str, Any]:
229
255
  """
@@ -240,11 +266,51 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
240
266
  driver: Driver,
241
267
  auth: Optional[tuple[str, str]] = None,
242
268
  database: Optional[str] = None,
243
- arrow: bool = True,
269
+ arrow: Union[str, bool] = True,
244
270
  arrow_disable_server_verification: bool = True,
245
271
  arrow_tls_root_certs: Optional[bytes] = None,
246
272
  bookmarks: Optional[Any] = None,
247
- ) -> "GraphDataScience":
273
+ arrow_client_options: Optional[dict[str, Any]] = None,
274
+ ) -> GraphDataScience:
275
+ """
276
+ Construct a new GraphDataScience object from an existing Neo4j Driver.
277
+ This method is useful when you already have a Neo4j Driver instance and want to use it with the GDS client.
278
+
279
+ Parameters
280
+ ----------
281
+ driver: Driver
282
+ The Neo4j Driver instance to use.
283
+ auth : Optional[Tuple[str, str]], default None
284
+ A username, password pair for authentication.
285
+ database: Optional[str], default None
286
+ The Neo4j database to query against.
287
+ arrow : Union[str, bool], default True
288
+ Arrow connection information. This is either a string or a bool.
289
+
290
+ - If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
291
+ - If it is a bool:
292
+ - True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
293
+ - False will make the client use Bolt for all operations.
294
+ arrow_disable_server_verification : bool, default True
295
+ .. deprecated:: 1.16
296
+ Use arrow_client_options instead
297
+
298
+ A flag that overrides other TLS settings and disables server verification for TLS connections.
299
+ arrow_tls_root_certs : Optional[bytes], default None
300
+ .. deprecated:: 1.16
301
+ Use arrow_client_options instead
302
+
303
+ PEM-encoded certificates that are used for the connection to the
304
+ GDS Arrow Flight server.
305
+ bookmarks : Optional[Any], default None
306
+ The Neo4j bookmarks to require a certain state before the next query gets executed.
307
+ show_progress : bool, default True
308
+ A flag to indicate whether to show progress bars for running procedures.
309
+ arrow_client_options : Optional[dict[str, Any]], default None
310
+ Additional options to be passed to the Arrow Flight client.
311
+ Returns:
312
+ A new GraphDataScience object. configured with the provided Neo4j Driver.
313
+ """
248
314
  return cls(
249
315
  driver,
250
316
  auth=auth,
@@ -253,6 +319,7 @@ class GraphDataScience(DirectEndpoints, UncallableNamespace):
253
319
  arrow_disable_server_verification=arrow_disable_server_verification,
254
320
  arrow_tls_root_certs=arrow_tls_root_certs,
255
321
  bookmarks=bookmarks,
322
+ arrow_client_options=arrow_client_options,
256
323
  )
257
324
 
258
325
  @staticmethod
@@ -0,0 +1,191 @@
1
+ import json
2
+ from typing import Any, List, Optional
3
+
4
+ from pandas import DataFrame
5
+
6
+ from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7
+ from ...arrow_client.v2.data_mapper_utils import deserialize_single
8
+ from ...arrow_client.v2.job_client import JobClient
9
+ from ...arrow_client.v2.mutation_client import MutationClient
10
+ from ...arrow_client.v2.write_back_client import WriteBackClient
11
+ from ...graph.graph_object import Graph
12
+ from ..api.estimation_result import EstimationResult
13
+ from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
14
+ from ..utils.config_converter import ConfigConverter
15
+
16
WCC_ENDPOINT = "v2/community.wcc"


class WccArrowEndpoints(WccEndpoints):
    """Weakly Connected Components (WCC) endpoints backed by the GDS Arrow v2 job API.

    Every mode submits a job to ``WCC_ENDPOINT`` via ``JobClient.run_job_and_wait`` and
    then, depending on the mode, fetches the job summary, streams the results, mutates
    the projected graph, or writes the results back.

    Parameters
    ----------
    arrow_client : AuthenticatedArrowClient
        Client used to submit jobs, fetch results and run actions on the GDS Arrow server.
    write_back_client : Optional[WriteBackClient], default None
        Client used by :meth:`write` to persist results. Only required for ``write``.
    """

    def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient] = None):
        self._arrow_client = arrow_client
        self._write_back_client = write_back_client

    def _run_job(self, G: Graph, **algorithm_config: Any) -> str:
        """Build the GDS config for graph ``G`` from ``algorithm_config``, run the WCC job and wait for it.

        Returns the job id produced by ``JobClient.run_job_and_wait``; callers use it to
        fetch the summary and results of the finished job.
        """
        config = ConfigConverter.convert_to_gds_config(graph_name=G.name(), **algorithm_config)
        return JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)

    def mutate(
        self,
        G: Graph,
        mutate_property: str,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccMutateResult:
        """Run WCC on ``G`` and store the component ids on the projected graph as ``mutate_property``.

        Returns
        -------
        WccMutateResult
            The job summary, extended with ``nodePropertiesWritten`` (taken from the
            mutation result) and ``mutateMillis``.

        NOTE(review): ``username`` is accepted for interface compatibility but is not
        forwarded to the job configuration — confirm this is intended.
        """
        import time  # local import: only this mode needs client-side timing

        run_job_id = self._run_job(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        # mutateMillis is measured client-side (wall clock) around the mutation call.
        started = time.perf_counter()
        mutate_result = MutationClient.mutate_node_property(self._arrow_client, run_job_id, mutate_property)
        mutate_millis = int((time.perf_counter() - started) * 1000)

        computation_result = JobClient.get_summary(self._arrow_client, run_job_id)
        computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
        computation_result["mutateMillis"] = mutate_millis

        return WccMutateResult(**computation_result)

    def stats(
        self,
        G: Graph,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccStatsResult:
        """Run WCC on ``G`` and return only the job summary.

        NOTE(review): ``username`` is accepted but not forwarded to the job configuration.
        """
        run_job_id = self._run_job(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )
        computation_result = JobClient.get_summary(self._arrow_client, run_job_id)

        return WccStatsResult(**computation_result)

    def stream(
        self,
        G: Graph,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> DataFrame:
        """Run WCC on ``G`` and return the job results as a DataFrame via ``JobClient.stream_results``.

        NOTE(review): ``username`` is accepted but not forwarded to the job configuration.
        """
        run_job_id = self._run_job(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )
        return JobClient.stream_results(self._arrow_client, G.name(), run_job_id)

    def write(
        self,
        G: Graph,
        write_property: str,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        write_concurrency: Optional[int] = None,
    ) -> WccWriteResult:
        """Run WCC on ``G`` and write the results back using the configured write-back client.

        ``write_concurrency`` falls back to ``concurrency`` when not given.

        Returns
        -------
        WccWriteResult
            The job summary, extended with ``writeMillis`` as returned by the write-back client.

        Raises
        ------
        RuntimeError
            If no write-back client was configured. This is checked before the job is
            submitted, so no computation is wasted on a call that cannot succeed.

        NOTE(review): ``write_property`` and ``username`` are accepted but not passed
        anywhere in this method — confirm this is intended.
        """
        if self._write_back_client is None:
            raise RuntimeError("Write back client is not initialized")

        run_job_id = self._run_job(
            G,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )
        computation_result = JobClient.get_summary(self._arrow_client, run_job_id)

        effective_write_concurrency = write_concurrency if write_concurrency is not None else concurrency
        write_millis = self._write_back_client.write(G.name(), run_job_id, effective_write_concurrency)

        computation_result["writeMillis"] = write_millis

        return WccWriteResult(**computation_result)

    def estimate(
        self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
    ) -> EstimationResult:
        """Request a WCC estimate from the server for an existing graph or a projection config.

        ``G`` takes precedence when both ``G`` and ``projection_config`` are given.

        Raises
        ------
        ValueError
            If neither ``G`` nor ``projection_config`` is provided.
        """
        payload: dict[str, Any]
        if G is not None:
            payload = {"graphName": G.name()}
        elif projection_config is not None:
            payload = projection_config
        else:
            raise ValueError("Either G or projection_config must be provided.")

        res = self._arrow_client.do_action_with_retry(f"{WCC_ENDPOINT}.estimate", json.dumps(payload).encode("utf-8"))

        return EstimationResult(**deserialize_single(res))
@@ -25,8 +25,7 @@ class ArrowQueryRunner(QueryRunner):
25
25
  arrow_info: ArrowInfo,
26
26
  arrow_authentication: Optional[ArrowAuthentication] = None,
27
27
  encrypted: bool = False,
28
- disable_server_verification: bool = False,
29
- tls_root_certs: Optional[bytes] = None,
28
+ arrow_client_options: Optional[dict[str, Any]] = None,
30
29
  connection_string_override: Optional[str] = None,
31
30
  retry_config: Optional[RetryConfig] = None,
32
31
  ) -> ArrowQueryRunner:
@@ -34,13 +33,12 @@ class ArrowQueryRunner(QueryRunner):
34
33
  raise ValueError("Arrow is not enabled on the server")
35
34
 
36
35
  gds_arrow_client = GdsArrowClient.create(
37
- arrow_info,
38
- arrow_authentication,
39
- encrypted,
40
- disable_server_verification,
41
- tls_root_certs,
42
- connection_string_override,
36
+ arrow_info=arrow_info,
37
+ auth=arrow_authentication,
38
+ encrypted=encrypted,
39
+ connection_string_override=connection_string_override,
43
40
  retry_config=retry_config,
41
+ arrow_client_options=arrow_client_options,
44
42
  )
45
43
 
46
44
  return ArrowQueryRunner(gds_arrow_client, fallback_query_runner, fallback_query_runner.server_version())
@@ -65,18 +63,22 @@ class ArrowQueryRunner(QueryRunner):
65
63
  query: str,
66
64
  params: Optional[dict[str, Any]] = None,
67
65
  database: Optional[str] = None,
66
+ mode: Optional[QueryMode] = None,
68
67
  custom_error: bool = True,
69
68
  ) -> DataFrame:
70
- return self._fallback_query_runner.run_cypher(query, params, database, custom_error)
69
+ return self._fallback_query_runner.run_cypher(query, params, database, mode, custom_error=custom_error)
71
70
 
72
71
  def run_retryable_cypher(
73
72
  self,
74
73
  query: str,
75
74
  params: Optional[dict[str, Any]] = None,
76
75
  database: Optional[str] = None,
76
+ mode: Optional[QueryMode] = None,
77
77
  custom_error: bool = True,
78
78
  ) -> DataFrame:
79
- return self._fallback_query_runner.run_retryable_cypher(query, params, database, custom_error=custom_error)
79
+ return self._fallback_query_runner.run_retryable_cypher(
80
+ query, params, database, mode, custom_error=custom_error
81
+ )
80
82
 
81
83
  def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
82
84
  return self._fallback_query_runner.call_function(endpoint, params)
@@ -7,6 +7,8 @@ from uuid import uuid4
7
7
 
8
8
  from pandas import DataFrame, concat
9
9
 
10
+ from graphdatascience.query_runner.query_mode import QueryMode
11
+
10
12
  from ..server_version.server_version import ServerVersion
11
13
  from .graph_constructor import GraphConstructor
12
14
  from .query_runner import QueryRunner
@@ -105,7 +107,9 @@ class CypherGraphConstructor(GraphConstructor):
105
107
  def _should_warn_about_arrow_missing(self) -> bool:
106
108
  try:
107
109
  license: str = self._query_runner.run_retryable_cypher(
108
- "CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value", custom_error=False
110
+ "CALL gds.debug.sysInfo() YIELD key, value WHERE key = 'gdsEdition' RETURN value",
111
+ custom_error=False,
112
+ mode=QueryMode.READ,
109
113
  ).squeeze()
110
114
  should_warn = license == "Licensed"
111
115
  except Exception as e:
@@ -55,6 +55,7 @@ class GdsArrowClient:
55
55
  tls_root_certs: Optional[bytes] = None,
56
56
  connection_string_override: Optional[str] = None,
57
57
  retry_config: Optional[RetryConfig] = None,
58
+ arrow_client_options: Optional[dict[str, Any]] = None,
58
59
  ) -> GdsArrowClient:
59
60
  connection_string: str
60
61
  if connection_string_override is not None:
@@ -78,14 +79,15 @@ class GdsArrowClient:
78
79
  )
79
80
 
80
81
  return GdsArrowClient(
81
- host,
82
- retry_config,
83
- int(port),
84
- auth,
85
- encrypted,
86
- disable_server_verification,
87
- tls_root_certs,
88
- arrow_endpoint_version,
82
+ host=host,
83
+ retry_config=retry_config,
84
+ port=int(port),
85
+ auth=auth,
86
+ encrypted=encrypted,
87
+ disable_server_verification=disable_server_verification,
88
+ tls_root_certs=tls_root_certs,
89
+ arrow_endpoint_version=arrow_endpoint_version,
90
+ arrow_client_options=arrow_client_options,
89
91
  )
90
92
 
91
93
  def __init__(
@@ -99,6 +101,7 @@ class GdsArrowClient:
99
101
  tls_root_certs: Optional[bytes] = None,
100
102
  arrow_endpoint_version: ArrowEndpointVersion = ArrowEndpointVersion.V1,
101
103
  user_agent: Optional[str] = None,
104
+ arrow_client_options: Optional[dict[str, Any]] = None,
102
105
  ):
103
106
  """Creates a new GdsArrowClient instance.
104
107
 
@@ -113,8 +116,12 @@ class GdsArrowClient:
113
116
  encrypted: bool
114
117
  A flag that indicates whether the connection should be encrypted (default is False)
115
118
  disable_server_verification: bool
119
+ .. deprecated:: 1.16
120
+ Use arrow_client_options instead
116
121
  A flag that disables server verification for TLS connections (default is False)
117
122
  tls_root_certs: Optional[bytes]
123
+ .. deprecated:: 1.16
124
+ Use arrow_client_options instead
118
125
  PEM-encoded certificates that are used for the connection to the GDS Arrow Flight server
119
126
  arrow_endpoint_version:
120
127
  The version of the Arrow endpoint to use (default is ArrowEndpointVersion.V1)
@@ -122,18 +129,26 @@ class GdsArrowClient:
122
129
  The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
123
130
  retry_config: Optional[RetryConfig]
124
131
  The retry configuration to use for the Arrow requests send by the client.
132
+ arrow_client_options: Optional[dict[str, Any]]
133
+ Additional configuration for the Arrow flight client.
134
+
125
135
  """
126
136
  self._arrow_endpoint_version = arrow_endpoint_version
127
137
  self._host = host
128
138
  self._port = port
129
139
  self._auth = None
130
140
  self._encrypted = encrypted
131
- self._disable_server_verification = disable_server_verification
132
- self._tls_root_certs = tls_root_certs
133
141
  self._user_agent = user_agent
134
142
  self._retry_config = retry_config
135
143
  self._logger = logging.getLogger("gds_arrow_client")
136
144
 
145
+ self._arrow_client_options = arrow_client_options if arrow_client_options is not None else {}
146
+
147
+ if disable_server_verification:
148
+ self._arrow_client_options["disable_server_verification"] = True
149
+ if tls_root_certs is not None:
150
+ self._arrow_client_options["tls_root_certs"] = tls_root_certs
151
+
137
152
  if auth:
138
153
  if not isinstance(auth, ArrowAuthentication):
139
154
  username, password = auth
@@ -149,18 +164,27 @@ class GdsArrowClient:
149
164
  if self._encrypted
150
165
  else flight.Location.for_grpc_tcp(self._host, self._port)
151
166
  )
152
- client_options: dict[str, Any] = {"disable_server_verification": self._disable_server_verification}
167
+
168
+ client_options = self._arrow_client_options.copy()
169
+
153
170
  if self._auth:
154
171
  user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
155
172
  if self._user_agent:
156
173
  user_agent = self._user_agent
157
174
 
158
- client_options["middleware"] = [
159
- AuthFactory(self._auth_middleware),
160
- UserAgentFactory(useragent=user_agent),
161
- ]
162
- if self._tls_root_certs:
163
- client_options["tls_root_certs"] = self._tls_root_certs
175
+ if "middleware" in client_options:
176
+ if not isinstance(client_options["middleware"], list):
177
+ raise TypeError("client_options['middleware'] must be a list")
178
+ else:
179
+ client_options["middleware"] = []
180
+
181
+ client_options["middleware"].extend(
182
+ [
183
+ AuthFactory(self._auth_middleware),
184
+ UserAgentFactory(useragent=user_agent),
185
+ ]
186
+ )
187
+
164
188
  return flight.FlightClient(location, **client_options)
165
189
 
166
190
  def connection_info(self) -> tuple[str, int]: