astreum 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of astreum might be problematic; see the registry's advisory for this release for more details.

@@ -1,734 +1,224 @@
1
- """
2
- General purpose Merkle tree implementation for the Astreum blockchain.
3
-
4
- This module provides a flexible Merkle tree implementation that can be used
5
- across the Astreum codebase, integrated with the existing storage system.
6
- Supports efficient binary search and resolvers for querying data.
7
- """
8
-
9
- import threading
10
- from dataclasses import dataclass
11
- from typing import List, Dict, Optional, Tuple, Set, Any, Callable, Union, TypeVar, Generic
12
- from enum import Enum, auto
13
-
14
- from ..utils import hash_data
15
-
16
-
17
class MerkleNodeType(Enum):
    """Kind of node stored in a Merkle tree.

    ``LEAF`` nodes carry payload bytes; ``BRANCH`` nodes reference a left and
    a right child by hash.
    """

    LEAF = 1    # holds the payload bytes
    BRANCH = 2  # internal node pointing at two children
21
-
22
-
23
@dataclass
class MerkleNode:
    """
    A single node of the Merkle tree.

    Attributes:
        node_type: Whether this node is a LEAF or a BRANCH.
        hash: Hash of the node's serialized form.
        data: Payload bytes (leaf nodes only).
        left_child: Hash of the left child (branch nodes only).
        right_child: Hash of the right child (branch nodes only).
    """
    node_type: 'MerkleNodeType'
    hash: bytes
    data: Optional[bytes] = None
    left_child: Optional[bytes] = None
    right_child: Optional[bytes] = None

    def serialize(self) -> bytes:
        """
        Serialize the node for storage.

        Layout:
            LEAF:   b'\\x00' + data
            BRANCH: b'\\x01' + left_child + right_child

        Returns:
            The serialized bytes.

        Raises:
            ValueError: If a field required by this node type is missing.
        """
        if self.node_type == MerkleNodeType.LEAF:
            if self.data is None:
                raise ValueError("Leaf node has no data to serialize")
            return b'\x00' + self.data
        # Any non-leaf node is serialized as a branch.
        if self.left_child is None or self.right_child is None:
            raise ValueError("Branch node is missing a child hash")
        return b'\x01' + self.left_child + self.right_child

    @classmethod
    def deserialize(cls, data: bytes) -> 'MerkleNode':
        """
        Rebuild a node from bytes produced by serialize().

        The node hash is recomputed with hash_data over the full serialized
        bytes, type byte included.

        Args:
            data: Serialized node bytes.

        Returns:
            The decoded MerkleNode.

        Raises:
            ValueError: If ``data`` is empty, has an unknown type byte, or is
                too short to hold two child hashes for a branch.
        """
        if not data:
            raise ValueError("Cannot deserialize an empty node")

        type_byte = data[0]
        if type_byte == 0:  # LEAF
            return cls(
                node_type=MerkleNodeType.LEAF,
                hash=hash_data(data),
                data=data[1:],
            )
        if type_byte == 1:  # BRANCH
            # Assumes hash_data produces 32-byte digests (two of them follow
            # the type byte) -- TODO confirm against ..utils.hash_data.
            if len(data) < 65:
                raise ValueError(
                    f"Branch node must be at least 65 bytes, got {len(data)}"
                )
            return cls(
                node_type=MerkleNodeType.BRANCH,
                hash=hash_data(data),
                left_child=data[1:33],
                right_child=data[33:65],
            )
        raise ValueError(f"Unknown node type: {type_byte}")
76
-
77
-
78
@dataclass
class MerkleProof:
    """
    A Merkle inclusion proof.

    ``siblings`` and ``path`` are ordered from the leaf upward: entry ``i``
    describes the i-th level above the leaf.

    Attributes:
        leaf_hash: Hash of the leaf node being proven.
        siblings: Sibling hash at each level.
        path: Position of the running hash at each level
            (False = left child, True = right child).
    """
    leaf_hash: bytes
    siblings: List[bytes]
    path: List[bool]  # False=left, True=right

    def verify(self, root_hash: bytes) -> bool:
        """
        Verify this proof against a root hash.

        A proof whose ``path`` and ``siblings`` differ in length is malformed
        and is rejected (False) instead of raising or being partially checked.

        Args:
            root_hash: The expected root hash of the Merkle tree.

        Returns:
            True if hashing up from ``leaf_hash`` reproduces ``root_hash``.
        """
        if len(self.path) != len(self.siblings):
            return False

        # With no siblings (single-node tree) the leaf must be the root.
        current_hash = self.leaf_hash
        for sibling, is_right in zip(self.siblings, self.path):
            if is_right:
                current_hash = hash_data(b'\x01' + sibling + current_hash)
            else:
                current_hash = hash_data(b'\x01' + current_hash + sibling)
        return current_hash == root_hash

    def serialize(self) -> bytes:
        """
        Serialize the proof to bytes.

        Layout: [32-byte leaf hash][1-byte path length][path bits][siblings...]
        Path bits are packed 8 per byte, least significant bit first.

        Raises:
            OverflowError: If the path has more than 255 entries.
        """
        path_bytes = bytearray((len(self.path) + 7) // 8)
        for i, went_right in enumerate(self.path):
            if went_right:
                path_bytes[i // 8] |= 1 << (i % 8)

        return (
            self.leaf_hash
            + len(self.path).to_bytes(1, 'big')
            + bytes(path_bytes)
            + b''.join(self.siblings)
        )

    @classmethod
    def deserialize(cls, data: bytes) -> 'MerkleProof':
        """
        Deserialize bytes produced by serialize().

        Truncated sibling data yields fewer siblings than path entries;
        verify() then rejects the proof.

        Raises:
            ValueError: If ``data`` is too short to hold the leaf hash and
                the path length byte.
        """
        if len(data) < 33:
            raise ValueError(
                "Proof data is too short to contain a leaf hash and path length"
            )

        leaf_hash = data[:32]
        path_length = data[32]

        # Path bits occupy ceil(path_length / 8) bytes.
        path_bytes_length = (path_length + 7) // 8
        path_bytes = data[33:33 + path_bytes_length]
        path = [
            i // 8 < len(path_bytes) and bool(path_bytes[i // 8] & (1 << (i % 8)))
            for i in range(path_length)
        ]

        # One 32-byte sibling per path entry, as far as the data reaches.
        offset = 33 + path_bytes_length
        siblings = [
            data[pos:pos + 32]
            for pos in range(offset, offset + 32 * path_length, 32)
            if pos + 32 <= len(data)
        ]

        return cls(leaf_hash=leaf_hash, siblings=siblings, path=path)
168
-
169
-
170
T = TypeVar('T')


def find_first(storage, root_hash: bytes, predicate: Callable[[bytes], bool]) -> Optional[bytes]:
    """
    Return the data of the first leaf, left to right, that satisfies ``predicate``.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        root_hash: Hash of the (sub)tree root to search.
        predicate: Called with each leaf's data; True selects the leaf.

    Returns:
        The first matching leaf data, or None if no leaf matches. A matching
        leaf whose data is ``b''`` is returned rather than skipped.
    """
    node = _get_node(storage, root_hash)
    if node is None:
        return None

    if node.node_type == MerkleNodeType.LEAF:
        return node.data if predicate(node.data) else None

    # Search left first so results follow leaf order. Compare against None so
    # an empty-bytes match is not mistaken for "not found".
    left_result = find_first(storage, node.left_child, predicate)
    if left_result is not None:
        return left_result
    return find_first(storage, node.right_child, predicate)
201
-
202
-
203
def find_all(storage, root_hash: bytes, predicate: Callable[[bytes], bool]) -> List[bytes]:
    """
    Gather the data of every leaf that satisfies ``predicate``.

    The traversal is delegated to ``_find_all_recursive``, which visits the
    left subtree before the right one, so matches come back in leaf order.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        root_hash: Hash of the tree root to search.
        predicate: Called with each leaf's data; True keeps the leaf.

    Returns:
        A new list with the matching leaf data (possibly empty).
    """
    matches: List[bytes] = []
    _find_all_recursive(storage, root_hash, predicate, matches)
    return matches
218
-
219
-
220
def _find_all_recursive(storage, node_hash: bytes, predicate: Callable[[bytes], bool],
                        results: List[bytes]) -> None:
    """
    Append the data of every matching leaf below ``node_hash`` to ``results``.

    MerkleTree pads a level with an odd node count by pairing the last node
    with itself. A branch whose two children are identical is therefore
    visited once, matching MerkleTree._collect_leaves, so padded trees do not
    report the same leaves twice. Genuinely equal adjacent subtrees are
    indistinguishable from padding and are also collapsed.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        node_hash: Hash of the current node.
        predicate: Called with each leaf's data; True keeps the leaf.
        results: List that receives matching leaf data (mutated in place).
    """
    node = _get_node(storage, node_hash)
    if node is None:
        return

    if node.node_type == MerkleNodeType.LEAF:
        if predicate(node.data):
            results.append(node.data)
        return

    _find_all_recursive(storage, node.left_child, predicate, results)
    if node.right_child != node.left_child:
        _find_all_recursive(storage, node.right_child, predicate, results)
243
-
244
-
245
def map(storage, root_hash: bytes, transform: Callable[[bytes], T]) -> List[T]:
    """
    Apply ``transform`` to each leaf's data and collect the results.

    Note: this module-level function shadows the built-in ``map`` inside
    this module.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        root_hash: Hash of the tree root.
        transform: Called with each leaf's data; its return value is kept.

    Returns:
        The transformed values, in the order ``_map_recursive`` visits leaves
        (left subtree before right).
    """
    outputs: List[T] = []
    _map_recursive(storage, root_hash, transform, outputs)
    return outputs
260
-
261
-
262
def _map_recursive(storage, node_hash: bytes, transform: Callable[[bytes], T],
                   results: List[T]) -> None:
    """
    Append ``transform(leaf.data)`` for every leaf below ``node_hash``.

    A branch whose two children are identical (the padding MerkleTree
    creates for odd-sized levels) is visited once, matching
    MerkleTree._collect_leaves, so padded leaves are not transformed twice.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        node_hash: Hash of the current node.
        transform: Called with each leaf's data.
        results: List that receives the transformed values (mutated in place).
    """
    node = _get_node(storage, node_hash)
    if node is None:
        return

    if node.node_type == MerkleNodeType.LEAF:
        results.append(transform(node.data))
        return

    _map_recursive(storage, node.left_child, transform, results)
    if node.right_child != node.left_child:
        _map_recursive(storage, node.right_child, transform, results)
284
-
285
-
286
def binary_search(storage, root_hash: bytes, compare: Callable[[bytes], int]) -> Optional[bytes]:
    """
    Look up a leaf in a Merkle tree whose leaves are in sorted order.

    ``compare(data)`` must return:
        0  if ``data`` is the target,
        1  if ``data`` is less than the target,
        -1 if ``data`` is greater than the target.
    The result is only meaningful when the leaves were stored sorted.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        root_hash: Hash of the tree root.
        compare: Three-way comparison against the target, as above.

    Returns:
        The matching leaf data, or None if no leaf matches.
    """
    found = _binary_search_recursive(storage, root_hash, compare)
    return found
305
-
306
-
307
def _binary_search_recursive(storage, node_hash: bytes,
                             compare: Callable[[bytes], int]) -> Optional[bytes]:
    """
    Recursively search a sorted Merkle (sub)tree for the target leaf.

    ``compare(data)`` returns 0 on a match, 1 if ``data`` < target and -1 if
    ``data`` > target.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        node_hash: Hash of the current node.
        compare: Three-way comparison against the target.

    Returns:
        The matching leaf data, or None if not found.
    """
    node = _get_node(storage, node_hash)
    if node is None:
        return None

    if node.node_type == MerkleNodeType.LEAF:
        return node.data if compare(node.data) == 0 else None

    # In a sorted tree every leaf of the left subtree is <= its rightmost
    # leaf, so that leaf decides which half can hold the target.
    left_max = _find_rightmost_leaf(storage, node.left_child)
    if left_max is None:
        # Left subtree could not be loaded; only the right side is searchable.
        return _binary_search_recursive(storage, node.right_child, compare)

    # compare(left_max) > 0 means left_max < target, so the target can only
    # be in the right subtree; otherwise target <= left_max and it is on the left.
    if compare(left_max) > 0:
        return _binary_search_recursive(storage, node.right_child, compare)
    return _binary_search_recursive(storage, node.left_child, compare)
347
-
348
-
349
def _find_rightmost_leaf(storage, node_hash: bytes) -> Optional[bytes]:
    """
    Return the data of the rightmost leaf in a subtree.

    Args:
        storage: Object whose ``get(hash)`` returns serialized node bytes.
        node_hash: Hash of the subtree root.

    Returns:
        The rightmost leaf's data (which may be ``b''``), or None if no leaf
        can be reached.
    """
    node = _get_node(storage, node_hash)
    if node is None:
        return None

    if node.node_type == MerkleNodeType.LEAF:
        return node.data

    # Compare against None so an empty-bytes leaf is still a valid result.
    right_result = _find_rightmost_leaf(storage, node.right_child)
    if right_result is not None:
        return right_result

    # Only reached when the right subtree could not be loaded from storage.
    return _find_rightmost_leaf(storage, node.left_child)
374
-
375
-
376
def _get_node(storage, node_hash: bytes) -> Optional[MerkleNode]:
    """
    Load and deserialize the node stored under ``node_hash``.

    Args:
        storage: Object providing ``get(key)``.
        node_hash: Key of the node.

    Returns:
        The decoded MerkleNode, or None when storage returns no (or empty)
        bytes for the key.
    """
    raw = storage.get(node_hash)
    return MerkleNode.deserialize(raw) if raw else None
391
-
392
-
393
class MerkleTree:
    """
    A general purpose Merkle tree implementation.

    Builds a binary Merkle tree over byte-string leaves and supports
    generating and verifying inclusion proofs. Nodes are cached in
    ``self.nodes`` and, when a storage object is supplied, persisted with
    ``storage.put(hash, serialized_node)``.

    Hashing scheme:
        leaf hash   = hash_data(b'\\x00' + data)
        branch hash = hash_data(b'\\x01' + left_hash + right_hash)
    A level with an odd number of nodes pairs its last node with itself.
    """

    def __init__(self, storage=None):
        """
        Initialize an empty tree.

        Args:
            storage: Optional object providing ``get(key)`` and
                ``put(key, value)``; None keeps nodes in memory only.
        """
        self.storage = storage
        self.root_hash = None
        self.nodes = {}  # In-memory cache: node hash -> MerkleNode
        self.lock = threading.Lock()  # Held by add() and add_sorted()

    def add(self, data: Union[bytes, List[bytes]]) -> Optional[bytes]:
        """
        Add data to the tree and return the resulting root hash.

        A list replaces the tree with a balanced tree built from the list (an
        empty list leaves the tree unchanged and returns None). A single bytes
        object is appended after the existing leaves and the tree is rebuilt.

        Args:
            data: A bytes object or a list of bytes objects.

        Returns:
            The root hash after the update.

        Raises:
            TypeError: If ``data`` is neither bytes nor a list.
        """
        with self.lock:
            if isinstance(data, list):
                return self._build_tree(data)
            if not isinstance(data, bytes):
                raise TypeError("Data must be bytes or list of bytes")
            if self.root_hash is None:
                return self._build_tree([data])
            return self._add_leaf(data)

    def add_sorted(self, data: List[bytes]) -> Optional[bytes]:
        """
        Replace the tree with one built from ``data`` in ascending order.

        Sorted leaves are what the module-level binary_search expects.
        ``data`` itself is not modified.

        Returns:
            The new root hash, or None if ``data`` is empty.
        """
        with self.lock:
            return self._build_tree(sorted(data))

    def _build_tree(self, items: List[bytes]) -> Optional[bytes]:
        """
        Build a balanced tree from ``items`` and make it the current tree.

        Returns:
            The new root hash, or None (tree unchanged) if ``items`` is empty.
        """
        if not items:
            return None

        leaves = []
        for item in items:
            leaf = MerkleNode(
                node_type=MerkleNodeType.LEAF,
                hash=hash_data(b'\x00' + item),
                data=item,
            )
            self._store(leaf)
            leaves.append(leaf)
        return self._build_tree_level(leaves)

    def _build_tree_level(self, nodes: List['MerkleNode']) -> Optional[bytes]:
        """
        Pair nodes level by level until one root remains, and record it.

        Returns:
            The root hash, or None if ``nodes`` is empty.
        """
        if not nodes:
            return None

        level = nodes
        while len(level) > 1:
            next_level = []
            for i in range(0, len(level), 2):
                left = level[i]
                # An odd node out is paired with itself.
                right = level[i + 1] if i + 1 < len(level) else left
                branch = MerkleNode(
                    node_type=MerkleNodeType.BRANCH,
                    hash=hash_data(b'\x01' + left.hash + right.hash),
                    left_child=left.hash,
                    right_child=right.hash,
                )
                self._store(branch)
                next_level.append(branch)
            level = next_level

        self.root_hash = level[0].hash
        return self.root_hash

    def _store(self, node: 'MerkleNode') -> None:
        """Cache ``node`` in memory and persist it when storage is configured."""
        self.nodes[node.hash] = node
        # Compare with None: an empty storage container may be falsy.
        if self.storage is not None:
            self.storage.put(node.hash, node.serialize())

    def _add_leaf(self, data: bytes) -> Optional[bytes]:
        """
        Append one leaf by rebuilding the tree from all current leaves.

        This is O(n); updating only the affected branches would be cheaper.
        """
        return self._build_tree(self._get_all_leaves() + [data])

    def _get_all_leaves(self) -> List[bytes]:
        """Return all leaf data of the current tree in left-to-right order."""
        if not self.root_hash:
            return []
        leaves: List[bytes] = []
        self._collect_leaves(self.root_hash, leaves)
        return leaves

    def _collect_leaves(self, node_hash: bytes, leaves: List[bytes]) -> None:
        """
        Append the data of every leaf below ``node_hash`` to ``leaves``.

        A branch with identical children is the padding created for odd-sized
        levels, so its right child is skipped to avoid counting leaves twice
        (genuinely equal adjacent subtrees are collapsed as well).
        """
        node = self._get_node(node_hash)
        if node is None:
            return

        if node.node_type == MerkleNodeType.LEAF:
            leaves.append(node.data)
        elif node.node_type == MerkleNodeType.BRANCH:
            self._collect_leaves(node.left_child, leaves)
            if node.right_child != node.left_child:
                self._collect_leaves(node.right_child, leaves)

    def _get_node(self, node_hash: bytes) -> Optional['MerkleNode']:
        """Return the node for ``node_hash`` from the cache or storage, else None."""
        if node_hash in self.nodes:
            return self.nodes[node_hash]

        if self.storage is not None:
            node_data = self.storage.get(node_hash)
            if node_data:
                node = MerkleNode.deserialize(node_data)
                self.nodes[node_hash] = node
                return node

        return None

    def generate_proof(self, data: bytes) -> Optional['MerkleProof']:
        """
        Build an inclusion proof for the leaf holding ``data``.

        Returns:
            A MerkleProof, or None if the tree is empty or ``data`` is not a leaf.
        """
        if not self.root_hash:
            return None

        leaf_hash = hash_data(b'\x00' + data)
        path: List[bool] = []
        siblings: List[bytes] = []
        if self._build_proof(self.root_hash, leaf_hash, path, siblings):
            return MerkleProof(leaf_hash=leaf_hash, siblings=siblings, path=path)
        return None

    def _build_proof(self, current_hash: bytes, target_hash: bytes,
                     path: List[bool], siblings: List[bytes]) -> bool:
        """
        Depth-first search for ``target_hash`` below ``current_hash``.

        On success ``path`` and ``siblings`` are filled from the leaf upward,
        because entries are appended as the recursion unwinds.

        Returns:
            True if the target was found.
        """
        if current_hash == target_hash:
            return True

        node = self._get_node(current_hash)
        if node is None or node.node_type != MerkleNodeType.BRANCH:
            return False

        if self._build_proof(node.left_child, target_hash, path, siblings):
            path.append(False)  # target lies in the left subtree
            siblings.append(node.right_child)
            return True

        if self._build_proof(node.right_child, target_hash, path, siblings):
            path.append(True)  # target lies in the right subtree
            siblings.append(node.left_child)
            return True

        return False

    def verify_proof(self, proof: 'MerkleProof') -> bool:
        """Verify ``proof`` against this tree's root; False for an empty tree."""
        if not self.root_hash:
            return False
        return proof.verify(self.root_hash)

    def create_resolver(self) -> 'MerkleResolver':
        """
        Create a MerkleResolver for querying this tree.

        NOTE(review): ``MerkleResolver`` is neither defined nor imported in this
        module, so calling this on a non-empty tree raises NameError. The
        annotation is quoted so that defining the class (and importing the
        module) no longer fails.

        Raises:
            ValueError: If the tree is empty.
        """
        if not self.root_hash:
            raise ValueError("Cannot create resolver for empty tree")
        return MerkleResolver(self.storage, self.root_hash)  # noqa: F821

    @classmethod
    def load_from_storage(cls, storage, root_hash: bytes) -> 'MerkleTree':
        """
        Attach to an existing tree in ``storage`` identified by ``root_hash``.

        The root node is cached eagerly when present; other nodes load lazily.
        """
        tree = cls(storage)
        tree.root_hash = root_hash
        root_data = storage.get(root_hash)
        if root_data:
            tree.nodes[root_hash] = MerkleNode.deserialize(root_data)
        return tree

    def get_root_hash(self) -> Optional[bytes]:
        """Return the root hash, or None if the tree is empty."""
        return self.root_hash
1
import blake3

import astreum.utils.bytes_format as bytes_format

from .storage import Storage
4
+
5
+
6
+
7
class MerkleTree:
    """
    Merkle tree whose nodes live in a content-addressed Storage.

    Every node is stored under its own hash (MerkleNode.hash). Trees built by
    build_tree_from_leaves are perfect binary trees: a level with an odd number
    of nodes duplicates its last node, so every leaf sits at the same depth.
    Leaf ``i`` is reached from the root by reading the bits of ``i`` from the
    most significant bit down (0 = left child, 1 = right child), which matches
    the left-to-right pairing used when building.
    """

    def __init__(self, storage: "Storage", root_hash: bytes = None, leaves: list[bytes] = None):
        """
        Initialize a Merkle tree from an existing root hash or by constructing a new tree from leaf data.

        If a list of leaf data is provided, the tree is built from the bottom up,
        every node is stored in the provided storage, and the computed root hash
        is used as the tree's identifier.

        :param storage: A Storage instance used for storing and retrieving tree nodes.
        :param root_hash: An optional existing root hash of a Merkle tree.
        :param leaves: An optional list of leaf data (each as bytes). If provided, a new tree is built.
        :raises ValueError: If neither root_hash nor leaves is provided, or leaves is empty.
        """
        self.storage = storage
        if leaves is not None:
            self.root_hash = self.build_tree_from_leaves(leaves)
        elif root_hash is not None:
            self.root_hash = root_hash
        else:
            raise ValueError("Either root_hash or leaves must be provided.")

    def _put(self, node: "MerkleNode") -> bytes:
        """Store ``node`` under its hash and return that hash."""
        node_hash = node.hash()
        self.storage.put(node_hash, node.to_bytes())
        return node_hash

    def build_tree_from_leaves(self, leaves: list[bytes]) -> bytes:
        """
        Construct a Merkle tree from a list of leaf data and store each node in storage.

        Each leaf is wrapped in a leaf MerkleNode. Nodes are then paired
        (duplicating the last node when a level has an odd count) into parent
        nodes whose data is ``left_hash + right_hash``, until one root remains.

        :param leaves: A non-empty list of bytes objects, each representing leaf data.
        :return: The computed root hash of the newly built tree.
        :raises ValueError: If ``leaves`` is empty.
        """
        if not leaves:
            raise ValueError("Cannot build a Merkle tree from an empty list of leaves.")

        current_level = [self._put(MerkleNode(True, leaf_data)) for leaf_data in leaves]

        while len(current_level) > 1:
            if len(current_level) % 2 == 1:
                current_level.append(current_level[-1])
            current_level = [
                self._put(MerkleNode(False, current_level[i] + current_level[i + 1]))
                for i in range(0, len(current_level), 2)
            ]

        return current_level[0]

    def _depth(self, node_hash: bytes = None) -> "int | None":
        """
        Count the internal levels between ``node_hash`` (default: the root) and the leaves.

        Follows left children only; trees built here are perfect, so the
        leftmost path has the same length as every other root-to-leaf path.

        :param node_hash: Hash of the subtree root; the tree root when omitted.
        :return: The depth (0 for a single leaf), or None if a node on the
            leftmost path is missing from storage.
        """
        start = self.root_hash if node_hash is None else node_hash
        node = MerkleNode.from_storage(self.storage, start)
        depth = 0
        while node is not None and not node.leaf:
            depth += 1
            node = MerkleNode.from_storage(self.storage, node.data[:32])
        return None if node is None else depth

    def get(self, index: int, level: int = 0) -> "bytes | None":
        """
        Retrieve the data stored in the leaf at a given index.

        ``index`` counts leaves left to right in the order they were passed to
        build_tree_from_leaves. Padding slots created by duplicating the last
        node of an odd level are addressable and return the duplicated data.
        Non-leaf nodes are assumed to hold two 32-byte child hashes.

        :param index: 0-based position of the leaf.
        :param level: Unused; retained for backward compatibility with the
            former recursive implementation.
        :return: The leaf data, or None if ``index`` is outside the tree or a
            node on the path is missing from storage.
        """
        depth = self._depth()
        if depth is None or not 0 <= index < (1 << depth):
            return None

        node = MerkleNode.from_storage(self.storage, self.root_hash)
        for shift in range(depth - 1, -1, -1):
            if node is None or node.leaf:
                break
            # Most significant remaining bit picks the branch: 0 left, 1 right.
            child_hash = node.data[32:64] if (index >> shift) & 1 else node.data[:32]
            node = MerkleNode.from_storage(self.storage, child_hash)

        if node is None or not node.leaf:
            return None
        return node.data

    def set(self, index: int, new_data: bytes) -> None:
        """
        Replace the data of the leaf at ``index`` and rebuild every ancestor.

        New nodes are written for the leaf and each node on its path, and
        ``root_hash`` is updated. Superseded nodes are intentionally left in
        storage. Nodes are keyed by content hash, so an identical node may
        still be referenced elsewhere: the duplicated padding sibling, an equal
        subtree, the unchanged node itself when ``new_data`` equals the old
        data, or another tree sharing the storage. Deleting it would corrupt
        those trees.

        :param index: 0-based position of the leaf to update.
        :param new_data: The new data (as bytes) to store in the leaf.
        :raises IndexError: If ``index`` is outside the tree.
        :raises Exception: If a node on the path is not found in storage.
        """
        depth = self._depth()
        if depth is None:
            raise Exception("Node not found in storage")
        if not 0 <= index < (1 << depth):
            raise IndexError(f"Leaf index {index} is out of range for this tree")
        self.root_hash = self._update(self.root_hash, index, 0, new_data, depth)

    def _update(self, node_hash: bytes, index: int, level: int, new_data: bytes,
                depth: int = None) -> bytes:
        """
        Rebuild the path from ``node_hash`` down to the leaf at ``index``.

        A leaf is replaced by a new leaf holding ``new_data``. For an internal
        node, the child on the path to ``index`` is rebuilt recursively and a
        new parent is created from it and the unchanged sibling.

        :param node_hash: The hash of the current node (at ``level`` on the path).
        :param index: The target leaf index whose path is being updated.
        :param level: The current depth in the tree (root = 0).
        :param new_data: The new data to set at the target leaf.
        :param depth: Total number of internal levels in the tree; derived from
            ``node_hash`` and ``level`` when omitted.
        :return: The hash of the newly constructed node replacing the current node.
        :raises Exception: If the node is not found in storage.
        """
        current_node = MerkleNode.from_storage(self.storage, node_hash)
        if current_node is None:
            raise Exception("Node not found in storage")

        if current_node.leaf:
            return self._put(MerkleNode(True, new_data))

        if depth is None:
            subtree_depth = self._depth(node_hash)
            if subtree_depth is None:
                raise Exception("Node not found in storage")
            depth = level + subtree_depth

        left_hash = current_node.data[:32]
        right_hash = current_node.data[32:64]
        # Consume index bits most-significant first, mirroring how leaves were paired.
        if (index >> (depth - 1 - level)) & 1:
            right_hash = self._update(right_hash, index, level + 1, new_data, depth)
        else:
            left_hash = self._update(left_hash, index, level + 1, new_data, depth)

        return self._put(MerkleNode(False, left_hash + right_hash))
+
160
+
161
class MerkleNode:
    """
    A node of a storage-backed Merkle tree.

    Leaf nodes carry user data; internal nodes carry the concatenation of
    their left and right child hashes. A node is serialized with
    ``bytes_format`` as ``[leaf_flag, data]`` and addressed by ``hash()``.
    """

    def __init__(self, leaf: bool, data: bytes):
        """
        Initialize a Merkle node.

        :param leaf: True for a leaf node, False for an internal node.
        :param data: For leaves, the stored data; for internal nodes, the
            concatenated child hashes.
        """
        self.leaf = leaf
        self.data = data
        self._hash = None  # Filled lazily by hash(); not invalidated if data changes.

    @classmethod
    def from_bytes(cls, data: bytes) -> "MerkleNode":
        """
        Deserialize a MerkleNode from bytes produced by to_bytes().

        :param data: The serialized node data.
        :return: A new MerkleNode instance.
        """
        # NOTE(review): assumes bytes_format.decode returns the two items that
        # to_bytes() encoded, with the flag as an int (1 = leaf) -- confirm.
        leaf_flag, node_data = bytes_format.decode(data)
        return cls(leaf_flag == 1, node_data)

    @classmethod
    def from_storage(cls, storage: "Storage", hash_value: bytes) -> "MerkleNode | None":
        """
        Retrieve and deserialize a MerkleNode from storage using its hash.

        :param storage: Object whose ``get(key)`` returns node bytes or None.
        :param hash_value: The hash key under which the node is stored.
        :return: A MerkleNode instance if found, otherwise None.
        """
        node_bytes = storage.get(hash_value)
        if node_bytes is None:
            return None
        return cls.from_bytes(node_bytes)

    def to_bytes(self) -> bytes:
        """
        Serialize the node as ``bytes_format.encode([leaf_flag, data])``.

        :return: The serialized bytes representing the node.
        """
        return bytes_format.encode([1 if self.leaf else 0, self.data])

    def hash(self) -> bytes:
        """
        Return the Blake3 digest of ``data``, computing it on first use.

        NOTE(review): the leaf flag is not part of the hashed input. A leaf whose
        data equals ``left_hash + right_hash`` of an internal node therefore gets
        the same hash, and the same storage key, as that node. Adding a domain
        prefix (e.g. RFC 6962's 0x00 / 0x01) would change every root hash, so
        the issue is flagged here rather than changed.

        :return: The Blake3 digest of the node's data.
        """
        if self._hash is None:
            self._hash = blake3.blake3(self.data).digest()
        return self._hash