robosystems-client 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of robosystems-client might be problematic. Click here for more details.
- robosystems_client/__init__.py +15 -4
- robosystems_client/api/agent/auto_select_agent.py +25 -0
- robosystems_client/api/agent/batch_process_queries.py +25 -0
- robosystems_client/api/agent/execute_specific_agent.py +25 -0
- robosystems_client/api/agent/get_agent_metadata.py +25 -0
- robosystems_client/api/agent/list_agents.py +20 -0
- robosystems_client/api/agent/recommend_agent.py +25 -0
- robosystems_client/api/backup/create_backup.py +25 -0
- robosystems_client/api/backup/export_backup.py +25 -0
- robosystems_client/api/backup/get_backup_download_url.py +20 -0
- robosystems_client/api/backup/get_backup_stats.py +25 -0
- robosystems_client/api/backup/list_backups.py +20 -0
- robosystems_client/api/backup/restore_backup.py +25 -0
- robosystems_client/api/connections/create_connection.py +25 -0
- robosystems_client/api/connections/create_link_token.py +25 -0
- robosystems_client/api/connections/delete_connection.py +25 -0
- robosystems_client/api/connections/exchange_link_token.py +25 -0
- robosystems_client/api/connections/get_connection.py +25 -0
- robosystems_client/api/connections/get_connection_options.py +25 -0
- robosystems_client/api/connections/init_o_auth.py +25 -0
- robosystems_client/api/connections/list_connections.py +20 -0
- robosystems_client/api/connections/oauth_callback.py +25 -0
- robosystems_client/api/connections/sync_connection.py +25 -0
- robosystems_client/api/copy/copy_data_to_graph.py +25 -0
- robosystems_client/api/create/create_graph.py +25 -0
- robosystems_client/api/graph_analytics/get_graph_metrics.py +25 -0
- robosystems_client/api/graph_analytics/get_graph_usage_stats.py +20 -0
- robosystems_client/api/graph_billing/get_current_graph_bill.py +25 -0
- robosystems_client/api/graph_billing/get_graph_billing_history.py +20 -0
- robosystems_client/api/graph_billing/get_graph_monthly_bill.py +25 -0
- robosystems_client/api/graph_billing/get_graph_usage_details.py +20 -0
- robosystems_client/api/graph_credits/check_credit_balance.py +20 -0
- robosystems_client/api/graph_credits/check_storage_limits.py +25 -0
- robosystems_client/api/graph_credits/get_credit_summary.py +25 -0
- robosystems_client/api/graph_credits/get_storage_usage.py +20 -0
- robosystems_client/api/graph_credits/list_credit_transactions.py +20 -0
- robosystems_client/api/graph_health/get_database_health.py +25 -0
- robosystems_client/api/graph_info/get_database_info.py +25 -0
- robosystems_client/api/graph_limits/get_graph_limits.py +25 -0
- robosystems_client/api/mcp/call_mcp_tool.py +20 -0
- robosystems_client/api/mcp/list_mcp_tools.py +25 -0
- robosystems_client/api/operations/cancel_operation.py +25 -0
- robosystems_client/api/operations/get_operation_status.py +25 -0
- robosystems_client/api/operations/stream_operation_events.py +20 -0
- robosystems_client/api/query/execute_cypher_query.py +20 -0
- robosystems_client/api/schema/export_graph_schema.py +20 -0
- robosystems_client/api/schema/get_graph_schema_info.py +25 -0
- robosystems_client/api/schema/list_schema_extensions.py +25 -0
- robosystems_client/api/schema/validate_schema.py +25 -0
- robosystems_client/api/subgraphs/create_subgraph.py +25 -0
- robosystems_client/api/subgraphs/delete_subgraph.py +25 -0
- robosystems_client/api/subgraphs/get_subgraph_info.py +25 -0
- robosystems_client/api/subgraphs/get_subgraph_quota.py +25 -0
- robosystems_client/api/subgraphs/list_subgraphs.py +25 -0
- robosystems_client/api/user/create_user_api_key.py +25 -0
- robosystems_client/api/user/get_all_credit_summaries.py +25 -0
- robosystems_client/api/user/get_current_user.py +25 -0
- robosystems_client/api/user/get_user_graphs.py +25 -0
- robosystems_client/api/user/list_user_api_keys.py +25 -0
- robosystems_client/api/user/revoke_user_api_key.py +25 -0
- robosystems_client/api/user/select_user_graph.py +25 -0
- robosystems_client/api/user/update_user.py +25 -0
- robosystems_client/api/user/update_user_api_key.py +25 -0
- robosystems_client/api/user/update_user_password.py +25 -0
- robosystems_client/api/user_analytics/get_detailed_user_analytics.py +20 -0
- robosystems_client/api/user_analytics/get_user_usage_overview.py +25 -0
- robosystems_client/api/user_limits/get_all_shared_repository_limits.py +25 -0
- robosystems_client/api/user_limits/get_shared_repository_limits.py +25 -0
- robosystems_client/api/user_limits/get_user_limits.py +25 -0
- robosystems_client/api/user_limits/get_user_usage.py +25 -0
- robosystems_client/api/user_subscriptions/cancel_shared_repository_subscription.py +25 -0
- robosystems_client/api/user_subscriptions/get_repository_credits.py +25 -0
- robosystems_client/api/user_subscriptions/get_shared_repository_credits.py +25 -0
- robosystems_client/api/user_subscriptions/get_user_shared_subscriptions.py +20 -0
- robosystems_client/api/user_subscriptions/subscribe_to_shared_repository.py +25 -0
- robosystems_client/api/user_subscriptions/upgrade_shared_repository_subscription.py +25 -0
- robosystems_client/extensions/__init__.py +70 -0
- robosystems_client/extensions/auth_integration.py +14 -1
- robosystems_client/extensions/copy_client.py +32 -22
- robosystems_client/extensions/dataframe_utils.py +455 -0
- robosystems_client/extensions/extensions.py +16 -0
- robosystems_client/extensions/operation_client.py +43 -21
- robosystems_client/extensions/query_client.py +109 -12
- robosystems_client/extensions/tests/test_dataframe_utils.py +334 -0
- robosystems_client/extensions/tests/test_integration.py +1 -1
- robosystems_client/extensions/tests/test_token_utils.py +274 -0
- robosystems_client/extensions/token_utils.py +417 -0
- robosystems_client/extensions/utils.py +32 -2
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/METADATA +1 -1
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/RECORD +92 -88
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/WHEEL +0 -0
- {robosystems_client-0.1.17.dist-info → robosystems_client-0.1.18.dist-info}/licenses/LICENSE +0 -0
|
@@ -71,7 +71,14 @@ class OperationClient:
|
|
|
71
71
|
def __init__(self, config: Dict[str, Any]):
|
|
72
72
|
self.config = config
|
|
73
73
|
self.base_url = config["base_url"]
|
|
74
|
+
self.headers = config.get("headers", {})
|
|
75
|
+
# Get token from config if passed by parent
|
|
76
|
+
self.token = config.get("token")
|
|
74
77
|
self.active_operations: Dict[str, SSEClient] = {}
|
|
78
|
+
# Thread safety for operations tracking
|
|
79
|
+
import threading
|
|
80
|
+
|
|
81
|
+
self._lock = threading.Lock()
|
|
75
82
|
|
|
76
83
|
def monitor_operation(
|
|
77
84
|
self, operation_id: str, options: MonitorOptions = None
|
|
@@ -144,7 +151,8 @@ class OperationClient:
|
|
|
144
151
|
# Connect and monitor
|
|
145
152
|
try:
|
|
146
153
|
sse_client.connect(operation_id)
|
|
147
|
-
self.
|
|
154
|
+
with self._lock:
|
|
155
|
+
self.active_operations[operation_id] = sse_client
|
|
148
156
|
|
|
149
157
|
# Wait for completion
|
|
150
158
|
import time
|
|
@@ -166,10 +174,11 @@ class OperationClient:
|
|
|
166
174
|
time.sleep(options.poll_interval or 0.1)
|
|
167
175
|
|
|
168
176
|
finally:
|
|
169
|
-
# Clean up
|
|
170
|
-
|
|
171
|
-
self.active_operations
|
|
172
|
-
|
|
177
|
+
# Clean up with thread safety
|
|
178
|
+
with self._lock:
|
|
179
|
+
if operation_id in self.active_operations:
|
|
180
|
+
self.active_operations[operation_id].close()
|
|
181
|
+
del self.active_operations[operation_id]
|
|
173
182
|
|
|
174
183
|
return result
|
|
175
184
|
|
|
@@ -179,11 +188,16 @@ class OperationClient:
|
|
|
179
188
|
from ..api.operations.get_operation_status import (
|
|
180
189
|
sync_detailed as get_operation_status,
|
|
181
190
|
)
|
|
182
|
-
from ..client import
|
|
191
|
+
from ..client import Client
|
|
183
192
|
|
|
184
|
-
|
|
193
|
+
# Use regular Client with headers instead of AuthenticatedClient
|
|
194
|
+
client = Client(base_url=self.base_url, headers=self.headers)
|
|
185
195
|
try:
|
|
186
|
-
|
|
196
|
+
kwargs = {"operation_id": operation_id, "client": client}
|
|
197
|
+
# Only add token if it's a valid string
|
|
198
|
+
if self.token and isinstance(self.token, str) and self.token.strip():
|
|
199
|
+
kwargs["token"] = self.token
|
|
200
|
+
response = get_operation_status(**kwargs)
|
|
187
201
|
if response.parsed:
|
|
188
202
|
return {
|
|
189
203
|
"operation_id": operation_id,
|
|
@@ -201,21 +215,27 @@ class OperationClient:
|
|
|
201
215
|
"""Cancel an operation"""
|
|
202
216
|
# This would use the generated SDK to call /v1/operations/{operation_id}/cancel
|
|
203
217
|
from ..api.operations.cancel_operation import sync_detailed as cancel_operation
|
|
204
|
-
from ..client import
|
|
218
|
+
from ..client import Client
|
|
205
219
|
|
|
206
|
-
|
|
220
|
+
# Use regular Client with headers instead of AuthenticatedClient
|
|
221
|
+
client = Client(base_url=self.base_url, headers=self.headers)
|
|
207
222
|
try:
|
|
208
|
-
|
|
223
|
+
kwargs = {"operation_id": operation_id, "client": client}
|
|
224
|
+
# Only add token if it's a valid string
|
|
225
|
+
if self.token and isinstance(self.token, str) and self.token.strip():
|
|
226
|
+
kwargs["token"] = self.token
|
|
227
|
+
response = cancel_operation(**kwargs)
|
|
209
228
|
if response.parsed:
|
|
210
229
|
return response.parsed.cancelled or False
|
|
211
230
|
except Exception as e:
|
|
212
231
|
print(f"Failed to cancel operation {operation_id}: {e}")
|
|
213
232
|
return False
|
|
214
233
|
|
|
215
|
-
# Also close any active SSE connection
|
|
216
|
-
|
|
217
|
-
self.active_operations
|
|
218
|
-
|
|
234
|
+
# Also close any active SSE connection with thread safety
|
|
235
|
+
with self._lock:
|
|
236
|
+
if operation_id in self.active_operations:
|
|
237
|
+
self.active_operations[operation_id].close()
|
|
238
|
+
del self.active_operations[operation_id]
|
|
219
239
|
|
|
220
240
|
return False
|
|
221
241
|
|
|
@@ -226,15 +246,17 @@ class OperationClient:
|
|
|
226
246
|
|
|
227
247
|
def close_all(self):
|
|
228
248
|
"""Close all active operation monitors"""
|
|
229
|
-
|
|
230
|
-
sse_client.
|
|
231
|
-
|
|
249
|
+
with self._lock:
|
|
250
|
+
for sse_client in self.active_operations.values():
|
|
251
|
+
sse_client.close()
|
|
252
|
+
self.active_operations.clear()
|
|
232
253
|
|
|
233
254
|
def close_operation(self, operation_id: str):
|
|
234
255
|
"""Close monitoring for a specific operation"""
|
|
235
|
-
|
|
236
|
-
self.active_operations
|
|
237
|
-
|
|
256
|
+
with self._lock:
|
|
257
|
+
if operation_id in self.active_operations:
|
|
258
|
+
self.active_operations[operation_id].close()
|
|
259
|
+
del self.active_operations[operation_id]
|
|
238
260
|
|
|
239
261
|
|
|
240
262
|
class AsyncOperationClient:
|
|
@@ -4,7 +4,17 @@ Provides intelligent query execution with automatic strategy selection.
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import (
|
|
8
|
+
Dict,
|
|
9
|
+
Any,
|
|
10
|
+
Optional,
|
|
11
|
+
Callable,
|
|
12
|
+
AsyncIterator,
|
|
13
|
+
Iterator,
|
|
14
|
+
Union,
|
|
15
|
+
Generator,
|
|
16
|
+
List,
|
|
17
|
+
)
|
|
8
18
|
from datetime import datetime
|
|
9
19
|
|
|
10
20
|
from ..api.query.execute_cypher_query import sync_detailed as execute_cypher_query
|
|
@@ -70,6 +80,9 @@ class QueryClient:
|
|
|
70
80
|
def __init__(self, config: Dict[str, Any]):
|
|
71
81
|
self.config = config
|
|
72
82
|
self.base_url = config["base_url"]
|
|
83
|
+
self.headers = config.get("headers", {})
|
|
84
|
+
# Get token from config if passed by parent
|
|
85
|
+
self.token = config.get("token")
|
|
73
86
|
self.sse_client: Optional[SSEClient] = None
|
|
74
87
|
|
|
75
88
|
def execute_query(
|
|
@@ -85,15 +98,17 @@ class QueryClient:
|
|
|
85
98
|
)
|
|
86
99
|
|
|
87
100
|
# Execute the query through the generated client
|
|
88
|
-
from ..client import
|
|
101
|
+
from ..client import Client
|
|
89
102
|
|
|
90
|
-
#
|
|
91
|
-
client =
|
|
103
|
+
# Create client with headers
|
|
104
|
+
client = Client(base_url=self.base_url, headers=self.headers)
|
|
92
105
|
|
|
93
106
|
try:
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
)
|
|
107
|
+
kwargs = {"graph_id": graph_id, "client": client, "body": query_request}
|
|
108
|
+
# Only add token if it's a valid string
|
|
109
|
+
if self.token and isinstance(self.token, str) and self.token.strip():
|
|
110
|
+
kwargs["token"] = self.token
|
|
111
|
+
response = execute_cypher_query(**kwargs)
|
|
97
112
|
|
|
98
113
|
# Check response type and handle accordingly
|
|
99
114
|
if hasattr(response, "parsed") and response.parsed:
|
|
@@ -145,7 +160,15 @@ class QueryClient:
|
|
|
145
160
|
except Exception as e:
|
|
146
161
|
if isinstance(e, QueuedQueryError):
|
|
147
162
|
raise
|
|
148
|
-
|
|
163
|
+
|
|
164
|
+
error_msg = str(e)
|
|
165
|
+
# Check for authentication errors
|
|
166
|
+
if (
|
|
167
|
+
"401" in error_msg or "403" in error_msg or "unauthorized" in error_msg.lower()
|
|
168
|
+
):
|
|
169
|
+
raise Exception(f"Authentication failed during query execution: {error_msg}")
|
|
170
|
+
else:
|
|
171
|
+
raise Exception(f"Query execution failed: {error_msg}")
|
|
149
172
|
|
|
150
173
|
# Unexpected response format
|
|
151
174
|
raise Exception("Unexpected response format from query endpoint")
|
|
@@ -316,18 +339,92 @@ class QueryClient:
|
|
|
316
339
|
cypher: str,
|
|
317
340
|
parameters: Dict[str, Any] = None,
|
|
318
341
|
chunk_size: int = 1000,
|
|
319
|
-
|
|
320
|
-
|
|
342
|
+
on_progress: Optional[Callable[[int, int], None]] = None,
|
|
343
|
+
) -> Generator[Any, None, None]:
|
|
344
|
+
"""Stream query results for large datasets with progress tracking
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
graph_id: Graph ID to query
|
|
348
|
+
cypher: Cypher query string
|
|
349
|
+
parameters: Query parameters
|
|
350
|
+
chunk_size: Number of records per chunk
|
|
351
|
+
on_progress: Callback for progress updates (current, total)
|
|
352
|
+
|
|
353
|
+
Yields:
|
|
354
|
+
Individual records from query results
|
|
355
|
+
|
|
356
|
+
Example:
|
|
357
|
+
>>> def progress(current, total):
|
|
358
|
+
... print(f"Processed {current}/{total} records")
|
|
359
|
+
>>> for record in query_client.stream_query(
|
|
360
|
+
... 'graph_id',
|
|
361
|
+
... 'MATCH (n) RETURN n',
|
|
362
|
+
... chunk_size=100,
|
|
363
|
+
... on_progress=progress
|
|
364
|
+
... ):
|
|
365
|
+
... process_record(record)
|
|
366
|
+
"""
|
|
321
367
|
request = QueryRequest(query=cypher, parameters=parameters)
|
|
322
368
|
result = self.execute_query(
|
|
323
369
|
graph_id, request, QueryOptions(mode="stream", chunk_size=chunk_size)
|
|
324
370
|
)
|
|
325
371
|
|
|
372
|
+
count = 0
|
|
326
373
|
if isinstance(result, Iterator):
|
|
327
|
-
|
|
374
|
+
for item in result:
|
|
375
|
+
count += 1
|
|
376
|
+
if on_progress and count % chunk_size == 0:
|
|
377
|
+
on_progress(count, None) # Total unknown in streaming
|
|
378
|
+
yield item
|
|
328
379
|
else:
|
|
329
380
|
# If not streaming, yield all results at once
|
|
330
|
-
|
|
381
|
+
total = len(result.data)
|
|
382
|
+
for item in result.data:
|
|
383
|
+
count += 1
|
|
384
|
+
if on_progress:
|
|
385
|
+
on_progress(count, total)
|
|
386
|
+
yield item
|
|
387
|
+
|
|
388
|
+
def query_batch(
|
|
389
|
+
self,
|
|
390
|
+
graph_id: str,
|
|
391
|
+
queries: List[str],
|
|
392
|
+
parameters_list: Optional[List[Dict[str, Any]]] = None,
|
|
393
|
+
parallel: bool = False,
|
|
394
|
+
) -> List[Union[QueryResult, Dict[str, Any]]]:
|
|
395
|
+
"""Execute multiple queries in batch
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
graph_id: Graph ID to query
|
|
399
|
+
queries: List of Cypher query strings
|
|
400
|
+
parameters_list: List of parameter dicts (one per query)
|
|
401
|
+
parallel: Execute queries in parallel (experimental)
|
|
402
|
+
|
|
403
|
+
Returns:
|
|
404
|
+
List of QueryResult objects or error dicts
|
|
405
|
+
|
|
406
|
+
Example:
|
|
407
|
+
>>> results = query_client.query_batch('graph_id', [
|
|
408
|
+
... 'MATCH (n:Person) RETURN count(n)',
|
|
409
|
+
... 'MATCH (c:Company) RETURN count(c)'
|
|
410
|
+
... ])
|
|
411
|
+
"""
|
|
412
|
+
if parameters_list is None:
|
|
413
|
+
parameters_list = [None] * len(queries)
|
|
414
|
+
|
|
415
|
+
if len(queries) != len(parameters_list):
|
|
416
|
+
raise ValueError("queries and parameters_list must have same length")
|
|
417
|
+
|
|
418
|
+
results = []
|
|
419
|
+
for query, params in zip(queries, parameters_list):
|
|
420
|
+
try:
|
|
421
|
+
result = self.query(graph_id, query, params)
|
|
422
|
+
results.append(result)
|
|
423
|
+
except Exception as e:
|
|
424
|
+
# Store error as result
|
|
425
|
+
results.append({"error": str(e), "query": query})
|
|
426
|
+
|
|
427
|
+
return results
|
|
331
428
|
|
|
332
429
|
def close(self):
|
|
333
430
|
"""Cancel any active SSE connections"""
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Tests for DataFrame utilities"""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
from unittest.mock import Mock, patch
|
|
5
|
+
import tempfile
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
# Make pandas optional for tests
|
|
9
|
+
try:
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
HAS_PANDAS = True
|
|
13
|
+
except ImportError:
|
|
14
|
+
HAS_PANDAS = False
|
|
15
|
+
pd = None
|
|
16
|
+
|
|
17
|
+
# Only run tests if pandas is available
|
|
18
|
+
pytestmark = pytest.mark.skipif(not HAS_PANDAS, reason="pandas not installed")
|
|
19
|
+
|
|
20
|
+
if HAS_PANDAS:
|
|
21
|
+
from robosystems_client.extensions.dataframe_utils import (
|
|
22
|
+
query_result_to_dataframe,
|
|
23
|
+
parse_datetime_columns,
|
|
24
|
+
stream_to_dataframe,
|
|
25
|
+
dataframe_to_cypher_params,
|
|
26
|
+
export_query_to_csv,
|
|
27
|
+
compare_dataframes,
|
|
28
|
+
DataFrameQueryClient,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestQueryResultToDataFrame:
|
|
33
|
+
"""Test converting query results to DataFrames"""
|
|
34
|
+
|
|
35
|
+
def test_query_result_to_dataframe_basic(self):
|
|
36
|
+
"""Test basic conversion from query result to DataFrame"""
|
|
37
|
+
result = {
|
|
38
|
+
"data": [
|
|
39
|
+
{"name": "Alice", "age": 30},
|
|
40
|
+
{"name": "Bob", "age": 25},
|
|
41
|
+
{"name": "Charlie", "age": 35},
|
|
42
|
+
],
|
|
43
|
+
"columns": ["name", "age"],
|
|
44
|
+
"row_count": 3,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
df = query_result_to_dataframe(result)
|
|
48
|
+
|
|
49
|
+
assert len(df) == 3
|
|
50
|
+
assert list(df.columns) == ["name", "age"]
|
|
51
|
+
assert df.iloc[0]["name"] == "Alice"
|
|
52
|
+
assert df.iloc[1]["age"] == 25
|
|
53
|
+
|
|
54
|
+
def test_query_result_to_dataframe_nested(self):
|
|
55
|
+
"""Test conversion with nested data"""
|
|
56
|
+
result = {
|
|
57
|
+
"data": [
|
|
58
|
+
{"name": "Alice", "company": {"name": "TechCorp", "revenue": 1000000}},
|
|
59
|
+
{"name": "Bob", "company": {"name": "StartupInc", "revenue": 500000}},
|
|
60
|
+
],
|
|
61
|
+
"columns": ["name", "company"],
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
df = query_result_to_dataframe(result, normalize_nested=True)
|
|
65
|
+
|
|
66
|
+
assert "name" in df.columns
|
|
67
|
+
assert "company.name" in df.columns
|
|
68
|
+
assert "company.revenue" in df.columns
|
|
69
|
+
assert df.iloc[0]["company.name"] == "TechCorp"
|
|
70
|
+
|
|
71
|
+
def test_query_result_to_dataframe_empty(self):
|
|
72
|
+
"""Test conversion of empty result"""
|
|
73
|
+
result = {"data": [], "columns": ["name", "age"], "row_count": 0}
|
|
74
|
+
|
|
75
|
+
df = query_result_to_dataframe(result)
|
|
76
|
+
|
|
77
|
+
assert len(df) == 0
|
|
78
|
+
assert list(df.columns) == ["name", "age"]
|
|
79
|
+
|
|
80
|
+
def test_query_result_to_dataframe_with_dates(self):
|
|
81
|
+
"""Test conversion with date parsing"""
|
|
82
|
+
result = {
|
|
83
|
+
"data": [
|
|
84
|
+
{"name": "Alice", "created_at": "2023-01-15T10:30:00"},
|
|
85
|
+
{"name": "Bob", "created_at": "2023-02-20T14:45:00"},
|
|
86
|
+
],
|
|
87
|
+
"columns": ["name", "created_at"],
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
df = query_result_to_dataframe(result, parse_dates=True)
|
|
91
|
+
|
|
92
|
+
assert pd.api.types.is_datetime64_any_dtype(df["created_at"])
|
|
93
|
+
assert df.iloc[0]["created_at"].year == 2023
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class TestParseDateTimeColumns:
|
|
97
|
+
"""Test datetime parsing functionality"""
|
|
98
|
+
|
|
99
|
+
def test_parse_datetime_columns_specific(self):
|
|
100
|
+
"""Test parsing specific datetime columns"""
|
|
101
|
+
df = pd.DataFrame(
|
|
102
|
+
{
|
|
103
|
+
"name": ["Alice", "Bob"],
|
|
104
|
+
"created_at": ["2023-01-15", "2023-02-20"],
|
|
105
|
+
"updated_at": ["2023-01-16T10:30:00", "2023-02-21T14:45:00"],
|
|
106
|
+
"count": [1, 2],
|
|
107
|
+
}
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
df = parse_datetime_columns(df, date_columns=["created_at", "updated_at"])
|
|
111
|
+
|
|
112
|
+
assert pd.api.types.is_datetime64_any_dtype(df["created_at"])
|
|
113
|
+
assert pd.api.types.is_datetime64_any_dtype(df["updated_at"])
|
|
114
|
+
assert not pd.api.types.is_datetime64_any_dtype(df["count"])
|
|
115
|
+
|
|
116
|
+
def test_parse_datetime_columns_infer(self):
|
|
117
|
+
"""Test automatic datetime column inference"""
|
|
118
|
+
df = pd.DataFrame(
|
|
119
|
+
{
|
|
120
|
+
"name": ["Alice", "Bob"],
|
|
121
|
+
"timestamp": ["2023-01-15T10:30:00", "2023-02-20T14:45:00"],
|
|
122
|
+
"not_a_date": ["abc", "def"],
|
|
123
|
+
}
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
df = parse_datetime_columns(df, infer=True)
|
|
127
|
+
|
|
128
|
+
assert pd.api.types.is_datetime64_any_dtype(df["timestamp"])
|
|
129
|
+
assert df["not_a_date"].dtype == "object"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class TestStreamToDataFrame:
|
|
133
|
+
"""Test streaming results to DataFrame"""
|
|
134
|
+
|
|
135
|
+
def test_stream_to_dataframe_basic(self):
|
|
136
|
+
"""Test converting stream to DataFrame"""
|
|
137
|
+
|
|
138
|
+
def mock_stream():
|
|
139
|
+
for i in range(10):
|
|
140
|
+
yield {"id": i, "value": i * 2}
|
|
141
|
+
|
|
142
|
+
df = stream_to_dataframe(mock_stream(), chunk_size=3)
|
|
143
|
+
|
|
144
|
+
assert len(df) == 10
|
|
145
|
+
assert df.iloc[5]["value"] == 10
|
|
146
|
+
|
|
147
|
+
def test_stream_to_dataframe_with_callback(self):
|
|
148
|
+
"""Test stream with chunk callback"""
|
|
149
|
+
chunk_counts = []
|
|
150
|
+
|
|
151
|
+
def on_chunk(chunk_df, total):
|
|
152
|
+
chunk_counts.append(len(chunk_df))
|
|
153
|
+
|
|
154
|
+
def mock_stream():
|
|
155
|
+
for i in range(10):
|
|
156
|
+
yield {"id": i, "value": i * 2}
|
|
157
|
+
|
|
158
|
+
df = stream_to_dataframe(mock_stream(), chunk_size=3, on_chunk=on_chunk)
|
|
159
|
+
|
|
160
|
+
assert len(df) == 10
|
|
161
|
+
assert chunk_counts == [3, 3, 3, 1] # 3 chunks of 3, 1 chunk of 1
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class TestDataFrameToCypherParams:
|
|
165
|
+
"""Test DataFrame to Cypher parameter conversion"""
|
|
166
|
+
|
|
167
|
+
def test_dataframe_to_cypher_params(self):
|
|
168
|
+
"""Test converting DataFrame to Cypher parameters"""
|
|
169
|
+
df = pd.DataFrame(
|
|
170
|
+
{
|
|
171
|
+
"name": ["Alice", "Bob", "Charlie"],
|
|
172
|
+
"age": [30, 25, 35],
|
|
173
|
+
"active": [True, False, True],
|
|
174
|
+
}
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
params = dataframe_to_cypher_params(df)
|
|
178
|
+
|
|
179
|
+
assert "data" in params
|
|
180
|
+
assert len(params["data"]) == 3
|
|
181
|
+
assert params["data"][0]["name"] == "Alice"
|
|
182
|
+
assert params["data"][1]["age"] == 25
|
|
183
|
+
|
|
184
|
+
def test_dataframe_to_cypher_params_with_nan(self):
|
|
185
|
+
"""Test handling NaN values"""
|
|
186
|
+
df = pd.DataFrame(
|
|
187
|
+
{"name": ["Alice", "Bob"], "age": [30, pd.NA], "score": [95.5, None]}
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
params = dataframe_to_cypher_params(df, param_name="records")
|
|
191
|
+
|
|
192
|
+
assert "records" in params
|
|
193
|
+
assert params["records"][1]["age"] is None
|
|
194
|
+
assert params["records"][1]["score"] is None
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
class TestExportQueryToCSV:
|
|
198
|
+
"""Test CSV export functionality"""
|
|
199
|
+
|
|
200
|
+
@patch("robosystems_client.extensions.dataframe_utils.logger")
|
|
201
|
+
def test_export_query_to_csv(self, mock_logger):
|
|
202
|
+
"""Test exporting query results to CSV"""
|
|
203
|
+
mock_client = Mock()
|
|
204
|
+
|
|
205
|
+
def mock_stream(*args, **kwargs):
|
|
206
|
+
for i in range(5):
|
|
207
|
+
yield {"id": i, "name": f"Item {i}"}
|
|
208
|
+
|
|
209
|
+
mock_client.stream_query = Mock(side_effect=mock_stream)
|
|
210
|
+
|
|
211
|
+
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".csv") as f:
|
|
212
|
+
temp_file = f.name
|
|
213
|
+
|
|
214
|
+
try:
|
|
215
|
+
count = export_query_to_csv(
|
|
216
|
+
mock_client, "graph_id", "MATCH (n) RETURN n", temp_file, chunk_size=2
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
assert count == 5
|
|
220
|
+
mock_logger.info.assert_called()
|
|
221
|
+
|
|
222
|
+
# Verify CSV content
|
|
223
|
+
df = pd.read_csv(temp_file)
|
|
224
|
+
assert len(df) == 5
|
|
225
|
+
assert df.iloc[0]["name"] == "Item 0"
|
|
226
|
+
|
|
227
|
+
finally:
|
|
228
|
+
if os.path.exists(temp_file):
|
|
229
|
+
os.unlink(temp_file)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class TestCompareDataFrames:
|
|
233
|
+
"""Test DataFrame comparison"""
|
|
234
|
+
|
|
235
|
+
def test_compare_dataframes_with_keys(self):
|
|
236
|
+
"""Test comparing DataFrames with key columns"""
|
|
237
|
+
df1 = pd.DataFrame(
|
|
238
|
+
{"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "age": [30, 25, 35]}
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
df2 = pd.DataFrame(
|
|
242
|
+
{"id": [1, 2, 4], "name": ["Alice", "Robert", "David"], "age": [31, 25, 40]}
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
diff = compare_dataframes(df1, df2, key_columns=["id"])
|
|
246
|
+
|
|
247
|
+
assert "_merge" in diff.columns
|
|
248
|
+
assert "name_old" in diff.columns
|
|
249
|
+
assert "name_new" in diff.columns
|
|
250
|
+
|
|
251
|
+
def test_compare_dataframes_without_keys(self):
|
|
252
|
+
"""Test comparing DataFrames without keys"""
|
|
253
|
+
df1 = pd.DataFrame({"name": ["Alice", "Bob"], "age": [30, 25]})
|
|
254
|
+
|
|
255
|
+
df2 = pd.DataFrame({"name": ["Alice", "Charlie"], "age": [30, 35]})
|
|
256
|
+
|
|
257
|
+
diff = compare_dataframes(df1, df2)
|
|
258
|
+
|
|
259
|
+
assert len(diff) == 2 # Bob and Charlie rows
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class TestDataFrameQueryClient:
|
|
263
|
+
"""Test DataFrameQueryClient class"""
|
|
264
|
+
|
|
265
|
+
def test_query_df(self):
|
|
266
|
+
"""Test query_df method"""
|
|
267
|
+
mock_client = Mock()
|
|
268
|
+
mock_client.query.return_value = {
|
|
269
|
+
"data": [{"name": "Alice"}, {"name": "Bob"}],
|
|
270
|
+
"columns": ["name"],
|
|
271
|
+
"row_count": 2,
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
df_client = DataFrameQueryClient(mock_client)
|
|
275
|
+
df = df_client.query_df("graph_id", "MATCH (n) RETURN n")
|
|
276
|
+
|
|
277
|
+
assert len(df) == 2
|
|
278
|
+
assert df.iloc[0]["name"] == "Alice"
|
|
279
|
+
mock_client.query.assert_called_once()
|
|
280
|
+
|
|
281
|
+
def test_stream_df(self):
|
|
282
|
+
"""Test stream_df method"""
|
|
283
|
+
mock_client = Mock()
|
|
284
|
+
|
|
285
|
+
def mock_stream(*args, **kwargs):
|
|
286
|
+
for i in range(3):
|
|
287
|
+
yield {"id": i, "value": i * 10}
|
|
288
|
+
|
|
289
|
+
mock_client.stream_query.return_value = mock_stream()
|
|
290
|
+
|
|
291
|
+
df_client = DataFrameQueryClient(mock_client)
|
|
292
|
+
df = df_client.stream_df("graph_id", "MATCH (n) RETURN n")
|
|
293
|
+
|
|
294
|
+
assert len(df) == 3
|
|
295
|
+
assert df.iloc[1]["value"] == 10
|
|
296
|
+
|
|
297
|
+
def test_query_batch_df(self):
|
|
298
|
+
"""Test query_batch_df method"""
|
|
299
|
+
mock_client = Mock()
|
|
300
|
+
mock_client.query_batch.return_value = [
|
|
301
|
+
{"data": [{"count": 10}], "columns": ["count"]},
|
|
302
|
+
{"data": [{"count": 20}], "columns": ["count"]},
|
|
303
|
+
{"error": "Query failed", "query": "INVALID"},
|
|
304
|
+
]
|
|
305
|
+
|
|
306
|
+
df_client = DataFrameQueryClient(mock_client)
|
|
307
|
+
dfs = df_client.query_batch_df(
|
|
308
|
+
"graph_id",
|
|
309
|
+
[
|
|
310
|
+
"MATCH (p:Person) RETURN count(p)",
|
|
311
|
+
"MATCH (c:Company) RETURN count(c)",
|
|
312
|
+
"INVALID",
|
|
313
|
+
],
|
|
314
|
+
)
|
|
315
|
+
|
|
316
|
+
assert len(dfs) == 3
|
|
317
|
+
assert dfs[0].iloc[0]["count"] == 10
|
|
318
|
+
assert dfs[1].iloc[0]["count"] == 20
|
|
319
|
+
assert "error" in dfs[2].columns
|
|
320
|
+
|
|
321
|
+
def test_export_to_csv(self):
|
|
322
|
+
"""Test export_to_csv method"""
|
|
323
|
+
mock_client = Mock()
|
|
324
|
+
|
|
325
|
+
with patch(
|
|
326
|
+
"robosystems_client.extensions.dataframe_utils.export_query_to_csv"
|
|
327
|
+
) as mock_export:
|
|
328
|
+
mock_export.return_value = 100
|
|
329
|
+
|
|
330
|
+
df_client = DataFrameQueryClient(mock_client)
|
|
331
|
+
count = df_client.export_to_csv("graph_id", "MATCH (n) RETURN n", "output.csv")
|
|
332
|
+
|
|
333
|
+
assert count == 100
|
|
334
|
+
mock_export.assert_called_once()
|
|
@@ -71,7 +71,7 @@ class TestAuthenticatedIntegration:
|
|
|
71
71
|
)
|
|
72
72
|
assert ext.config["headers"]["Authorization"] == "Bearer jwt_token_here"
|
|
73
73
|
|
|
74
|
-
@patch("robosystems_client.
|
|
74
|
+
@patch("robosystems_client.api.query.execute_cypher_query.sync_detailed")
|
|
75
75
|
def test_cypher_query_execution(self, mock_sync_detailed, extensions):
|
|
76
76
|
"""Test executing Cypher queries through authenticated client"""
|
|
77
77
|
# Mock the response
|