machbaseapi 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,661 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import ipaddress
5
+ import socket
6
+ import struct
7
+ from dataclasses import dataclass
8
+ from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
9
+
10
+ from .constants import *
11
+ from .errors import DatabaseError, OperationalError, ProgrammingError
12
+ from .marshal import MarshalReader, MarshalWriter
13
+ from .packet import PacketReader
14
+ from .types import spiner_type_from_legacy_sql, spiner_type_from_value
15
+
16
+
17
+ @dataclass
18
+ class ColumnMeta:
19
+ name: str
20
+ cm_type: int
21
+ precision: int
22
+ scale: int
23
+ spiner_type: int
24
+ length: int
25
+ is_variable: bool
26
+
27
+
28
+ @dataclass
29
+ class QueryMeta:
30
+ statement_id: int
31
+ columns: List[ColumnMeta]
32
+ rows: List[Dict[str, Any]]
33
+ rows_affected: int
34
+ message: Optional[str]
35
+ more: bool
36
+
37
+
38
+ @dataclass
39
+ class _AppendSession:
40
+ statement_id: int
41
+ table: str
42
+ columns: List[ColumnMeta]
43
+ server_le: bool
44
+ user_types: Optional[List[int]] = None
45
+
46
+
47
+ def extract_spiner_type(cm_type: int) -> int:
48
+ return (cm_type >> 56) & 0xFF
49
+
50
+
51
+ def extract_precision(cm_type: int) -> int:
52
+ return (cm_type >> 28) & 0x0FFFFFFF
53
+
54
+
55
+ def extract_scale(cm_type: int) -> int:
56
+ return cm_type & 0x0FFFFFFF
57
+
58
+
59
+ def _compute_column_length(spiner_type: int, precision: int) -> int:
60
+ if spiner_type in {CMD_INT16_TYPE, CMD_UINT16_TYPE}:
61
+ return 2
62
+ if spiner_type in {CMD_INT32_TYPE, CMD_UINT32_TYPE, CMD_FLT32_TYPE}:
63
+ return 4
64
+ if spiner_type in {CMD_INT64_TYPE, CMD_UINT64_TYPE, CMD_DATE_TYPE, CMD_FLT64_TYPE}:
65
+ return 8
66
+ if spiner_type == CMD_IPV4_TYPE:
67
+ return 5
68
+ if spiner_type == CMD_IPV6_TYPE:
69
+ return 17
70
+ if spiner_type == CMD_BOOL_TYPE:
71
+ return 1
72
+ if spiner_type == CMD_NUL_TYPE:
73
+ return 0
74
+ return precision
75
+
76
+
77
+ def _decode_ip_field(raw: bytes) -> str:
78
+ if len(raw) == 5 and raw[0] == 4:
79
+ return ".".join(str(b) for b in raw[1:])
80
+ if len(raw) == 17 and raw[0] == 6:
81
+ return str(ipaddress.IPv6Address(raw[1:]))
82
+ return ""
83
+
84
+
85
+ def _format_datetime(value: int) -> Optional[str]:
86
+ if value in (0xFFFFFFFFFFFFFFFF, -1):
87
+ return None
88
+
89
+ dt: Optional[datetime.datetime] = None
90
+ for divisor in (1_000_000_000, 1_000_000, 1_000):
91
+ try:
92
+ candidate = datetime.datetime.fromtimestamp(value / divisor)
93
+ except (OverflowError, OSError, ValueError):
94
+ continue
95
+ if 1980 <= candidate.year <= 2500:
96
+ dt = candidate
97
+ break
98
+ if dt is None:
99
+ dt = candidate
100
+
101
+ if dt is None:
102
+ dt = datetime.datetime.fromtimestamp(0)
103
+
104
+ if dt.microsecond:
105
+ return dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
106
+ return dt.strftime("%Y-%m-%d %H:%M:%S")
107
+
108
+
109
+ def _decode_column_value(spiner_type: int, data: bytes) -> Any:
110
+ if spiner_type == CMD_INT16_TYPE and data == b"\x80\x00":
111
+ return None
112
+ if spiner_type == CMD_UINT16_TYPE and data == b"\xff\xff":
113
+ return None
114
+ if spiner_type == CMD_INT32_TYPE and data == b"\x80\x00\x00\x00":
115
+ return None
116
+ if spiner_type == CMD_UINT32_TYPE and data == b"\xff\xff\xff\xff":
117
+ return None
118
+ if spiner_type == CMD_INT64_TYPE and data == b"\x80\x00\x00\x00\x00\x00\x00\x00":
119
+ return None
120
+ if spiner_type == CMD_UINT64_TYPE and data == b"\xff\xff\xff\xff\xff\xff\xff\xff":
121
+ return None
122
+ if spiner_type == CMD_BOOL_TYPE:
123
+ return data[0] != 0 if data else None
124
+ if spiner_type == CMD_INT16_TYPE:
125
+ return int.from_bytes(data[:2], "big", signed=True)
126
+ if spiner_type == CMD_UINT16_TYPE:
127
+ return int.from_bytes(data[:2], "big", signed=False)
128
+ if spiner_type == CMD_INT32_TYPE:
129
+ return int.from_bytes(data[:4], "big", signed=True)
130
+ if spiner_type == CMD_UINT32_TYPE:
131
+ return int.from_bytes(data[:4], "big", signed=False)
132
+ if spiner_type == CMD_INT64_TYPE:
133
+ return int.from_bytes(data[:8], "big", signed=True)
134
+ if spiner_type == CMD_UINT64_TYPE:
135
+ return int.from_bytes(data[:8], "big", signed=False)
136
+ if spiner_type == CMD_FLT32_TYPE:
137
+ return struct.unpack(">f", data[:4])[0]
138
+ if spiner_type == CMD_FLT64_TYPE:
139
+ return struct.unpack(">d", data[:8])[0]
140
+ if spiner_type == CMD_DATE_TYPE:
141
+ return _format_datetime(int.from_bytes(data[:8], "big", signed=True))
142
+ if spiner_type == CMD_IPV4_TYPE:
143
+ return _decode_ip_field(data[:5])
144
+ if spiner_type == CMD_IPV6_TYPE:
145
+ return _decode_ip_field(data[:17])
146
+ if spiner_type in (CMD_VARCHAR_TYPE, CMD_CHAR_TYPE, CMD_TEXT_TYPE, CMD_CLOB_TYPE, CMD_JSON_TYPE):
147
+ try:
148
+ return data.decode("utf-8")
149
+ except Exception:
150
+ return bytes(data)
151
+ if spiner_type in (CMD_BINARY_TYPE, CMD_BLOB_TYPE):
152
+ return bytes(data)
153
+ if spiner_type == CMD_NUL_TYPE:
154
+ return None
155
+ if not data:
156
+ return None
157
+ try:
158
+ return data.decode("utf-8")
159
+ except Exception:
160
+ return bytes(data)
161
+
162
+
163
+ def _consume_row_values(unit_data: bytes, columns: Sequence[ColumnMeta]) -> Dict[str, Any]:
164
+ row: Dict[str, Any] = {}
165
+ offset = 0
166
+ for column in columns:
167
+ if column.is_variable:
168
+ if offset + 4 > len(unit_data):
169
+ raise DatabaseError("Malformed variable-length field")
170
+ size = int.from_bytes(unit_data[offset:offset + 4], "big", signed=False)
171
+ offset += 4
172
+ data = unit_data[offset:offset + size]
173
+ offset += size
174
+ row[column.name] = _decode_column_value(column.spiner_type, data)
175
+ continue
176
+ data = unit_data[offset:offset + column.length]
177
+ offset += column.length
178
+ row[column.name] = _decode_column_value(column.spiner_type, data)
179
+ return row
180
+
181
+
182
+ def collect_units(payload: bytes) -> Dict[int, List[Tuple[int, bytes]]]:
183
+ buckets: Dict[int, List[Tuple[int, bytes]]] = {}
184
+ for unit in MarshalReader(payload):
185
+ buckets.setdefault(unit.id, []).append((unit.type, unit.data))
186
+ return buckets
187
+
188
+
189
+ def _first_unit(units: Mapping[int, List[Tuple[int, bytes]]], unit_id: int) -> Optional[Tuple[int, bytes]]:
190
+ values = units.get(unit_id)
191
+ return values[0] if values else None
192
+
193
+
194
+ def _read_status(data: bytes) -> int:
195
+ return int.from_bytes(data[:8], "little", signed=False)
196
+
197
+
198
+ def _status_error_code(status: int) -> int:
199
+ return status & 0xFFFFFFFF
200
+
201
+
202
+ def _format_error_message(units: Mapping[int, List[Tuple[int, bytes]]]) -> str:
203
+ parts: List[str] = []
204
+ for unit_id in (CMI_R_MESSAGE_ID, CMI_R_EMESSAGE_ID):
205
+ item = _first_unit(units, unit_id)
206
+ if item:
207
+ parts.append(item[1].decode("utf-8", errors="replace"))
208
+ message = "; ".join(part for part in parts if part)
209
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
210
+ if status_unit:
211
+ error_no = _status_error_code(_read_status(status_unit[1]))
212
+ if error_no:
213
+ if message:
214
+ return f"{error_no} - {message}"
215
+ return str(error_no)
216
+ return message
217
+
218
+
219
+ def _make_protocol_version(major: int, minor: int, patch: int) -> int:
220
+ # JVM 기반 클라이언트에서 사용하는 makeVersion(major, minor, patch) 포맷을 그대로 사용한다.
221
+ return ((major & 0xFFFF) << 48) | ((minor & 0xFFFF) << 32) | (patch & 0xFFFFFFFF)
222
+
223
+
224
+ def _status_ok(status: int) -> bool:
225
+ masked = status & 0xFFFFFFFF00000000
226
+ return masked in {CMI_OK_RESULT, CMI_LAST_RESULT}
227
+
228
+
229
+ def _status_last(status: int) -> bool:
230
+ return (status & 0xFFFFFFFF00000000) == CMI_LAST_RESULT
231
+
232
+
233
+ def _build_column_meta(units: Mapping[int, List[Tuple[int, bytes]]]) -> List[ColumnMeta]:
234
+ names = [item[1].decode("utf-8", errors="replace") for item in units.get(CMI_P_COLNAME_ID, [])]
235
+ types = units.get(CMI_P_COLTYPE_ID, [])
236
+ count = min(len(names), len(types))
237
+ columns: List[ColumnMeta] = []
238
+ for index in range(count):
239
+ cm_type = int.from_bytes(types[index][1][:8], "little", signed=False)
240
+ spiner_type = extract_spiner_type(cm_type)
241
+ precision = extract_precision(cm_type)
242
+ scale = extract_scale(cm_type)
243
+ columns.append(
244
+ ColumnMeta(
245
+ name=names[index],
246
+ cm_type=cm_type,
247
+ precision=precision,
248
+ scale=scale,
249
+ spiner_type=spiner_type,
250
+ length=_compute_column_length(spiner_type, precision),
251
+ is_variable=is_variable_type(spiner_type),
252
+ )
253
+ )
254
+ return columns
255
+
256
+
257
+ def _rows_from_units(units: Mapping[int, List[Tuple[int, bytes]]], columns: Sequence[ColumnMeta]) -> List[Dict[str, Any]]:
258
+ rows: List[Dict[str, Any]] = []
259
+ for _, data in units.get(CMI_F_VALUE_ID, []):
260
+ rows.append(_consume_row_values(data, columns))
261
+ return rows
262
+
263
+
264
+ def _int_to_microseconds(value: Any) -> int:
265
+ if value is None:
266
+ return -1
267
+ if isinstance(value, datetime.datetime):
268
+ return int(value.timestamp() * 1_000_000_000)
269
+ if isinstance(value, datetime.date):
270
+ return int(
271
+ datetime.datetime.combine(value, datetime.time.min).timestamp() * 1_000_000_000
272
+ )
273
+ if isinstance(value, (int, float)):
274
+ return int(value)
275
+ text = str(value).strip()
276
+ if not text:
277
+ return -1
278
+ for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
279
+ try:
280
+ dt = datetime.datetime.strptime(text, fmt)
281
+ return int(dt.timestamp() * 1_000_000_000)
282
+ except ValueError:
283
+ continue
284
+ try:
285
+ return int(text)
286
+ except ValueError:
287
+ return -1
288
+
289
+
290
+ def _encode_length_prefixed(data: bytes, server_le: bool) -> bytes:
291
+ return len(data).to_bytes(4, "little" if server_le else "big") + data
292
+
293
+
294
+ def _encode_fixed_int(value: int, size: int, signed: bool, server_le: bool) -> bytes:
295
+ return int(value).to_bytes(size, "little" if server_le else "big", signed=signed)
296
+
297
+
298
+ def _encode_value(spiner_type: int, value: Any, server_le: bool) -> bytes:
299
+ if value is None:
300
+ if is_variable_type(spiner_type):
301
+ return _encode_length_prefixed(b"", server_le)
302
+ if spiner_type == CMD_DATE_TYPE:
303
+ return _encode_fixed_int(-1, 8, True, server_le)
304
+ if spiner_type == CMD_IPV4_TYPE:
305
+ return bytes([4, 0, 0, 0, 0])
306
+ if spiner_type == CMD_IPV6_TYPE:
307
+ return bytes([6]) + bytes(16)
308
+ return b""
309
+ if spiner_type == CMD_BOOL_TYPE:
310
+ return b"\x01" if bool(value) else b"\x00"
311
+ if spiner_type == CMD_INT16_TYPE:
312
+ return _encode_fixed_int(int(value), 2, True, server_le)
313
+ if spiner_type == CMD_UINT16_TYPE:
314
+ return _encode_fixed_int(int(value), 2, False, server_le)
315
+ if spiner_type == CMD_INT32_TYPE:
316
+ return _encode_fixed_int(int(value), 4, True, server_le)
317
+ if spiner_type == CMD_UINT32_TYPE:
318
+ return _encode_fixed_int(int(value), 4, False, server_le)
319
+ if spiner_type == CMD_INT64_TYPE:
320
+ return _encode_fixed_int(int(value), 8, True, server_le)
321
+ if spiner_type == CMD_UINT64_TYPE:
322
+ return _encode_fixed_int(int(value), 8, False, server_le)
323
+ if spiner_type == CMD_FLT32_TYPE:
324
+ return struct.pack("<f" if server_le else ">f", float(value))
325
+ if spiner_type == CMD_FLT64_TYPE:
326
+ return struct.pack("<d" if server_le else ">d", float(value))
327
+ if spiner_type == CMD_DATE_TYPE:
328
+ return _encode_fixed_int(_int_to_microseconds(value), 8, True, server_le)
329
+ if spiner_type == CMD_IPV4_TYPE:
330
+ packed = ipaddress.ip_address(str(value)).packed
331
+ return bytes([4]) + packed
332
+ if spiner_type == CMD_IPV6_TYPE:
333
+ packed = ipaddress.ip_address(str(value)).packed
334
+ return bytes([6]) + packed
335
+ if isinstance(value, (bytes, bytearray, memoryview)):
336
+ return _encode_length_prefixed(bytes(value), server_le)
337
+ return _encode_length_prefixed(str(value).encode("utf-8"), server_le)
338
+
339
+
340
+ class MachbaseProtocolClient:
341
+ def __init__(
342
+ self,
343
+ *,
344
+ host: str = "127.0.0.1",
345
+ port: int = 5656,
346
+ user: str = "SYS",
347
+ password: str = "MANAGER",
348
+ database: str = "data",
349
+ timeout_ms: int = DEFAULT_CONNECT_TIMEOUT_MS,
350
+ query_timeout_ms: int = DEFAULT_QUERY_TIMEOUT_MS,
351
+ client_id: str = "PYMCB",
352
+ show_hidden_columns: bool = False,
353
+ timezone: str = "",
354
+ ):
355
+ self.host = host
356
+ self.port = int(port)
357
+ self.user = user
358
+ self.password = password
359
+ self.database = database
360
+ self.connect_timeout_ms = int(timeout_ms)
361
+ self.query_timeout_ms = int(query_timeout_ms)
362
+ self.client_id = client_id
363
+ self.show_hidden_columns = bool(show_hidden_columns)
364
+ self.timezone = timezone
365
+ self.socket: Optional[socket.socket] = None
366
+ self.reader: Optional[PacketReader] = None
367
+ self.connected = False
368
+ self.next_stmt = 1
369
+ self.session_id = 0
370
+ self.server_le = True
371
+
372
+ def connect(self) -> None:
373
+ if self.connected:
374
+ return
375
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
376
+ try:
377
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
378
+ sock.settimeout(self.connect_timeout_ms / 1000.0)
379
+ sock.connect((self.host, self.port))
380
+ sock.settimeout(None)
381
+ self._handshake(sock)
382
+ self.socket = sock
383
+ self.reader = PacketReader(sock)
384
+ self._send_connect()
385
+ self.connected = True
386
+ except Exception as exc:
387
+ try:
388
+ sock.close()
389
+ except Exception:
390
+ pass
391
+ self.socket = None
392
+ self.reader = None
393
+ raise OperationalError(str(exc)) from exc
394
+
395
+ def close(self) -> None:
396
+ if self.socket:
397
+ try:
398
+ self.socket.close()
399
+ finally:
400
+ self.socket = None
401
+ self.reader = None
402
+ self.connected = False
403
+
404
+ def _next_statement_id(self) -> int:
405
+ sid = self.next_stmt
406
+ self.next_stmt = (self.next_stmt + 1) & 0x7FFFFFFF
407
+ if self.next_stmt == 0:
408
+ self.next_stmt = 1
409
+ return sid
410
+
411
+ def _handshake(self, sock: socket.socket) -> None:
412
+ sock.sendall(f"{CMI_HANDSHAKE_PREFIX}{CMI_HANDSHAKE_ENDIAN_LITTLE}".encode("ascii"))
413
+ response = self._recv_exact(sock, CMI_PROTO_CNT)
414
+ if response.decode("ascii") != CMI_HANDSHAKE_READY:
415
+ raise OperationalError("Invalid handshake response")
416
+
417
+ def _send_connect(self) -> None:
418
+ version = _make_protocol_version(
419
+ CMI_PROTOCOL_MAJOR_VERSION,
420
+ CMI_PROTOCOL_MINOR_VERSION,
421
+ CMI_PROTOCOL_FIX_VERSION,
422
+ )
423
+ writer = MarshalWriter(CMI_CONNECT_PROTOCOL, 0)
424
+ writer.add_uint64(CMI_C_VERSION_ID, version)
425
+ writer.add_string(CMI_C_CLIENT_ID, self.client_id)
426
+ writer.add_string(CMI_C_DATABASE_ID, self.database)
427
+ writer.add_string(CMI_C_USER_ID, self.user)
428
+ writer.add_string(CMI_C_PASSWORD_ID, self.password)
429
+ writer.add_uint64(CMI_C_TIMEOUT_ID, max(1, self.query_timeout_ms // 1000))
430
+ writer.add_uint32(CMI_C_SHC_ID, 1 if self.show_hidden_columns else 0)
431
+ writer.add_string(CMI_C_IP_ID, self.host)
432
+ if self.timezone:
433
+ writer.add_string(CMI_C_TIMEZONE_ID, self.timezone)
434
+ self._send_packets(writer.finalize())
435
+ units = collect_units(self._read_protocol(CMI_CONNECT_PROTOCOL, self.query_timeout_ms))
436
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
437
+ if not status_unit:
438
+ raise OperationalError("CONNECT response missing status")
439
+ status = _read_status(status_unit[1])
440
+ if not _status_ok(status):
441
+ raise OperationalError(_format_error_message(units) or "CONNECT failed")
442
+ sid = _first_unit(units, CMI_C_SID_ID)
443
+ if sid:
444
+ self.session_id = int.from_bytes(sid[1][:8], "little", signed=False)
445
+ endian = _first_unit(units, CMI_C_ENDIAN_ID)
446
+ if endian:
447
+ self.server_le = int.from_bytes(endian[1][:4], "little", signed=False) == 0
448
+
449
+ def _send_packets(self, packets: Sequence[bytes]) -> None:
450
+ if not self.socket:
451
+ raise OperationalError("Not connected")
452
+ for packet in packets:
453
+ self.socket.sendall(packet)
454
+
455
+ def _read_protocol(self, protocol_id: int, timeout_ms: int) -> bytes:
456
+ if not self.reader:
457
+ raise OperationalError("Not connected")
458
+ chunks: List[bytes] = []
459
+ while True:
460
+ packet = self.reader.next(timeout_ms)
461
+ if packet.protocol != protocol_id:
462
+ raise OperationalError(f"Protocol mismatch: expected {protocol_id}, got {packet.protocol}")
463
+ chunks.append(packet.body)
464
+ if packet.flag in (0, 3):
465
+ return b"".join(chunks)
466
+
467
+ def exec_direct(self, sql: str) -> QueryMeta:
468
+ stmt_id = self._next_statement_id()
469
+ writer = MarshalWriter(CMI_EXECDIRECT_PROTOCOL, stmt_id)
470
+ writer.add_string(CMI_D_STATEMENT_ID, sql)
471
+ writer.add_uint64(CMI_P_ID_ID, stmt_id)
472
+ self._send_packets(writer.finalize())
473
+ units = collect_units(self._read_protocol(CMI_EXECDIRECT_PROTOCOL, self.query_timeout_ms))
474
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
475
+ if not status_unit:
476
+ raise DatabaseError("EXEC response missing status")
477
+ status = _read_status(status_unit[1])
478
+ if not _status_ok(status):
479
+ raise DatabaseError(_format_error_message(units) or "Execute failed")
480
+ rows = _first_unit(units, CMI_P_ROWS_ID)
481
+ rows_affected = int.from_bytes(rows[1][:8], "little", signed=False) if rows else 0
482
+ message = _format_error_message(units) or None
483
+ return QueryMeta(stmt_id, [], [], rows_affected, message, _status_last(status))
484
+
485
+ def _fetch_rows_loop(self, statement_id: int, columns: Sequence[ColumnMeta]) -> List[Dict[str, Any]]:
486
+ rows: List[Dict[str, Any]] = []
487
+ while True:
488
+ writer = MarshalWriter(CMI_FETCH_PROTOCOL, statement_id)
489
+ writer.add_uint32(CMI_F_ID_ID, statement_id)
490
+ writer.add_sint64(CMI_F_ROWS_ID, 1000)
491
+ self._send_packets(writer.finalize())
492
+ units = collect_units(self._read_protocol(CMI_FETCH_PROTOCOL, self.query_timeout_ms))
493
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
494
+ if not status_unit:
495
+ raise DatabaseError("FETCH response missing status")
496
+ status_units = units.get(CMI_R_RESULT_ID, [])
497
+ has_last = False
498
+ for _, status_bytes in status_units:
499
+ status = _read_status(status_bytes)
500
+ if not _status_ok(status):
501
+ raise DatabaseError(_format_error_message(units) or "Fetch failed")
502
+ if _status_last(status):
503
+ has_last = True
504
+ batch = _rows_from_units(units, columns)
505
+ rows.extend(batch)
506
+ row_count = _first_unit(units, CMI_F_ROWS_ID)
507
+ if has_last:
508
+ break
509
+ row_cnt = int.from_bytes(row_count[1][:8], "little", signed=False) if row_count else None
510
+ no_more_rows = (row_cnt is None or row_cnt <= 0) and not batch
511
+ if no_more_rows:
512
+ break
513
+ return rows
514
+
515
+ def query(self, sql: str) -> QueryMeta:
516
+ stmt_id = self._next_statement_id()
517
+ writer = MarshalWriter(CMI_EXECDIRECT_PROTOCOL, stmt_id)
518
+ writer.add_string(CMI_D_STATEMENT_ID, sql)
519
+ writer.add_uint64(CMI_P_ID_ID, stmt_id)
520
+ self._send_packets(writer.finalize())
521
+ units = collect_units(self._read_protocol(CMI_EXECDIRECT_PROTOCOL, self.query_timeout_ms))
522
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
523
+ if not status_unit:
524
+ raise DatabaseError("QUERY response missing status")
525
+ status = _read_status(status_unit[1])
526
+ if not _status_ok(status):
527
+ raise DatabaseError(_format_error_message(units) or "Query failed")
528
+ columns = _build_column_meta(units)
529
+ rows = _rows_from_units(units, columns)
530
+ if columns and not _status_last(status):
531
+ rows.extend(self._fetch_rows_loop(stmt_id, columns))
532
+ return QueryMeta(stmt_id, columns, rows, len(rows), _format_error_message(units) or None, _status_last(status))
533
+
534
+ def append_open(self, table: str) -> _AppendSession:
535
+ stmt_id = self._next_statement_id()
536
+ writer = MarshalWriter(CMI_APPEND_OPEN_PROTOCOL, stmt_id)
537
+ writer.add_uint64(CMI_P_ID_ID, stmt_id)
538
+ writer.add_string(CMI_P_TABLE_ID, table)
539
+ writer.add_uint64(CMI_E_ENDIAN_ID, 0)
540
+ self._send_packets(writer.finalize())
541
+ units = collect_units(self._read_protocol(CMI_APPEND_OPEN_PROTOCOL, self.query_timeout_ms))
542
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
543
+ if not status_unit:
544
+ raise DatabaseError("APPEND open response missing status")
545
+ status = _read_status(status_unit[1])
546
+ if not _status_ok(status):
547
+ raise DatabaseError(_format_error_message(units) or "Append open failed")
548
+ return _AppendSession(stmt_id, table, _build_column_meta(units), self.server_le)
549
+
550
+ def append_data(
551
+ self,
552
+ session: _AppendSession,
553
+ rows: Sequence[Sequence[Any]],
554
+ types: Optional[Sequence[int]] = None,
555
+ times: Optional[Sequence[Any]] = None,
556
+ on_ack: Optional[Callable[[Dict[int, List[Tuple[int, bytes]]]], None]] = None,
557
+ ) -> int:
558
+ writer = MarshalWriter(CMI_APPEND_DATA_PROTOCOL, session.statement_id, adds=session.statement_id & 0xFFFF)
559
+ metadata = session.columns
560
+ has_arrival = bool(metadata) and metadata[0].name.upper() in {"", "_ARRIVAL_TIME"}
561
+ null_bytes = (len(metadata) // 8) + 1 if metadata else 0
562
+ for row_index, row in enumerate(rows):
563
+ values = list(row)
564
+ if metadata and has_arrival:
565
+ if len(values) == len(metadata) - 1:
566
+ if times and row_index < len(times):
567
+ values = [times[row_index]] + values
568
+ else:
569
+ values = [0] + values
570
+ elif len(values) == len(metadata):
571
+ if times and row_index < len(times) and values[0] is None:
572
+ values[0] = times[row_index]
573
+ else:
574
+ raise ProgrammingError("Append row length does not match append metadata")
575
+ elif metadata and len(values) != len(metadata):
576
+ raise ProgrammingError("Append row length does not match append metadata")
577
+ null_bits = bytearray(null_bytes)
578
+ payload = bytearray()
579
+ payload.append(0)
580
+ payload.extend(int(null_bytes).to_bytes(4, "little" if session.server_le else "big"))
581
+ payload.extend(null_bits)
582
+ for col_index, value in enumerate(values):
583
+ if has_arrival and col_index == 0:
584
+ payload.extend(_encode_value(CMD_DATE_TYPE, value, session.server_le))
585
+ continue
586
+ if metadata:
587
+ spiner_type = metadata[col_index].spiner_type
588
+ elif types and col_index < len(types):
589
+ spiner_type = spiner_type_from_legacy_sql(int(types[col_index]))
590
+ else:
591
+ spiner_type = spiner_type_from_value(value)
592
+ if value is None and null_bytes:
593
+ null_index = col_index
594
+ if null_index < 0:
595
+ payload.extend(_encode_value(spiner_type, value, session.server_le))
596
+ continue
597
+ null_offset = 5 + (null_index // 8)
598
+ payload[null_offset] |= 1 << (7 - (null_index % 8))
599
+ continue
600
+ payload.extend(_encode_value(spiner_type, value, session.server_le))
601
+ writer.add_binary(CMI_P_ROWS_ID, bytes(payload))
602
+ self._send_packets(writer.finalize())
603
+ try:
604
+ units = collect_units(self._read_protocol(CMI_APPEND_DATA_PROTOCOL, 0))
605
+ except TimeoutError:
606
+ return len(rows)
607
+ if on_ack is not None:
608
+ on_ack(units)
609
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
610
+ if status_unit:
611
+ status = _read_status(status_unit[1])
612
+ if not _status_ok(status):
613
+ raise DatabaseError(_format_error_message(units) or "Append data failed")
614
+ failures = _first_unit(units, CMI_X_APPEND_FAILURE_ID)
615
+ if failures and int.from_bytes(failures[1][:8], "little", signed=False) > 0:
616
+ raise DatabaseError(_format_error_message(units) or "Append data reported failures")
617
+ return len(rows)
618
+
619
+ def append_close(self, session: _AppendSession) -> int:
620
+ writer = MarshalWriter(CMI_APPEND_CLOSE_PROTOCOL, session.statement_id)
621
+ writer.add_uint64(CMI_P_ID_ID, session.statement_id)
622
+ self._send_packets(writer.finalize())
623
+ units = collect_units(self._read_protocol(CMI_APPEND_CLOSE_PROTOCOL, self.query_timeout_ms))
624
+ status_unit = _first_unit(units, CMI_R_RESULT_ID)
625
+ if not status_unit:
626
+ raise DatabaseError("APPEND close response missing status")
627
+ status = _read_status(status_unit[1])
628
+ if not _status_ok(status):
629
+ raise DatabaseError(_format_error_message(units) or "Append close failed")
630
+ success = _first_unit(units, CMI_X_APPEND_SUCCESS_ID)
631
+ return int.from_bytes(success[1][:8], "little", signed=False) if success else 0
632
+
633
+ def append(
634
+ self,
635
+ table: str,
636
+ rows: Sequence[Sequence[Any]],
637
+ types: Optional[Sequence[int]] = None,
638
+ times: Optional[Sequence[Any]] = None,
639
+ ) -> int:
640
+ session = self.append_open(table)
641
+ try:
642
+ self.append_data(session, rows, types=types, times=times)
643
+ return self.append_close(session)
644
+ except Exception:
645
+ try:
646
+ self.append_close(session)
647
+ except Exception:
648
+ pass
649
+ raise
650
+
651
+ @staticmethod
652
+ def _recv_exact(sock: socket.socket, length: int) -> bytes:
653
+ chunks: List[bytes] = []
654
+ remaining = length
655
+ while remaining > 0:
656
+ chunk = sock.recv(remaining)
657
+ if not chunk:
658
+ raise OperationalError("Socket closed while reading protocol")
659
+ chunks.append(chunk)
660
+ remaining -= len(chunk)
661
+ return b"".join(chunks)