robosystems-client: robosystems_client-0.1.17-py3-none-any.whl → robosystems_client-0.1.18-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of robosystems-client has been flagged as potentially problematic; consult the registry's advisory page for this release for details.

Files changed (92):
  1. robosystems_client/__init__.py +15 -4
  2. robosystems_client/api/agent/auto_select_agent.py +25 -0
  3. robosystems_client/api/agent/batch_process_queries.py +25 -0
  4. robosystems_client/api/agent/execute_specific_agent.py +25 -0
  5. robosystems_client/api/agent/get_agent_metadata.py +25 -0
  6. robosystems_client/api/agent/list_agents.py +20 -0
  7. robosystems_client/api/agent/recommend_agent.py +25 -0
  8. robosystems_client/api/backup/create_backup.py +25 -0
  9. robosystems_client/api/backup/export_backup.py +25 -0
  10. robosystems_client/api/backup/get_backup_download_url.py +20 -0
  11. robosystems_client/api/backup/get_backup_stats.py +25 -0
  12. robosystems_client/api/backup/list_backups.py +20 -0
  13. robosystems_client/api/backup/restore_backup.py +25 -0
  14. robosystems_client/api/connections/create_connection.py +25 -0
  15. robosystems_client/api/connections/create_link_token.py +25 -0
  16. robosystems_client/api/connections/delete_connection.py +25 -0
  17. robosystems_client/api/connections/exchange_link_token.py +25 -0
  18. robosystems_client/api/connections/get_connection.py +25 -0
  19. robosystems_client/api/connections/get_connection_options.py +25 -0
  20. robosystems_client/api/connections/init_o_auth.py +25 -0
  21. robosystems_client/api/connections/list_connections.py +20 -0
  22. robosystems_client/api/connections/oauth_callback.py +25 -0
  23. robosystems_client/api/connections/sync_connection.py +25 -0
  24. robosystems_client/api/copy/copy_data_to_graph.py +25 -0
  25. robosystems_client/api/create/create_graph.py +25 -0
  26. robosystems_client/api/graph_analytics/get_graph_metrics.py +25 -0
  27. robosystems_client/api/graph_analytics/get_graph_usage_stats.py +20 -0
  28. robosystems_client/api/graph_billing/get_current_graph_bill.py +25 -0
  29. robosystems_client/api/graph_billing/get_graph_billing_history.py +20 -0
  30. robosystems_client/api/graph_billing/get_graph_monthly_bill.py +25 -0
  31. robosystems_client/api/graph_billing/get_graph_usage_details.py +20 -0
  32. robosystems_client/api/graph_credits/check_credit_balance.py +20 -0
  33. robosystems_client/api/graph_credits/check_storage_limits.py +25 -0
  34. robosystems_client/api/graph_credits/get_credit_summary.py +25 -0
  35. robosystems_client/api/graph_credits/get_storage_usage.py +20 -0
  36. robosystems_client/api/graph_credits/list_credit_transactions.py +20 -0
  37. robosystems_client/api/graph_health/get_database_health.py +25 -0
  38. robosystems_client/api/graph_info/get_database_info.py +25 -0
  39. robosystems_client/api/graph_limits/get_graph_limits.py +25 -0
  40. robosystems_client/api/mcp/call_mcp_tool.py +20 -0
  41. robosystems_client/api/mcp/list_mcp_tools.py +25 -0
  42. robosystems_client/api/operations/cancel_operation.py +25 -0
  43. robosystems_client/api/operations/get_operation_status.py +25 -0
  44. robosystems_client/api/operations/stream_operation_events.py +20 -0
  45. robosystems_client/api/query/execute_cypher_query.py +20 -0
  46. robosystems_client/api/schema/export_graph_schema.py +20 -0
  47. robosystems_client/api/schema/get_graph_schema_info.py +25 -0
  48. robosystems_client/api/schema/list_schema_extensions.py +25 -0
  49. robosystems_client/api/schema/validate_schema.py +25 -0
  50. robosystems_client/api/subgraphs/create_subgraph.py +25 -0
  51. robosystems_client/api/subgraphs/delete_subgraph.py +25 -0
  52. robosystems_client/api/subgraphs/get_subgraph_info.py +25 -0
  53. robosystems_client/api/subgraphs/get_subgraph_quota.py +25 -0
  54. robosystems_client/api/subgraphs/list_subgraphs.py +25 -0
  55. robosystems_client/api/user/create_user_api_key.py +25 -0
  56. robosystems_client/api/user/get_all_credit_summaries.py +25 -0
  57. robosystems_client/api/user/get_current_user.py +25 -0
  58. robosystems_client/api/user/get_user_graphs.py +25 -0
  59. robosystems_client/api/user/list_user_api_keys.py +25 -0
  60. robosystems_client/api/user/revoke_user_api_key.py +25 -0
  61. robosystems_client/api/user/select_user_graph.py +25 -0
  62. robosystems_client/api/user/update_user.py +25 -0
  63. robosystems_client/api/user/update_user_api_key.py +25 -0
  64. robosystems_client/api/user/update_user_password.py +25 -0
  65. robosystems_client/api/user_analytics/get_detailed_user_analytics.py +20 -0
  66. robosystems_client/api/user_analytics/get_user_usage_overview.py +25 -0
  67. robosystems_client/api/user_limits/get_all_shared_repository_limits.py +25 -0
  68. robosystems_client/api/user_limits/get_shared_repository_limits.py +25 -0
  69. robosystems_client/api/user_limits/get_user_limits.py +25 -0
  70. robosystems_client/api/user_limits/get_user_usage.py +25 -0
  71. robosystems_client/api/user_subscriptions/cancel_shared_repository_subscription.py +25 -0
  72. robosystems_client/api/user_subscriptions/get_repository_credits.py +25 -0
  73. robosystems_client/api/user_subscriptions/get_shared_repository_credits.py +25 -0
  74. robosystems_client/api/user_subscriptions/get_user_shared_subscriptions.py +20 -0
  75. robosystems_client/api/user_subscriptions/subscribe_to_shared_repository.py +25 -0
  76. robosystems_client/api/user_subscriptions/upgrade_shared_repository_subscription.py +25 -0
  77. robosystems_client/extensions/__init__.py +70 -0
  78. robosystems_client/extensions/auth_integration.py +14 -1
  79. robosystems_client/extensions/copy_client.py +32 -22
  80. robosystems_client/extensions/dataframe_utils.py +455 -0
  81. robosystems_client/extensions/extensions.py +16 -0
  82. robosystems_client/extensions/operation_client.py +43 -21
  83. robosystems_client/extensions/query_client.py +109 -12
  84. robosystems_client/extensions/tests/test_dataframe_utils.py +334 -0
  85. robosystems_client/extensions/tests/test_integration.py +1 -1
  86. robosystems_client/extensions/tests/test_token_utils.py +274 -0
  87. robosystems_client/extensions/token_utils.py +417 -0
  88. robosystems_client/extensions/utils.py +32 -2
  89. {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/METADATA +1 -1
  90. {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/RECORD +92 -88
  91. {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/WHEEL +0 -0
  92. {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/licenses/LICENSE +0 -0
@@ -71,7 +71,14 @@ class OperationClient:
71
71
  def __init__(self, config: Dict[str, Any]):
72
72
  self.config = config
73
73
  self.base_url = config["base_url"]
74
+ self.headers = config.get("headers", {})
75
+ # Get token from config if passed by parent
76
+ self.token = config.get("token")
74
77
  self.active_operations: Dict[str, SSEClient] = {}
78
+ # Thread safety for operations tracking
79
+ import threading
80
+
81
+ self._lock = threading.Lock()
75
82
 
76
83
  def monitor_operation(
77
84
  self, operation_id: str, options: MonitorOptions = None
@@ -144,7 +151,8 @@ class OperationClient:
144
151
  # Connect and monitor
145
152
  try:
146
153
  sse_client.connect(operation_id)
147
- self.active_operations[operation_id] = sse_client
154
+ with self._lock:
155
+ self.active_operations[operation_id] = sse_client
148
156
 
149
157
  # Wait for completion
150
158
  import time
@@ -166,10 +174,11 @@ class OperationClient:
166
174
  time.sleep(options.poll_interval or 0.1)
167
175
 
168
176
  finally:
169
- # Clean up
170
- if operation_id in self.active_operations:
171
- self.active_operations[operation_id].close()
172
- del self.active_operations[operation_id]
177
+ # Clean up with thread safety
178
+ with self._lock:
179
+ if operation_id in self.active_operations:
180
+ self.active_operations[operation_id].close()
181
+ del self.active_operations[operation_id]
173
182
 
174
183
  return result
175
184
 
@@ -179,11 +188,16 @@ class OperationClient:
179
188
  from ..api.operations.get_operation_status import (
180
189
  sync_detailed as get_operation_status,
181
190
  )
182
- from ..client import AuthenticatedClient
191
+ from ..client import Client
183
192
 
184
- client = AuthenticatedClient(base_url=self.base_url)
193
+ # Use regular Client with headers instead of AuthenticatedClient
194
+ client = Client(base_url=self.base_url, headers=self.headers)
185
195
  try:
186
- response = get_operation_status(operation_id=operation_id, client=client)
196
+ kwargs = {"operation_id": operation_id, "client": client}
197
+ # Only add token if it's a valid string
198
+ if self.token and isinstance(self.token, str) and self.token.strip():
199
+ kwargs["token"] = self.token
200
+ response = get_operation_status(**kwargs)
187
201
  if response.parsed:
188
202
  return {
189
203
  "operation_id": operation_id,
@@ -201,21 +215,27 @@ class OperationClient:
201
215
  """Cancel an operation"""
202
216
  # This would use the generated SDK to call /v1/operations/{operation_id}/cancel
203
217
  from ..api.operations.cancel_operation import sync_detailed as cancel_operation
204
- from ..client import AuthenticatedClient
218
+ from ..client import Client
205
219
 
206
- client = AuthenticatedClient(base_url=self.base_url)
220
+ # Use regular Client with headers instead of AuthenticatedClient
221
+ client = Client(base_url=self.base_url, headers=self.headers)
207
222
  try:
208
- response = cancel_operation(operation_id=operation_id, client=client)
223
+ kwargs = {"operation_id": operation_id, "client": client}
224
+ # Only add token if it's a valid string
225
+ if self.token and isinstance(self.token, str) and self.token.strip():
226
+ kwargs["token"] = self.token
227
+ response = cancel_operation(**kwargs)
209
228
  if response.parsed:
210
229
  return response.parsed.cancelled or False
211
230
  except Exception as e:
212
231
  print(f"Failed to cancel operation {operation_id}: {e}")
213
232
  return False
214
233
 
215
- # Also close any active SSE connection
216
- if operation_id in self.active_operations:
217
- self.active_operations[operation_id].close()
218
- del self.active_operations[operation_id]
234
+ # Also close any active SSE connection with thread safety
235
+ with self._lock:
236
+ if operation_id in self.active_operations:
237
+ self.active_operations[operation_id].close()
238
+ del self.active_operations[operation_id]
219
239
 
220
240
  return False
221
241
 
@@ -226,15 +246,17 @@ class OperationClient:
226
246
 
227
247
  def close_all(self):
228
248
  """Close all active operation monitors"""
229
- for sse_client in self.active_operations.values():
230
- sse_client.close()
231
- self.active_operations.clear()
249
+ with self._lock:
250
+ for sse_client in self.active_operations.values():
251
+ sse_client.close()
252
+ self.active_operations.clear()
232
253
 
233
254
  def close_operation(self, operation_id: str):
234
255
  """Close monitoring for a specific operation"""
235
- if operation_id in self.active_operations:
236
- self.active_operations[operation_id].close()
237
- del self.active_operations[operation_id]
256
+ with self._lock:
257
+ if operation_id in self.active_operations:
258
+ self.active_operations[operation_id].close()
259
+ del self.active_operations[operation_id]
238
260
 
239
261
 
240
262
  class AsyncOperationClient:
@@ -4,7 +4,17 @@ Provides intelligent query execution with automatic strategy selection.
4
4
  """
5
5
 
6
6
  from dataclasses import dataclass
7
- from typing import Dict, Any, Optional, Callable, AsyncIterator, Iterator, Union
7
+ from typing import (
8
+ Dict,
9
+ Any,
10
+ Optional,
11
+ Callable,
12
+ AsyncIterator,
13
+ Iterator,
14
+ Union,
15
+ Generator,
16
+ List,
17
+ )
8
18
  from datetime import datetime
9
19
 
10
20
  from ..api.query.execute_cypher_query import sync_detailed as execute_cypher_query
@@ -70,6 +80,9 @@ class QueryClient:
70
80
  def __init__(self, config: Dict[str, Any]):
71
81
  self.config = config
72
82
  self.base_url = config["base_url"]
83
+ self.headers = config.get("headers", {})
84
+ # Get token from config if passed by parent
85
+ self.token = config.get("token")
73
86
  self.sse_client: Optional[SSEClient] = None
74
87
 
75
88
  def execute_query(
@@ -85,15 +98,17 @@ class QueryClient:
85
98
  )
86
99
 
87
100
  # Execute the query through the generated client
88
- from ..client import AuthenticatedClient
101
+ from ..client import Client
89
102
 
90
- # Get client instance (you'd configure this based on your setup)
91
- client = AuthenticatedClient(base_url=self.base_url)
103
+ # Create client with headers
104
+ client = Client(base_url=self.base_url, headers=self.headers)
92
105
 
93
106
  try:
94
- response = execute_cypher_query(
95
- graph_id=graph_id, client=client, body=query_request
96
- )
107
+ kwargs = {"graph_id": graph_id, "client": client, "body": query_request}
108
+ # Only add token if it's a valid string
109
+ if self.token and isinstance(self.token, str) and self.token.strip():
110
+ kwargs["token"] = self.token
111
+ response = execute_cypher_query(**kwargs)
97
112
 
98
113
  # Check response type and handle accordingly
99
114
  if hasattr(response, "parsed") and response.parsed:
@@ -145,7 +160,15 @@ class QueryClient:
145
160
  except Exception as e:
146
161
  if isinstance(e, QueuedQueryError):
147
162
  raise
148
- raise Exception(f"Query execution failed: {str(e)}")
163
+
164
+ error_msg = str(e)
165
+ # Check for authentication errors
166
+ if (
167
+ "401" in error_msg or "403" in error_msg or "unauthorized" in error_msg.lower()
168
+ ):
169
+ raise Exception(f"Authentication failed during query execution: {error_msg}")
170
+ else:
171
+ raise Exception(f"Query execution failed: {error_msg}")
149
172
 
150
173
  # Unexpected response format
151
174
  raise Exception("Unexpected response format from query endpoint")
@@ -316,18 +339,92 @@ class QueryClient:
316
339
  cypher: str,
317
340
  parameters: Dict[str, Any] = None,
318
341
  chunk_size: int = 1000,
319
- ) -> Iterator[Any]:
320
- """Streaming query for large results"""
342
+ on_progress: Optional[Callable[[int, int], None]] = None,
343
+ ) -> Generator[Any, None, None]:
344
+ """Stream query results for large datasets with progress tracking
345
+
346
+ Args:
347
+ graph_id: Graph ID to query
348
+ cypher: Cypher query string
349
+ parameters: Query parameters
350
+ chunk_size: Number of records per chunk
351
+ on_progress: Callback for progress updates (current, total)
352
+
353
+ Yields:
354
+ Individual records from query results
355
+
356
+ Example:
357
+ >>> def progress(current, total):
358
+ ... print(f"Processed {current}/{total} records")
359
+ >>> for record in query_client.stream_query(
360
+ ... 'graph_id',
361
+ ... 'MATCH (n) RETURN n',
362
+ ... chunk_size=100,
363
+ ... on_progress=progress
364
+ ... ):
365
+ ... process_record(record)
366
+ """
321
367
  request = QueryRequest(query=cypher, parameters=parameters)
322
368
  result = self.execute_query(
323
369
  graph_id, request, QueryOptions(mode="stream", chunk_size=chunk_size)
324
370
  )
325
371
 
372
+ count = 0
326
373
  if isinstance(result, Iterator):
327
- yield from result
374
+ for item in result:
375
+ count += 1
376
+ if on_progress and count % chunk_size == 0:
377
+ on_progress(count, None) # Total unknown in streaming
378
+ yield item
328
379
  else:
329
380
  # If not streaming, yield all results at once
330
- yield from result.data
381
+ total = len(result.data)
382
+ for item in result.data:
383
+ count += 1
384
+ if on_progress:
385
+ on_progress(count, total)
386
+ yield item
387
+
388
+ def query_batch(
389
+ self,
390
+ graph_id: str,
391
+ queries: List[str],
392
+ parameters_list: Optional[List[Dict[str, Any]]] = None,
393
+ parallel: bool = False,
394
+ ) -> List[Union[QueryResult, Dict[str, Any]]]:
395
+ """Execute multiple queries in batch
396
+
397
+ Args:
398
+ graph_id: Graph ID to query
399
+ queries: List of Cypher query strings
400
+ parameters_list: List of parameter dicts (one per query)
401
+ parallel: Execute queries in parallel (experimental)
402
+
403
+ Returns:
404
+ List of QueryResult objects or error dicts
405
+
406
+ Example:
407
+ >>> results = query_client.query_batch('graph_id', [
408
+ ... 'MATCH (n:Person) RETURN count(n)',
409
+ ... 'MATCH (c:Company) RETURN count(c)'
410
+ ... ])
411
+ """
412
+ if parameters_list is None:
413
+ parameters_list = [None] * len(queries)
414
+
415
+ if len(queries) != len(parameters_list):
416
+ raise ValueError("queries and parameters_list must have same length")
417
+
418
+ results = []
419
+ for query, params in zip(queries, parameters_list):
420
+ try:
421
+ result = self.query(graph_id, query, params)
422
+ results.append(result)
423
+ except Exception as e:
424
+ # Store error as result
425
+ results.append({"error": str(e), "query": query})
426
+
427
+ return results
331
428
 
332
429
  def close(self):
333
430
  """Cancel any active SSE connections"""
@@ -0,0 +1,334 @@
1
+ """Tests for DataFrame utilities"""
2
+
3
+ import pytest
4
+ from unittest.mock import Mock, patch
5
+ import tempfile
6
+ import os
7
+
8
+ # Make pandas optional for tests
9
+ try:
10
+ import pandas as pd
11
+
12
+ HAS_PANDAS = True
13
+ except ImportError:
14
+ HAS_PANDAS = False
15
+ pd = None
16
+
17
+ # Only run tests if pandas is available
18
+ pytestmark = pytest.mark.skipif(not HAS_PANDAS, reason="pandas not installed")
19
+
20
+ if HAS_PANDAS:
21
+ from robosystems_client.extensions.dataframe_utils import (
22
+ query_result_to_dataframe,
23
+ parse_datetime_columns,
24
+ stream_to_dataframe,
25
+ dataframe_to_cypher_params,
26
+ export_query_to_csv,
27
+ compare_dataframes,
28
+ DataFrameQueryClient,
29
+ )
30
+
31
+
32
+ class TestQueryResultToDataFrame:
33
+ """Test converting query results to DataFrames"""
34
+
35
+ def test_query_result_to_dataframe_basic(self):
36
+ """Test basic conversion from query result to DataFrame"""
37
+ result = {
38
+ "data": [
39
+ {"name": "Alice", "age": 30},
40
+ {"name": "Bob", "age": 25},
41
+ {"name": "Charlie", "age": 35},
42
+ ],
43
+ "columns": ["name", "age"],
44
+ "row_count": 3,
45
+ }
46
+
47
+ df = query_result_to_dataframe(result)
48
+
49
+ assert len(df) == 3
50
+ assert list(df.columns) == ["name", "age"]
51
+ assert df.iloc[0]["name"] == "Alice"
52
+ assert df.iloc[1]["age"] == 25
53
+
54
+ def test_query_result_to_dataframe_nested(self):
55
+ """Test conversion with nested data"""
56
+ result = {
57
+ "data": [
58
+ {"name": "Alice", "company": {"name": "TechCorp", "revenue": 1000000}},
59
+ {"name": "Bob", "company": {"name": "StartupInc", "revenue": 500000}},
60
+ ],
61
+ "columns": ["name", "company"],
62
+ }
63
+
64
+ df = query_result_to_dataframe(result, normalize_nested=True)
65
+
66
+ assert "name" in df.columns
67
+ assert "company.name" in df.columns
68
+ assert "company.revenue" in df.columns
69
+ assert df.iloc[0]["company.name"] == "TechCorp"
70
+
71
+ def test_query_result_to_dataframe_empty(self):
72
+ """Test conversion of empty result"""
73
+ result = {"data": [], "columns": ["name", "age"], "row_count": 0}
74
+
75
+ df = query_result_to_dataframe(result)
76
+
77
+ assert len(df) == 0
78
+ assert list(df.columns) == ["name", "age"]
79
+
80
+ def test_query_result_to_dataframe_with_dates(self):
81
+ """Test conversion with date parsing"""
82
+ result = {
83
+ "data": [
84
+ {"name": "Alice", "created_at": "2023-01-15T10:30:00"},
85
+ {"name": "Bob", "created_at": "2023-02-20T14:45:00"},
86
+ ],
87
+ "columns": ["name", "created_at"],
88
+ }
89
+
90
+ df = query_result_to_dataframe(result, parse_dates=True)
91
+
92
+ assert pd.api.types.is_datetime64_any_dtype(df["created_at"])
93
+ assert df.iloc[0]["created_at"].year == 2023
94
+
95
+
96
+ class TestParseDateTimeColumns:
97
+ """Test datetime parsing functionality"""
98
+
99
+ def test_parse_datetime_columns_specific(self):
100
+ """Test parsing specific datetime columns"""
101
+ df = pd.DataFrame(
102
+ {
103
+ "name": ["Alice", "Bob"],
104
+ "created_at": ["2023-01-15", "2023-02-20"],
105
+ "updated_at": ["2023-01-16T10:30:00", "2023-02-21T14:45:00"],
106
+ "count": [1, 2],
107
+ }
108
+ )
109
+
110
+ df = parse_datetime_columns(df, date_columns=["created_at", "updated_at"])
111
+
112
+ assert pd.api.types.is_datetime64_any_dtype(df["created_at"])
113
+ assert pd.api.types.is_datetime64_any_dtype(df["updated_at"])
114
+ assert not pd.api.types.is_datetime64_any_dtype(df["count"])
115
+
116
+ def test_parse_datetime_columns_infer(self):
117
+ """Test automatic datetime column inference"""
118
+ df = pd.DataFrame(
119
+ {
120
+ "name": ["Alice", "Bob"],
121
+ "timestamp": ["2023-01-15T10:30:00", "2023-02-20T14:45:00"],
122
+ "not_a_date": ["abc", "def"],
123
+ }
124
+ )
125
+
126
+ df = parse_datetime_columns(df, infer=True)
127
+
128
+ assert pd.api.types.is_datetime64_any_dtype(df["timestamp"])
129
+ assert df["not_a_date"].dtype == "object"
130
+
131
+
132
+ class TestStreamToDataFrame:
133
+ """Test streaming results to DataFrame"""
134
+
135
+ def test_stream_to_dataframe_basic(self):
136
+ """Test converting stream to DataFrame"""
137
+
138
+ def mock_stream():
139
+ for i in range(10):
140
+ yield {"id": i, "value": i * 2}
141
+
142
+ df = stream_to_dataframe(mock_stream(), chunk_size=3)
143
+
144
+ assert len(df) == 10
145
+ assert df.iloc[5]["value"] == 10
146
+
147
+ def test_stream_to_dataframe_with_callback(self):
148
+ """Test stream with chunk callback"""
149
+ chunk_counts = []
150
+
151
+ def on_chunk(chunk_df, total):
152
+ chunk_counts.append(len(chunk_df))
153
+
154
+ def mock_stream():
155
+ for i in range(10):
156
+ yield {"id": i, "value": i * 2}
157
+
158
+ df = stream_to_dataframe(mock_stream(), chunk_size=3, on_chunk=on_chunk)
159
+
160
+ assert len(df) == 10
161
+ assert chunk_counts == [3, 3, 3, 1] # 3 chunks of 3, 1 chunk of 1
162
+
163
+
164
+ class TestDataFrameToCypherParams:
165
+ """Test DataFrame to Cypher parameter conversion"""
166
+
167
+ def test_dataframe_to_cypher_params(self):
168
+ """Test converting DataFrame to Cypher parameters"""
169
+ df = pd.DataFrame(
170
+ {
171
+ "name": ["Alice", "Bob", "Charlie"],
172
+ "age": [30, 25, 35],
173
+ "active": [True, False, True],
174
+ }
175
+ )
176
+
177
+ params = dataframe_to_cypher_params(df)
178
+
179
+ assert "data" in params
180
+ assert len(params["data"]) == 3
181
+ assert params["data"][0]["name"] == "Alice"
182
+ assert params["data"][1]["age"] == 25
183
+
184
+ def test_dataframe_to_cypher_params_with_nan(self):
185
+ """Test handling NaN values"""
186
+ df = pd.DataFrame(
187
+ {"name": ["Alice", "Bob"], "age": [30, pd.NA], "score": [95.5, None]}
188
+ )
189
+
190
+ params = dataframe_to_cypher_params(df, param_name="records")
191
+
192
+ assert "records" in params
193
+ assert params["records"][1]["age"] is None
194
+ assert params["records"][1]["score"] is None
195
+
196
+
197
+ class TestExportQueryToCSV:
198
+ """Test CSV export functionality"""
199
+
200
+ @patch("robosystems_client.extensions.dataframe_utils.logger")
201
+ def test_export_query_to_csv(self, mock_logger):
202
+ """Test exporting query results to CSV"""
203
+ mock_client = Mock()
204
+
205
+ def mock_stream(*args, **kwargs):
206
+ for i in range(5):
207
+ yield {"id": i, "name": f"Item {i}"}
208
+
209
+ mock_client.stream_query = Mock(side_effect=mock_stream)
210
+
211
+ with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
212
+ temp_file = f.name
213
+
214
+ try:
215
+ count = export_query_to_csv(
216
+ mock_client, "graph_id", "MATCH (n) RETURN n", temp_file, chunk_size=2
217
+ )
218
+
219
+ assert count == 5
220
+ mock_logger.info.assert_called()
221
+
222
+ # Verify CSV content
223
+ df = pd.read_csv(temp_file)
224
+ assert len(df) == 5
225
+ assert df.iloc[0]["name"] == "Item 0"
226
+
227
+ finally:
228
+ if os.path.exists(temp_file):
229
+ os.unlink(temp_file)
230
+
231
+
232
+ class TestCompareDataFrames:
233
+ """Test DataFrame comparison"""
234
+
235
+ def test_compare_dataframes_with_keys(self):
236
+ """Test comparing DataFrames with key columns"""
237
+ df1 = pd.DataFrame(
238
+ {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "age": [30, 25, 35]}
239
+ )
240
+
241
+ df2 = pd.DataFrame(
242
+ {"id": [1, 2, 4], "name": ["Alice", "Robert", "David"], "age": [31, 25, 40]}
243
+ )
244
+
245
+ diff = compare_dataframes(df1, df2, key_columns=["id"])
246
+
247
+ assert "_merge" in diff.columns
248
+ assert "name_old" in diff.columns
249
+ assert "name_new" in diff.columns
250
+
251
+ def test_compare_dataframes_without_keys(self):
252
+ """Test comparing DataFrames without keys"""
253
+ df1 = pd.DataFrame({"name": ["Alice", "Bob"], "age": [30, 25]})
254
+
255
+ df2 = pd.DataFrame({"name": ["Alice", "Charlie"], "age": [30, 35]})
256
+
257
+ diff = compare_dataframes(df1, df2)
258
+
259
+ assert len(diff) == 2 # Bob and Charlie rows
260
+
261
+
262
+ class TestDataFrameQueryClient:
263
+ """Test DataFrameQueryClient class"""
264
+
265
+ def test_query_df(self):
266
+ """Test query_df method"""
267
+ mock_client = Mock()
268
+ mock_client.query.return_value = {
269
+ "data": [{"name": "Alice"}, {"name": "Bob"}],
270
+ "columns": ["name"],
271
+ "row_count": 2,
272
+ }
273
+
274
+ df_client = DataFrameQueryClient(mock_client)
275
+ df = df_client.query_df("graph_id", "MATCH (n) RETURN n")
276
+
277
+ assert len(df) == 2
278
+ assert df.iloc[0]["name"] == "Alice"
279
+ mock_client.query.assert_called_once()
280
+
281
+ def test_stream_df(self):
282
+ """Test stream_df method"""
283
+ mock_client = Mock()
284
+
285
+ def mock_stream(*args, **kwargs):
286
+ for i in range(3):
287
+ yield {"id": i, "value": i * 10}
288
+
289
+ mock_client.stream_query.return_value = mock_stream()
290
+
291
+ df_client = DataFrameQueryClient(mock_client)
292
+ df = df_client.stream_df("graph_id", "MATCH (n) RETURN n")
293
+
294
+ assert len(df) == 3
295
+ assert df.iloc[1]["value"] == 10
296
+
297
+ def test_query_batch_df(self):
298
+ """Test query_batch_df method"""
299
+ mock_client = Mock()
300
+ mock_client.query_batch.return_value = [
301
+ {"data": [{"count": 10}], "columns": ["count"]},
302
+ {"data": [{"count": 20}], "columns": ["count"]},
303
+ {"error": "Query failed", "query": "INVALID"},
304
+ ]
305
+
306
+ df_client = DataFrameQueryClient(mock_client)
307
+ dfs = df_client.query_batch_df(
308
+ "graph_id",
309
+ [
310
+ "MATCH (p:Person) RETURN count(p)",
311
+ "MATCH (c:Company) RETURN count(c)",
312
+ "INVALID",
313
+ ],
314
+ )
315
+
316
+ assert len(dfs) == 3
317
+ assert dfs[0].iloc[0]["count"] == 10
318
+ assert dfs[1].iloc[0]["count"] == 20
319
+ assert "error" in dfs[2].columns
320
+
321
+ def test_export_to_csv(self):
322
+ """Test export_to_csv method"""
323
+ mock_client = Mock()
324
+
325
+ with patch(
326
+ "robosystems_client.extensions.dataframe_utils.export_query_to_csv"
327
+ ) as mock_export:
328
+ mock_export.return_value = 100
329
+
330
+ df_client = DataFrameQueryClient(mock_client)
331
+ count = df_client.export_to_csv("graph_id", "MATCH (n) RETURN n", "output.csv")
332
+
333
+ assert count == 100
334
+ mock_export.assert_called_once()
@@ -71,7 +71,7 @@ class TestAuthenticatedIntegration:
71
71
  )
72
72
  assert ext.config["headers"]["Authorization"] == "Bearer jwt_token_here"
73
73
 
74
- @patch("robosystems_client.extensions.auth_integration.sync_detailed")
74
+ @patch("robosystems_client.api.query.execute_cypher_query.sync_detailed")
75
75
  def test_cypher_query_execution(self, mock_sync_detailed, extensions):
76
76
  """Test executing Cypher queries through authenticated client"""
77
77
  # Mock the response