machbaseapi 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- machbaseAPI/__init__.py +32 -0
- machbaseAPI/connector.py +251 -0
- machbaseAPI/conntest.py +27 -0
- machbaseAPI/constants.py +189 -0
- machbaseAPI/errors.py +38 -0
- machbaseAPI/machbaseAPI.py +318 -0
- machbaseAPI/marshal.py +225 -0
- machbaseAPI/packet.py +87 -0
- machbaseAPI/protocol.py +661 -0
- machbaseAPI/sample/ConnTest.py +55 -0
- machbaseAPI/sample/MakeData.py +14 -0
- machbaseAPI/sample/Sample1Connect.py +27 -0
- machbaseAPI/sample/Sample2Simple.py +77 -0
- machbaseAPI/sample/Sample3Append.py +70 -0
- machbaseAPI/sample/Sample4Fetch.py +137 -0
- machbaseAPI/sample/Sample5Append2.py +81 -0
- machbaseAPI/sample/Sample5ConnectEx.py +57 -0
- machbaseAPI/sample/__init__.py +1 -0
- machbaseAPI/sample/data.txt +99 -0
- machbaseAPI/types.py +52 -0
- machbaseapi-2.0.0.dist-info/METADATA +121 -0
- machbaseapi-2.0.0.dist-info/RECORD +24 -0
- machbaseapi-2.0.0.dist-info/WHEEL +5 -0
- machbaseapi-2.0.0.dist-info/top_level.txt +1 -0
machbaseAPI/protocol.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import ipaddress
|
|
5
|
+
import socket
|
|
6
|
+
import struct
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
9
|
+
|
|
10
|
+
from .constants import *
|
|
11
|
+
from .errors import DatabaseError, OperationalError, ProgrammingError
|
|
12
|
+
from .marshal import MarshalReader, MarshalWriter
|
|
13
|
+
from .packet import PacketReader
|
|
14
|
+
from .types import spiner_type_from_legacy_sql, spiner_type_from_value
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ColumnMeta:
    """Decoded description of one result/append column.

    Instances are built positionally elsewhere in this module, so the field
    order below is part of the interface.
    """

    # Column name as sent by the server (decoded as UTF-8 with replacement).
    name: str
    # Raw 64-bit column type word the other fields are derived from.
    cm_type: int
    # Bits 28..55 of ``cm_type``.
    precision: int
    # Bits 0..27 of ``cm_type``.
    scale: int
    # Bits 56..63 of ``cm_type``: the spiner type code.
    spiner_type: int
    # Byte width used when slicing a fixed-size value out of a fetched row.
    length: int
    # True when the value is sent with a 4-byte length prefix.
    is_variable: bool
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class QueryMeta:
    """Result of an EXECDIRECT-based call (``exec_direct`` or ``query``).

    Constructed positionally in this module; keep the field order stable.
    """

    # Statement id allocated for the request.
    statement_id: int
    # Column descriptions (empty for non-SELECT statements).
    columns: List[ColumnMeta]
    # Decoded rows keyed by column name.
    rows: List[Dict[str, Any]]
    # Affected-row count, or the number of fetched rows for queries.
    rows_affected: int
    # Server message text, or None when there was none.
    message: Optional[str]
    # NOTE(review): callers in this module store ``_status_last(status)`` here,
    # i.e. True when the server flagged the LAST result, despite the name.
    more: bool
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class _AppendSession:
    """State of an open APPEND stream returned by ``append_open``."""

    # Statement id used for APPEND_OPEN / APPEND_DATA / APPEND_CLOSE.
    statement_id: int
    # Target table name.
    table: str
    # Column metadata reported by the server when the stream was opened.
    columns: List[ColumnMeta]
    # Byte order for encoded values: True = little-endian.
    server_le: bool
    # Optional per-column legacy SQL type codes; not populated by append_open
    # in this module (presumably set by callers — verify).
    user_types: Optional[List[int]] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def extract_spiner_type(cm_type: int) -> int:
    """Return the spiner type code stored in the top byte (bits 56..63) of *cm_type*."""
    top_byte = cm_type >> 56
    return top_byte & 0xFF
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def extract_precision(cm_type: int) -> int:
    """Return the 28-bit precision field (bits 28..55) of *cm_type*."""
    without_scale = cm_type >> 28
    return without_scale & 0x0FFFFFFF
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def extract_scale(cm_type: int) -> int:
    """Return the 28-bit scale field (bits 0..27) of *cm_type*."""
    return cm_type % (1 << 28)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _compute_column_length(spiner_type: int, precision: int) -> int:
    """Return the byte width used to slice a fixed-size value of *spiner_type*.

    Types with a known fixed width map to it; any other type falls back to
    *precision*.
    """
    # Checked in order; the first group containing the type wins.
    width_groups = (
        ((CMD_INT16_TYPE, CMD_UINT16_TYPE), 2),
        ((CMD_INT32_TYPE, CMD_UINT32_TYPE, CMD_FLT32_TYPE), 4),
        ((CMD_INT64_TYPE, CMD_UINT64_TYPE, CMD_DATE_TYPE, CMD_FLT64_TYPE), 8),
        ((CMD_IPV4_TYPE,), 5),  # 1 family byte + 4 address bytes
        ((CMD_IPV6_TYPE,), 17),  # 1 family byte + 16 address bytes
        ((CMD_BOOL_TYPE,), 1),
        ((CMD_NUL_TYPE,), 0),
    )
    for members, width in width_groups:
        if spiner_type in members:
            return width
    return precision
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _decode_ip_field(raw: bytes) -> str:
    """Render a family-tagged IP field as text.

    *raw* is a family byte (4 or 6) followed by 4 or 16 address bytes. Any
    other shape yields an empty string.
    """
    size = len(raw)
    family = raw[0] if size else None
    if size == 5 and family == 4:
        return ".".join(map(str, raw[1:]))
    if size == 17 and family == 6:
        return str(ipaddress.IPv6Address(raw[1:]))
    return ""
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _format_datetime(value: int) -> Optional[str]:
    """Format a raw DATE value as a local-time string.

    The unit of *value* is detected heuristically. It is read as
    nanoseconds, then microseconds, then milliseconds since the Unix epoch.
    The first scale whose local-time year falls in 1980..2500 is used. If
    none qualifies, the first scale that produced any valid datetime is used.
    If no scale converts at all, the epoch is used.

    Args:
        value: Raw signed 64-bit DATE value; ``-1`` (or its unsigned form)
            is the NULL sentinel.

    Returns:
        ``None`` for the NULL sentinel. Otherwise ``"YYYY-MM-DD HH:MM:SS"``,
        with a ``.mmm`` millisecond suffix when the sub-second part is
        non-zero (truncated, not rounded).
    """
    if value in (0xFFFFFFFFFFFFFFFF, -1):
        return None

    chosen: Optional[datetime.datetime] = None
    fallback: Optional[datetime.datetime] = None
    for divisor in (1_000_000_000, 1_000_000, 1_000):
        # Split with integer arithmetic: a float of ~1.7e9 seconds only has
        # ~0.24 us resolution, which corrupts microseconds and can round a
        # value like ...999_999_999 ns up into the next second.
        seconds, remainder = divmod(value, divisor)
        try:
            candidate = datetime.datetime.fromtimestamp(seconds) + datetime.timedelta(
                microseconds=remainder * 1_000_000 // divisor
            )
        except (OverflowError, OSError, ValueError):
            continue
        if 1980 <= candidate.year <= 2500:
            chosen = candidate
            break
        if fallback is None:
            fallback = candidate

    if chosen is None:
        chosen = fallback if fallback is not None else datetime.datetime.fromtimestamp(0)

    if chosen.microsecond:
        return chosen.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    return chosen.strftime("%Y-%m-%d %H:%M:%S")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _utf8_or_raw(data: bytes) -> Any:
    """Return *data* decoded as UTF-8, or a ``bytes`` copy if decoding fails."""
    try:
        return data.decode("utf-8")
    except Exception:
        return bytes(data)


def _decode_column_value(spiner_type: int, data: bytes) -> Any:
    """Decode one fetched column value according to its spiner type code.

    Numbers are read big-endian. Each integer type has a NULL sentinel (the
    minimum signed value, or all-ones for unsigned types) that decodes to
    ``None``. Character types decode as UTF-8, falling back to raw bytes.
    Binary types are returned as ``bytes``. Unknown types are treated like
    text, with empty data mapped to ``None``.
    """
    # (type code, byte width, signed?, exact NULL sentinel)
    integer_layouts = (
        (CMD_INT16_TYPE, 2, True, b"\x80\x00"),
        (CMD_UINT16_TYPE, 2, False, b"\xff\xff"),
        (CMD_INT32_TYPE, 4, True, b"\x80\x00\x00\x00"),
        (CMD_UINT32_TYPE, 4, False, b"\xff\xff\xff\xff"),
        (CMD_INT64_TYPE, 8, True, b"\x80\x00\x00\x00\x00\x00\x00\x00"),
        (CMD_UINT64_TYPE, 8, False, b"\xff\xff\xff\xff\xff\xff\xff\xff"),
    )
    for type_code, width, signed, null_marker in integer_layouts:
        if spiner_type == type_code:
            if data == null_marker:
                return None
            return int.from_bytes(data[:width], "big", signed=signed)

    if spiner_type == CMD_BOOL_TYPE:
        return (data[0] != 0) if data else None
    if spiner_type == CMD_FLT32_TYPE:
        return struct.unpack(">f", data[:4])[0]
    if spiner_type == CMD_FLT64_TYPE:
        return struct.unpack(">d", data[:8])[0]
    if spiner_type == CMD_DATE_TYPE:
        return _format_datetime(int.from_bytes(data[:8], "big", signed=True))
    if spiner_type == CMD_IPV4_TYPE:
        return _decode_ip_field(data[:5])
    if spiner_type == CMD_IPV6_TYPE:
        return _decode_ip_field(data[:17])
    if spiner_type in (CMD_VARCHAR_TYPE, CMD_CHAR_TYPE, CMD_TEXT_TYPE, CMD_CLOB_TYPE, CMD_JSON_TYPE):
        return _utf8_or_raw(data)
    if spiner_type in (CMD_BINARY_TYPE, CMD_BLOB_TYPE):
        return bytes(data)
    if spiner_type == CMD_NUL_TYPE or not data:
        return None
    return _utf8_or_raw(data)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _consume_row_values(unit_data: bytes, columns: Sequence[ColumnMeta]) -> Dict[str, Any]:
    """Split one fetched row buffer into ``{column name: decoded value}``.

    Variable-length columns carry a 4-byte big-endian length prefix. Fixed
    columns occupy ``column.length`` bytes.

    Raises:
        DatabaseError: if a variable-length prefix runs past the buffer.
    """
    decoded: Dict[str, Any] = {}
    cursor = 0
    total = len(unit_data)
    for col in columns:
        if col.is_variable:
            header_end = cursor + 4
            if header_end > total:
                raise DatabaseError("Malformed variable-length field")
            width = int.from_bytes(unit_data[cursor:header_end], "big", signed=False)
            cursor = header_end
        else:
            width = col.length
        chunk = unit_data[cursor:cursor + width]
        cursor += width
        decoded[col.name] = _decode_column_value(col.spiner_type, chunk)
    return decoded
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def collect_units(payload: bytes) -> Dict[int, List[Tuple[int, bytes]]]:
    """Group the marshal units in *payload* by unit id.

    Each id maps to its ``(type, data)`` pairs in the order they were read.
    """
    grouped: Dict[int, List[Tuple[int, bytes]]] = {}
    for unit in MarshalReader(payload):
        key = unit.id
        entry = (unit.type, unit.data)
        if key in grouped:
            grouped[key].append(entry)
        else:
            grouped[key] = [entry]
    return grouped
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _first_unit(units: Mapping[int, List[Tuple[int, bytes]]], unit_id: int) -> Optional[Tuple[int, bytes]]:
    """Return the first ``(type, data)`` pair for *unit_id*, or ``None`` if absent or empty."""
    matches = units.get(unit_id)
    if not matches:
        return None
    return matches[0]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _read_status(data: bytes) -> int:
    """Interpret the first (up to) 8 bytes of *data* as an unsigned little-endian status word."""
    head = data[:8]
    return int.from_bytes(head, byteorder="little", signed=False)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _status_error_code(status: int) -> int:
    """Return the low 32 bits of a status word (the server error number)."""
    return status % (1 << 32)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _format_error_message(units: Mapping[int, List[Tuple[int, bytes]]]) -> str:
    """Build a human-readable message from a response's message/status units.

    The first MESSAGE and EMESSAGE units are UTF-8 decoded (with replacement)
    and the non-empty ones are joined with ``"; "``. If the result status
    carries a non-zero error number, it is prefixed as ``"<no> - <message>"``,
    or returned alone when there is no message text.
    """
    entries = (_first_unit(units, unit_id) for unit_id in (CMI_R_MESSAGE_ID, CMI_R_EMESSAGE_ID))
    texts = [entry[1].decode("utf-8", errors="replace") for entry in entries if entry]
    message = "; ".join(text for text in texts if text)

    status_entry = _first_unit(units, CMI_R_RESULT_ID)
    error_no = _status_error_code(_read_status(status_entry[1])) if status_entry else 0
    if not error_no:
        return message
    return f"{error_no} - {message}" if message else str(error_no)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _make_protocol_version(major: int, minor: int, patch: int) -> int:
    """Pack a protocol version into one 64-bit integer.

    Uses the same makeVersion(major, minor, patch) layout as the JVM-based
    client: 16 bits major (bits 48..63), 16 bits minor (bits 32..47) and
    32 bits patch (bits 0..31). Each part is masked to its width.
    """
    major_bits = (major & 0xFFFF) << 48
    minor_bits = (minor & 0xFFFF) << 32
    patch_bits = patch & 0xFFFFFFFF
    return major_bits | minor_bits | patch_bits
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _status_ok(status: int) -> bool:
    """Return True when the high 32 bits of *status* mark an OK or LAST result."""
    result_kind = status & 0xFFFFFFFF00000000
    return result_kind in (CMI_OK_RESULT, CMI_LAST_RESULT)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _status_last(status: int) -> bool:
    """Return True when the high 32 bits of *status* mark the LAST result."""
    result_kind = status & 0xFFFFFFFF00000000
    return result_kind == CMI_LAST_RESULT
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _build_column_meta(units: Mapping[int, List[Tuple[int, bytes]]]) -> List[ColumnMeta]:
    """Pair up COLNAME and COLTYPE units into ``ColumnMeta`` records.

    Names are UTF-8 decoded with replacement. Each type unit's first 8 bytes
    are read as a little-endian type word. Extra names or types beyond the
    shorter list are ignored.
    """
    names = [entry[1].decode("utf-8", errors="replace") for entry in units.get(CMI_P_COLNAME_ID, [])]
    type_entries = units.get(CMI_P_COLTYPE_ID, [])

    result: List[ColumnMeta] = []
    for column_name, (_, raw_type) in zip(names, type_entries):
        type_word = int.from_bytes(raw_type[:8], "little", signed=False)
        type_code = extract_spiner_type(type_word)
        column_precision = extract_precision(type_word)
        result.append(
            ColumnMeta(
                name=column_name,
                cm_type=type_word,
                precision=column_precision,
                scale=extract_scale(type_word),
                spiner_type=type_code,
                length=_compute_column_length(type_code, column_precision),
                is_variable=is_variable_type(type_code),
            )
        )
    return result
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _rows_from_units(units: Mapping[int, List[Tuple[int, bytes]]], columns: Sequence[ColumnMeta]) -> List[Dict[str, Any]]:
    """Decode every VALUE unit in *units* into a row dict using *columns*."""
    return [_consume_row_values(row_bytes, columns) for _, row_bytes in units.get(CMI_F_VALUE_ID, [])]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _datetime_to_epoch_ns(moment: datetime.datetime) -> int:
    """Return *moment* as integer nanoseconds since the Unix epoch, without float rounding."""
    # timestamp() of a whole-second datetime is an exact integral float, so only
    # the microsecond part needs separate (exact) integer handling.
    whole_seconds = int(moment.replace(microsecond=0).timestamp())
    return whole_seconds * 1_000_000_000 + moment.microsecond * 1_000


def _int_to_microseconds(value: Any) -> int:
    """Convert an append value for a DATE column into the integer that is sent.

    Despite the name, datetime-like inputs are converted to *nanoseconds*
    since the Unix epoch. Naive datetimes are interpreted as local time, as
    ``datetime.timestamp`` does.

    Args:
        value: ``None``; a ``datetime``/``date``; an ``int``/``float`` (sent
            as ``int(value)`` unchanged); or a string in
            ``"%Y-%m-%d %H:%M:%S.%f"``, ``"%Y-%m-%d %H:%M:%S"`` or
            ``"%Y-%m-%d"`` form, or a decimal integer string.

    Returns:
        The nanosecond value, or ``-1`` (NULL) for ``None``, blank strings and
        unparseable text.
    """
    if value is None:
        return -1
    # datetime must be tested before date: datetime is a date subclass.
    if isinstance(value, datetime.datetime):
        return _datetime_to_epoch_ns(value)
    if isinstance(value, datetime.date):
        return _datetime_to_epoch_ns(datetime.datetime.combine(value, datetime.time.min))
    if isinstance(value, (int, float)):
        return int(value)
    text = str(value).strip()
    if not text:
        return -1
    for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
        try:
            return _datetime_to_epoch_ns(datetime.datetime.strptime(text, fmt))
        except ValueError:
            continue
    try:
        return int(text)
    except ValueError:
        return -1
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def _encode_length_prefixed(data: bytes, server_le: bool) -> bytes:
    """Return *data* preceded by its length as a 4-byte unsigned integer in server byte order."""
    order = "little" if server_le else "big"
    prefix = len(data).to_bytes(4, order)
    return prefix + data
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def _encode_fixed_int(value: int, size: int, signed: bool, server_le: bool) -> bytes:
    """Encode ``int(value)`` as *size* bytes in server byte order.

    Raises:
        OverflowError: if the value does not fit in *size* bytes.
    """
    order = "little" if server_le else "big"
    as_int = int(value)
    return as_int.to_bytes(size, order, signed=signed)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _encode_value(spiner_type: int, value: Any, server_le: bool) -> bytes:
    """Encode one append value for a column of type *spiner_type*.

    Encoding rules:
    - Fixed-width numbers use the byte order selected by *server_le*.
    - DATE values go through ``_int_to_microseconds``.
    - IP values are a family byte (4/6) followed by the packed address.
    - Any other type is sent length-prefixed: raw bytes for bytes-like
      values, otherwise the UTF-8 text of ``str(value)``.

    Args:
        spiner_type: Target column type code.
        value: Python value to encode; ``None`` produces the NULL encoding
            used here (empty length-prefixed field, -1 date, zero IP, or no
            bytes for other fixed types).
        server_le: True for little-endian encoding.

    Raises:
        ValueError: if an IP value cannot be parsed, or an IPv6 address is
            given for an IPv4 column.
        OverflowError: if an integer does not fit the column width.
    """
    if value is None:
        if is_variable_type(spiner_type):
            return _encode_length_prefixed(b"", server_le)
        if spiner_type == CMD_DATE_TYPE:
            return _encode_fixed_int(-1, 8, True, server_le)
        if spiner_type == CMD_IPV4_TYPE:
            return bytes([4, 0, 0, 0, 0])
        if spiner_type == CMD_IPV6_TYPE:
            return bytes([6]) + bytes(16)
        return b""
    if spiner_type == CMD_BOOL_TYPE:
        return b"\x01" if bool(value) else b"\x00"
    if spiner_type == CMD_INT16_TYPE:
        return _encode_fixed_int(int(value), 2, True, server_le)
    if spiner_type == CMD_UINT16_TYPE:
        return _encode_fixed_int(int(value), 2, False, server_le)
    if spiner_type == CMD_INT32_TYPE:
        return _encode_fixed_int(int(value), 4, True, server_le)
    if spiner_type == CMD_UINT32_TYPE:
        return _encode_fixed_int(int(value), 4, False, server_le)
    if spiner_type == CMD_INT64_TYPE:
        return _encode_fixed_int(int(value), 8, True, server_le)
    if spiner_type == CMD_UINT64_TYPE:
        return _encode_fixed_int(int(value), 8, False, server_le)
    if spiner_type == CMD_FLT32_TYPE:
        return struct.pack("<f" if server_le else ">f", float(value))
    if spiner_type == CMD_FLT64_TYPE:
        return struct.pack("<d" if server_le else ">d", float(value))
    if spiner_type == CMD_DATE_TYPE:
        return _encode_fixed_int(_int_to_microseconds(value), 8, True, server_le)
    if spiner_type == CMD_IPV4_TYPE:
        # IPv4Address rejects IPv6 text. ip_address() would pack 16 bytes into
        # the 5-byte IPv4 field and misalign every following column.
        return bytes([4]) + ipaddress.IPv4Address(str(value)).packed
    if spiner_type == CMD_IPV6_TYPE:
        address = ipaddress.ip_address(str(value))
        if address.version == 4:
            # Widen to the IPv4-mapped form (::ffff:a.b.c.d) so the field is
            # always 16 address bytes instead of a truncated 4.
            address = ipaddress.IPv6Address(b"\x00" * 10 + b"\xff\xff" + address.packed)
        return bytes([6]) + address.packed
    if isinstance(value, (bytes, bytearray, memoryview)):
        return _encode_length_prefixed(bytes(value), server_le)
    return _encode_length_prefixed(str(value).encode("utf-8"), server_le)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class MachbaseProtocolClient:
    """Synchronous client for the Machbase CMI wire protocol over one TCP socket.

    Workflow:
    - ``connect`` performs the ASCII handshake and the CONNECT exchange.
    - ``exec_direct`` / ``query`` run SQL through EXECDIRECT, plus FETCH
      round-trips for result sets.
    - ``append_open`` / ``append_data`` / ``append_close`` (or the ``append``
      wrapper) stream rows into a table.
    """

    def __init__(
        self,
        *,
        host: str = "127.0.0.1",
        port: int = 5656,
        user: str = "SYS",
        password: str = "MANAGER",
        database: str = "data",
        timeout_ms: int = DEFAULT_CONNECT_TIMEOUT_MS,
        query_timeout_ms: int = DEFAULT_QUERY_TIMEOUT_MS,
        client_id: str = "PYMCB",
        show_hidden_columns: bool = False,
        timezone: str = "",
    ):
        """Store connection settings; no network I/O happens until ``connect``.

        Args:
            host: Server host name or address.
            port: Server TCP port.
            user: Login user name.
            password: Login password.
            database: Database name sent in the CONNECT request.
            timeout_ms: Timeout for establishing the TCP connection and
                completing the handshake.
            query_timeout_ms: Timeout used while waiting for protocol
                responses. It is also sent to the server in whole seconds
                (minimum 1).
            client_id: Client identifier sent in the CONNECT request.
            show_hidden_columns: Sent as the CMI_C_SHC_ID flag (1/0).
            timezone: Optional timezone sent in the CONNECT request when non-empty.
        """
        self.host = host
        self.port = int(port)
        self.user = user
        self.password = password
        self.database = database
        self.connect_timeout_ms = int(timeout_ms)
        self.query_timeout_ms = int(query_timeout_ms)
        self.client_id = client_id
        self.show_hidden_columns = bool(show_hidden_columns)
        self.timezone = timezone
        self.socket: Optional[socket.socket] = None
        self.reader: Optional[PacketReader] = None
        self.connected = False
        self.next_stmt = 1
        self.session_id = 0
        # Updated from the CONNECT response's endian unit (0 means little-endian).
        self.server_le = True

    def connect(self) -> None:
        """Open the socket, perform the handshake and the CONNECT exchange.

        Does nothing when already connected.

        Raises:
            OperationalError: on any failure. The socket is closed and the
                connection state is reset first.
        """
        if self.connected:
            return
        sock: Optional[socket.socket] = None
        try:
            # create_connection resolves host names and supports IPv6 hosts,
            # unlike a hard-coded AF_INET socket.
            sock = socket.create_connection(
                (self.host, self.port), timeout=self.connect_timeout_ms / 1000.0
            )
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Keep the connect timeout during the handshake so an unresponsive
            # peer cannot block forever; only then switch to blocking mode.
            self._handshake(sock)
            sock.settimeout(None)
            self.socket = sock
            self.reader = PacketReader(sock)
            self._send_connect()
            self.connected = True
        except Exception as exc:
            if sock is not None:
                try:
                    sock.close()
                except OSError:
                    pass
            self.socket = None
            self.reader = None
            raise OperationalError(str(exc)) from exc

    def close(self) -> None:
        """Close the socket if one is open and mark the client disconnected."""
        if self.socket:
            try:
                self.socket.close()
            finally:
                self.socket = None
                self.reader = None
        self.connected = False

    def _next_statement_id(self) -> int:
        """Return the next statement id, cycling through 1..0x7FFFFFFF (never 0)."""
        sid = self.next_stmt
        self.next_stmt = (self.next_stmt + 1) & 0x7FFFFFFF
        if self.next_stmt == 0:
            self.next_stmt = 1
        return sid

    def _handshake(self, sock: socket.socket) -> None:
        """Send the handshake prefix and endian marker, then check the server's reply.

        Raises:
            OperationalError: if the reply is not CMI_HANDSHAKE_READY.
        """
        sock.sendall(f"{CMI_HANDSHAKE_PREFIX}{CMI_HANDSHAKE_ENDIAN_LITTLE}".encode("ascii"))
        response = self._recv_exact(sock, CMI_PROTO_CNT)
        if response.decode("ascii") != CMI_HANDSHAKE_READY:
            raise OperationalError("Invalid handshake response")

    @staticmethod
    def _checked_status(
        units: Mapping[int, List[Tuple[int, bytes]]],
        missing_message: str,
        failure_message: str,
        error_cls: Callable[[str], Exception],
    ) -> int:
        """Return the first result status in *units*.

        Raises:
            error_cls: with *missing_message* when no status unit is present,
                or with the server message (else *failure_message*) when the
                status is not OK.
        """
        status_unit = _first_unit(units, CMI_R_RESULT_ID)
        if not status_unit:
            raise error_cls(missing_message)
        status = _read_status(status_unit[1])
        if not _status_ok(status):
            raise error_cls(_format_error_message(units) or failure_message)
        return status

    def _send_connect(self) -> None:
        """Send the CONNECT request and record the session id and server byte order.

        Raises:
            OperationalError: if the response has no status or the status is not OK.
        """
        version = _make_protocol_version(
            CMI_PROTOCOL_MAJOR_VERSION,
            CMI_PROTOCOL_MINOR_VERSION,
            CMI_PROTOCOL_FIX_VERSION,
        )
        writer = MarshalWriter(CMI_CONNECT_PROTOCOL, 0)
        writer.add_uint64(CMI_C_VERSION_ID, version)
        writer.add_string(CMI_C_CLIENT_ID, self.client_id)
        writer.add_string(CMI_C_DATABASE_ID, self.database)
        writer.add_string(CMI_C_USER_ID, self.user)
        writer.add_string(CMI_C_PASSWORD_ID, self.password)
        writer.add_uint64(CMI_C_TIMEOUT_ID, max(1, self.query_timeout_ms // 1000))
        writer.add_uint32(CMI_C_SHC_ID, 1 if self.show_hidden_columns else 0)
        writer.add_string(CMI_C_IP_ID, self.host)
        if self.timezone:
            writer.add_string(CMI_C_TIMEZONE_ID, self.timezone)
        self._send_packets(writer.finalize())
        units = collect_units(self._read_protocol(CMI_CONNECT_PROTOCOL, self.query_timeout_ms))
        self._checked_status(units, "CONNECT response missing status", "CONNECT failed", OperationalError)
        sid = _first_unit(units, CMI_C_SID_ID)
        if sid:
            self.session_id = int.from_bytes(sid[1][:8], "little", signed=False)
        endian = _first_unit(units, CMI_C_ENDIAN_ID)
        if endian:
            self.server_le = int.from_bytes(endian[1][:4], "little", signed=False) == 0

    def _send_packets(self, packets: Sequence[bytes]) -> None:
        """Write each packet to the socket in order.

        Raises:
            OperationalError: if there is no open socket.
        """
        if not self.socket:
            raise OperationalError("Not connected")
        for packet in packets:
            self.socket.sendall(packet)

    def _read_protocol(self, protocol_id: int, timeout_ms: int) -> bytes:
        """Read packets for *protocol_id* and return their concatenated bodies.

        Reading stops after a packet whose flag is 0 or 3.

        Raises:
            OperationalError: if not connected, or if a packet belongs to a
                different protocol.
        """
        if not self.reader:
            raise OperationalError("Not connected")
        chunks: List[bytes] = []
        while True:
            packet = self.reader.next(timeout_ms)
            if packet.protocol != protocol_id:
                raise OperationalError(f"Protocol mismatch: expected {protocol_id}, got {packet.protocol}")
            chunks.append(packet.body)
            if packet.flag in (0, 3):
                return b"".join(chunks)

    def _exec_direct_round_trip(
        self, sql: str, missing_message: str, failure_message: str
    ) -> Tuple[int, Dict[int, List[Tuple[int, bytes]]], int]:
        """Send *sql* via EXECDIRECT and return ``(statement id, units, status)``.

        Raises:
            DatabaseError: with *missing_message* / *failure_message* as in
                ``_checked_status``.
        """
        stmt_id = self._next_statement_id()
        writer = MarshalWriter(CMI_EXECDIRECT_PROTOCOL, stmt_id)
        writer.add_string(CMI_D_STATEMENT_ID, sql)
        writer.add_uint64(CMI_P_ID_ID, stmt_id)
        self._send_packets(writer.finalize())
        units = collect_units(self._read_protocol(CMI_EXECDIRECT_PROTOCOL, self.query_timeout_ms))
        status = self._checked_status(units, missing_message, failure_message, DatabaseError)
        return stmt_id, units, status

    def exec_direct(self, sql: str) -> QueryMeta:
        """Execute *sql* without fetching rows.

        Returns:
            A ``QueryMeta`` with the server-reported affected-row count
            (0 if absent) and any message.

        Raises:
            DatabaseError: if the response has no status or reports failure.
        """
        stmt_id, units, status = self._exec_direct_round_trip(
            sql, "EXEC response missing status", "Execute failed"
        )
        rows = _first_unit(units, CMI_P_ROWS_ID)
        rows_affected = int.from_bytes(rows[1][:8], "little", signed=False) if rows else 0
        message = _format_error_message(units) or None
        return QueryMeta(stmt_id, [], [], rows_affected, message, _status_last(status))

    def _fetch_rows_loop(
        self, statement_id: int, columns: Sequence[ColumnMeta], fetch_size: int = 1000
    ) -> List[Dict[str, Any]]:
        """Issue FETCH requests of *fetch_size* rows until the result set is exhausted.

        The loop ends when any status is LAST, or when a response carries no
        rows and no positive row count.

        Raises:
            DatabaseError: if a response has no status or any status is not OK.
        """
        rows: List[Dict[str, Any]] = []
        while True:
            writer = MarshalWriter(CMI_FETCH_PROTOCOL, statement_id)
            writer.add_uint32(CMI_F_ID_ID, statement_id)
            writer.add_sint64(CMI_F_ROWS_ID, fetch_size)
            self._send_packets(writer.finalize())
            units = collect_units(self._read_protocol(CMI_FETCH_PROTOCOL, self.query_timeout_ms))
            status_units = units.get(CMI_R_RESULT_ID, [])
            if not status_units:
                raise DatabaseError("FETCH response missing status")
            has_last = False
            for _, status_bytes in status_units:
                status = _read_status(status_bytes)
                if not _status_ok(status):
                    raise DatabaseError(_format_error_message(units) or "Fetch failed")
                if _status_last(status):
                    has_last = True
            batch = _rows_from_units(units, columns)
            rows.extend(batch)
            if has_last:
                break
            row_count = _first_unit(units, CMI_F_ROWS_ID)
            row_cnt = int.from_bytes(row_count[1][:8], "little", signed=False) if row_count else None
            if (row_cnt is None or row_cnt <= 0) and not batch:
                break
        return rows

    def query(self, sql: str) -> QueryMeta:
        """Execute *sql* and return all result rows.

        Rows embedded in the EXECDIRECT response are decoded first. If the
        statement has columns and the status is not LAST, the remaining rows
        are fetched.

        Raises:
            DatabaseError: if any response has no status or reports failure.
        """
        stmt_id, units, status = self._exec_direct_round_trip(
            sql, "QUERY response missing status", "Query failed"
        )
        columns = _build_column_meta(units)
        rows = _rows_from_units(units, columns)
        is_last = _status_last(status)
        if columns and not is_last:
            rows.extend(self._fetch_rows_loop(stmt_id, columns))
        return QueryMeta(stmt_id, columns, rows, len(rows), _format_error_message(units) or None, is_last)

    def append_open(self, table: str) -> _AppendSession:
        """Open an APPEND stream on *table* and return its session state.

        Raises:
            DatabaseError: if the response has no status or reports failure.
        """
        stmt_id = self._next_statement_id()
        writer = MarshalWriter(CMI_APPEND_OPEN_PROTOCOL, stmt_id)
        writer.add_uint64(CMI_P_ID_ID, stmt_id)
        writer.add_string(CMI_P_TABLE_ID, table)
        writer.add_uint64(CMI_E_ENDIAN_ID, 0)
        self._send_packets(writer.finalize())
        units = collect_units(self._read_protocol(CMI_APPEND_OPEN_PROTOCOL, self.query_timeout_ms))
        self._checked_status(units, "APPEND open response missing status", "Append open failed", DatabaseError)
        return _AppendSession(stmt_id, table, _build_column_meta(units), self.server_le)

    @staticmethod
    def _align_append_row(
        metadata: Sequence[ColumnMeta],
        has_arrival: bool,
        values: List[Any],
        row_index: int,
        times: Optional[Sequence[Any]],
    ) -> List[Any]:
        """Match one row's values to the column metadata, filling the arrival-time slot.

        When the first column is the arrival time and the row omits it, the
        value is taken from ``times[row_index]``, or 0 if unavailable. When
        the row includes it as ``None``, ``times[row_index]`` fills it if
        available.

        Raises:
            ProgrammingError: if the row length does not fit the metadata.
        """
        if metadata and has_arrival:
            if len(values) == len(metadata) - 1:
                if times and row_index < len(times):
                    return [times[row_index]] + values
                return [0] + values
            if len(values) == len(metadata):
                if times and row_index < len(times) and values[0] is None:
                    values[0] = times[row_index]
                return values
            raise ProgrammingError("Append row length does not match append metadata")
        if metadata and len(values) != len(metadata):
            raise ProgrammingError("Append row length does not match append metadata")
        return values

    @staticmethod
    def _encode_append_row(
        server_le: bool,
        metadata: Sequence[ColumnMeta],
        has_arrival: bool,
        null_bytes: int,
        values: Sequence[Any],
        types: Optional[Sequence[int]],
    ) -> bytes:
        """Encode one aligned row for APPEND_DATA.

        Layout: one 0 byte, a 4-byte null-bitmap length, the bitmap, then the
        encoded non-NULL values. NULL values only set their bitmap bit (MSB
        first) when a bitmap exists.
        """
        payload = bytearray()
        payload.append(0)
        payload.extend(int(null_bytes).to_bytes(4, "little" if server_le else "big"))
        payload.extend(bytearray(null_bytes))
        for col_index, value in enumerate(values):
            if has_arrival and col_index == 0:
                payload.extend(_encode_value(CMD_DATE_TYPE, value, server_le))
                continue
            # Type is resolved before the NULL check, as type lookup may validate input.
            if metadata:
                spiner_type = metadata[col_index].spiner_type
            elif types and col_index < len(types):
                spiner_type = spiner_type_from_legacy_sql(int(types[col_index]))
            else:
                spiner_type = spiner_type_from_value(value)
            if value is None and null_bytes:
                # Bitmap starts at offset 5 (after the flag byte and 4-byte length).
                payload[5 + col_index // 8] |= 1 << (7 - (col_index % 8))
                continue
            payload.extend(_encode_value(spiner_type, value, server_le))
        return bytes(payload)

    def append_data(
        self,
        session: _AppendSession,
        rows: Sequence[Sequence[Any]],
        types: Optional[Sequence[int]] = None,
        times: Optional[Sequence[Any]] = None,
        on_ack: Optional[Callable[[Dict[int, List[Tuple[int, bytes]]]], None]] = None,
    ) -> int:
        """Send *rows* on an open APPEND session.

        Args:
            session: Session returned by ``append_open``.
            rows: Row value sequences.
            types: Legacy SQL type codes; only used when the session has no metadata.
            times: Optional per-row arrival times for an arrival-time first column.
            on_ack: Called with the acknowledgement units if one is received.

        Returns:
            ``len(rows)``. This is also returned when no acknowledgement
            arrives (read timeout).

        Raises:
            ProgrammingError: if a row does not match the metadata.
            DatabaseError: if the acknowledgement reports failure.
        """
        writer = MarshalWriter(CMI_APPEND_DATA_PROTOCOL, session.statement_id, adds=session.statement_id & 0xFFFF)
        metadata = session.columns
        has_arrival = bool(metadata) and metadata[0].name.upper() in {"", "_ARRIVAL_TIME"}
        # NOTE(review): n // 8 + 1 (not ceil) gives an extra byte when n % 8 == 0;
        # presumably matches the server's expected layout — verify before changing.
        null_bytes = (len(metadata) // 8) + 1 if metadata else 0
        for row_index, row in enumerate(rows):
            values = self._align_append_row(metadata, has_arrival, list(row), row_index, times)
            payload = self._encode_append_row(session.server_le, metadata, has_arrival, null_bytes, values, types)
            writer.add_binary(CMI_P_ROWS_ID, payload)
        self._send_packets(writer.finalize())
        try:
            units = collect_units(self._read_protocol(CMI_APPEND_DATA_PROTOCOL, 0))
        except (TimeoutError, socket.timeout):
            # socket.timeout is only an alias of TimeoutError from Python 3.10 on.
            return len(rows)
        if on_ack is not None:
            on_ack(units)
        status_unit = _first_unit(units, CMI_R_RESULT_ID)
        if status_unit:
            status = _read_status(status_unit[1])
            if not _status_ok(status):
                raise DatabaseError(_format_error_message(units) or "Append data failed")
        failures = _first_unit(units, CMI_X_APPEND_FAILURE_ID)
        if failures and int.from_bytes(failures[1][:8], "little", signed=False) > 0:
            raise DatabaseError(_format_error_message(units) or "Append data reported failures")
        return len(rows)

    def append_close(self, session: _AppendSession) -> int:
        """Close an APPEND session.

        Returns:
            The success count the server reports, or 0 if absent.

        Raises:
            DatabaseError: if the response has no status or reports failure.
        """
        writer = MarshalWriter(CMI_APPEND_CLOSE_PROTOCOL, session.statement_id)
        writer.add_uint64(CMI_P_ID_ID, session.statement_id)
        self._send_packets(writer.finalize())
        units = collect_units(self._read_protocol(CMI_APPEND_CLOSE_PROTOCOL, self.query_timeout_ms))
        self._checked_status(units, "APPEND close response missing status", "Append close failed", DatabaseError)
        success = _first_unit(units, CMI_X_APPEND_SUCCESS_ID)
        return int.from_bytes(success[1][:8], "little", signed=False) if success else 0

    def append(
        self,
        table: str,
        rows: Sequence[Sequence[Any]],
        types: Optional[Sequence[int]] = None,
        times: Optional[Sequence[Any]] = None,
    ) -> int:
        """Open, fill and close an APPEND session on *table*.

        Returns:
            The success count from ``append_close``.

        On error the session is closed on a best-effort basis before the
        original exception propagates.
        """
        session = self.append_open(table)
        try:
            self.append_data(session, rows, types=types, times=times)
            return self.append_close(session)
        except Exception:
            try:
                self.append_close(session)
            except Exception:
                # Deliberately ignored so the original error is the one raised.
                pass
            raise

    @staticmethod
    def _recv_exact(sock: socket.socket, length: int) -> bytes:
        """Read exactly *length* bytes from *sock*.

        Raises:
            OperationalError: if the peer closes the connection first.
        """
        chunks: List[bytes] = []
        remaining = length
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise OperationalError("Socket closed while reading protocol")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)
|