prismdb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prismdb/__init__.py +43 -0
- prismdb/_codec.py +177 -0
- prismdb/client.py +289 -0
- prismdb/connection.py +116 -0
- prismdb/document.py +50 -0
- prismdb/errors.py +78 -0
- prismdb/messages.py +354 -0
- prismdb/py.typed +0 -0
- prismdb/query.py +136 -0
- prismdb/update.py +62 -0
- prismdb/value.py +196 -0
- prismdb-0.1.0.dist-info/METADATA +133 -0
- prismdb-0.1.0.dist-info/RECORD +15 -0
- prismdb-0.1.0.dist-info/WHEEL +5 -0
- prismdb-0.1.0.dist-info/top_level.txt +1 -0
prismdb/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""prismdb — a pure-Python client for PrismDB over the binary wire protocol
|
|
2
|
+
(``docs/specs/wire-protocol.md``). No native build, no C extensions.
|
|
3
|
+
|
|
4
|
+
from prismdb import Client, Q, U
|
|
5
|
+
|
|
6
|
+
with Client.connect(host="127.0.0.1", port=4444, username="admin", password="admin") as db:
|
|
7
|
+
db.sql("CREATE TABLE users (id BIGINT PRIMARY KEY, name TEXT)")
|
|
8
|
+
db.sql("INSERT INTO users VALUES (1, 'alice')")
|
|
9
|
+
print(db.sql("SELECT * FROM users").rows)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from .client import Client, SqlResult
|
|
13
|
+
from .document import Document
|
|
14
|
+
from .errors import ErrorCode, ErrorInfo, PrismError, PrismServerError, ProtocolError
|
|
15
|
+
from .query import DocQuery, Q
|
|
16
|
+
from .update import DocUpdate, U
|
|
17
|
+
from .value import TAG, ObjectId, Typed, Value, float64, int32, int64, timestamp
|
|
18
|
+
|
|
19
|
+
__version__ = "0.1.0"
|
|
20
|
+
|
|
21
|
+
__all__ = [
|
|
22
|
+
"Client",
|
|
23
|
+
"SqlResult",
|
|
24
|
+
"Document",
|
|
25
|
+
"Q",
|
|
26
|
+
"DocQuery",
|
|
27
|
+
"U",
|
|
28
|
+
"DocUpdate",
|
|
29
|
+
"ObjectId",
|
|
30
|
+
"Typed",
|
|
31
|
+
"Value",
|
|
32
|
+
"TAG",
|
|
33
|
+
"int32",
|
|
34
|
+
"int64",
|
|
35
|
+
"float64",
|
|
36
|
+
"timestamp",
|
|
37
|
+
"PrismError",
|
|
38
|
+
"PrismServerError",
|
|
39
|
+
"ProtocolError",
|
|
40
|
+
"ErrorInfo",
|
|
41
|
+
"ErrorCode",
|
|
42
|
+
"__version__",
|
|
43
|
+
]
|
prismdb/_codec.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Low-level binary codec: a growable little-endian ``Writer``, a bounds-checked
|
|
2
|
+
``Reader``, and the length-prefixed frame helpers. The byte layouts mirror
|
|
3
|
+
``crates/prism-protocol/src/codec.rs`` exactly (all multi-byte integers LE)."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import struct
|
|
8
|
+
|
|
9
|
+
from .errors import ProtocolError
|
|
10
|
+
|
|
11
|
+
_U64_MASK = (1 << 64) - 1
|
|
12
|
+
_U128_MASK = (1 << 128) - 1
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Writer:
    """A growable little-endian writer over a bytearray.

    The fixed-width integer methods wrap their argument to the field width
    (two's complement for the signed variants). The length-prefixed methods
    refuse payloads whose byte length does not fit the prefix: a wrapped
    prefix followed by the full payload would desynchronise the stream for
    whoever decodes it.
    """

    __slots__ = ("_buf",)

    def __init__(self) -> None:
        self._buf = bytearray()

    def u8(self, v: int) -> None:
        """Append the low 8 bits of ``v`` as one byte."""
        self._buf.append(v & 0xFF)

    def u16(self, v: int) -> None:
        """Append the low 16 bits of ``v`` as 2 little-endian bytes."""
        self._buf += struct.pack("<H", v & 0xFFFF)

    def u32(self, v: int) -> None:
        """Append the low 32 bits of ``v`` as 4 little-endian bytes."""
        self._buf += struct.pack("<I", v & 0xFFFFFFFF)

    def i32(self, v: int) -> None:
        """Append ``v`` wrapped to a signed 32-bit integer."""
        self._buf += struct.pack("<i", _to_signed(v, 32))

    def u64(self, v: int) -> None:
        """Append the low 64 bits of ``v`` as 8 little-endian bytes."""
        self._buf += struct.pack("<Q", v & _U64_MASK)

    def i64(self, v: int) -> None:
        """Append ``v`` wrapped to a signed 64-bit integer."""
        self._buf += struct.pack("<q", _to_signed(v, 64))

    def f64(self, v: float) -> None:
        """Append ``v`` as an 8-byte little-endian double."""
        self._buf += struct.pack("<d", v)

    def u128(self, v: int) -> None:
        """A 128-bit unsigned integer as 16 little-endian bytes."""
        x = v & _U128_MASK
        # Low 64-bit half first, then the high half.
        self._buf += struct.pack("<QQ", x & _U64_MASK, (x >> 64) & _U64_MASK)

    def raw(self, b: bytes) -> None:
        """Append ``b`` verbatim, with no length prefix."""
        self._buf += b

    def _length_prefixed(self, data: bytes, prefix_bits: int, what: str) -> None:
        """Append ``data`` behind a ``prefix_bits``-wide length prefix.

        Raises:
            ProtocolError: if ``len(data)`` does not fit in the prefix. The
                check runs before anything is written, so the buffer is left
                untouched on failure.
        """
        size = len(data)
        if size >> prefix_bits:
            raise ProtocolError(
                f"{what} of {size} bytes does not fit a u{prefix_bits} length prefix"
            )
        if prefix_bits == 16:
            self.u16(size)
        else:
            self.u32(size)
        self._buf += data

    def str_u16(self, s: str) -> None:
        """A UTF-8 string with a u16 length prefix (at most 65535 encoded bytes)."""
        self._length_prefixed(s.encode("utf-8"), 16, "string")

    def str_u32(self, s: str) -> None:
        """A UTF-8 string with a u32 length prefix."""
        self._length_prefixed(s.encode("utf-8"), 32, "string")

    def bytes_u16(self, b: bytes) -> None:
        """A byte string with a u16 length prefix (at most 65535 bytes)."""
        self._length_prefixed(b, 16, "byte string")

    def bytes_u32(self, b: bytes) -> None:
        """A byte string with a u32 length prefix."""
        self._length_prefixed(b, 32, "byte string")

    def out(self) -> bytes:
        """Return an immutable copy of everything written so far."""
        return bytes(self._buf)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Reader:
    """Sequential little-endian decoder over a bytes-like object.

    Each accessor first confirms that enough bytes remain and raises
    ``ProtocolError`` otherwise, then advances the read position past the
    value it returns.
    """

    __slots__ = ("_buf", "_p")

    def __init__(self, buf: bytes) -> None:
        self._buf = buf
        self._p = 0

    def _need(self, n: int) -> None:
        """Raise ``ProtocolError`` unless ``n`` more bytes are available."""
        available = len(self._buf) - self._p
        if n > available:
            raise ProtocolError(f"truncated: need {n} bytes at offset {self._p}")

    def _scalar(self, fmt: str, width: int):
        """Decode one ``struct`` value of ``width`` bytes and step past it."""
        self._need(width)
        (value,) = struct.unpack_from(fmt, self._buf, self._p)
        self._p += width
        return value

    def u8(self) -> int:
        """Read one unsigned byte."""
        self._need(1)
        position = self._p
        self._p = position + 1
        return self._buf[position]

    def u16(self) -> int:
        """Read an unsigned 16-bit integer."""
        return self._scalar("<H", 2)

    def u32(self) -> int:
        """Read an unsigned 32-bit integer."""
        return self._scalar("<I", 4)

    def i32(self) -> int:
        """Read a signed 32-bit integer."""
        return self._scalar("<i", 4)

    def u64(self) -> int:
        """Read an unsigned 64-bit integer."""
        return self._scalar("<Q", 8)

    def i64(self) -> int:
        """Read a signed 64-bit integer."""
        return self._scalar("<q", 8)

    def f64(self) -> float:
        """Read an 8-byte double."""
        return self._scalar("<d", 8)

    def u128(self) -> int:
        """Read a 128-bit unsigned integer stored as 16 little-endian bytes."""
        return int.from_bytes(self.raw(16), "little")

    def raw(self, n: int) -> bytes:
        """Read exactly ``n`` bytes and return them as ``bytes``."""
        self._need(n)
        start = self._p
        self._p = start + n
        return bytes(self._buf[start : start + n])

    def str_u16(self) -> str:
        """Read a UTF-8 string that follows a u16 length prefix."""
        size = self.u16()
        return self.raw(size).decode("utf-8")

    def str_u32(self) -> str:
        """Read a UTF-8 string that follows a u32 length prefix."""
        size = self.u32()
        return self.raw(size).decode("utf-8")

    def bytes_u16(self) -> bytes:
        """Read a byte string that follows a u16 length prefix."""
        size = self.u16()
        return self.raw(size)

    def bytes_u32(self) -> bytes:
        """Read a byte string that follows a u32 length prefix."""
        size = self.u32()
        return self.raw(size)

    def remaining(self) -> int:
        """Number of bytes not yet consumed."""
        return len(self._buf) - self._p

    def expect_end(self) -> None:
        """Raise unless every byte has been consumed."""
        left = self.remaining()
        if left:
            raise ProtocolError(f"{left} trailing byte(s) after message")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _to_signed(v: int, bits: int) -> int:
    """Reinterpret the low ``bits`` bits of ``v`` as a two's-complement integer."""
    modulus = 1 << bits
    low = v & (modulus - 1)
    # Values with the top bit set represent negatives.
    return low - modulus if low >= 1 << (bits - 1) else low
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def frame_encode(payload: bytes) -> bytes:
    """Prefix ``payload`` with its length as a little-endian u32: ``[len:u32][payload]``."""
    header = struct.pack("<I", len(payload))
    return header + payload
|
prismdb/client.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""The high-level client: connect + handshake, then SQL / KV / document calls
|
|
2
|
+
and transaction control. One client owns one connection = one server session,
|
|
3
|
+
so a ``begin()`` … ``commit()`` brackets the calls in between."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Dict, List, Optional, Sequence, Union
|
|
9
|
+
|
|
10
|
+
from .connection import Connection, NoticeHandler, TlsArg
|
|
11
|
+
from .document import Document, decode_document, encode_document
|
|
12
|
+
from .errors import ErrorInfo, PrismServerError, ProtocolError
|
|
13
|
+
from .messages import (
|
|
14
|
+
AUTH_PASSWORD,
|
|
15
|
+
FEATURE_CONNECT_DB,
|
|
16
|
+
TXN_READ_ONLY,
|
|
17
|
+
TXN_READ_WRITE,
|
|
18
|
+
AuthAck,
|
|
19
|
+
ColumnDesc,
|
|
20
|
+
DocResultMsg,
|
|
21
|
+
HelloAck,
|
|
22
|
+
KvResultMsg,
|
|
23
|
+
Pong,
|
|
24
|
+
SqlResultMsg,
|
|
25
|
+
TxnAck,
|
|
26
|
+
abort_body,
|
|
27
|
+
auth_body,
|
|
28
|
+
begin_body,
|
|
29
|
+
commit_body,
|
|
30
|
+
doc_body,
|
|
31
|
+
doc_insert_many_body,
|
|
32
|
+
hello_body,
|
|
33
|
+
kv_delete_body,
|
|
34
|
+
kv_get_body,
|
|
35
|
+
kv_put_body,
|
|
36
|
+
ping_body,
|
|
37
|
+
sql_body,
|
|
38
|
+
)
|
|
39
|
+
from .query import DocQuery, Q, encode_doc_query
|
|
40
|
+
from .update import DocUpdate, encode_doc_update
|
|
41
|
+
from .value import ObjectId, Value
|
|
42
|
+
|
|
43
|
+
_PROTOCOL_VERSION = 1
|
|
44
|
+
_EMPTY = b""
|
|
45
|
+
|
|
46
|
+
BytesLike = Union[str, bytes, bytearray]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class SqlResult:
    """A SQL result set. ``rows`` are keyed by column name; ``raw`` keeps cell order.

    Built by ``Client.sql`` from a successful server reply.
    """

    # Column descriptors exactly as carried by the reply.
    columns: List[ColumnDesc]
    # One dict per row, keyed by column name. Because it is a dict, cells
    # whose columns share a name collapse to the last one; use ``raw`` when
    # that matters.
    rows: List[Dict[str, Value]]
    # The same rows as positional cell lists, in the reply's order.
    raw: List[List[Value]]
    # Affected-row count as reported by the reply.
    affected_rows: int
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _fail(error: Optional[ErrorInfo]) -> "None":
    """Raise ``PrismServerError`` for a failed reply; never returns normally.

    When the reply carried no error details, a generic placeholder
    (code 0, SQLSTATE ``XX000``) is raised instead.
    """
    details = error if error else ErrorInfo(code=0, message="server error", sqlstate="XX000")
    raise PrismServerError(details)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _bytes(v: BytesLike) -> bytes:
    """Normalise a key or value to ``bytes``; text is encoded as UTF-8."""
    if isinstance(v, str):
        return v.encode("utf-8")
    return bytes(v)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class KvSurface:
    """``client.kv`` — namespaced key/value operations.

    Keys and values may be given as ``str`` (encoded as UTF-8), ``bytes`` or
    ``bytearray``. Server-side failures surface through the owning client's
    reply helper.
    """

    def __init__(self, client: "Client") -> None:
        self._c = client

    def get(self, namespace: str, key: BytesLike) -> Optional[bytes]:
        """Fetch the value stored under ``key`` and return the reply's ``value``.

        Raises:
            ProtocolError: if the reply is not a KV get result (``op != 1``).
        """
        request = kv_get_body(namespace, _bytes(key))
        reply = self._c._kv_reply(*request)
        if reply.op == 1:
            return reply.value
        raise ProtocolError("expected a KV get result")

    def put(self, namespace: str, key: BytesLike, value: BytesLike) -> None:
        """Store ``value`` under ``key`` in ``namespace``."""
        request = kv_put_body(namespace, _bytes(key), _bytes(value))
        self._c._kv_reply(*request)

    def delete(self, namespace: str, key: BytesLike) -> None:
        """Delete ``key`` from ``namespace``."""
        request = kv_delete_body(namespace, _bytes(key))
        self._c._kv_reply(*request)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class DocSurface:
    """``client.doc`` — document collection operations.

    Every call is one request/reply round trip through the owning client.
    Read and count calls default to ``Q.all()`` when no query is given.
    """

    def __init__(self, client: "Client") -> None:
        self._c = client

    def _send(self, op: int, collection: str, parts: List[bytes]) -> DocResultMsg:
        """Issue document operation ``op`` with the given payload parts."""
        return self._c._doc_reply(*doc_body(op, collection, parts))

    def insert_one(self, collection: str, document: Document) -> ObjectId:
        """Insert one document and return its ``_id``.

        Raises:
            ProtocolError: if the reply carries no inserted id.
        """
        reply = self._send(1, collection, [encode_document(document)])
        ids = reply.inserted_ids
        if ids:
            return ids[0]
        raise ProtocolError("insert returned no _id")

    def insert_many(self, collection: str, documents: Sequence[Document]) -> List[ObjectId]:
        """Insert several documents in one request; returns the reply's inserted ids."""
        encoded = [encode_document(item) for item in documents]
        request = doc_insert_many_body(collection, encoded)
        return self._c._doc_reply(*request).inserted_ids

    def find(self, collection: str, query: Optional[DocQuery] = None) -> List[Document]:
        """Return every document in the reply, decoded."""
        selector = encode_doc_query(query or Q.all())
        reply = self._send(3, collection, [selector, _EMPTY])
        return [decode_document(blob) for blob in reply.docs]

    def find_one(self, collection: str, query: Optional[DocQuery] = None) -> Optional[Document]:
        """Return the first document in the reply, or ``None`` when there is none."""
        selector = encode_doc_query(query or Q.all())
        reply = self._send(4, collection, [selector, _EMPTY])
        if reply.docs:
            return decode_document(reply.docs[0])
        return None

    def count(self, collection: str, query: Optional[DocQuery] = None) -> int:
        """Return the reply's ``affected`` figure for a count request."""
        selector = encode_doc_query(query or Q.all())
        return self._send(9, collection, [selector, _EMPTY]).affected

    def update_one(self, collection: str, query: DocQuery, update: List[DocUpdate]) -> int:
        """Send an update-one request; returns the reply's ``affected`` count."""
        parts = [encode_doc_query(query), encode_doc_update(update), _EMPTY]
        return self._send(5, collection, parts).affected

    def update_many(self, collection: str, query: DocQuery, update: List[DocUpdate]) -> int:
        """Send an update-many request; returns the reply's ``affected`` count."""
        parts = [encode_doc_query(query), encode_doc_update(update), _EMPTY]
        return self._send(6, collection, parts).affected

    def delete_one(self, collection: str, query: DocQuery) -> int:
        """Send a delete-one request; returns the reply's ``affected`` count."""
        parts = [encode_doc_query(query), _EMPTY]
        return self._send(7, collection, parts).affected

    def delete_many(self, collection: str, query: DocQuery) -> int:
        """Send a delete-many request; returns the reply's ``affected`` count."""
        parts = [encode_doc_query(query), _EMPTY]
        return self._send(8, collection, parts).affected
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Client:
    """A connected, authenticated Prism session.

    One client wraps one ``Connection``. ``kv`` and ``doc`` expose the
    key/value and document operations; SQL, transaction control and ping are
    methods on the client itself. The client is a context manager that closes
    the connection on exit.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn
        self.kv = KvSurface(self)
        self.doc = DocSurface(self)

    @classmethod
    def connect(
        cls,
        host: str = "127.0.0.1",
        port: int = 4444,
        *,
        username: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        tls: TlsArg = None,
        server_hostname: Optional[str] = None,
        connect_timeout: float = 10.0,
        client_name: str = "prismdb-python",
        client_version: str = "0.1.0",
        on_notice: Optional[NoticeHandler] = None,
    ) -> "Client":
        """Connect, perform the handshake, and (if ``username`` is set) authenticate.

        Args:
            host, port: Server address.
            username, password: Credentials; authentication is skipped when
                ``username`` is ``None``. A missing password is sent as "".
            database: Database to bind the session to, if any.
            tls, server_hostname, connect_timeout, on_notice: Passed through
                to ``Connection.connect``.
            client_name, client_version: Identification sent in the hello.

        Raises:
            PrismServerError: if the server rejects the hello, the
                credentials, or the fallback ``USE`` statement.
            ProtocolError: if a reply has an unexpected type.
        """
        conn = Connection.connect(
            host,
            port,
            tls=tls,
            server_hostname=server_hostname,
            connect_timeout=connect_timeout,
            on_notice=on_notice,
        )
        client = cls(conn)
        try:
            connect_db_honored = client._handshake(
                username, password, database or "", client_name, client_version
            )
            # Fall back to `USE` only when the server did not bind the database
            # in the handshake (an older server without FEATURE_CONNECT_DB).
            if database and not connect_db_honored:
                # NOTE(review): ``database`` is interpolated into the statement
                # unescaped, so it must be a trusted identifier.
                client.sql(f"USE {database}", return_rows=False)
        except BaseException:
            # BaseException rather than Exception: an interrupted handshake
            # (KeyboardInterrupt, SystemExit) must not leave the socket open.
            conn.close()
            raise
        return client

    def _handshake(
        self,
        username: Optional[str],
        password: Optional[str],
        database: str,
        client_name: str,
        client_version: str,
    ) -> bool:
        """Exchange hello (and auth, when ``username`` is given).

        Returns:
            ``True`` when a database was requested and the server's
            acknowledged features include ``FEATURE_CONNECT_DB``.
        """
        # Only advertise the connect-db feature when there is a database to bind.
        features = FEATURE_CONNECT_DB if database else 0
        ack = self._conn.request(*hello_body(_PROTOCOL_VERSION, client_name, client_version, features, database))
        if not isinstance(ack, HelloAck):
            raise ProtocolError("expected HelloAck")
        if ack.status != 0:
            _fail(ack.error)
        connect_db_honored = (ack.features & FEATURE_CONNECT_DB) != 0 and database != ""

        if username is not None:
            auth_ack = self._conn.request(*auth_body(AUTH_PASSWORD, username, password or ""))
            if not isinstance(auth_ack, AuthAck):
                raise ProtocolError("expected AuthAck")
            if auth_ack.status != 0:
                _fail(auth_ack.error)
        return connect_db_honored

    # ---- SQL --------------------------------------------------------------

    def sql(
        self,
        text: str,
        params: Optional[Sequence[Value]] = None,
        *,
        return_rows: bool = True,
    ) -> SqlResult:
        """Execute a SQL statement. Returns rows for ``SELECT``, counts otherwise.

        Args:
            text: The statement text.
            params: Positional parameter values; ``None`` means no parameters.
            return_rows: Sent to the server as a 1/0 flag in the request.

        Raises:
            PrismServerError: if the reply status is non-zero.
            ProtocolError: if the reply is not a SQL result, or the server
                announces further frames (streaming is not supported).
        """
        reply = self._conn.request(*sql_body(text, list(params or []), 1 if return_rows else 0))
        if not isinstance(reply, SqlResultMsg):
            raise ProtocolError("expected SqlResult")
        if reply.status != 0:
            _fail(reply.error)
        if reply.more_frames:
            raise ProtocolError("streamed SQL results are not yet supported")
        names = [c.name for c in reply.columns]
        # Keyed view of each row; the positional cells are kept in ``raw``.
        rows = [{names[i]: cell for i, cell in enumerate(cells)} for cells in reply.rows]
        return SqlResult(reply.columns, rows, reply.rows, reply.affected_rows)

    # ---- transactions -----------------------------------------------------

    def begin(self, mode: str = "read_write") -> int:
        """Begin a transaction; returns the assigned transaction id.

        Args:
            mode: ``"read_write"`` (default) or ``"read_only"``.

        Raises:
            ValueError: if ``mode`` is anything else. A mistyped mode such as
                ``"readonly"`` is rejected instead of silently opening a
                read-write transaction.
        """
        if mode == "read_only":
            txn_mode = TXN_READ_ONLY
        elif mode == "read_write":
            txn_mode = TXN_READ_WRITE
        else:
            raise ValueError(
                f"unknown transaction mode {mode!r}; expected 'read_only' or 'read_write'"
            )
        ack = self._txn(*begin_body(txn_mode))
        return ack.txn_id

    def commit(self, idempotency_key: int = 0) -> None:
        """Commit the current transaction (optionally idempotent)."""
        self._txn(*commit_body(idempotency_key))

    def abort(self) -> None:
        """Abort the current transaction."""
        self._txn(*abort_body())

    def _txn(self, type_code: int, body: bytes) -> TxnAck:
        """Send a transaction-control message and return its successful ack."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, TxnAck):
            raise ProtocolError("expected TxnAck")
        if reply.status != 0:
            _fail(reply.error)
        return reply

    # ---- misc -------------------------------------------------------------

    def ping(self) -> None:
        """Round-trip a keep-alive ping."""
        reply = self._conn.request(*ping_body())
        if not isinstance(reply, Pong):
            raise ProtocolError("expected Pong")

    def close(self) -> None:
        """Close the underlying connection."""
        self._conn.close()

    def __enter__(self) -> "Client":
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # ---- internal reply helpers -------------------------------------------

    def _kv_reply(self, type_code: int, body: bytes) -> KvResultMsg:
        """Send a KV request and return its successful reply."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, KvResultMsg):
            raise ProtocolError("expected KvResult")
        if reply.status != 0:
            _fail(reply.error)
        return reply

    def _doc_reply(self, type_code: int, body: bytes) -> DocResultMsg:
        """Send a document request and return its successful, single-frame reply."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, DocResultMsg):
            raise ProtocolError("expected DocResult")
        if reply.status != 0:
            _fail(reply.error)
        if reply.more_frames:
            raise ProtocolError("streamed document results are not yet supported")
        return reply
|
prismdb/connection.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""The transport: a TCP (optionally TLS) socket that frames outgoing messages,
|
|
2
|
+
reads back full frames, and matches each reply to its request by the echoed
|
|
3
|
+
``request_id``. Server-initiated notices (request_id 0) go to a handler.
|
|
4
|
+
|
|
5
|
+
The connection is synchronous: each :meth:`request` writes a frame and blocks
|
|
6
|
+
until the matching reply arrives, dispatching any notices seen in between."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import socket
|
|
11
|
+
import ssl
|
|
12
|
+
from typing import Callable, Optional, Tuple, Union
|
|
13
|
+
|
|
14
|
+
from ._codec import frame_encode
|
|
15
|
+
from .errors import ProtocolError
|
|
16
|
+
from .messages import Notice, decode_packet, encode_packet
|
|
17
|
+
|
|
18
|
+
NoticeHandler = Callable[[Notice], None]
|
|
19
|
+
# ``tls``: False/None = plaintext; True = default client context; an
|
|
20
|
+
# ``ssl.SSLContext`` = use it as-is.
|
|
21
|
+
TlsArg = Union[bool, ssl.SSLContext, None]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Connection:
    """A framed, request/reply socket connection to a Prism server.

    Frames are ``[len:u32 little-endian][payload]``. The connection is
    synchronous: one request is in flight at a time. Once a send or receive
    fails, or ``close`` is called, the first failure is remembered and every
    later ``request`` re-raises it.
    """

    def __init__(self, sock: socket.socket, on_notice: Optional[NoticeHandler] = None) -> None:
        self._sock = sock
        self._on_notice = on_notice
        self._next_id = 1
        # Bytes received from the socket but not yet consumed as frames.
        self._inbound = bytearray()
        # The error that ended the connection, or None while it is usable.
        self._closed: Optional[Exception] = None

    @classmethod
    def connect(
        cls,
        host: str = "127.0.0.1",
        port: int = 4444,
        *,
        tls: TlsArg = None,
        server_hostname: Optional[str] = None,
        connect_timeout: float = 10.0,
        on_notice: Optional[NoticeHandler] = None,
    ) -> "Connection":
        """Open a connection (TCP, or TLS when ``tls`` is set).

        Args:
            host, port: Server address.
            tls: Falsy for plaintext, ``True`` for a default client context,
                or an ``ssl.SSLContext`` to use as-is.
            server_hostname: Name used for the TLS handshake; defaults to
                ``host``.
            connect_timeout: Seconds allowed for establishing the connection,
                including the TLS handshake when TLS is requested.
            on_notice: Optional handler for server-initiated notices.

        The socket is closed again if anything fails after the TCP connect.
        """
        sock: socket.socket = socket.create_connection((host, port), timeout=connect_timeout)
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            if tls:
                ctx = tls if isinstance(tls, ssl.SSLContext) else ssl.create_default_context()
                # Wrap while the connect timeout is still set so a stalled
                # handshake cannot block forever.
                sock = ctx.wrap_socket(sock, server_hostname=server_hostname or host)
            # Requests block for as long as the server takes to answer.
            sock.settimeout(None)
        except BaseException:
            sock.close()
            raise
        return cls(sock, on_notice)

    def request(self, type_code: int, body: bytes) -> object:
        """Send a client message and return the matching reply message.

        Notices that arrive while waiting are passed to the notice handler
        (or dropped when none is set). Replies whose ``request_id`` does not
        match are skipped.

        Raises:
            ProtocolError: if the connection is closed or the send/receive
                fails.
        """
        if self._closed is not None:
            raise self._closed
        request_id = self._next_id
        # Ids cycle through 1..0xFFFFFFFF; the wrap goes back to 1, never 0.
        self._next_id = 1 if self._next_id >= 0xFFFFFFFF else self._next_id + 1
        try:
            self._sock.sendall(frame_encode(encode_packet(request_id, type_code, body)))
        except OSError as e:
            self._fail(ProtocolError(f"send failed: {e}"))
            raise self._closed  # type: ignore[misc]

        while True:
            payload = self._read_frame()
            packet = decode_packet(payload)
            if isinstance(packet.message, Notice):
                if self._on_notice is not None:
                    self._on_notice(packet.message)
                continue
            if packet.request_id == request_id:
                return packet.message
            # An unmatched reply (e.g. a late response) is ignored.

    def close(self) -> None:
        """Close the connection. Further use raises."""
        self._fail(ProtocolError("connection closed by client"))
        try:
            self._sock.close()
        except OSError:
            # Best effort: the connection is already marked closed.
            pass

    # -- internals ----------------------------------------------------------

    def _read_frame(self) -> bytes:
        """Read one frame and return its payload (without the length header)."""
        header = self._read_exact(4)
        length = int.from_bytes(header, "little")
        return self._read_exact(length)

    def _read_exact(self, n: int) -> bytes:
        """Return exactly ``n`` bytes, receiving more from the socket as needed."""
        # Drain anything already buffered from a previous read first.
        while len(self._inbound) < n:
            try:
                chunk = self._sock.recv(65536)
            except OSError as e:
                self._fail(ProtocolError(f"connection closed by server: {e}"))
                raise self._closed  # type: ignore[misc]
            if not chunk:
                # An empty read means the peer closed the stream.
                self._fail(ProtocolError("connection closed by server"))
                raise self._closed  # type: ignore[misc]
            self._inbound += chunk
        out = bytes(self._inbound[:n])
        del self._inbound[:n]
        return out

    def _fail(self, err: Exception) -> None:
        """Record ``err`` as the terminal error; the first one recorded wins."""
        if self._closed is None:
            self._closed = err
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def split_tls_args(tls: TlsArg) -> Tuple[TlsArg, Optional[str]]:  # pragma: no cover
    """Placeholder for richer TLS option parsing.

    Currently hands ``tls`` back unchanged, paired with ``None``.
    """
    passthrough = (tls, None)
    return passthrough
|
prismdb/document.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""The document tagged-binary codec.
|
|
2
|
+
|
|
3
|
+
Mirrors ``crates/prism-doc/src/value.rs`` (``Document::encode``/``decode``). A
|
|
4
|
+
document is ``[total:u32][count:u16]`` followed by, per field,
|
|
5
|
+
``[tag:u8][nameLen:u16][name][value bytes]``. Field value bytes use the same
|
|
6
|
+
encoding as scalar values, except documents have no Binary type."""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Dict
|
|
11
|
+
|
|
12
|
+
from ._codec import Reader, Writer
|
|
13
|
+
from .errors import ProtocolError
|
|
14
|
+
from .value import TAG, Value, decode_untagged, encode_untagged, tag_of
|
|
15
|
+
|
|
16
|
+
# A document is a plain dict; field insertion order is preserved.
|
|
17
|
+
Document = Dict[str, Value]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def encode_document(doc: Document) -> bytes:
    """Serialise ``doc`` into its tagged-binary payload.

    Layout: ``[total:u32][count:u16]`` followed, for each field in insertion
    order, by ``[tag:u8][nameLen:u16][name][value bytes]``.

    Raises:
        ProtocolError: if there are more than 65535 fields, or a field holds
            a binary value.
    """
    field_count = len(doc)
    if field_count > 0xFFFF:
        raise ProtocolError("too many document fields")

    fields = Writer()
    fields.u16(field_count)
    for name, value in doc.items():
        tag = tag_of(value)
        if tag == TAG.BINARY:
            raise ProtocolError(f'field "{name}": binary values are not supported in documents')
        fields.u8(tag)
        fields.str_u16(name)
        encode_untagged(fields, tag, value)
    payload = fields.out()

    # The leading total counts its own four bytes as well as the payload.
    header = Writer()
    header.u32(4 + len(payload))
    return header.out() + payload
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def decode_document(raw: bytes) -> Document:
    """Parse a tagged-binary payload back into a document dict."""
    reader = Reader(raw)
    # The leading total length is read and discarded; the blob length already
    # delimits the document.
    reader.u32()
    fields_left = reader.u16()
    fields: Document = {}
    while fields_left > 0:
        tag = reader.u8()
        name = reader.str_u16()
        fields[name] = decode_untagged(reader, tag)
        fields_left -= 1
    return fields
|