wherobots-python-dbapi 0.8.0__tar.gz → 0.9.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: wherobots-python-dbapi
3
- Version: 0.8.0
3
+ Version: 0.9.1
4
4
  Summary: Python DB-API driver for Wherobots DB
5
5
  License: Apache 2.0
6
6
  Author: Maxime Petazzoni
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.9
12
12
  Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
15
16
  Requires-Dist: StrEnum (>=0.4.15,<0.5.0)
16
17
  Requires-Dist: cbor2 (>=5.6.3,<6.0.0)
17
18
  Requires-Dist: numpy (<2)
@@ -19,7 +20,7 @@ Requires-Dist: pandas (>=2.1.0,<3.0.0)
19
20
  Requires-Dist: pyarrow (>=14.0.2,<15.0.0)
20
21
  Requires-Dist: requests (>=2.31.0,<3.0.0)
21
22
  Requires-Dist: tenacity (>=8.2.3,<9.0.0)
22
- Requires-Dist: websockets (>=12.0,<13.0)
23
+ Requires-Dist: websockets (==13.0)
23
24
  Description-Content-Type: text/markdown
24
25
 
25
26
  # wherobots-python-dbapi
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "wherobots-python-dbapi"
3
- version = "0.8.0"
3
+ version = "0.9.1"
4
4
  description = "Python DB-API driver for Wherobots DB"
5
5
  authors = ["Maxime Petazzoni <max@wherobots.com>"]
6
6
  license = "Apache 2.0"
@@ -14,7 +14,7 @@ Tracker = "https://github.com/wherobots/wherobots-python-dbapi-driver/issues"
14
14
  [tool.poetry.dependencies]
15
15
  python = "^3.9"
16
16
  requests = "^2.31.0"
17
- websockets = "^12.0"
17
+ websockets = "13.0"
18
18
  tenacity = "^8.2.3"
19
19
  pyarrow = "^14.0.2"
20
20
  cbor2 = "^5.6.3"
@@ -29,6 +29,7 @@ pytest = "^8.0.2"
29
29
  black = "^24.2.0"
30
30
  pre-commit = "^3.6.2"
31
31
  conventional-pre-commit = "^3.1.0"
32
+ types-requests = "^2.32.0.20241016"
32
33
  rich = "^13.7.1"
33
34
 
34
35
  [build-system]
@@ -7,6 +7,7 @@ from dataclasses import dataclass
7
7
  from typing import Any, Callable, Union
8
8
 
9
9
  import cbor2
10
+ import pandas
10
11
  import pyarrow
11
12
  import websockets.exceptions
12
13
  import websockets.protocol
@@ -74,19 +75,19 @@ class Connection:
74
75
  def __exit__(self, exc_type, exc_val, exc_tb):
75
76
  self.close()
76
77
 
77
- def close(self):
78
+ def close(self) -> None:
78
79
  self.__ws.close()
79
80
 
80
- def commit(self):
81
+ def commit(self) -> None:
81
82
  raise NotSupportedError
82
83
 
83
- def rollback(self):
84
+ def rollback(self) -> None:
84
85
  raise NotSupportedError
85
86
 
86
87
  def cursor(self) -> Cursor:
87
88
  return Cursor(self.__execute_sql, self.__cancel_query)
88
89
 
89
- def __main_loop(self):
90
+ def __main_loop(self) -> None:
90
91
  """Main background loop listening for messages from the SQL session."""
91
92
  logging.info("Starting background connection handling loop...")
92
93
  while self.__ws.protocol.state < websockets.protocol.State.CLOSING:
@@ -101,7 +102,7 @@ class Connection:
101
102
  except Exception as e:
102
103
  logging.exception("Error handling message from SQL session", exc_info=e)
103
104
 
104
- def __listen(self):
105
+ def __listen(self) -> None:
105
106
  """Waits for the next message from the SQL session and processes it.
106
107
 
107
108
  The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
@@ -120,7 +121,8 @@ class Connection:
120
121
  )
121
122
  return
122
123
 
123
- if kind == EventKind.STATE_UPDATED:
124
+ # Incoming state transitions are handled here.
125
+ if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT:
124
126
  try:
125
127
  query.state = ExecutionState[message["state"].upper()]
126
128
  logging.info("Query %s is now %s.", execution_id, query.state)
@@ -128,43 +130,32 @@ class Connection:
128
130
  logging.warning("Invalid state update message for %s", execution_id)
129
131
  return
130
132
 
131
- # Incoming state transitions are handled here.
132
133
  if query.state == ExecutionState.SUCCEEDED:
133
- self.__request_results(execution_id)
134
+ # On a state_updated event telling us the query succeeded,
135
+ # ask for results.
136
+ if kind == EventKind.STATE_UPDATED:
137
+ self.__request_results(execution_id)
138
+ return
139
+
140
+ # Otherwise, process the results from the execution_result event.
141
+ results = message.get("results")
142
+ if not results or not isinstance(results, dict):
143
+ logging.warning("Got no results back from %s.", execution_id)
144
+ return
145
+
146
+ query.state = ExecutionState.COMPLETED
147
+ query.handler(self._handle_results(execution_id, results))
148
+ elif query.state == ExecutionState.CANCELLED:
149
+ logging.info(
150
+ "Query %s has been cancelled; returning empty results.",
151
+ execution_id,
152
+ )
153
+ query.handler(pandas.DataFrame())
154
+ self.__queries.pop(execution_id)
134
155
  elif query.state == ExecutionState.FAILED:
135
156
  # Don't do anything here; the ERROR event is coming with more
136
157
  # details.
137
158
  pass
138
-
139
- elif kind == EventKind.EXECUTION_RESULT:
140
- results = message.get("results")
141
- if not results or not isinstance(results, dict):
142
- logging.warning("Got no results back from %s.", execution_id)
143
- return
144
-
145
- result_bytes = results.get("result_bytes")
146
- result_format = results.get("format")
147
- result_compression = results.get("compression")
148
- logging.info(
149
- "Received %d bytes of %s-compressed %s results from %s.",
150
- len(result_bytes),
151
- result_compression,
152
- result_format,
153
- execution_id,
154
- )
155
-
156
- query.state = ExecutionState.COMPLETED
157
- if result_format == ResultsFormat.JSON:
158
- query.handler(json.loads(result_bytes.decode("utf-8")))
159
- elif result_format == ResultsFormat.ARROW:
160
- buffer = pyarrow.py_buffer(result_bytes)
161
- stream = pyarrow.input_stream(buffer, result_compression)
162
- with pyarrow.ipc.open_stream(stream) as reader:
163
- query.handler(reader.read_pandas())
164
- else:
165
- query.handler(
166
- OperationalError(f"Unsupported results format {result_format}")
167
- )
168
159
  elif kind == EventKind.ERROR:
169
160
  query.state = ExecutionState.FAILED
170
161
  error = message.get("message")
@@ -172,6 +163,28 @@ class Connection:
172
163
  else:
173
164
  logging.warning("Received unknown %s event!", kind)
174
165
 
166
+ def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
167
+ result_bytes = results.get("result_bytes")
168
+ result_format = results.get("format")
169
+ result_compression = results.get("compression")
170
+ logging.info(
171
+ "Received %d bytes of %s-compressed %s results from %s.",
172
+ len(result_bytes),
173
+ result_compression,
174
+ result_format,
175
+ execution_id,
176
+ )
177
+
178
+ if result_format == ResultsFormat.JSON:
179
+ return json.loads(result_bytes.decode("utf-8"))
180
+ elif result_format == ResultsFormat.ARROW:
181
+ buffer = pyarrow.py_buffer(result_bytes)
182
+ stream = pyarrow.input_stream(buffer, result_compression)
183
+ with pyarrow.ipc.open_stream(stream) as reader:
184
+ return reader.read_pandas()
185
+ else:
186
+ return OperationalError(f"Unsupported results format {result_format}")
187
+
175
188
  def __send(self, message: dict[str, Any]) -> None:
176
189
  request = json.dumps(message)
177
190
  logging.debug("Request: %s", request)
@@ -230,7 +243,14 @@ class Connection:
230
243
  self.__send(request)
231
244
 
232
245
  def __cancel_query(self, execution_id: str) -> None:
233
- query = self.__queries.pop(execution_id)
234
- if query:
235
- logging.info("Cancelled query %s.", execution_id)
236
- # TODO: when protocol supports it, send cancellation request.
246
+ """Cancels the query with the given execution ID."""
247
+ query = self.__queries.get(execution_id)
248
+ if not query:
249
+ return
250
+
251
+ request = {
252
+ "kind": RequestKind.CANCEL.value,
253
+ "execution_id": execution_id,
254
+ }
255
+ logging.info("Cancelling query %s...", execution_id)
256
+ self.__send(request)
@@ -1,4 +1,5 @@
1
1
  from enum import auto
2
+ from packaging.version import Version
2
3
  from strenum import LowercaseStrEnum, StrEnum
3
4
 
4
5
  from .region import Region
@@ -15,7 +16,7 @@ DEFAULT_SESSION_WAIT_TIMEOUT_SECONDS: float = 900
15
16
  DEFAULT_REUSE_SESSION: bool = True
16
17
 
17
18
  MAX_MESSAGE_SIZE: int = 100 * 2**20 # 100MiB
18
- PROTOCOL_VERSION: str = "1.0.0"
19
+ PROTOCOL_VERSION: Version = Version("1.0.0")
19
20
 
20
21
 
21
22
  class ExecutionState(LowercaseStrEnum):
@@ -31,6 +32,9 @@ class ExecutionState(LowercaseStrEnum):
31
32
  SUCCEEDED = auto()
32
33
  "The SQL session has reported the query has completed successfully."
33
34
 
35
+ CANCELLED = auto()
36
+ "The SQL session has reported the query has been cancelled."
37
+
34
38
  FAILED = auto()
35
39
  "The SQL session has reported the query has failed."
36
40
 
@@ -40,13 +44,18 @@ class ExecutionState(LowercaseStrEnum):
40
44
  COMPLETED = auto()
41
45
  "The driver has completed processing the query results."
42
46
 
43
- def is_terminal_state(self):
44
- return self in (ExecutionState.COMPLETED, ExecutionState.FAILED)
47
+ def is_terminal_state(self) -> bool:
48
+ return self in (
49
+ ExecutionState.COMPLETED,
50
+ ExecutionState.CANCELLED,
51
+ ExecutionState.FAILED,
52
+ )
45
53
 
46
54
 
47
55
  class RequestKind(LowercaseStrEnum):
48
56
  EXECUTE_SQL = auto()
49
57
  RETRIEVE_RESULTS = auto()
58
+ CANCEL = auto()
50
59
 
51
60
 
52
61
  class EventKind(LowercaseStrEnum):
@@ -88,7 +97,7 @@ class AppStatus(StrEnum):
88
97
  DESTROY_FAILED = auto()
89
98
  DESTROYED = auto()
90
99
 
91
- def is_starting(self):
100
+ def is_starting(self) -> bool:
92
101
  return self in (
93
102
  AppStatus.PENDING,
94
103
  AppStatus.PREPARING,
@@ -98,7 +107,7 @@ class AppStatus(StrEnum):
98
107
  AppStatus.INITIALIZING,
99
108
  )
100
109
 
101
- def is_terminal_state(self):
110
+ def is_terminal_state(self) -> bool:
102
111
  return self in (
103
112
  AppStatus.PREPARE_FAILED,
104
113
  AppStatus.DEPLOY_FAILED,
@@ -1,7 +1,7 @@
1
1
  import queue
2
2
  from typing import Any, Optional, List, Tuple
3
3
 
4
- from .errors import ProgrammingError, DatabaseError
4
+ from .errors import DatabaseError, ProgrammingError
5
5
 
6
6
  _TYPE_MAP = {
7
7
  "object": "STRING",
@@ -16,7 +16,7 @@ _TYPE_MAP = {
16
16
 
17
17
  class Cursor:
18
18
 
19
- def __init__(self, exec_fn, cancel_fn):
19
+ def __init__(self, exec_fn, cancel_fn) -> None:
20
20
  self.__exec_fn = exec_fn
21
21
  self.__cancel_fn = cancel_fn
22
22
 
@@ -72,7 +72,7 @@ class Cursor:
72
72
 
73
73
  return self.__results
74
74
 
75
- def execute(self, operation: str, parameters: dict[str, Any] = None):
75
+ def execute(self, operation: str, parameters: dict[str, Any] = None) -> None:
76
76
  if self.__current_execution_id:
77
77
  self.__cancel_fn(self.__current_execution_id)
78
78
 
@@ -84,37 +84,40 @@ class Cursor:
84
84
  sql = operation.format(**(parameters or {}))
85
85
  self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)
86
86
 
87
- def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
87
+ def executemany(
88
+ self, operation: str, seq_of_parameters: list[dict[str, Any]]
89
+ ) -> None:
88
90
  raise NotImplementedError
89
91
 
90
- def fetchone(self):
92
+ def fetchone(self) -> Any:
91
93
  results = self.__get_results()[self.__current_row :]
92
- if not results:
94
+ if len(results) == 0:
93
95
  return None
94
96
  self.__current_row += 1
95
97
  return results[0]
96
98
 
97
- def fetchmany(self, size: int = None):
99
+ def fetchmany(self, size: int = None) -> list[Any]:
98
100
  size = size or self.arraysize
99
101
  results = self.__get_results()[self.__current_row : self.__current_row + size]
100
102
  self.__current_row += size
101
103
  return results
102
104
 
103
- def fetchall(self):
105
+ def fetchall(self) -> list[Any]:
104
106
  return self.__get_results()[self.__current_row :]
105
107
 
106
- def close(self):
108
+ def close(self) -> None:
107
109
  """Close the cursor."""
108
- pass
110
+ if self.__results is None and self.__current_execution_id:
111
+ self.__cancel_fn(self.__current_execution_id)
109
112
 
110
113
  def __iter__(self):
111
114
  return self
112
115
 
113
- def __next__(self):
116
+ def __next__(self) -> None:
114
117
  raise StopIteration
115
118
 
116
119
  def __enter__(self):
117
120
  return self
118
121
 
119
- def __exit__(self, exc_type, exc_val, exc_tb):
122
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
120
123
  self.close()
@@ -3,19 +3,19 @@
3
3
  A PEP-0249 compatible driver for interfacing with Wherobots DB.
4
4
  """
5
5
 
6
+ from importlib import metadata
7
+ from importlib.metadata import PackageNotFoundError
6
8
  import logging
9
+ from packaging.version import Version
7
10
  import platform
8
- import urllib.parse
9
11
  import queue
10
- from importlib import metadata
11
- from importlib.metadata import PackageNotFoundError
12
-
13
12
  import requests
14
13
  import tenacity
15
- import threading
16
14
  from typing import Union
15
+ import urllib.parse
17
16
  import websockets.sync.client
18
17
 
18
+ from .connection import Connection
19
19
  from .constants import (
20
20
  DEFAULT_ENDPOINT,
21
21
  DEFAULT_REGION,
@@ -36,7 +36,6 @@ from .errors import (
36
36
  )
37
37
  from .region import Region
38
38
  from .runtime import Runtime
39
- from .connection import Connection
40
39
 
41
40
  apilevel = "2.0"
42
41
  threadsafety = 1
@@ -163,6 +162,7 @@ def http_to_ws(uri: str) -> str:
163
162
 
164
163
  def connect_direct(
165
164
  uri: str,
165
+ protocol: Version = PROTOCOL_VERSION,
166
166
  headers: dict[str, str] = None,
167
167
  read_timeout: float = DEFAULT_READ_TIMEOUT_SECONDS,
168
168
  results_format: Union[ResultsFormat, None] = None,
@@ -170,34 +170,20 @@ def connect_direct(
170
170
  geometry_representation: Union[GeometryRepresentation, None] = None,
171
171
  ) -> Connection:
172
172
  q = queue.SimpleQueue()
173
- uri_with_protocol = f"{uri}/{PROTOCOL_VERSION}"
174
-
175
- def create_ws_connection():
176
- try:
177
- logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
178
- ws = websockets.sync.client.connect(
179
- uri=uri_with_protocol,
180
- additional_headers=headers,
181
- max_size=MAX_MESSAGE_SIZE,
182
- )
183
- q.put(ws)
184
- except Exception as e:
185
- q.put(e)
186
-
187
- dt = threading.Thread(
188
- name="wherobots-ws-connector",
189
- target=create_ws_connection,
190
- daemon=True,
191
- )
192
- dt.start()
193
- dt.join()
173
+ uri_with_protocol = f"{uri}/{protocol}"
194
174
 
195
- result = q.get()
196
- if isinstance(result, Exception):
197
- raise InterfaceError("Failed to connect to SQL session!") from result
175
+ try:
176
+ logging.info("Connecting to SQL session at %s ...", uri_with_protocol)
177
+ ws = websockets.sync.client.connect(
178
+ uri=uri_with_protocol,
179
+ additional_headers=headers,
180
+ max_size=MAX_MESSAGE_SIZE,
181
+ )
182
+ except Exception as e:
183
+ raise InterfaceError("Failed to connect to SQL session!") from e
198
184
 
199
185
  return Connection(
200
- result,
186
+ ws,
201
187
  read_timeout=read_timeout,
202
188
  results_format=results_format,
203
189
  data_compression=data_compression,