opteryx-sqlalchemy 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opteryx_sqlalchemy-0.0.5.dist-info/METADATA +13 -0
- opteryx_sqlalchemy-0.0.5.dist-info/RECORD +9 -0
- opteryx_sqlalchemy-0.0.5.dist-info/WHEEL +5 -0
- opteryx_sqlalchemy-0.0.5.dist-info/entry_points.txt +2 -0
- opteryx_sqlalchemy-0.0.5.dist-info/licenses/LICENSE +201 -0
- opteryx_sqlalchemy-0.0.5.dist-info/top_level.txt +1 -0
- sqlalchemy_dialect/__init__.py +19 -0
- sqlalchemy_dialect/dbapi.py +825 -0
- sqlalchemy_dialect/dialect.py +370 -0
|
@@ -0,0 +1,825 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DBAPI 2.0 (PEP 249) compliant interface for Opteryx (opteryx.app).
|
|
3
|
+
|
|
4
|
+
This module implements a minimal DBAPI 2.0 interface that communicates
|
|
5
|
+
with the Opteryx data service via HTTP. It provides Connection and Cursor
|
|
6
|
+
classes that translate SQL queries into HTTP requests.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import logging
|
|
13
|
+
import time
|
|
14
|
+
from importlib.metadata import version
|
|
15
|
+
from typing import Any
|
|
16
|
+
from typing import Dict
|
|
17
|
+
from typing import List
|
|
18
|
+
from typing import Optional
|
|
19
|
+
from typing import Sequence
|
|
20
|
+
from typing import Tuple
|
|
21
|
+
from typing import Union
|
|
22
|
+
from urllib.parse import urljoin
|
|
23
|
+
|
|
24
|
+
import requests
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger("sqlalchemy.dialects.opteryx")
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
__version__ = version("opteryx-sqlalchemy")
|
|
30
|
+
except Exception:
|
|
31
|
+
__version__ = "unknown"
|
|
32
|
+
|
|
33
|
+
# Module globals required by PEP 249
|
|
34
|
+
apilevel = "2.0"
|
|
35
|
+
threadsafety = 1 # Threads may share the module, but not connections
|
|
36
|
+
paramstyle = "named" # Named style: WHERE name=:name
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Error(Exception):
    """Root of the DBAPI error hierarchy defined by this module."""


class Warning(Exception):  # noqa: A001
    """DBAPI warning category; derives from Exception, not from Error."""


class InterfaceError(Error):
    """Error attributed to the database interface rather than the database."""


class DatabaseError(Error):
    """Error attributed to the database itself."""


class DataError(DatabaseError):
    """Database error caused by problems with the processed data."""


class OperationalError(DatabaseError):
    """Database error related to the operation of the database or connection."""


class IntegrityError(DatabaseError):
    """Database error raised for integrity constraint violations."""


class InternalError(DatabaseError):
    """Database error raised for internal failures."""


class ProgrammingError(DatabaseError):
    """Database error raised for programming mistakes (e.g. using a closed cursor)."""


class NotSupportedError(DatabaseError):
    """Database error raised when an operation is not supported."""
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Type constructors (required by PEP 249)
|
|
80
|
+
def Date(year: int, month: int, day: int) -> str:
    """Return the date as a zero-padded ``YYYY-MM-DD`` string."""
    return "{:04d}-{:02d}-{:02d}".format(year, month, day)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def Time(hour: int, minute: int, second: int) -> str:
    """Return the time as a zero-padded ``HH:MM:SS`` string."""
    return ":".join(format(part, "02d") for part in (hour, minute, second))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def Timestamp(year: int, month: int, day: int, hour: int, minute: int, second: int) -> str:
    """Return the timestamp as a ``YYYY-MM-DD HH:MM:SS`` string."""
    date_part = "{:04d}-{:02d}-{:02d}".format(year, month, day)
    time_part = "{:02d}:{:02d}:{:02d}".format(hour, minute, second)
    return date_part + " " + time_part
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def Binary(string: bytes) -> bytes:
    """Return the given binary value unchanged (no wrapping or copying)."""
    payload = string
    return payload
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
STRING = str
|
|
101
|
+
BINARY = bytes
|
|
102
|
+
NUMBER = float
|
|
103
|
+
DATETIME = str
|
|
104
|
+
ROWID = str
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class Cursor:
    """DBAPI 2.0 Cursor implementation for Opteryx.

    A cursor submits a statement through its parent connection, polls the
    statement status until it reaches a terminal state, and then buffers all
    result rows in memory; the ``fetch*`` methods serve rows from that buffer.
    """

    def __init__(self, connection: "Connection") -> None:
        """Create a cursor bound to *connection* and attempt authentication.

        Authentication is best-effort: every failure is logged and swallowed so
        that cursor creation never raises because of it; subsequent requests
        are then sent without a freshly obtained JWT.
        """
        self._connection = connection
        self._jwt_token: Optional[str] = None
        self._description: Optional[
            List[Tuple[str, Any, None, None, None, None, Optional[bool]]]
        ] = None
        self._rowcount = -1
        self._rows: List[Tuple[Any, ...]] = []
        self._row_index = 0
        self._arraysize = 1
        self._closed = False
        self._statement_handle: Optional[str] = None

        # Execution option placeholders (may be set by dialect.do_execute)
        self._opteryx_execution_options: dict = {}
        self._opteryx_stream_results_requested: bool = False
        self._opteryx_max_row_buffer: Optional[int] = None

        try:
            self._authenticate()
        except Exception as e:
            # Any unexpected failure in auth should not crash cursor creation
            logger.error("Unexpected error during authentication: %s", e, exc_info=True)
            self._jwt_token = None

    @staticmethod
    def _auth_host_for(domain: str) -> str:
        """Return the host that serves the token endpoint for *domain*.

        The ``authenticate.`` prefix is only added when *domain* looks like a
        DNS name; ``localhost``, dot-less hosts and IP literals are returned
        unchanged because a subdomain of an IP address is not resolvable.
        """
        import ipaddress

        try:
            ipaddress.ip_address(domain)
        except ValueError:
            pass  # not an IP literal; fall through to the DNS-name check
        else:
            return domain
        if "." in domain and not domain.startswith("localhost"):
            return f"authenticate.{domain}"
        return domain

    def _authenticate(self) -> None:
        """Run the client-credentials flow and remember the returned token.

        The connection's ``_username`` is sent as ``client_id`` and its
        ``_token`` as ``client_secret``. Nothing is done when either is unset.
        Request-level failures are logged as warnings and leave ``_jwt_token``
        set to None; on success the bearer token is also installed on the
        connection's session headers.
        """
        conn = self._connection
        username = getattr(conn, "_username", None)
        secret = getattr(conn, "_token", None)
        if not (username and secret):
            return

        try:
            logger.debug("Attempting client credentials authentication for user: %s", username)
            host = getattr(conn, "_host", "localhost")
            try:
                domain = conn._normalize_domain(host)
            except Exception as e:
                logger.debug("Failed to normalize domain '%s': %s", host, e)
                domain = host
            auth_host = self._auth_host_for(domain)
            scheme = "https" if getattr(conn, "_ssl", False) else "http"
            auth_url = f"{scheme}://{auth_host}/token"
            logger.debug("Authentication URL: %s", auth_url)

            # Form-encoded payload for the token endpoint
            payload = {
                "grant_type": "client_credentials",
                "client_id": username,
                "client_secret": secret,
            }
            headers = {
                "accept": "application/json",
                "Content-Type": "application/x-www-form-urlencoded",
            }

            # Prefer the connection's session; only build (and later close) a
            # temporary one when the connection does not expose a session.
            sess = getattr(conn, "_session", None)
            owns_session = sess is None
            if owns_session:
                sess = requests.Session()
            try:
                resp = sess.post(
                    auth_url,
                    data=payload,
                    headers=headers,
                    timeout=getattr(conn, "_timeout", 30),
                )
                resp.raise_for_status()
                body = resp.json() if resp.text else {}
            finally:
                if owns_session:
                    sess.close()

            token = body.get("access_token") or body.get("token") or body.get("jwt")
            if token:
                self._jwt_token = token
                logger.info("Authentication successful for user: %s", username)
                # Set Authorization header for subsequent requests via the connection session
                try:
                    self._connection._session.headers["Authorization"] = f"Bearer {token}"
                except Exception as e:
                    logger.warning("Failed to set Authorization header: %s", e)
            else:
                logger.warning("Authentication response missing token for user: %s", username)
        except requests.exceptions.RequestException as e:
            # Authentication failed — don't raise here; queries are attempted without the JWT
            logger.warning("Authentication failed for user %s: %s", username, e)
            self._jwt_token = None

    @property
    def description(
        self,
    ) -> Optional[List[Tuple[str, Any, None, None, None, None, Optional[bool]]]]:
        """Column description as required by PEP 249 (7-item tuples, name first)."""
        return self._description

    @property
    def rowcount(self) -> int:
        """Number of rows buffered by the last ``execute`` (-1 before any result)."""
        return self._rowcount

    @property
    def arraysize(self) -> int:
        """Default number of rows returned by ``fetchmany``."""
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._arraysize = value

    def close(self) -> None:
        """Close the cursor and drop buffered rows and column metadata."""
        self._closed = True
        self._rows = []
        self._description = None

    def _check_closed(self) -> None:
        """Raise ProgrammingError if the cursor has been closed."""
        if self._closed:
            raise ProgrammingError("Cursor is closed")

    @staticmethod
    def _bind_positional(
        operation: str, parameters: Sequence[Any]
    ) -> Tuple[str, Dict[str, Any]]:
        """Rewrite ``?`` placeholders to ``:p0``, ``:p1``, ... named parameters.

        Placeholders inside single- or double-quoted sections are left alone so
        a literal such as ``'why?'`` is not corrupted. At most
        ``len(parameters)`` placeholders are rewritten; any surplus ``?`` stays
        in the statement as-is.

        Returns:
            The rewritten statement and the ``{"p<i>": value}`` mapping.
        """
        named = {f"p{i}": value for i, value in enumerate(parameters)}
        pieces: List[str] = []
        open_quote: Optional[str] = None
        next_index = 0
        for ch in operation:
            if open_quote is not None:
                # A doubled quote ('') closes and immediately reopens the
                # literal, which keeps the state correct without lookahead.
                if ch == open_quote:
                    open_quote = None
            elif ch in ("'", '"'):
                open_quote = ch
            elif ch == "?" and next_index < len(named):
                pieces.append(f":p{next_index}")
                next_index += 1
                continue
            pieces.append(ch)
        return "".join(pieces), named

    def execute(
        self,
        operation: str,
        parameters: Optional[Union[Dict[str, Any], Sequence[Any]]] = None,
    ) -> "Cursor":
        """Execute a SQL statement and buffer its results.

        Args:
            operation: SQL statement to execute
            parameters: Optional parameters; a dict is sent as-is, any other
                sequence is converted to named parameters ``p0``, ``p1``, ...

        Returns:
            Self for method chaining

        Raises:
            ProgrammingError: if the cursor is closed or the statement fails.
            DatabaseError: if the server returns no execution id.
            OperationalError: if polling times out.
        """
        self._check_closed()
        self._rows = []
        self._row_index = 0
        self._description = None
        self._rowcount = -1

        # Log query (truncate if too long)
        query_preview = operation[:200] + "..." if len(operation) > 200 else operation
        logger.debug("Executing query: %s", query_preview)
        if parameters:
            logger.debug("Query parameters: %s", parameters)

        params_dict: Optional[Dict[str, Any]] = None
        if parameters is not None:
            if isinstance(parameters, dict):
                params_dict = parameters
            else:
                operation, params_dict = self._bind_positional(operation, parameters)

        # Submit the statement
        start_time = time.time()
        response = self._connection._submit_statement(operation, params_dict)
        self._statement_handle = response.get("execution_id")

        if not self._statement_handle:
            logger.error("No execution ID in response: %s", response)
            raise DatabaseError("No statement handle returned from server")

        logger.debug("Statement submitted with execution_id: %s", self._statement_handle)

        # Poll for completion
        self._poll_for_results()

        elapsed = time.time() - start_time
        logger.info("Query completed in %.2fs, returned %d rows", elapsed, self._rowcount)

        # Ensure description is not None so SQLAlchemy treats this as a rows-capable result.
        # Some Opteryx responses may delay column metadata; setting an empty description here
        # prevents SQLAlchemy from closing the result object immediately.
        if self._description is None:
            self._description = []

        return self

    def _poll_for_results(self) -> None:
        """Poll the server until statement execution completes, then fetch results.

        Raises:
            ProgrammingError: when the statement ends FAILED or CANCELLED.
            DatabaseError: when the server reports an unrecognised state.
            OperationalError: when no terminal state is reached within the wait limit.
        """
        if not self._statement_handle:
            return

        max_wait = 300  # Maximum wait time in seconds
        poll_interval = 0.5  # Initial poll interval in seconds
        log_interval = 5.0  # Log progress every 5 seconds
        last_log_time = 0.0  # Elapsed value at the last progress log
        # Measure wall-clock time so slow status requests count towards the
        # limit, not just the time spent sleeping between polls.
        started = time.monotonic()
        elapsed = 0.0

        logger.debug("Polling for execution_id: %s", self._statement_handle)

        while elapsed < max_wait:
            status = self._connection._get_statement_status(self._statement_handle)
            raw_state = status.get("status")

            # "status" may be a plain state string or a dict carrying the state
            if isinstance(raw_state, dict):
                state_value = raw_state.get("state")
                status_details = raw_state
            else:
                state_value = raw_state
                status_details = {}

            if not state_value:
                state_value = status.get("state")

            # str() guards against a non-string state in the payload
            normalized_state = str(state_value or "UNKNOWN").upper()

            # Log progress periodically for long-running queries
            if elapsed - last_log_time >= log_interval:
                logger.info(
                    "Query still executing (state=%s, elapsed=%.1fs)", normalized_state, elapsed
                )
                last_log_time = elapsed

            if normalized_state in ("COMPLETED", "SUCCEEDED", "INCHOATE"):
                logger.debug("Query execution completed with state: %s", normalized_state)
                self._fetch_results()
                return
            if normalized_state in ("FAILED", "CANCELLED"):
                error_message = (
                    status.get("error_message")
                    or status_details.get("description")
                    or status.get("description")
                    or status.get("detail")
                    or "Unknown error"
                )
                logger.error("Query failed with state %s: %s", normalized_state, error_message)
                raise ProgrammingError(error_message)
            if normalized_state in ("UNKNOWN", "SUBMITTED", "EXECUTING", "RUNNING"):
                logger.debug("Query state: %s (elapsed: %.1fs)", normalized_state, elapsed)
                time.sleep(poll_interval)
                elapsed = time.monotonic() - started
                poll_interval = min(poll_interval * 1.5, 2.5)  # back off, capped at 2.5s
                continue

            logger.error("Unexpected statement state: %s", state_value)
            raise DatabaseError(f"Unknown statement state: {state_value}")

        logger.error("Query execution timed out after %.1fs", elapsed)
        raise OperationalError("Statement execution timed out")

    @staticmethod
    def _rows_from_columnar_data(column_data: Sequence[Dict[str, Any]]) -> List[Tuple[Any, ...]]:
        """Convert column-oriented payloads into row tuples.

        Each dict entry contributes its ``"values"`` list as one column;
        non-dict entries are skipped. Columns shorter than the longest one are
        padded with None.
        """
        column_values: List[List[Any]] = [
            list(column.get("values") or [])
            for column in column_data
            if isinstance(column, dict)
        ]
        if not column_values:
            return []
        max_rows = max(len(values) for values in column_values)
        return [
            tuple(
                values[row_index] if row_index < len(values) else None
                for values in column_values
            )
            for row_index in range(max_rows)
        ]

    def _fetch_results(self) -> None:
        """Fetch all result pages of the completed statement into the row buffer.

        The status payload is processed first, then result pages are requested
        with an increasing offset until the reported ``total_rows`` is reached,
        a page comes back empty, or a page is shorter than the page size.
        """
        if not self._statement_handle:
            return

        page_size = max(self._opteryx_max_row_buffer or 10, self._arraysize)
        has_description = False
        rows: List[Tuple[Any, ...]] = []
        total_rows: Optional[int] = None

        def set_description(names: Any) -> None:
            """Record column names as PEP 249 description tuples."""
            nonlocal has_description
            self._description = [(name, None, None, None, None, None, None) for name in names]
            has_description = True

        def process_result_page(result: Dict[str, Any]) -> int:
            """Append the rows of one payload to ``rows``; return how many were added."""
            nonlocal total_rows
            if total_rows is None and "total_rows" in result:
                try:
                    total_rows = int(result.get("total_rows", 0))
                except (TypeError, ValueError):
                    total_rows = None

            columns_meta = result.get("columns", [])
            if columns_meta and not has_description:
                set_description(
                    [col.get("name", f"col{i}") for i, col in enumerate(columns_meta)]
                )

            data = result.get("data", [])
            if not data:
                return 0

            first = data[0]
            if isinstance(first, dict) and "values" in first:
                # Columnar format: each dict has {name: ..., values: [...]}
                if not has_description:
                    set_description([col.get("name", f"col{i}") for i, col in enumerate(data)])
                column_rows = self._rows_from_columnar_data(data)
                rows.extend(column_rows)
                return len(column_rows)

            if isinstance(first, dict):
                # Row format: each dict is a row with {col1: val1, col2: val2, ...}
                if not has_description:
                    set_description(list(first.keys()))
                # Convert each row dict to a tuple in the described column order
                col_order = [col[0] for col in (self._description or [])]
                rows.extend(tuple(row_dict.get(col) for col in col_order) for row_dict in data)
                return len(data)

            # List/tuple format
            rows.extend(tuple(row) for row in data)
            return len(data)

        status_result = self._connection._get_statement_status(self._statement_handle)  # pylint: disable=protected-access
        process_result_page(status_result)
        offset = len(rows)

        # NOTE(review): the results call can fall back to the status payload;
        # whether that payload honours ``offset`` is not visible here — confirm
        # a full-sized fallback page cannot repeat indefinitely.
        while True:
            if total_rows is not None and offset >= total_rows:
                break

            result = self._connection._get_statement_results(  # pylint: disable=protected-access
                self._statement_handle, num_rows=page_size, offset=offset
            )

            fetched_this_page = process_result_page(result)
            if fetched_this_page <= 0:
                break

            offset += fetched_this_page
            if fetched_this_page < page_size:
                break

        # Finalize rows and counts
        self._rows = rows
        self._rowcount = len(self._rows)

    def executemany(
        self,
        operation: str,
        seq_of_parameters: Sequence[Union[Dict[str, Any], Sequence[Any]]],
    ) -> "Cursor":
        """Execute a SQL statement once per parameter set; only the last result is kept."""
        self._check_closed()
        for parameters in seq_of_parameters:
            self.execute(operation, parameters)
        return self

    def fetchone(self) -> Optional[Tuple[Any, ...]]:
        """Return the next buffered row, or None when the buffer is exhausted."""
        self._check_closed()
        if self._row_index >= len(self._rows):
            return None
        row = self._rows[self._row_index]
        self._row_index += 1
        return row

    def fetchmany(self, size: Optional[int] = None) -> List[Tuple[Any, ...]]:
        """Return up to *size* buffered rows (defaults to ``arraysize``)."""
        self._check_closed()
        if size is None:
            size = self._arraysize
        rows = self._rows[self._row_index : self._row_index + size]
        self._row_index += len(rows)
        return rows

    def fetchall(self) -> List[Tuple[Any, ...]]:
        """Return all remaining buffered rows."""
        self._check_closed()
        rows = self._rows[self._row_index :]
        self._row_index = len(self._rows)
        return rows

    def setinputsizes(self, sizes: Sequence[Any]) -> None:
        """Set input sizes (no-op, but required by PEP 249)."""
        _ = sizes
        return None

    def setoutputsize(self, size: int, column: Optional[int] = None) -> None:
        """Set output size (no-op, but required by PEP 249)."""
        _ = size
        _ = column
        return None

    def __iter__(self) -> "Cursor":
        """Make cursor iterable."""
        return self

    def __next__(self) -> Tuple[Any, ...]:
        """Return the next row, raising StopIteration when none remain."""
        row = self.fetchone()
        if row is None:
            raise StopIteration
        return row
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
class Connection:
    """DBAPI 2.0 Connection implementation for Opteryx.

    Manages HTTP connections to the Opteryx data service through a single
    ``requests.Session`` that carries the default headers for every request.
    """

    def __init__(
        self,
        host: str = "jobs.opteryx.app",
        port: int = 8000,
        username: Optional[str] = None,
        token: Optional[str] = None,
        database: Optional[str] = None,
        ssl: bool = False,
        timeout: float = 30.0,
    ) -> None:
        """Initialize connection to Opteryx data service.

        Args:
            host: Hostname of the Opteryx data service
            port: Port number
            username: Username for authentication (optional)
            token: Bearer token for authentication
            database: Database/schema name (optional)
            ssl: Whether to use HTTPS
            timeout: Request timeout in seconds
        """
        self._host = host
        self._port = port
        self._username = username
        self._token = token
        self._database = database
        self._ssl = ssl
        self._timeout = timeout
        self._closed = False

        # Base URL of the configured host (omits the port when it is the scheme default)
        scheme = "https" if ssl else "http"
        if (ssl and port == 443) or (not ssl and port == 80):
            self._base_url = f"{scheme}://{host}"
        else:
            self._base_url = f"{scheme}://{host}:{port}"

        logger.debug(
            "Creating connection to %s (ssl=%s, timeout=%.1fs)", self._base_url, ssl, timeout
        )

        # Create session for connection pooling
        self._session = requests.Session()
        if token:
            self._session.headers["Authorization"] = f"Bearer {token}"
            logger.debug("Using pre-configured token for authentication")
        self._session.headers["Content-Type"] = "application/json"

    def _normalize_domain(self, host: str) -> str:
        """Return the base domain for the given host by stripping known subdomain prefixes.

        Examples:
            'jobs.opteryx.app' -> 'opteryx.app'
            'authenticate.opteryx.app' -> 'opteryx.app'
            'opteryx.app' -> 'opteryx.app'
            'localhost' -> 'localhost'
        """
        domain = host
        for p in ("jobs.", "authenticate."):
            if domain.startswith(p):
                domain = domain[len(p) :]
        return domain

    @staticmethod
    def _is_ip_address(host: str) -> bool:
        """Return True when *host* is an IPv4 or IPv6 address literal."""
        import ipaddress

        try:
            ipaddress.ip_address(host)
        except ValueError:
            return False
        return True

    def _data_base_url(self) -> str:
        """Construct the base URL used for API requests.

        DNS-style hosts are redirected to their ``jobs.`` subdomain; localhost,
        dot-less hosts and IP addresses are used unchanged, since an IP address
        cannot carry a subdomain. The port is omitted when it is the default
        for the scheme.
        """
        scheme = "https" if self._ssl else "http"
        domain = self._normalize_domain(self._host)
        is_dns_name = (
            "." in domain
            and not domain.startswith("localhost")
            and not self._is_ip_address(domain)
        )
        data_host = f"jobs.{domain}" if is_dns_name else domain
        if (self._ssl and self._port == 443) or (not self._ssl and self._port == 80):
            return f"{scheme}://{data_host}"
        return f"{scheme}://{data_host}:{self._port}"

    def _check_closed(self) -> None:
        """Raise ProgrammingError if the connection has been closed."""
        if self._closed:
            raise ProgrammingError("Connection is closed")

    @staticmethod
    def _error_detail(response: Any, fallback: str) -> Any:
        """Extract a human-readable error detail from an HTTP error response.

        Returns the ``"detail"`` entry of a JSON object body when present;
        otherwise the raw response text, or *fallback* when the body is empty.
        A JSON body that is not an object is treated like a non-JSON body.
        """
        try:
            body = response.json()
        except ValueError:  # covers json.JSONDecodeError and its subclasses
            return response.text or fallback
        if isinstance(body, dict):
            return body.get("detail", fallback)
        return response.text or fallback

    @staticmethod
    def _parse_ndjson(text: str) -> List[Dict[str, Any]]:
        """Parse newline-delimited JSON into a list of row dicts.

        Blank lines (including a stray carriage return) are skipped. Lines are
        split on ``\\n`` only, because other line-break characters may appear
        unescaped inside JSON strings.

        Raises:
            DataError: when a line is not valid JSON or is not a JSON object.
        """
        records: List[Dict[str, Any]] = []
        for line_no, raw_line in enumerate(text.split("\n"), start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except ValueError as e:
                raise DataError(f"Malformed result row on line {line_no}: {e}") from e
            if not isinstance(record, dict):
                raise DataError(f"Result row on line {line_no} is not a JSON object")
            records.append(record)
        return records

    def _submit_statement(
        self, sql: str, parameters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Submit a SQL statement to the data service.

        Returns:
            The decoded JSON response of the submission request.

        Raises:
            ProgrammingError: if the connection is closed.
            OperationalError: on HTTP 401/403 or on a transport-level failure.
            DatabaseError: on any other HTTP error status.
        """
        self._check_closed()

        url = urljoin(self._data_base_url() + "/", "api/v1/jobs")
        payload: Dict[str, Any] = {
            "sql_text": sql,
            "client_info": {
                "application_name": "opteryx-sqlalchemy",
                "application_version": __version__,
            },
        }
        if parameters:
            payload["parameters"] = parameters

        logger.debug("Submitting statement to %s", url)

        try:
            response = self._session.post(url, json=payload, timeout=self._timeout)
            response.raise_for_status()
            result = response.json()
            logger.debug(
                "Statement submitted successfully, execution_id: %s", result.get("execution_id")
            )
            return result
        except requests.exceptions.HTTPError as e:
            if e.response is not None:
                status_code = e.response.status_code
                detail = self._error_detail(e.response, str(e))
                # Authentication/authorization errors should raise OperationalError
                if status_code in (401, 403):
                    logger.error("Authentication error (HTTP %d): %s", status_code, detail)
                    raise OperationalError(f"Authentication error: {detail}") from e
                logger.error("HTTP error %d submitting statement: %s", status_code, detail)
                raise DatabaseError(f"HTTP error: {detail}") from e
            logger.error("HTTP error submitting statement: %s", e)
            raise DatabaseError(f"HTTP error: {e}") from e
        except requests.exceptions.RequestException as e:
            logger.error("Connection error submitting statement: %s", e)
            raise OperationalError(f"Connection error: {e}") from e

    def _get_statement_status(self, statement_handle: str) -> Dict[str, Any]:
        """Get the status of a submitted statement.

        Raises:
            ProgrammingError: if the connection is closed or the server answers 404.
            OperationalError: on HTTP 401/403 or on a transport-level failure.
            DatabaseError: on any other HTTP error status.
        """
        self._check_closed()

        url = urljoin(self._data_base_url() + "/", f"api/v1/jobs/{statement_handle}/status")

        logger.debug("Checking status for execution_id: %s", statement_handle)

        try:
            response = self._session.get(url, timeout=self._timeout)
            response.raise_for_status()
            result = response.json()
            logger.debug("Status check response: %s", result.get("status") or result.get("state"))
            return result
        except requests.exceptions.HTTPError as e:
            if e.response is not None:
                status_code = e.response.status_code
                # Authentication/authorization errors should raise OperationalError
                if status_code in (401, 403):
                    detail = self._error_detail(e.response, str(e))
                    logger.error(
                        "Authentication error (HTTP %d) checking status: %s", status_code, detail
                    )
                    raise OperationalError(f"Authentication error: {detail}") from e
                if status_code == 404:
                    logger.error("Statement not found: %s", statement_handle)
                    raise ProgrammingError("Statement not found") from e
                detail = self._error_detail(e.response, str(e))
                logger.error("HTTP error %d checking status: %s", status_code, detail)
                raise DatabaseError(f"HTTP error: {detail}") from e
            logger.error("HTTP error checking status: %s", e)
            raise DatabaseError(f"HTTP error: {e}") from e
        except requests.exceptions.RequestException as e:
            logger.error("Connection error checking status: %s", e)
            raise OperationalError(f"Connection error: {e}") from e

    def _get_statement_results(
        self, statement_handle: str, num_rows: Optional[int] = None, offset: Optional[int] = None
    ) -> Dict[str, Any]:
        """Get results for a completed statement using the /download endpoint.

        Args:
            statement_handle: The execution ID returned by submit
            num_rows: Maximum number of rows to return (maps to 'limit' param)
            offset: Row offset for pagination

        Returns:
            ``{"data": [row dicts], "columns": [{"name": ...}]}`` built from the
            NDJSON body, with column names taken from the first row. When the
            download request fails for a reason other than HTTP 401/403, the
            status endpoint payload is returned instead.

        Raises:
            ProgrammingError: if the connection is closed.
            OperationalError: on HTTP 401/403.
            DataError: if the download body is not valid NDJSON.
        """
        self._check_closed()

        url = urljoin(self._data_base_url() + "/", f"api/v1/jobs/{statement_handle}/download")
        params: Dict[str, Any] = {"file_format": "json"}
        if num_rows is not None:
            params["limit"] = int(num_rows)
        if offset is not None:
            params["offset"] = int(offset)

        logger.debug(
            "Fetching results for execution_id: %s (limit=%s, offset=%s)",
            statement_handle,
            num_rows,
            offset,
        )

        try:
            response = self._session.get(url, params=params, timeout=self._timeout)
            response.raise_for_status()

            # The download endpoint returns NDJSON: one JSON object (row) per line
            rows = self._parse_ndjson(response.text)
            columns = list(rows[0].keys()) if rows else []

            # Convert to the format expected by the cursor's page processing
            result = {"data": rows, "columns": [{"name": col} for col in columns]}

            logger.debug("Fetched %d rows from download endpoint", len(rows))
            return result
        except requests.exceptions.HTTPError as e:
            # Authentication/authorization errors should raise OperationalError
            if e.response is not None and e.response.status_code in (401, 403):
                detail = self._error_detail(e.response, str(e))
                logger.error("Authentication error fetching results: %s", detail)
                raise OperationalError(f"Authentication error: {detail}") from e
            # For other HTTP errors, fall back to status endpoint
            logger.debug("Download endpoint unavailable, falling back to status endpoint")
        except requests.exceptions.RequestException as e:
            logger.debug("Error fetching from download endpoint: %s, falling back", e)

        # Fallback to status endpoint if dedicated download endpoint is unavailable
        return self._get_statement_status(statement_handle)

    def close(self) -> None:
        """Close the connection and its HTTP session (idempotent)."""
        if not self._closed:
            logger.debug("Closing connection to %s", self._base_url)
            self._session.close()
            self._closed = True

    def commit(self) -> None:
        """Commit transaction (no-op for Opteryx as it's read-only)."""
        self._check_closed()

    def rollback(self) -> None:
        """Rollback transaction (no-op for Opteryx as it's read-only)."""
        self._check_closed()

    def cursor(self) -> "Cursor":
        """Create a new cursor object bound to this connection."""
        self._check_closed()
        return Cursor(self)

    def __enter__(self) -> "Connection":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit; closes the connection."""
        self.close()
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
def connect(
    host: str = "localhost",
    port: int = 8000,
    username: Optional[str] = None,
    token: Optional[str] = None,
    database: Optional[str] = None,
    ssl: bool = False,
    timeout: float = 30.0,
) -> Connection:
    """Open a connection to the Opteryx data service.

    This is the module-level DBAPI entry point; it forwards every argument
    unchanged to the ``Connection`` constructor.

    Args:
        host: Hostname of the Opteryx data service
        port: Port number
        username: Username for authentication (optional)
        token: Bearer token for authentication
        database: Database/schema name (optional)
        ssl: Whether to use HTTPS
        timeout: Request timeout in seconds

    Returns:
        A new Connection object
    """
    settings = {
        "host": host,
        "port": port,
        "username": username,
        "token": token,
        "database": database,
        "ssl": ssl,
        "timeout": timeout,
    }
    return Connection(**settings)
|