prismdb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prismdb/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """prismdb — a pure-Python client for PrismDB over the binary wire protocol
2
+ (``docs/specs/wire-protocol.md``). No native build, no C extensions.
3
+
4
+ from prismdb import Client, Q, U
5
+
6
+ with Client.connect(host="127.0.0.1", port=4444, username="admin", password="admin") as db:
7
+ db.sql("CREATE TABLE users (id BIGINT PRIMARY KEY, name TEXT)")
8
+ db.sql("INSERT INTO users VALUES (1, 'alice')")
9
+ print(db.sql("SELECT * FROM users").rows)
10
+ """
11
+
12
+ from .client import Client, SqlResult
13
+ from .document import Document
14
+ from .errors import ErrorCode, ErrorInfo, PrismError, PrismServerError, ProtocolError
15
+ from .query import DocQuery, Q
16
+ from .update import DocUpdate, U
17
+ from .value import TAG, ObjectId, Typed, Value, float64, int32, int64, timestamp
18
+
19
__version__ = "0.1.0"

# Public API re-exported from the submodules, grouped by area. The order is
# kept stable so ``prismdb.__all__`` compares equal across releases.
__all__ = [
    # client
    "Client", "SqlResult",
    # documents, queries and updates
    "Document", "Q", "DocQuery", "U", "DocUpdate",
    # value types and typed-value constructors
    "ObjectId", "Typed", "Value", "TAG",
    "int32", "int64", "float64", "timestamp",
    # errors
    "PrismError", "PrismServerError", "ProtocolError", "ErrorInfo", "ErrorCode",
    # metadata
    "__version__",
]
prismdb/_codec.py ADDED
@@ -0,0 +1,177 @@
1
+ """Low-level binary codec: a growable little-endian ``Writer``, a bounds-checked
2
+ ``Reader``, and the length-prefixed frame helpers. The byte layouts mirror
3
+ ``crates/prism-protocol/src/codec.rs`` exactly (all multi-byte integers LE)."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import struct
8
+
9
+ from .errors import ProtocolError
10
+
11
+ _U64_MASK = (1 << 64) - 1
12
+ _U128_MASK = (1 << 128) - 1
13
+
14
+
15
class Writer:
    """A growable little-endian writer over a bytearray.

    Fixed-width integer writers mask (or two's-complement wrap) their argument
    into range. Length-prefixed writers instead reject payloads whose byte
    length does not fit the prefix. Masking the length there would emit a
    prefix that disagrees with the bytes that follow and desynchronise the
    reader on the other side.
    """

    __slots__ = ("_buf",)

    def __init__(self) -> None:
        self._buf = bytearray()

    def u8(self, v: int) -> None:
        self._buf.append(v & 0xFF)

    def u16(self, v: int) -> None:
        self._buf += struct.pack("<H", v & 0xFFFF)

    def u32(self, v: int) -> None:
        self._buf += struct.pack("<I", v & 0xFFFFFFFF)

    def i32(self, v: int) -> None:
        self._buf += struct.pack("<i", _to_signed(v, 32))

    def u64(self, v: int) -> None:
        self._buf += struct.pack("<Q", v & _U64_MASK)

    def i64(self, v: int) -> None:
        self._buf += struct.pack("<q", _to_signed(v, 64))

    def f64(self, v: float) -> None:
        self._buf += struct.pack("<d", v)

    def u128(self, v: int) -> None:
        """A 128-bit unsigned integer as 16 little-endian bytes."""
        x = v & _U128_MASK
        self._buf += struct.pack("<QQ", x & _U64_MASK, (x >> 64) & _U64_MASK)

    def raw(self, b: bytes) -> None:
        """Append ``b`` verbatim, with no length prefix."""
        self._buf += b

    def str_u16(self, s: str) -> None:
        """A UTF-8 string with a u16 length prefix (at most 0xFFFF encoded bytes).

        Raises:
            ProtocolError: if the encoded string is longer than 0xFFFF bytes.
        """
        self.bytes_u16(s.encode("utf-8"))

    def str_u32(self, s: str) -> None:
        """A UTF-8 string with a u32 length prefix (at most 0xFFFFFFFF encoded bytes).

        Raises:
            ProtocolError: if the encoded string is longer than 0xFFFFFFFF bytes.
        """
        self.bytes_u32(s.encode("utf-8"))

    def bytes_u16(self, b: bytes) -> None:
        """A byte string with a u16 length prefix.

        Raises:
            ProtocolError: if ``b`` is longer than 0xFFFF bytes.
        """
        self._check_len(len(b), 0xFFFF, "u16")
        self.u16(len(b))
        self._buf += b

    def bytes_u32(self, b: bytes) -> None:
        """A byte string with a u32 length prefix.

        Raises:
            ProtocolError: if ``b`` is longer than 0xFFFFFFFF bytes.
        """
        self._check_len(len(b), 0xFFFFFFFF, "u32")
        self.u32(len(b))
        self._buf += b

    def out(self) -> bytes:
        """Return an immutable copy of everything written so far."""
        return bytes(self._buf)

    @staticmethod
    def _check_len(n: int, limit: int, prefix: str) -> None:
        """Raise ProtocolError when an ``n``-byte payload cannot carry a ``prefix`` length."""
        if n > limit:
            raise ProtocolError(
                f"{n}-byte payload exceeds the {prefix} length prefix (max {limit})"
            )
76
+
77
+
78
class Reader:
    """A bounds-checked little-endian reader over a bytes-like object.

    Every read advances an internal offset. Two kinds of malformed input raise
    :class:`ProtocolError` instead of ``IndexError``/``UnicodeDecodeError``:
    reading past the end of the buffer, and length-prefixed strings that are
    not valid UTF-8.
    """

    __slots__ = ("_buf", "_p")

    def __init__(self, buf: bytes) -> None:
        self._buf = buf
        self._p = 0  # offset of the next unread byte

    def _need(self, n: int) -> None:
        """Raise ProtocolError unless ``n`` more bytes are available."""
        if self._p + n > len(self._buf):
            raise ProtocolError(f"truncated: need {n} bytes at offset {self._p}")

    def u8(self) -> int:
        self._need(1)
        v = self._buf[self._p]
        self._p += 1
        return v

    def u16(self) -> int:
        self._need(2)
        v = struct.unpack_from("<H", self._buf, self._p)[0]
        self._p += 2
        return v

    def u32(self) -> int:
        self._need(4)
        v = struct.unpack_from("<I", self._buf, self._p)[0]
        self._p += 4
        return v

    def i32(self) -> int:
        self._need(4)
        v = struct.unpack_from("<i", self._buf, self._p)[0]
        self._p += 4
        return v

    def u64(self) -> int:
        self._need(8)
        v = struct.unpack_from("<Q", self._buf, self._p)[0]
        self._p += 8
        return v

    def i64(self) -> int:
        self._need(8)
        v = struct.unpack_from("<q", self._buf, self._p)[0]
        self._p += 8
        return v

    def f64(self) -> float:
        self._need(8)
        v = struct.unpack_from("<d", self._buf, self._p)[0]
        self._p += 8
        return v

    def u128(self) -> int:
        """A 128-bit unsigned integer stored as two little-endian u64 halves."""
        self._need(16)
        lo, hi = struct.unpack_from("<QQ", self._buf, self._p)
        self._p += 16
        return (hi << 64) | lo

    def raw(self, n: int) -> bytes:
        """Return the next ``n`` bytes as an independent ``bytes`` object."""
        self._need(n)
        s = self._buf[self._p : self._p + n]
        self._p += n
        return bytes(s)

    def str_u16(self) -> str:
        """A UTF-8 string with a u16 length prefix."""
        return self._utf8(self.raw(self.u16()))

    def str_u32(self) -> str:
        """A UTF-8 string with a u32 length prefix."""
        return self._utf8(self.raw(self.u32()))

    def bytes_u16(self) -> bytes:
        return self.raw(self.u16())

    def bytes_u32(self) -> bytes:
        return self.raw(self.u32())

    def remaining(self) -> int:
        return len(self._buf) - self._p

    def expect_end(self) -> None:
        """Raise unless every byte has been consumed."""
        if self.remaining() != 0:
            raise ProtocolError(f"{self.remaining()} trailing byte(s) after message")

    @staticmethod
    def _utf8(data: bytes) -> str:
        """Decode ``data`` as UTF-8, reporting malformed input as ProtocolError."""
        try:
            return data.decode("utf-8")
        except UnicodeDecodeError as e:
            raise ProtocolError(
                f"invalid UTF-8 in string: {e.reason} at byte {e.start}"
            ) from e
164
+
165
+
166
+ def _to_signed(v: int, bits: int) -> int:
167
+ """Wrap ``v`` into a signed ``bits``-wide integer (two's complement)."""
168
+ mask = (1 << bits) - 1
169
+ v &= mask
170
+ if v >= 1 << (bits - 1):
171
+ v -= 1 << bits
172
+ return v
173
+
174
+
175
def frame_encode(payload: bytes) -> bytes:
    """Prefix ``payload`` with its byte length as a little-endian u32.

    The result is the on-wire frame ``[len:u32][payload]``.
    """
    header = struct.pack("<I", len(payload))
    return b"".join((header, payload))
prismdb/client.py ADDED
@@ -0,0 +1,289 @@
1
+ """The high-level client: connect + handshake, then SQL / KV / document calls
2
+ and transaction control. One client owns one connection = one server session,
3
+ so a ``begin()`` … ``commit()`` brackets the calls in between."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Optional, Sequence, Union
9
+
10
+ from .connection import Connection, NoticeHandler, TlsArg
11
+ from .document import Document, decode_document, encode_document
12
+ from .errors import ErrorInfo, PrismServerError, ProtocolError
13
+ from .messages import (
14
+ AUTH_PASSWORD,
15
+ FEATURE_CONNECT_DB,
16
+ TXN_READ_ONLY,
17
+ TXN_READ_WRITE,
18
+ AuthAck,
19
+ ColumnDesc,
20
+ DocResultMsg,
21
+ HelloAck,
22
+ KvResultMsg,
23
+ Pong,
24
+ SqlResultMsg,
25
+ TxnAck,
26
+ abort_body,
27
+ auth_body,
28
+ begin_body,
29
+ commit_body,
30
+ doc_body,
31
+ doc_insert_many_body,
32
+ hello_body,
33
+ kv_delete_body,
34
+ kv_get_body,
35
+ kv_put_body,
36
+ ping_body,
37
+ sql_body,
38
+ )
39
+ from .query import DocQuery, Q, encode_doc_query
40
+ from .update import DocUpdate, encode_doc_update
41
+ from .value import ObjectId, Value
42
+
43
+ _PROTOCOL_VERSION = 1
44
+ _EMPTY = b""
45
+
46
+ BytesLike = Union[str, bytes, bytearray]
47
+
48
+
49
@dataclass()
class SqlResult:
    """The outcome of one SQL statement.

    ``rows`` presents each result row as a dict keyed by column name, while
    ``raw`` holds the same cells as positional lists in column order (useful
    when two columns share a name). ``affected_rows`` is the count reported by
    the server.
    """

    columns: List[ColumnDesc]  # column descriptors, in result order
    rows: List[Dict[str, Value]]  # one dict per row, keyed by column name
    raw: List[List[Value]]  # one list per row, cells in column order
    affected_rows: int  # row count as reported by the server
57
+
58
+
59
def _fail(error: Optional[ErrorInfo]) -> "None":
    """Raise :class:`PrismServerError` for a reply with a non-zero status.

    When the reply carried no (or a falsy) ``error``, a generic
    ``XX000`` "server error" is substituted so the exception always holds
    structured details. This function never returns normally.
    """
    if not error:
        error = ErrorInfo(code=0, message="server error", sqlstate="XX000")
    raise PrismServerError(error)
63
+
64
+
65
+ def _bytes(v: BytesLike) -> bytes:
66
+ return v.encode("utf-8") if isinstance(v, str) else bytes(v)
67
+
68
+
69
class KvSurface:
    """Namespaced key/value operations, exposed as ``client.kv``.

    Keys and values may be ``str`` (sent as UTF-8) or bytes-like. Non-zero
    reply statuses are raised as exceptions by the owning client's
    ``_kv_reply`` helper.
    """

    def __init__(self, client: "Client") -> None:
        self._c = client

    def get(self, namespace: str, key: BytesLike) -> Optional[bytes]:
        """Fetch ``key`` from ``namespace`` and return the reply's value.

        The ``Optional`` return presumably means ``None`` for a missing key;
        that is decided by the server reply, not by this method.

        Raises:
            ProtocolError: if the reply is not a get result (``op != 1``).
        """
        type_code, body = kv_get_body(namespace, _bytes(key))
        result = self._c._kv_reply(type_code, body)
        if result.op == 1:
            return result.value
        raise ProtocolError("expected a KV get result")

    def put(self, namespace: str, key: BytesLike, value: BytesLike) -> None:
        """Store ``value`` under ``key`` in ``namespace``."""
        type_code, body = kv_put_body(namespace, _bytes(key), _bytes(value))
        self._c._kv_reply(type_code, body)

    def delete(self, namespace: str, key: BytesLike) -> None:
        """Remove ``key`` from ``namespace``."""
        type_code, body = kv_delete_body(namespace, _bytes(key))
        self._c._kv_reply(type_code, body)
86
+
87
+
88
class DocSurface:
    """Document-collection operations, exposed as ``client.doc``.

    Each call sends one request and returns data from the reply validated by
    the owning client's ``_doc_reply`` helper. Query-based calls append the
    module's ``_EMPTY`` blob after their arguments; its meaning is defined by
    ``doc_body``, which is not shown here.
    """

    # Operation codes passed as the first argument of ``doc_body``; insert-many
    # uses its own builder (``doc_insert_many_body``) instead.
    _OP_INSERT_ONE = 1
    _OP_FIND = 3
    _OP_FIND_ONE = 4
    _OP_UPDATE_ONE = 5
    _OP_UPDATE_MANY = 6
    _OP_DELETE_ONE = 7
    _OP_DELETE_MANY = 8
    _OP_COUNT = 9

    def __init__(self, client: "Client") -> None:
        self._c = client

    def _send(self, op: int, collection: str, *blobs: bytes) -> DocResultMsg:
        """Send ``blobs`` followed by ``_EMPTY`` as document operation ``op``."""
        return self._c._doc_reply(*doc_body(op, collection, [*blobs, _EMPTY]))

    def insert_one(self, collection: str, document: Document) -> ObjectId:
        """Insert one document and return the first ``_id`` from the reply.

        Raises:
            ProtocolError: if the reply carries no inserted id.
        """
        encoded = encode_document(document)
        result = self._c._doc_reply(*doc_body(self._OP_INSERT_ONE, collection, [encoded]))
        ids = result.inserted_ids
        if ids:
            return ids[0]
        raise ProtocolError("insert returned no _id")

    def insert_many(self, collection: str, documents: Sequence[Document]) -> List[ObjectId]:
        """Insert several documents and return the reply's ``_id`` list."""
        encoded = [encode_document(doc) for doc in documents]
        return self._c._doc_reply(*doc_insert_many_body(collection, encoded)).inserted_ids

    def find(self, collection: str, query: Optional[DocQuery] = None) -> List[Document]:
        """Return the documents matching ``query`` (``Q.all()`` when omitted)."""
        result = self._send(self._OP_FIND, collection, encode_doc_query(query or Q.all()))
        return list(map(decode_document, result.docs))

    def find_one(self, collection: str, query: Optional[DocQuery] = None) -> Optional[Document]:
        """Return the first matching document, or ``None`` if the reply has none."""
        result = self._send(self._OP_FIND_ONE, collection, encode_doc_query(query or Q.all()))
        if not result.docs:
            return None
        return decode_document(result.docs[0])

    def count(self, collection: str, query: Optional[DocQuery] = None) -> int:
        """Return the reply's ``affected`` count for ``query`` (``Q.all()`` when omitted)."""
        selector = encode_doc_query(query or Q.all())
        return self._send(self._OP_COUNT, collection, selector).affected

    def update_one(self, collection: str, query: DocQuery, update: List[DocUpdate]) -> int:
        """Apply ``update`` to one matching document; returns the affected count."""
        result = self._send(
            self._OP_UPDATE_ONE, collection, encode_doc_query(query), encode_doc_update(update)
        )
        return result.affected

    def update_many(self, collection: str, query: DocQuery, update: List[DocUpdate]) -> int:
        """Apply ``update`` to every matching document; returns the affected count."""
        result = self._send(
            self._OP_UPDATE_MANY, collection, encode_doc_query(query), encode_doc_update(update)
        )
        return result.affected

    def delete_one(self, collection: str, query: DocQuery) -> int:
        """Delete one matching document; returns the affected count."""
        return self._send(self._OP_DELETE_ONE, collection, encode_doc_query(query)).affected

    def delete_many(self, collection: str, query: DocQuery) -> int:
        """Delete every matching document; returns the affected count."""
        return self._send(self._OP_DELETE_MANY, collection, encode_doc_query(query)).affected
136
+
137
+
138
class Client:
    """A connected, authenticated Prism session.

    One client owns one :class:`Connection` (one server session), so a
    ``begin()`` … ``commit()`` pair brackets the calls made in between.
    Key/value calls are on ``client.kv`` and document calls on ``client.doc``.
    The client is also a context manager that closes the connection on exit.
    """

    def __init__(self, conn: Connection) -> None:
        self._conn = conn
        self.kv = KvSurface(self)
        self.doc = DocSurface(self)

    @classmethod
    def connect(
        cls,
        host: str = "127.0.0.1",
        port: int = 4444,
        *,
        username: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        tls: TlsArg = None,
        server_hostname: Optional[str] = None,
        connect_timeout: float = 10.0,
        client_name: str = "prismdb-python",
        client_version: str = "0.1.0",
        on_notice: Optional[NoticeHandler] = None,
    ) -> "Client":
        """Connect, perform the handshake, and (if ``username`` is set) authenticate.

        ``database`` is requested during the handshake. When the server does
        not acknowledge that, the database is selected with ``USE`` instead.
        Transport options (``tls``, ``server_hostname``, ``connect_timeout``,
        ``on_notice``) are passed through to :meth:`Connection.connect`.

        If anything fails after the socket is open, the connection is closed
        before the exception propagates. This includes an interruption such as
        ``KeyboardInterrupt``.

        Raises:
            PrismServerError: if the server rejects the Hello or the login.
            ProtocolError: if the server sends an unexpected reply.
        """
        conn = Connection.connect(
            host,
            port,
            tls=tls,
            server_hostname=server_hostname,
            connect_timeout=connect_timeout,
            on_notice=on_notice,
        )
        client = cls(conn)
        try:
            connect_db_honored = client._handshake(
                username, password, database or "", client_name, client_version
            )
            # Fall back to `USE` only when the server did not bind the database
            # in the handshake (an older server without FEATURE_CONNECT_DB).
            # NOTE(review): the name is interpolated unquoted into SQL, so it
            # must come from a trusted source.
            if database and not connect_db_honored:
                client.sql(f"USE {database}", return_rows=False)
        except BaseException:
            # BaseException rather than Exception, so that e.g. a
            # KeyboardInterrupt mid-handshake does not leak the socket.
            conn.close()
            raise
        return client

    def _handshake(
        self,
        username: Optional[str],
        password: Optional[str],
        database: str,
        client_name: str,
        client_version: str,
    ) -> bool:
        """Send Hello, then Auth when ``username`` is not ``None``.

        Returns:
            True when ``database`` is non-empty and the HelloAck advertises
            FEATURE_CONNECT_DB, i.e. the database was bound in the handshake.

        Raises:
            ProtocolError: if a reply has the wrong message type.
            PrismServerError: if Hello or Auth is rejected.
        """
        features = FEATURE_CONNECT_DB if database else 0
        ack = self._conn.request(*hello_body(_PROTOCOL_VERSION, client_name, client_version, features, database))
        if not isinstance(ack, HelloAck):
            raise ProtocolError("expected HelloAck")
        if ack.status != 0:
            _fail(ack.error)
        connect_db_honored = (ack.features & FEATURE_CONNECT_DB) != 0 and database != ""

        if username is not None:
            auth_ack = self._conn.request(*auth_body(AUTH_PASSWORD, username, password or ""))
            if not isinstance(auth_ack, AuthAck):
                raise ProtocolError("expected AuthAck")
            if auth_ack.status != 0:
                _fail(auth_ack.error)
        return connect_db_honored

    # ---- SQL --------------------------------------------------------------

    def sql(
        self,
        text: str,
        params: Optional[Sequence[Value]] = None,
        *,
        return_rows: bool = True,
    ) -> SqlResult:
        """Execute one SQL statement.

        Args:
            text: the statement text.
            params: positional parameter values (``None`` means no parameters).
            return_rows: sent to the server as the flag 1/0; ``False`` is used,
                e.g., for the ``USE`` fallback whose rows are not needed.

        Returns:
            A :class:`SqlResult`. In ``rows`` a later column with a duplicate
            name overwrites an earlier one; ``raw`` keeps every cell in order.

        Raises:
            PrismServerError: if the server reports a non-zero status.
            ProtocolError: on an unexpected reply type or a streamed result.
        """
        reply = self._conn.request(*sql_body(text, list(params or []), 1 if return_rows else 0))
        if not isinstance(reply, SqlResultMsg):
            raise ProtocolError("expected SqlResult")
        if reply.status != 0:
            _fail(reply.error)
        if reply.more_frames:
            raise ProtocolError("streamed SQL results are not yet supported")
        names = [c.name for c in reply.columns]
        rows = [{names[i]: cell for i, cell in enumerate(cells)} for cells in reply.rows]
        return SqlResult(reply.columns, rows, reply.rows, reply.affected_rows)

    # ---- transactions -----------------------------------------------------

    def begin(self, mode: str = "read_write") -> int:
        """Begin a transaction and return the server-assigned transaction id.

        Args:
            mode: ``"read_write"`` (default) or ``"read_only"``.

        Raises:
            ValueError: for any other ``mode``. Previously an unrecognised
                mode, such as the typo ``"readonly"``, silently opened a
                read-write transaction.
            PrismServerError: if the server rejects the request.
        """
        if mode == "read_only":
            txn_mode = TXN_READ_ONLY
        elif mode == "read_write":
            txn_mode = TXN_READ_WRITE
        else:
            raise ValueError(
                f"unknown transaction mode {mode!r}; expected 'read_write' or 'read_only'"
            )
        ack = self._txn(*begin_body(txn_mode))
        return ack.txn_id

    def commit(self, idempotency_key: int = 0) -> None:
        """Commit the current transaction, passing ``idempotency_key`` (default 0) to the server."""
        self._txn(*commit_body(idempotency_key))

    def abort(self) -> None:
        """Abort the current transaction."""
        self._txn(*abort_body())

    def _txn(self, type_code: int, body: bytes) -> TxnAck:
        """Send a transaction-control request and return its checked TxnAck."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, TxnAck):
            raise ProtocolError("expected TxnAck")
        if reply.status != 0:
            _fail(reply.error)
        return reply

    # ---- misc -------------------------------------------------------------

    def ping(self) -> None:
        """Round-trip a keep-alive ping."""
        reply = self._conn.request(*ping_body())
        if not isinstance(reply, Pong):
            raise ProtocolError("expected Pong")

    def close(self) -> None:
        """Close the underlying connection."""
        self._conn.close()

    def __enter__(self) -> "Client":
        return self

    def __exit__(self, *exc: object) -> None:
        self.close()

    # ---- internal reply helpers -------------------------------------------

    def _kv_reply(self, type_code: int, body: bytes) -> KvResultMsg:
        """Send a KV request; raise unless the reply is a successful KvResult."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, KvResultMsg):
            raise ProtocolError("expected KvResult")
        if reply.status != 0:
            _fail(reply.error)
        return reply

    def _doc_reply(self, type_code: int, body: bytes) -> DocResultMsg:
        """Send a document request; raise unless the reply is a successful, single-frame DocResult."""
        reply = self._conn.request(type_code, body)
        if not isinstance(reply, DocResultMsg):
            raise ProtocolError("expected DocResult")
        if reply.status != 0:
            _fail(reply.error)
        if reply.more_frames:
            raise ProtocolError("streamed document results are not yet supported")
        return reply
prismdb/connection.py ADDED
@@ -0,0 +1,116 @@
1
+ """The transport: a TCP (optionally TLS) socket that frames outgoing messages,
2
+ reads back full frames, and matches each reply to its request by the echoed
3
+ ``request_id``. Server-initiated notices (request_id 0) go to a handler.
4
+
5
+ The connection is synchronous: each :meth:`request` writes a frame and blocks
6
+ until the matching reply arrives, dispatching any notices seen in between."""
7
+
8
+ from __future__ import annotations
9
+
10
+ import socket
11
+ import ssl
12
+ from typing import Callable, Optional, Tuple, Union
13
+
14
+ from ._codec import frame_encode
15
+ from .errors import ProtocolError
16
+ from .messages import Notice, decode_packet, encode_packet
17
+
18
# Callback invoked with each server-initiated Notice (request_id 0).
NoticeHandler = Callable[[Notice], None]
# TLS selection: ``False``/``None`` = plaintext, ``True`` = the default client
# context, an ``ssl.SSLContext`` = used as-is.
TlsArg = Optional[Union[bool, ssl.SSLContext]]
22
+
23
+
24
class Connection:
    """A framed, request/reply socket connection to a Prism server.

    Each :meth:`request` writes one frame and blocks until the reply carrying
    the same ``request_id`` arrives. Notices seen meanwhile go to
    ``on_notice``, and replies to other ids are dropped. After a transport
    failure or :meth:`close`, every further request re-raises the first
    recorded error.
    """

    def __init__(self, sock: socket.socket, on_notice: Optional[NoticeHandler] = None) -> None:
        self._sock = sock
        self._on_notice = on_notice
        self._next_id = 1
        self._inbound = bytearray()  # received bytes not yet consumed
        self._closed: Optional[Exception] = None  # first fatal error; set => unusable

    @classmethod
    def connect(
        cls,
        host: str = "127.0.0.1",
        port: int = 4444,
        *,
        tls: TlsArg = None,
        server_hostname: Optional[str] = None,
        connect_timeout: float = 10.0,
        on_notice: Optional[NoticeHandler] = None,
    ) -> "Connection":
        """Open a connection (TCP, or TLS when ``tls`` is set).

        ``connect_timeout`` bounds the TCP connect and, with TLS, the
        handshake as well. Afterwards the socket is put in blocking mode with
        no timeout. If any setup step fails, the socket is closed before the
        exception propagates.
        """
        raw = socket.create_connection((host, port), timeout=connect_timeout)
        sock: socket.socket = raw
        try:
            raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            if tls:
                ctx = tls if isinstance(tls, ssl.SSLContext) else ssl.create_default_context()
                # The wrapped socket inherits raw's timeout, so the handshake
                # performed inside wrap_socket is bounded by connect_timeout
                # (the timeout used to be cleared first, allowing a hang).
                sock = ctx.wrap_socket(raw, server_hostname=server_hostname or host)
            sock.settimeout(None)
        except BaseException:
            # If wrap_socket failed, ``sock`` is still ``raw`` (closing a
            # socket the SSL layer already detached is a harmless no-op).
            sock.close()
            raise
        return cls(sock, on_notice)

    def request(self, type_code: int, body: bytes) -> object:
        """Send a client message and return the matching reply message.

        Raises:
            ProtocolError: if the connection is closed or the transport fails.
        """
        if self._closed is not None:
            raise self._closed
        request_id = self._next_id
        # Ids cycle through 1..0xFFFFFFFF; 0 is reserved for server notices.
        self._next_id = 1 if self._next_id >= 0xFFFFFFFF else self._next_id + 1
        try:
            self._sock.sendall(frame_encode(encode_packet(request_id, type_code, body)))
        except OSError as e:
            self._fail(ProtocolError(f"send failed: {e}"))
            raise self._closed  # type: ignore[misc]

        while True:
            payload = self._read_frame()
            packet = decode_packet(payload)
            if isinstance(packet.message, Notice):
                if self._on_notice is not None:
                    self._on_notice(packet.message)
                continue
            if packet.request_id == request_id:
                return packet.message
            # An unmatched reply (e.g. a late response) is ignored.

    def close(self) -> None:
        """Close the connection. Further use raises."""
        self._fail(ProtocolError("connection closed by client"))
        try:
            self._sock.close()
        except OSError:
            pass

    # -- internals ----------------------------------------------------------

    def _read_frame(self) -> bytes:
        """Read one ``[len:u32 LE][payload]`` frame and return the payload.

        NOTE(review): the length comes straight from the peer, so a hostile or
        corrupt header can make this buffer up to 4 GiB before returning.
        """
        header = self._read_exact(4)
        length = int.from_bytes(header, "little")
        return self._read_exact(length)

    def _read_exact(self, n: int) -> bytes:
        """Return exactly ``n`` bytes, receiving more from the socket as needed."""
        # Drain anything already buffered from a previous read first.
        while len(self._inbound) < n:
            try:
                chunk = self._sock.recv(65536)
            except OSError as e:
                self._fail(ProtocolError(f"connection closed by server: {e}"))
                raise self._closed  # type: ignore[misc]
            if not chunk:
                self._fail(ProtocolError("connection closed by server"))
                raise self._closed  # type: ignore[misc]
            self._inbound += chunk
        out = bytes(self._inbound[:n])
        del self._inbound[:n]
        return out

    def _fail(self, err: Exception) -> None:
        """Record ``err`` as the terminal error unless one is already recorded."""
        if self._closed is None:
            self._closed = err
112
+
113
+
114
def split_tls_args(tls: TlsArg) -> Tuple[TlsArg, Optional[str]]:  # pragma: no cover
    """Reserved hook for richer TLS option parsing.

    For now it hands ``tls`` back untouched. The second slot is reserved
    (presumably for a hostname override) and is always ``None``.
    """
    reserved: Optional[str] = None
    return (tls, reserved)
prismdb/document.py ADDED
@@ -0,0 +1,50 @@
1
+ """The document tagged-binary codec.
2
+
3
+ Mirrors ``crates/prism-doc/src/value.rs`` (``Document::encode``/``decode``). A
4
+ document is ``[total:u32][count:u16]`` followed by, per field,
5
+ ``[tag:u8][nameLen:u16][name][value bytes]``. Field value bytes use the same
6
+ encoding as scalar values, except documents have no Binary type."""
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Dict
11
+
12
+ from ._codec import Reader, Writer
13
+ from .errors import ProtocolError
14
+ from .value import TAG, Value, decode_untagged, encode_untagged, tag_of
15
+
16
# A document is a plain dict; field insertion order is preserved.
Document = Dict[str, Value]


def encode_document(doc: Document) -> bytes:
    """Encode a document to its tagged-binary payload.

    Layout: ``[total:u32][count:u16]`` followed, per field in dict order, by
    ``[tag:u8][nameLen:u16][name][value bytes]``. ``total`` includes its own
    four bytes.

    Raises:
        ProtocolError: if the document has more than 0xFFFF fields, a value
            is binary (documents have no Binary type), a field name encodes
            to more than 0xFFFF UTF-8 bytes, or the payload is too large for
            the u32 total length.
    """
    if len(doc) > 0xFFFF:
        raise ProtocolError("too many document fields")
    body = Writer()
    body.u16(len(doc))
    for name, value in doc.items():
        tag = tag_of(value)
        if tag == TAG.BINARY:
            raise ProtocolError(f'field "{name}": binary values are not supported in documents')
        # A u16 length prefix cannot describe a longer name; reject it up
        # front with a field-level error instead of emitting a bad prefix.
        name_len = len(name.encode("utf-8"))
        if name_len > 0xFFFF:
            raise ProtocolError(
                f"field name of {name_len} UTF-8 bytes exceeds the 65535-byte limit"
            )
        body.u8(tag)
        body.str_u16(name)
        encode_untagged(body, tag, value)
    inner = body.out()
    total = 4 + len(inner)  # total length, including the u32 itself
    if total > 0xFFFFFFFF:
        raise ProtocolError(f"encoded document of {total} bytes exceeds the u32 length limit")
    out = Writer()
    out.u32(total)
    out.raw(inner)
    return out.out()
38
+
39
+
40
def decode_document(raw: bytes) -> Document:
    """Decode a tagged-binary document payload into a dict.

    The leading u32 total length is read and discarded, because the enclosing
    blob already delimits the document. Bytes after the last field are not
    checked, and a repeated field name keeps its last value.
    """
    reader = Reader(raw)
    reader.u32()  # total length: redundant with the frame's blob length
    remaining_fields = reader.u16()
    fields: Document = {}
    while remaining_fields:
        tag = reader.u8()
        name = reader.str_u16()
        fields[name] = decode_untagged(reader, tag)
        remaining_fields -= 1
    return fields