asebytes-0.1.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,48 @@
+ Metadata-Version: 2.3
+ Name: asebytes
+ Version: 0.1.0
+ Summary: Efficient serialization and storage for ASE Atoms objects using LMDB
+ Requires-Dist: ase>=3.26.0
+ Requires-Dist: lmdb>=1.7.5
+ Requires-Dist: msgpack>=1.1.2
+ Requires-Dist: msgpack-numpy>=0.4.8
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+
+ # asebytes
+
+ Efficient serialization and storage for ASE Atoms objects using LMDB.
+
+ ## API
+
+ - **`encode(atoms)`** - Encode an ASE Atoms object to a dict of bytes
+ - **`decode(data)`** - Decode bytes back into an ASE Atoms object
+ - **`BytesIO(file, prefix)`** - LMDB-backed list-like storage for bytes dictionaries
+ - **`ASEIO(file, prefix)`** - LMDB-backed list-like storage for ASE Atoms objects
+
+ ## Examples
+
+ ```python
+ from asebytes import ASEIO, BytesIO, encode, decode
+ import molify
+
+ # Generate conformers from SMILES
+ ethanol = molify.smiles2conformers("CCO", numConfs=100)
+
+ # Serialize/deserialize single molecule
+ data = encode(ethanol[0])
+ atoms_restored = decode(data)
+
+ # High-level: Store Atoms objects directly
+ db = ASEIO("conformers.lmdb")
+ db.extend(ethanol) # Add all conformers
+ mol = db[0] # Returns ase.Atoms
+
+ # Low-level: BytesIO stores serialized data
+ bytes_db = BytesIO("conformers.lmdb")
+ bytes_db.append(encode(ethanol[0])) # Manual serialization
+ data = bytes_db[0] # Returns dict[bytes, bytes]
+ mol = decode(data) # Manual deserialization
+
+ # ASEIO = BytesIO + automatic encode/decode
+ ```
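
The examples above lean on molify for conformer generation, but the encode/decode round trip itself only needs ase, which is already a dependency. A minimal sketch with a hand-built molecule (the geometry, info key and energy value below are illustrative, not part of the package):

```python
import numpy as np
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

from asebytes import decode, encode

# Build a small water molecule by hand; no conformer generator needed.
water = Atoms("OH2", positions=[(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)])
water.info["name"] = "water"
water.calc = SinglePointCalculator(water, energy=-14.2, forces=np.zeros((3, 3)))

data = encode(water)     # dict[bytes, bytes] with keys like b"arrays.positions"
restored = decode(data)  # round trip back to ase.Atoms

assert restored.info["name"] == "water"
assert np.allclose(restored.positions, water.positions)
assert restored.calc.results["energy"] == -14.2
```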
@@ -0,0 +1,37 @@
+ # asebytes
+
+ Efficient serialization and storage for ASE Atoms objects using LMDB.
+
+ ## API
+
+ - **`encode(atoms)`** - Encode an ASE Atoms object to a dict of bytes
+ - **`decode(data)`** - Decode bytes back into an ASE Atoms object
+ - **`BytesIO(file, prefix)`** - LMDB-backed list-like storage for bytes dictionaries
+ - **`ASEIO(file, prefix)`** - LMDB-backed list-like storage for ASE Atoms objects
+
+ ## Examples
+
+ ```python
+ from asebytes import ASEIO, BytesIO, encode, decode
+ import molify
+
+ # Generate conformers from SMILES
+ ethanol = molify.smiles2conformers("CCO", numConfs=100)
+
+ # Serialize/deserialize single molecule
+ data = encode(ethanol[0])
+ atoms_restored = decode(data)
+
+ # High-level: Store Atoms objects directly
+ db = ASEIO("conformers.lmdb")
+ db.extend(ethanol) # Add all conformers
+ mol = db[0] # Returns ase.Atoms
+
+ # Low-level: BytesIO stores serialized data
+ bytes_db = BytesIO("conformers.lmdb")
+ bytes_db.append(encode(ethanol[0])) # Manual serialization
+ data = bytes_db[0] # Returns dict[bytes, bytes]
+ mol = decode(data) # Manual deserialization
+
+ # ASEIO = BytesIO + automatic encode/decode
+ ```
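
The API list above omits the partial-loading helpers that ASEIO also exposes (get_available_keys and get, defined in io.py further down). A minimal sketch, assuming a throwaway database path and ase.build.molecule purely for test data:

```python
from ase.build import molecule

from asebytes import ASEIO

db = ASEIO("partial.lmdb")
db.append(molecule("H2O"))

# List the stored fields for entry 0 without decoding anything.
print(db.get_available_keys(0))  # e.g. [b"cell", b"pbc", b"arrays.numbers", b"arrays.positions"]

# Reconstruct an Atoms object from selected fields only.
atoms = db.get(0, keys=[b"arrays.numbers", b"arrays.positions"])
print(atoms.get_chemical_formula())  # H2O
```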
@@ -0,0 +1,38 @@
+ [project]
+ name = "asebytes"
+ version = "0.1.0"
+ description = "Efficient serialization and storage for ASE Atoms objects using LMDB"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "ase>=3.26.0",
+     "lmdb>=1.7.5",
+     "msgpack>=1.1.2",
+     "msgpack-numpy>=0.4.8",
+ ]
+
+ [build-system]
+ requires = ["uv_build>=0.9.6,<0.10.0"]
+ build-backend = "uv_build"
+
+ [dependency-groups]
+ dev = [
+     "ase-db-backends>=0.10.0",
+     "ipykernel>=7.1.0",
+     "matplotlib>=3.10.7",
+     "molify>=0.0.1a0",
+     "pandas>=2.3.3",
+     "pytest>=8.4.2",
+     "pytest-benchmark>=5.2.1",
+ ]
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ python_files = ["test_*.py"]
+ python_classes = ["Test*"]
+ python_functions = ["test_*"]
+ addopts = ["-v", "--strict-markers", "-m", "not benchmark"]
+
+ markers = [
+     "benchmark: marks tests as benchmark tests (deselect with '-m \"not benchmark\"')",
+ ]
@@ -0,0 +1,10 @@
+ import importlib.metadata
+
+ from .decode import decode
+ from .io import ASEIO, BytesIO
+ from .metadata import get_metadata
+ from .encode import encode
+
+ __all__ = ["encode", "decode", "BytesIO", "ASEIO", "get_metadata"]
+
+ __version__ = importlib.metadata.version("asebytes")
@@ -0,0 +1,102 @@
+ import ase
+ import msgpack
+ import msgpack_numpy as m
+ import numpy as np
+ from ase.calculators.singlepoint import SinglePointCalculator
+ from ase.cell import Cell
+
+
+ def decode(data: dict[bytes, bytes], fast: bool = True) -> ase.Atoms:
+     """
+     Deserialize bytes into an ASE Atoms object.
+
+     Parameters
+     ----------
+     data : dict[bytes, bytes]
+         Dictionary with byte keys and msgpack-serialized byte values.
+     fast : bool, default=True
+         If True, use optimized direct attribute assignment (6x faster).
+         If False, use standard Atoms constructor (safer but slower).
+
+     Returns
+     -------
+     ase.Atoms
+         Reconstructed Atoms object.
+
+     Raises
+     ------
+     ValueError
+         If unknown keys are present in data.
+
+     Missing entries such as 'arrays.numbers', 'cell' or 'pbc' fall back to defaults rather than raising.
+     """
+     if b"arrays.numbers" in data:
+         numbers_array = msgpack.unpackb(data[b"arrays.numbers"], object_hook=m.decode)
+     else:
+         numbers_array = np.array([], dtype=int)
+
+     # Extract optional parameters with defaults
+     if b"cell" in data:
+         cell_array = msgpack.unpackb(data[b"cell"], object_hook=m.decode)
+     else:
+         cell_array = None
+
+     if b"pbc" in data:
+         pbc_array = msgpack.unpackb(data[b"pbc"], object_hook=m.decode)
+     else:
+         pbc_array = np.array([False, False, False], dtype=bool)
+
+     if fast:
+         # Skip Atoms.__init__() and directly assign attributes for better performance
+         atoms = ase.Atoms.__new__(ase.Atoms)
+
+         # Set cell - use provided cell or default empty cell
+         if cell_array is not None:
+             atoms._cellobj = Cell(cell_array)
+         else:
+             atoms._cellobj = Cell(np.zeros((3, 3)))
+
+         atoms._pbc = pbc_array
+         atoms.arrays = {"numbers": numbers_array}
+
+         # Initialize positions if not provided
+         if b"arrays.positions" not in data:
+             # Create default positions (zeros) based on number of atoms
+             n_atoms = len(numbers_array)
+             atoms.arrays["positions"] = np.zeros((n_atoms, 3))
+
+         atoms.info = {}
+         atoms.constraints = []
+         atoms._celldisp = np.zeros(3)
+         atoms._calc = None
+     else:
+         # Use standard Atoms constructor
+         atoms = ase.Atoms(numbers=numbers_array, cell=cell_array, pbc=pbc_array)
+
+     for key in data:
+         if key in [b"cell", b"pbc", b"arrays.numbers"]:
+             continue
+         if key.startswith(b"arrays."):
+             array_data = msgpack.unpackb(data[key], object_hook=m.decode)
+             atoms.arrays[key.decode().split("arrays.")[1]] = array_data
+         elif key.startswith(b"info."):
+             info_key = key.decode().split("info.")[1]
+             info_array = msgpack.unpackb(data[key], object_hook=m.decode)
+             atoms.info[info_key] = info_array
+         elif key.startswith(b"calc."):
+             if not hasattr(atoms, "calc") or atoms.calc is None:
+                 atoms.calc = SinglePointCalculator(atoms)
+             calc_key = key.decode().split("calc.")[1]
+             calc_array = msgpack.unpackb(data[key], object_hook=m.decode)
+             atoms.calc.results[calc_key] = calc_array
+         elif key == b"constraints":
+             constraints_data = msgpack.unpackb(data[key], object_hook=m.decode)
+             constraints = []
+             for constraint_dict in constraints_data:
+                 constraint = ase.constraints.dict2constraint(constraint_dict)
+                 constraints.append(constraint)
+             atoms.set_constraint(constraints)
+         else:
+             raise ValueError(f"Unknown key in data: {key}")
+
+     return atoms
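
decode() defaults to the fast path that bypasses Atoms.__init__ and assigns attributes directly; fast=False goes through the regular constructor. A small sketch suggesting both paths reconstruct the same structure (the CO geometry and 5 Å cell are arbitrary):

```python
import numpy as np
from ase import Atoms

from asebytes import decode, encode

atoms = Atoms("CO", positions=[(0, 0, 0), (0, 0, 1.1)], cell=np.eye(3) * 5, pbc=True)
data = encode(atoms)

fast = decode(data)              # default: direct attribute assignment
safe = decode(data, fast=False)  # regular Atoms constructor

assert fast == safe  # both paths yield the same positions, numbers, cell and pbc
```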
@@ -0,0 +1,68 @@
+ import ase
+ import msgpack
+ import msgpack_numpy as m
+ import numpy as np
+
+
+ def encode(atoms: ase.Atoms) -> dict[bytes, bytes]:
+     """
+     Serialize an ASE Atoms object into a dictionary of bytes.
+
+     Parameters
+     ----------
+     atoms : ase.Atoms
+         Atoms object to serialize.
+
+     Returns
+     -------
+     dict[bytes, bytes]
+         Dictionary with byte keys and msgpack-serialized byte values.
+
+     Raises
+     ------
+     TypeError
+         If input is not an ase.Atoms object.
+     ValueError
+         If any key in atoms.arrays, atoms.info, or atoms.calc.results contains a dot.
+     """
+     if not isinstance(atoms, ase.Atoms):
+         raise TypeError("Input must be an ase.Atoms object.")
+     data: dict[bytes, bytes] = {}
+     cell: np.ndarray = atoms.get_cell().array
+     data[b"cell"] = msgpack.packb(cell, default=m.encode)
+     data[b"pbc"] = msgpack.packb(atoms.get_pbc(), default=m.encode)
+
+     for key in atoms.arrays:
+         if "." in key:
+             raise ValueError(
+                 f"Key '{key}' in atoms.arrays contains a dot (.), which is not allowed as it is used as a path separator."
+             )
+         data[f"arrays.{key}".encode()] = msgpack.packb(
+             atoms.arrays[key], default=m.encode
+         )
+     for key in atoms.info:
+         if "." in key:
+             raise ValueError(
+                 f"Key '{key}' in atoms.info contains a dot (.), which is not allowed as it is used as a path separator."
+             )
+         value = atoms.info[key]
+         data[f"info.{key}".encode()] = msgpack.packb(value, default=m.encode)
+     if atoms.calc is not None:
+         for key in atoms.calc.results:
+             if "." in key:
+                 raise ValueError(
+                     f"Key '{key}' in atoms.calc.results contains a dot (.), which is not allowed as it is used as a path separator."
+                 )
+             value = atoms.calc.results[key]
+             data[f"calc.{key}".encode()] = msgpack.packb(value, default=m.encode)
+
+     # Serialize constraints
+     if atoms.constraints:
+         constraints_data = []
+         for constraint in atoms.constraints:
+             if isinstance(constraint, ase.constraints.FixConstraint):
+                 constraints_data.append(constraint.todict())
+         if constraints_data:
+             data[b"constraints"] = msgpack.packb(constraints_data, default=m.encode)
+
+     return data
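
encode() flattens an Atoms object into namespaced byte keys (cell, pbc, arrays.*, info.*, calc.*, plus constraints when present) and rejects user keys containing a dot. A sketch of the resulting key layout (molecule, SMILES string and energy are illustrative):

```python
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

from asebytes import encode

atoms = Atoms("H2", positions=[(0, 0, 0), (0, 0, 0.74)])
atoms.info["smiles"] = "[H][H]"
atoms.calc = SinglePointCalculator(atoms, energy=-1.17)

print(sorted(encode(atoms)))
# [b'arrays.numbers', b'arrays.positions', b'calc.energy', b'cell', b'info.smiles', b'pbc']

atoms.info["bad.key"] = 1
# encode(atoms) would now raise ValueError: dots are reserved as path separators
```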
@@ -0,0 +1,558 @@
+ from collections.abc import MutableSequence
+ from typing import Iterator
+
+ import ase
+ import lmdb
+
+ from asebytes.decode import decode
+ from asebytes.encode import encode
+
+
+ class ASEIO(MutableSequence):
+     """
+     LMDB-backed mutable sequence for ASE Atoms objects.
+
+     Parameters
+     ----------
+     file : str
+         Path to LMDB database file.
+     prefix : bytes, default=b""
+         Key prefix for namespacing entries.
+     """
+
+     def __init__(self, file: str, prefix: bytes = b""):
+         self.io = BytesIO(file, prefix)
+
+     def __getitem__(self, index: int) -> ase.Atoms:
+         data = self.io[index]
+         return decode(data)
+
+     def __setitem__(self, index: int, value: ase.Atoms) -> None:
+         data = encode(value)
+         self.io[index] = data
+
+     def __delitem__(self, index: int) -> None:
+         del self.io[index]
+
+     def insert(self, index: int, value: ase.Atoms) -> None:
+         data = encode(value)
+         self.io.insert(index, data)
+
+     def extend(self, values: list[ase.Atoms]) -> None:
+         """
+         Efficiently extend with multiple Atoms objects using bulk operations.
+
+         Serializes all Atoms objects first, then performs a single bulk transaction.
+         Much faster than calling append() in a loop.
+
+         Parameters
+         ----------
+         values : list[ase.Atoms]
+             Atoms objects to append.
+         """
+         # Serialize all atoms objects first
+         serialized_data = [encode(atoms) for atoms in values]
+         # Use BytesIO's bulk extend (single transaction)
+         self.io.extend(serialized_data)
+
+     def __len__(self) -> int:
+         return len(self.io)
+
+     def __iter__(self) -> Iterator:
+         for i in range(len(self)):
+             yield self[i]
+
+     def get_available_keys(self, index: int) -> list[bytes]:
+         """
+         Get all available keys for a given index.
+
+         Parameters
+         ----------
+         index : int
+             Logical index to query.
+
+         Returns
+         -------
+         list[bytes]
+             Available keys at the index.
+
+         Raises
+         ------
+         KeyError
+             If the index does not exist.
+         """
+         return self.io.get_available_keys(index)
+
+     def get(self, index: int, keys: list[bytes] | None = None) -> ase.Atoms:
+         """
+         Get Atoms object at index, optionally filtering to specific keys.
+
+         Parameters
+         ----------
+         index : int
+             Logical index to retrieve.
+         keys : list[bytes], optional
+             Keys to retrieve (e.g., b"arrays.positions", b"info.smiles", b"calc.energy").
+             If None, returns all data.
+
+         Returns
+         -------
+         ase.Atoms
+             Atoms object reconstructed from the requested keys.
+
+         Raises
+         ------
+         KeyError
+             If the index does not exist.
+         """
+         data = self.io.get(index, keys=keys)
+         return decode(data)
+
+
+ class BytesIO(MutableSequence):
+     """
+     LMDB-backed mutable sequence for byte dictionaries.
+
+     Parameters
+     ----------
+     file : str
+         Path to LMDB database file.
+     prefix : bytes, default=b""
+         Key prefix for namespacing entries.
+     """
+
+     def __init__(self, file: str, prefix: bytes = b""):
+         self.file = file
+         self.prefix = prefix
+         self.env = lmdb.open(
+             file,
+             # map_size=1099511627776,
+             # subdir=False,
+             # readonly=False,
+             # lock=True,
+             # readahead=True,
+             # meminit=False,
+         )
+
+     # Metadata helpers
+     def _get_count(self, txn) -> int:
+         """Get the current count from metadata (returns 0 if not set)."""
+         count_key = self.prefix + b"__meta__count"
+         count_bytes = txn.get(count_key)
+         if count_bytes is None:
+             return 0
+         return int(count_bytes.decode())
+
+     def _set_count(self, txn, count: int) -> None:
+         """Set the count in metadata."""
+         count_key = self.prefix + b"__meta__count"
+         txn.put(count_key, str(count).encode())
+
+     def _get_next_sort_key(self, txn) -> int:
+         """Get the next available sort key counter (returns 0 if not set)."""
+         key = self.prefix + b"__meta__next_sort_key"
+         value = txn.get(key)
+         if value is None:
+             return 0
+         return int(value.decode())
+
+     def _set_next_sort_key(self, txn, value: int) -> None:
+         """Set the next available sort key counter."""
+         key = self.prefix + b"__meta__next_sort_key"
+         txn.put(key, str(value).encode())
+
+     # Mapping helpers (logical_index → sort_key)
+     def _get_mapping(self, txn, logical_index: int) -> int | None:
+         """Get sort_key for a logical index (returns None if not found)."""
+         mapping_key = self.prefix + b"__idx__" + str(logical_index).encode()
+         sort_key_bytes = txn.get(mapping_key)
+         if sort_key_bytes is None:
+             return None
+         return int(sort_key_bytes.decode())
+
+     def _set_mapping(self, txn, logical_index: int, sort_key: int) -> None:
+         """Set the mapping from logical_index to sort_key."""
+         mapping_key = self.prefix + b"__idx__" + str(logical_index).encode()
+         txn.put(mapping_key, str(sort_key).encode())
+
+     def _delete_mapping(self, txn, logical_index: int) -> None:
+         """Delete the mapping for a logical index."""
+         mapping_key = self.prefix + b"__idx__" + str(logical_index).encode()
+         txn.delete(mapping_key)
+
+     def _allocate_sort_key(self, txn) -> int:
+         """Allocate a new unique sort key by incrementing the counter."""
+         next_key = self._get_next_sort_key(txn)
+         self._set_next_sort_key(txn, next_key + 1)
+         return next_key
+
+     # Metadata helpers for field keys
+     def _get_field_keys_metadata(self, txn, sort_key: int) -> list[bytes] | None:
+         """
+         Get field keys for a sort key from metadata.
+
+         Parameters
+         ----------
+         txn : lmdb.Transaction
+             LMDB transaction.
+         sort_key : int
+             Sort key to query.
+
+         Returns
+         -------
+         list[bytes] or None
+             Field keys (without prefix) or None if not found.
+         """
+         metadata_key = self.prefix + b"__keys__" + str(sort_key).encode()
+         metadata_bytes = txn.get(metadata_key)
+         if metadata_bytes is None:
+             return None
+         # Deserialize the list of keys (stored as newline-separated bytes)
+         return metadata_bytes.split(b"\n") if metadata_bytes else []
+
+     def _set_field_keys_metadata(
+         self, txn, sort_key: int, field_keys: list[bytes]
+     ) -> None:
+         """
+         Store field keys for a sort key in metadata.
+
+         Parameters
+         ----------
+         txn : lmdb.Transaction
+             LMDB transaction.
+         sort_key : int
+             Sort key.
+         field_keys : list[bytes]
+             Field keys (without prefix).
+         """
+         metadata_key = self.prefix + b"__keys__" + str(sort_key).encode()
+         # Serialize as newline-separated bytes
+         metadata_bytes = b"\n".join(field_keys)
+         txn.put(metadata_key, metadata_bytes)
+
+     def _delete_field_keys_metadata(self, txn, sort_key: int) -> None:
+         """
+         Delete field keys metadata for a sort key.
+
+         Parameters
+         ----------
+         txn : lmdb.Transaction
+             LMDB transaction.
+         sort_key : int
+             Sort key.
+         """
+         metadata_key = self.prefix + b"__keys__" + str(sort_key).encode()
+         txn.delete(metadata_key)
+
+     def __setitem__(self, index: int, data: dict[bytes, bytes]) -> None:
+         with self.env.begin(write=True) as txn:
+             current_count = self._get_count(txn)
+
+             # Get or allocate sort key for this index
+             sort_key = self._get_mapping(txn, index)
+             is_new_index = sort_key is None
+
+             if is_new_index:
+                 # Allocate new unique sort key
+                 sort_key = self._allocate_sort_key(txn)
+                 self._set_mapping(txn, index, sort_key)
+             else:
+                 # Delete existing data keys if overwriting
+                 try:
+                     _, _, keys_to_delete = self._get_full_keys(txn, index)
+                     for key in keys_to_delete:
+                         txn.delete(key)
+                 except KeyError:
+                     # No existing data, continue
+                     pass
+
+             # Write new data with sort key prefix using putmulti
+             sort_key_str = str(sort_key).encode()
+             items_to_insert = [
+                 (self.prefix + sort_key_str + b"-" + key, value)
+                 for key, value in data.items()
+             ]
+             if items_to_insert:
+                 cursor = txn.cursor()
+                 cursor.putmulti(items_to_insert, dupdata=False)
+
+             # Store metadata for field keys
+             field_keys = list(data.keys())
+             self._set_field_keys_metadata(txn, sort_key, field_keys)
+
+             # Update count if needed (when index == current_count, we're appending)
+             if is_new_index and index >= current_count:
+                 self._set_count(txn, index + 1)
+
+     def _get_full_keys(self, txn, index: int) -> tuple[int, bytes, list[bytes]]:
+         """
+         Get sort key, prefix, and all full keys for an index.
+
+         Parameters
+         ----------
+         txn : lmdb.Transaction
+             LMDB transaction.
+         index : int
+             Logical index to query.
+
+         Returns
+         -------
+         tuple[int, bytes, list[bytes]]
+             Tuple of (sort_key, prefix, full keys including prefix).
+
+         Raises
+         ------
+         KeyError
+             If the index does not exist.
+         """
+         # Look up the sort key for this logical index
+         sort_key = self._get_mapping(txn, index)
+
+         if sort_key is None:
+             raise KeyError(f"Index {index} not found")
+
+         # Build prefix
+         sort_key_str = str(sort_key).encode()
+         prefix = self.prefix + sort_key_str + b"-"
+
+         # Get field keys from metadata
+         field_keys = self._get_field_keys_metadata(txn, sort_key)
+         if field_keys is None:
+             raise KeyError(
+                 f"Metadata not found for index {index} (sort_key {sort_key})"
+             )
+
+         # Build full keys with prefix
+         keys_to_fetch = [prefix + field_key for field_key in field_keys]
+
+         return sort_key, prefix, keys_to_fetch
+
+     def __getitem__(self, index: int) -> dict[bytes, bytes]:
+         with self.env.begin() as txn:
+             _, prefix, keys_to_fetch = self._get_full_keys(txn, index)
+
+             # Use getmulti for efficient batch retrieval
+             result = {}
+             if keys_to_fetch:
+                 cursor = txn.cursor()
+                 for key, value in cursor.getmulti(keys_to_fetch):
+                     # Extract the field name after the sort_key prefix
+                     field_name = key[len(prefix) :]
+                     result[field_name] = value
+
+             return result
+
+     def get_available_keys(self, index: int) -> list[bytes]:
+         """
+         Get all available keys for a given index.
+
+         Parameters
+         ----------
+         index : int
+             Logical index to query.
+
+         Returns
+         -------
+         list[bytes]
+             Available keys at the index.
+
+         Raises
+         ------
+         KeyError
+             If the index does not exist.
+         """
+         with self.env.begin() as txn:
+             _, prefix, keys_to_fetch = self._get_full_keys(txn, index)
+
+             # Extract field names from full keys
+             return [key[len(prefix) :] for key in keys_to_fetch]
+
+     def get(self, index: int, keys: list[bytes] | None = None) -> dict[bytes, bytes]:
+         """
+         Get data at index, optionally filtering to specific keys.
+
+         Parameters
+         ----------
+         index : int
+             Logical index to retrieve.
+         keys : list[bytes], optional
+             Keys to retrieve. If None, returns all keys.
+
+         Returns
+         -------
+         dict[bytes, bytes]
+             Key-value pairs. If keys provided, only existing keys are returned.
+
+         Raises
+         ------
+         KeyError
+             If the index does not exist.
+         """
+         with self.env.begin() as txn:
+             _, prefix, keys_to_fetch = self._get_full_keys(txn, index)
+
+             # Filter keys if requested
+             if keys is not None:
+                 keys_set = set(keys)
+                 # Filter to only the requested keys
+                 keys_to_fetch = [
+                     k for k in keys_to_fetch if k[len(prefix) :] in keys_set
+                 ]
+
+             # Use getmulti for efficient batch retrieval
+             result = {}
+             if keys_to_fetch:
+                 cursor = txn.cursor()
+                 for key, value in cursor.getmulti(keys_to_fetch):
+                     # Extract the field name after the sort_key prefix
+                     field_name = key[len(prefix) :]
+                     result[field_name] = value
+
+             return result
+
+     def __delitem__(self, key: int) -> None:
+         with self.env.begin(write=True) as txn:
+             current_count = self._get_count(txn)
+
+             if key < 0 or key >= current_count:
+                 raise IndexError(f"Index {key} out of range [0, {current_count})")
+
+             # Get the sort key for this index and data keys before deleting mapping
+             sort_key = self._get_mapping(txn, key)
+             if sort_key is None:
+                 raise KeyError(f"Index {key} not found")
+
+             # Get the data keys to delete before modifying mappings
+             _, _, keys_to_delete = self._get_full_keys(txn, key)
+
+             # Collect all mappings that need to be shifted
+             # We need to shift indices [key+1, key+2, ..., count-1] down by 1
+             mappings_to_shift = []
+             for i in range(key + 1, current_count):
+                 sk = self._get_mapping(txn, i)
+                 if sk is not None:
+                     mappings_to_shift.append((i, sk))
+
+             # Delete the mapping for the deleted index
+             self._delete_mapping(txn, key)
+
+             # Shift all subsequent mappings down by 1
+             # Delete old mappings first, then write new ones
+             for old_index, sk in mappings_to_shift:
+                 self._delete_mapping(txn, old_index)
+
+             for old_index, sk in mappings_to_shift:
+                 new_index = old_index - 1
+                 self._set_mapping(txn, new_index, sk)
+
+             # Delete the data keys
+             for k in keys_to_delete:
+                 txn.delete(k)
+
+             # Delete metadata for field keys
+             self._delete_field_keys_metadata(txn, sort_key)
+
+             # Update count
+             self._set_count(txn, current_count - 1)
+
+     def insert(self, index: int, input: dict[bytes, bytes]) -> None:
+         with self.env.begin(write=True) as txn:
+             current_count = self._get_count(txn)
+
+             # Clamp index to valid range [0, count]
+             if index < 0:
+                 index = 0
+             if index > current_count:
+                 index = current_count
+
+             # Collect all mappings that need to be shifted right
+             # We need to shift indices [index, index+1, ..., count-1] up by 1
+             mappings_to_shift = []
+             for i in range(index, current_count):
+                 sk = self._get_mapping(txn, i)
+                 if sk is not None:
+                     mappings_to_shift.append((i, sk))
+
+             # Shift all mappings up by 1
+             # Do this in reverse order to avoid conflicts
+             # Delete old mappings first, then write new ones
+             for old_index, sk in mappings_to_shift:
+                 self._delete_mapping(txn, old_index)
+
+             for old_index, sk in reversed(mappings_to_shift):
+                 new_index = old_index + 1
+                 self._set_mapping(txn, new_index, sk)
+
+             # Allocate a new sort key for the new item
+             sort_key = self._allocate_sort_key(txn)
+             self._set_mapping(txn, index, sort_key)
+
+             # Write the new data with sort key prefix using putmulti
+             sort_key_str = str(sort_key).encode()
+             items_to_insert = [
+                 (self.prefix + sort_key_str + b"-" + key, value)
+                 for key, value in input.items()
+             ]
+             if items_to_insert:
+                 cursor = txn.cursor()
+                 cursor.putmulti(items_to_insert, dupdata=False)
+
+             # Store metadata for field keys
+             field_keys = list(input.keys())
+             self._set_field_keys_metadata(txn, sort_key, field_keys)
+
+             # Update count
+             self._set_count(txn, current_count + 1)
+
+     def extend(self, items: list[dict[bytes, bytes]]) -> None:
+         """
+         Efficiently extend the sequence with multiple items using bulk operations.
+
+         Parameters
+         ----------
+         items : list[dict[bytes, bytes]]
+             Dictionaries to append.
+         """
+         if not items:
+             return
+
+         with self.env.begin(write=True) as txn:
+             current_count = self._get_count(txn)
+
+             # Prepare all items with their mappings, data keys, and metadata
+             items_to_insert = []
+
+             for idx, item in enumerate(items):
+                 logical_index = current_count + idx
+                 sort_key = self._allocate_sort_key(txn)
+                 sort_key_str = str(sort_key).encode()
+
+                 # Add mapping entry
+                 mapping_key = self.prefix + b"__idx__" + str(logical_index).encode()
+                 items_to_insert.append((mapping_key, sort_key_str))
+
+                 # Collect field keys and add data entries
+                 field_keys = list(item.keys())
+                 for field_key, field_value in item.items():
+                     data_key = self.prefix + sort_key_str + b"-" + field_key
+                     items_to_insert.append((data_key, field_value))
+
+                 # Add metadata entry (inline with other inserts for single putmulti)
+                 metadata_key = self.prefix + b"__keys__" + sort_key_str
+                 metadata_value = b"\n".join(field_keys)
+                 items_to_insert.append((metadata_key, metadata_value))
+
+             # Bulk insert all items (mappings + data + metadata) in one call
+             cursor = txn.cursor()
+             cursor.putmulti(items_to_insert, dupdata=False)
+
+             # Update count
+             self._set_count(txn, current_count + len(items))
+
+     def __iter__(self):
+         for i in range(len(self)):
+             yield self[i]
+
+     def __len__(self) -> int:
+         with self.env.begin() as txn:
+             return self._get_count(txn)
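
BytesIO maps each logical index to a stable sort key, so insert and delete only rewrite the small __idx__ mapping entries while each payload stays under its original sort key. A sketch of the resulting list-like semantics, driven through ASEIO (database path and molecules are illustrative):

```python
from ase.build import molecule

from asebytes import ASEIO

db = ASEIO("mutable.lmdb")
db.extend([molecule("H2O"), molecule("NH3"), molecule("CH4")])

db[1] = molecule("C2H6")       # overwrite in place; old data keys are deleted first
db.insert(0, molecule("CO2"))  # later logical indices are remapped, payloads stay put
del db[2]                      # remaining indices shift down by one

print(len(db))  # 3
print([atoms.get_chemical_formula() for atoms in db])
```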
@@ -0,0 +1,88 @@
+ import msgpack
+ import msgpack_numpy as m
+ import numpy as np
+
+
+ def get_metadata(data: dict[bytes, bytes]) -> dict[str, dict]:
+     """
+     Extract type, shape, and dtype information from serialized data.
+
+     Parameters
+     ----------
+     data : dict[bytes, bytes]
+         Dictionary with byte keys and msgpack-serialized byte values.
+
+     Returns
+     -------
+     dict[str, dict]
+         Mapping of decoded string keys to metadata dictionaries.
+         Each metadata dict contains:
+
+         - For ndarrays: {"type": "ndarray", "dtype": str, "shape": tuple}
+         - For numpy scalars: {"type": "numpy_scalar", "dtype": str}
+         - For Python types: {"type": typename} where typename is one of
+           "str", "int", "float", "bool", "NoneType", "list", "dict"
+     """
+     metadata = {}
+
+     for key_bytes, value_bytes in data.items():
+         # Decode the key from bytes to string
+         key = key_bytes.decode("utf-8")
+
+         # Deserialize the value
+         value = msgpack.unpackb(value_bytes, object_hook=m.decode)
+
+         # Determine type and extract metadata
+         metadata[key] = _get_value_metadata(value)
+
+     return metadata
+
+
+ def _get_value_metadata(value) -> dict:
+     """
+     Extract metadata for a single value.
+
+     Parameters
+     ----------
+     value : Any
+         Deserialized value.
+
+     Returns
+     -------
+     dict
+         Type information and additional metadata.
+     """
+     # Check for NumPy array
+     if isinstance(value, np.ndarray):
+         return {
+             "type": "ndarray",
+             "dtype": str(value.dtype),
+             "shape": value.shape,
+         }
+
+     # Check for NumPy scalar types
+     if isinstance(value, np.generic):
+         return {
+             "type": "numpy_scalar",
+             "dtype": value.dtype.name,
+         }
+
+     if isinstance(value, bytes):
+         return {"type": "bytes"}
+     elif value is None:
+         return {"type": "NoneType"}
+     elif isinstance(value, bool):
+         return {"type": "bool"}
+     elif isinstance(value, int):
+         return {"type": "int"}
+     elif isinstance(value, float):
+         return {"type": "float"}
+     elif isinstance(value, str):
+         return {"type": "str"}
+     elif isinstance(value, list):
+         return {"type": "list"}
+     elif isinstance(value, dict):
+         return {"type": "dict"}
+     else:
+         # Fallback for unknown types
+         return {"type": type(value).__name__}
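
get_metadata() reports type, dtype and shape information per stored key without reconstructing the Atoms object. A short sketch (ase.build.bulk is used only to produce test data; the exact dtypes depend on the platform):

```python
from ase.build import bulk

from asebytes import encode, get_metadata

copper = bulk("Cu", "fcc", a=3.6)
meta = get_metadata(encode(copper))

print(meta["arrays.positions"])  # e.g. {'type': 'ndarray', 'dtype': 'float64', 'shape': (1, 3)}
print(meta["pbc"])               # e.g. {'type': 'ndarray', 'dtype': 'bool', 'shape': (3,)}
```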