embed-tree 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. embed_tree/__init__.py +121 -0
  2. embed_tree/cache/__init__.py +10 -0
  3. embed_tree/cache/json.py +7 -0
  4. embed_tree/cache/model.py +13 -0
  5. embed_tree/cache/sqlalchemy.py +7 -0
  6. embed_tree/config.py +100 -0
  7. embed_tree/embedders/__init__.py +7 -0
  8. embed_tree/embedders/huggingface.py +69 -0
  9. embed_tree/embedders/model.py +27 -0
  10. embed_tree/labelers/__init__.py +8 -0
  11. embed_tree/labelers/function.py +25 -0
  12. embed_tree/labelers/llm.py +26 -0
  13. embed_tree/labelers/model.py +38 -0
  14. embed_tree/loaders/__init__.py +17 -0
  15. embed_tree/loaders/filesystem.py +83 -0
  16. embed_tree/loaders/json.py +49 -0
  17. embed_tree/loaders/model.py +20 -0
  18. embed_tree/loaders/sqlalchemy.py +91 -0
  19. embed_tree/loaders/sqlalchemy_content.py +63 -0
  20. embed_tree/loaders/sqlite.py +21 -0
  21. embed_tree/persisters/__init__.py +15 -0
  22. embed_tree/persisters/filesystem.py +293 -0
  23. embed_tree/persisters/json.py +29 -0
  24. embed_tree/persisters/model.py +23 -0
  25. embed_tree/persisters/sqlalchemy.py +76 -0
  26. embed_tree/projectors/__init__.py +7 -0
  27. embed_tree/projectors/model.py +39 -0
  28. embed_tree/projectors/pca.py +57 -0
  29. embed_tree/providers/__init__.py +20 -0
  30. embed_tree/providers/base.py +104 -0
  31. embed_tree/providers/fake.py +26 -0
  32. embed_tree/providers/local.py +44 -0
  33. embed_tree/providers/openai.py +49 -0
  34. embed_tree/reconcilers/__init__.py +6 -0
  35. embed_tree/reconcilers/default.py +65 -0
  36. embed_tree/reconcilers/model.py +25 -0
  37. embed_tree/reducers.py +194 -0
  38. embed_tree/representation/__init__.py +27 -0
  39. embed_tree/representation/default.py +59 -0
  40. embed_tree/representation/model.py +87 -0
  41. embed_tree/store.py +5 -0
  42. embed_tree/stores/__init__.py +8 -0
  43. embed_tree/stores/file.py +32 -0
  44. embed_tree/stores/model.py +25 -0
  45. embed_tree/stores/null.py +16 -0
  46. embed_tree/taggers.py +132 -0
  47. embed_tree/tree.py +691 -0
  48. embed_tree-0.0.6.dist-info/METADATA +182 -0
  49. embed_tree-0.0.6.dist-info/RECORD +51 -0
  50. embed_tree-0.0.6.dist-info/WHEEL +5 -0
  51. embed_tree-0.0.6.dist-info/top_level.txt +1 -0
embed_tree/__init__.py ADDED
@@ -0,0 +1,121 @@
1
+ """embed-tree: an incremental hierarchical clustering tree over embeddings.
2
+
3
+ See DESIGN.md for the full design. Minimal usage:
4
+
5
+ from embed_tree import EmbedTree, TreeConfig, FileTreeStore
6
+
7
+ tree = EmbedTree(embedder=my_embed_fn, store=FileTreeStore("./tree.json"))
8
+ tree.add("some content")
9
+ hits = tree.query("similar content", k=5)
10
+ """
11
+
12
+ from .config import LLMConfig, RebalanceConfig, TreeConfig
13
+ from .embedders import HuggingFaceTextEmbedder, TextEmbedder, embed_texts
14
+ from .labelers import FunctionLabeler, LabelCandidate, Labeler, LabelRequest, LLMLabeler
15
+ from .loaders import (
16
+ FileSystemTreeLoader,
17
+ JsonTreeLoader,
18
+ SQLAlchemyContentLoader,
19
+ SQLAlchemyTreeLoader,
20
+ SQLiteTreeLoader,
21
+ TreeLoader,
22
+ )
23
+ from .persisters import (
24
+ FileSystemTreePersister,
25
+ FolderTreePersister,
26
+ JsonTreePersister,
27
+ MaterializedTreeState,
28
+ SQLAlchemyTreePersister,
29
+ TreePersister,
30
+ )
31
+ from .projectors import PCAConfig, PCAProjector, VectorProjector
32
+ from .providers import (
33
+ EmbeddingProvider,
34
+ FakeEmbeddingProvider,
35
+ OpenAIEmbeddingProvider,
36
+ SentenceTransformerProvider,
37
+ )
38
+ from .reducers import (
39
+ FreezePCAReducer,
40
+ IdentityReducer,
41
+ IncrementalPCAReducer,
42
+ Reducer,
43
+ )
44
+ from .reconcilers import DefaultTreeReconciler, TreeReconciler
45
+ from .representation import (
46
+ ContentNode,
47
+ DefaultTreeRepresentation,
48
+ KeyNode,
49
+ NodeAggregate,
50
+ NodeEmbedding,
51
+ NodeId,
52
+ PartialTree,
53
+ TreeEdge,
54
+ VectorData,
55
+ partial_tree_from_dict,
56
+ partial_tree_to_dict,
57
+ )
58
+ from .store import FileTreeStore, NullTreeStore, TreeState, TreeStore
59
+ from .taggers import KeywordTagger, LLMTagger, Tagger, make_tagger
60
+ from .tree import EmbedTree, Item, Node
61
+
62
+ __all__ = [
63
+ "EmbedTree",
64
+ "Item",
65
+ "Node",
66
+ "TreeConfig",
67
+ "RebalanceConfig",
68
+ "LLMConfig",
69
+ "TextEmbedder",
70
+ "HuggingFaceTextEmbedder",
71
+ "embed_texts",
72
+ "PCAConfig",
73
+ "PCAProjector",
74
+ "VectorProjector",
75
+ "LabelCandidate",
76
+ "LabelRequest",
77
+ "Labeler",
78
+ "FunctionLabeler",
79
+ "LLMLabeler",
80
+ "TreeLoader",
81
+ "FileSystemTreeLoader",
82
+ "JsonTreeLoader",
83
+ "SQLAlchemyContentLoader",
84
+ "SQLAlchemyTreeLoader",
85
+ "SQLiteTreeLoader",
86
+ "MaterializedTreeState",
87
+ "TreeReconciler",
88
+ "DefaultTreeReconciler",
89
+ "TreePersister",
90
+ "FolderTreePersister",
91
+ "FileSystemTreePersister",
92
+ "JsonTreePersister",
93
+ "SQLAlchemyTreePersister",
94
+ "DefaultTreeRepresentation",
95
+ "PartialTree",
96
+ "ContentNode",
97
+ "KeyNode",
98
+ "TreeEdge",
99
+ "NodeEmbedding",
100
+ "NodeAggregate",
101
+ "NodeId",
102
+ "VectorData",
103
+ "partial_tree_from_dict",
104
+ "partial_tree_to_dict",
105
+ "TreeState",
106
+ "TreeStore",
107
+ "FileTreeStore",
108
+ "NullTreeStore",
109
+ "EmbeddingProvider",
110
+ "FakeEmbeddingProvider",
111
+ "OpenAIEmbeddingProvider",
112
+ "SentenceTransformerProvider",
113
+ "Reducer",
114
+ "IdentityReducer",
115
+ "FreezePCAReducer",
116
+ "IncrementalPCAReducer",
117
+ "Tagger",
118
+ "KeywordTagger",
119
+ "LLMTagger",
120
+ "make_tagger",
121
+ ]
@@ -0,0 +1,10 @@
1
+ """Deprecated cache compatibility imports.
2
+
3
+ Use loaders plus persisters directly for new code.
4
+ """
5
+
6
+ from .json import JsonTreeCache
7
+ from .model import MaterializedTreeState, TreeCache
8
+ from .sqlalchemy import SQLAlchemyTreeCache
9
+
10
+ __all__ = ["MaterializedTreeState", "TreeCache", "JsonTreeCache", "SQLAlchemyTreeCache"]
@@ -0,0 +1,7 @@
1
+ """Compatibility import for JSON tree cache."""
2
+
3
+ from embed_tree.loaders.json import JsonTreeLoader
4
+
5
+
class JsonTreeCache(JsonTreeLoader):
    """Legacy name for ``JsonTreeLoader``, kept so old imports keep working.

    The subclass adds nothing; all behaviour is inherited from
    ``JsonTreeLoader``.
    """

    pass
@@ -0,0 +1,13 @@
1
+ """Deprecated cache compatibility contract."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol, runtime_checkable
6
+
7
+ from embed_tree.loaders.model import TreeLoader
8
+ from embed_tree.persisters.model import MaterializedTreeState, TreePersister
9
+
10
+
@runtime_checkable
class TreeCache(TreeLoader, TreePersister, Protocol):
    """Deprecated combined contract: a ``TreeLoader`` that is also a ``TreePersister``.

    Declares no members of its own; it only merges the two parent protocols.
    """

    pass
@@ -0,0 +1,7 @@
1
+ """Compatibility import for SQLAlchemy tree cache."""
2
+
3
+ from embed_tree.loaders.sqlalchemy import SQLAlchemyTreeLoader
4
+
5
+
class SQLAlchemyTreeCache(SQLAlchemyTreeLoader):
    """Legacy name for ``SQLAlchemyTreeLoader``, kept so old imports keep working.

    The subclass adds nothing; all behaviour is inherited from
    ``SQLAlchemyTreeLoader``.
    """

    pass
embed_tree/config.py ADDED
@@ -0,0 +1,100 @@
1
+ """Configuration for embed-tree, as a single pydantic object.
2
+
3
+ `TreeConfig` is a plain pydantic `BaseModel` (NOT `BaseSettings`): it is
4
+ constructed explicitly and handed whole to `EmbedTree(config=...)`. It does
5
+ **not** read environment variables — every value must be passed in code, so the
6
+ configuration is always explicit and reproducible.
7
+
8
+ See DESIGN.md §4/§5.3. M1 adds PCA dimensionality reduction (pca_dims) in two
9
+ modes (freeze / incremental); see those sections for the rebalance contract.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import Literal
15
+
16
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
17
+
18
+
class RebalanceConfig(BaseModel):
    """Settings for rebuilding the whole tree from its leaves (DESIGN.md §4)."""

    # Whether rebalancing is enabled.
    enabled: bool = True
    # Automatic rebuild cadence, counted in inserts; None disables it.
    every_n_inserts: int | None = 10_000
    # Whether a manual tree.rebalance() call is allowed.
    on_demand: bool = True
25
+
26
+
class LLMConfig(BaseModel):
    """Settings for auto-naming taxonomy nodes (DESIGN.md §10).

    With provider "none" a keyword tagger that needs no network is used;
    "openai" and "local" generate labels with an LLM.
    """

    provider: Literal["none", "openai", "local"] = "none"
    # OpenAI model id, or a Hugging Face model id when the provider is local.
    model: str = "gpt-4o-mini"
    # OpenAI key; always passed explicitly, never read from the environment.
    api_key: str | None = None
    # OpenAI-compatible endpoint, e.g. a local server.
    base_url: str | None = None
    # Number of member texts shown to the LLM when naming a cluster.
    max_samples: int = 15
    # Word cap that keeps labels short and browsable.
    max_label_words: int = 6
38
+
39
+
class TreeConfig(BaseModel):
    """Top-level tree settings, always built explicitly in code.

    Nothing is read from environment variables. Defaults target the M0/M1
    (<100k items) regime.
    """

    # An empty protected_namespaces tuple permits the `model_args` field name.
    model_config = ConfigDict(protected_namespaces=())

    # Browsable-taxonomy defaults (DESIGN.md §10): a small fan-out and small
    # leaves keep each level readable (<=5 sub-topics, <=10 items per leaf).
    # Raise both for a large-scale retrieval index instead.
    max_branches: int = 5  # max sub-topics per level (k for KMeans)
    leaf_capacity: int = 10  # max items in a leaf before it subdivides
    split_algo: str = "kmeans"  # M0 accepts "kmeans" only

    # Distance is always cosine: vectors are L2-normalized and compared with
    # Euclidean distance (rank-equivalent on the unit sphere). Embeddings carry
    # meaning in direction, not magnitude, hence no separate distance setting.

    # --- PCA dimensionality reduction (M1; see DESIGN.md §5.3) -------------
    # Disabled by default: it only pays off at scale (thousands+). With tens of
    # items PCA is meaningless and pca_warmup is never reached anyway.
    pca_dims: int | None = None  # None = operate in the raw embedding space
    pca_mode: Literal["freeze", "incremental"] = "freeze"
    pca_warmup: int = 1000  # items buffered before the first PCA fit
    pca_batch_size: int = 256  # incremental mode: partial_fit cadence

    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)  # node auto-naming
    model_args: dict = Field(default_factory=dict)  # forwarded to KMeans

    @field_validator("max_branches")
    @classmethod
    def _min_branches(cls, v: int) -> int:
        """Reject a fan-out smaller than two."""
        if v >= 2:
            return v
        raise ValueError("max_branches must be >= 2")

    @field_validator("split_algo")
    @classmethod
    def _supported_split(cls, v: str) -> str:
        """Accept only "kmeans"; anything else raises NotImplementedError."""
        if v == "kmeans":
            return v
        raise NotImplementedError(
            f"split_algo={v!r} arrives in a later milestone; "
            "M0 supports 'kmeans' only"
        )

    @model_validator(mode="after")
    def _cross_field(self) -> "TreeConfig":
        """Check constraints that span several fields.

        Raises:
            ValueError: if leaf_capacity is below max_branches, or, when
                pca_dims is set, if pca_dims is below 2 or either pca_warmup
                or pca_batch_size is below pca_dims.
        """
        if self.leaf_capacity < self.max_branches:
            raise ValueError("leaf_capacity must be >= max_branches")

        dims = self.pca_dims
        if dims is None:
            # No reduction requested, so the PCA fields are not checked.
            return self

        if dims < 2:
            raise ValueError("pca_dims must be >= 2")
        # PCA needs at least n_components samples to fit / partial_fit.
        if self.pca_warmup < dims:
            raise ValueError("pca_warmup must be >= pca_dims")
        # NOTE(review): this check also runs when pca_mode is "freeze", although
        # the field comment describes pca_batch_size as incremental-only —
        # confirm that is intended.
        if self.pca_batch_size < dims:
            raise ValueError("pca_batch_size must be >= pca_dims")
        return self
@@ -0,0 +1,7 @@
1
+ """Embedding model integrations."""
2
+
3
+ from .huggingface import HuggingFaceTextEmbedder, embed_texts
4
+ from .model import TextEmbedder, Vector
5
+
6
+ __all__ = ["TextEmbedder", "Vector", "HuggingFaceTextEmbedder", "embed_texts"]
7
+
@@ -0,0 +1,69 @@
1
+ """Open-source local embeddings via Hugging Face sentence-transformers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+
10
+ from embed_tree.providers.local import SentenceTransformerProvider
11
+
12
+
class HuggingFaceTextEmbedder(SentenceTransformerProvider):
    """Sentence-transformers embedder with automatic device selection.

    Unless a ready ``model_obj`` is supplied, a ``SentenceTransformer`` is built
    from the model name, so sentence-transformers handles fetching and caching
    the weights. ``device="auto"`` is resolved by ``_resolve_device`` before the
    model is built; any other value is passed through unchanged.
    """

    def __init__(
        self,
        model: str = "BAAI/bge-small-en-v1.5",
        *,
        device: str | None = "auto",
        cache_folder: str | Path | None = None,
        model_obj: Any | None = None,
        encode_kwargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the settings, build the model if needed, then init the provider.

        Args:
            model: Model id handed to ``SentenceTransformer`` and the parent class.
            device: Device name, ``None``, or ``"auto"`` to pick one automatically.
            cache_folder: Optional cache directory, stored as a string.
            model_obj: Pre-built model; when given, no model is constructed here
                and ``cache_folder`` is only recorded.
            encode_kwargs: Forwarded to the parent provider.
            **kwargs: Forwarded to the parent provider.

        Raises:
            ImportError: if a model must be built but sentence-transformers is
                not installed.
        """
        chosen_device = _resolve_device(device)
        self.model_name = model
        self.device = chosen_device
        self.cache_folder = str(cache_folder) if cache_folder is not None else None

        if model_obj is None:
            # Imported lazily so the dependency is only needed on this path.
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as e:  # pragma: no cover
                raise ImportError(
                    'HuggingFaceTextEmbedder needs the "local" extra: '
                    'pip install "embed-tree[local]"'
                ) from e
            model_obj = SentenceTransformer(
                model,
                device=chosen_device,
                cache_folder=self.cache_folder,
            )

        super().__init__(
            model=model,
            device=chosen_device,
            model_obj=model_obj,
            encode_kwargs=encode_kwargs,
            **kwargs,
        )
47
+
48
+
def _resolve_device(device: str | None) -> str | None:
    """Turn ``"auto"`` into a concrete device name; pass other values through.

    For ``"auto"`` the result is ``None`` when torch cannot be imported,
    otherwise ``"mps"`` if available, then ``"cuda"`` if available, else
    ``"cpu"``.
    """
    if device != "auto":
        return device

    try:
        import torch
    except ImportError:
        return None

    # Some torch builds have no `backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"
61
+
62
+
def embed_texts(embedder: Any, texts: list[str]) -> np.ndarray:
    """Embed a batch through any callable or TextEmbedder-like object.

    Strategies are tried in this order:

    1. ``embedder.embed_batch(texts)`` when that attribute is callable;
    2. ``embedder(text)`` per item when the object itself is callable;
    3. ``embedder.embed(text)`` per item for objects that only expose ``embed``.

    Args:
        embedder: A callable, or an object with ``embed_batch`` and/or ``embed``.
        texts: Strings to embed, in order.

    Returns:
        ``np.asarray`` of the produced vectors, in input order.

    Raises:
        TypeError: when ``texts`` is non-empty and the embedder offers none of
            the three entry points.
    """
    batch_fn = getattr(embedder, "embed_batch", None)
    if callable(batch_fn):
        return np.asarray(batch_fn(texts))

    if callable(embedder):
        embed_one = embedder
    else:
        # Fall back to `.embed`. If that is missing too, keep the object itself
        # so the call below raises the usual "not callable" TypeError.
        embed_one = getattr(embedder, "embed", embedder)
    return np.asarray([embed_one(text) for text in texts])
69
+
@@ -0,0 +1,27 @@
1
+ """Embedding model contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol, Sequence, runtime_checkable
6
+
7
+ import numpy as np
8
+
9
+ Vector = np.ndarray
10
+
11
+
@runtime_checkable
class TextEmbedder(Protocol):
    """Structural contract for objects that map strings to embedding vectors."""

    def embed(self, text: str) -> Vector:
        """Return the vector for a single string."""

    def embed_batch(self, texts: Sequence[str]) -> np.ndarray:
        """Return vectors for several strings, preserving input order."""

    def __call__(self, text: str) -> Vector:
        """Return the vector for a single string (call form)."""
27
+
@@ -0,0 +1,8 @@
1
+ """Labeling model integrations."""
2
+
3
+ from .function import FunctionLabeler
4
+ from .llm import LLMLabeler
5
+ from .model import LabelCandidate, LabelRequest, Labeler
6
+
7
+ __all__ = ["LabelCandidate", "LabelRequest", "Labeler", "FunctionLabeler", "LLMLabeler"]
8
+
@@ -0,0 +1,25 @@
1
+ """Function-backed labeler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Callable, Iterable
6
+
7
+ from .model import LabelRequest
8
+
9
+
class FunctionLabeler:
    """Wrap a plain user function so it satisfies the streaming labeler protocol.

    The function may return either one string or an iterable of string chunks.
    """

    def __init__(self, fn: Callable[[LabelRequest], str | Iterable[str]]) -> None:
        """Store the labelling function; it is not called until a label is requested."""
        self.fn = fn

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Yield the label chunks produced by the wrapped function.

        A single string result is yielded as one chunk; any other result is
        iterated and its items are yielded in order.
        """
        result = self.fn(request)
        chunks = (result,) if isinstance(result, str) else result
        yield from chunks

    def label(self, request: LabelRequest) -> str:
        """Return all chunks concatenated, with surrounding whitespace stripped."""
        pieces = self.stream(request)
        joined = "".join(pieces)
        return joined.strip()
25
+
@@ -0,0 +1,26 @@
1
+ """LLM-backed streaming labeler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Iterable
6
+
7
+ from embed_tree.config import LLMConfig
8
+ from embed_tree.taggers import LLMTagger
9
+
10
+ from .model import LabelRequest
11
+
12
+
class LLMLabeler:
    """Labeler that delegates to an ``LLMTagger`` built from an ``LLMConfig``."""

    def __init__(
        self,
        config: LLMConfig | None = None,
        *,
        client: Any | None = None,
        pipeline: Any | None = None,
    ) -> None:
        """Create the tagger.

        Args:
            config: LLM settings; a default ``LLMConfig()`` is used when falsy.
            client: Forwarded to ``LLMTagger``.
            pipeline: Forwarded to ``LLMTagger``.
        """
        self.config = config if config else LLMConfig()
        self.tagger = LLMTagger(self.config, client=client, pipeline=pipeline)

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Yield the complete label as a single chunk."""
        whole_label = self.label(request)
        yield whole_label

    def label(self, request: LabelRequest) -> str:
        """Return the tagger's label for the texts of the request's candidates."""
        # Only the candidate texts are passed on; request.max_words is not read here.
        candidate_texts = []
        for candidate in request.candidates:
            candidate_texts.append(candidate.text)
        return self.tagger(candidate_texts)
26
+
@@ -0,0 +1,38 @@
1
+ """Labeling model contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Iterable, Protocol, runtime_checkable
7
+
8
+
@dataclass(frozen=True)
class LabelCandidate:
    """One nearby node or item supplied as context when generating a label."""

    # Identifier of the node or item.
    id: Any
    # Text shown to the labeler.
    text: str
    # Optional distance value; None when not supplied.
    distance: float | None = None
    # Optional extra data attached to the candidate.
    payload: Any = None
17
+
18
+
@dataclass(frozen=True)
class LabelRequest:
    """Input handed to a labeler when a label is needed."""

    # Context candidates the label is derived from.
    candidates: list[LabelCandidate]
    # Word limit for the label; defaults to 6.
    max_words: int = 6
25
+
26
+
@runtime_checkable
class Labeler(Protocol):
    """Structural contract for objects that produce a label from a request."""

    def stream(self, request: LabelRequest) -> Iterable[str]:
        """Yield the label piece by piece."""

    def label(self, request: LabelRequest) -> str:
        """Return the complete label as one string."""
38
+
@@ -0,0 +1,17 @@
1
+ """Tree loader contracts."""
2
+
3
+ from .filesystem import FileSystemTreeLoader
4
+ from .json import JsonTreeLoader
5
+ from .model import TreeLoader
6
+ from .sqlalchemy_content import SQLAlchemyContentLoader
7
+ from .sqlalchemy import SQLAlchemyTreeLoader
8
+ from .sqlite import SQLiteTreeLoader
9
+
10
+ __all__ = [
11
+ "TreeLoader",
12
+ "FileSystemTreeLoader",
13
+ "JsonTreeLoader",
14
+ "SQLAlchemyContentLoader",
15
+ "SQLAlchemyTreeLoader",
16
+ "SQLiteTreeLoader",
17
+ ]
@@ -0,0 +1,83 @@
1
+ """Filesystem-backed ground-truth loader."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from pathlib import Path
7
+ from typing import Iterable
8
+
9
+ from embed_tree.representation import ContentNode, KeyNode, PartialTree, TreeEdge
10
+
11
+
class FileSystemTreeLoader:
    """Load files under a directory as content nodes.

    Directories are emitted as ``KeyNode`` records with edges to their child
    directories and files. A file's node id is the MD5 hash of its bytes, so
    the id depends on content only and not on where the file lives.
    """

    def __init__(
        self,
        root: str | Path,
        *,
        include_suffixes: Iterable[str] | None = None,
        encoding: str = "utf-8",
        include_hidden: bool = False,
    ) -> None:
        """Configure the loader.

        Args:
            root: Directory to walk.
            include_suffixes: Suffixes to keep, compared case-insensitively.
                A missing leading dot is added ("md" is treated as ".md");
                the empty string selects files without a suffix. ``None``
                keeps every file.
            encoding: Encoding used to read file contents.
            include_hidden: When false, any path with a component starting
                with "." (relative to ``root``) is skipped.
        """
        self.root = Path(root)
        self.include_suffixes = (
            None
            if include_suffixes is None
            else {self._normalize_suffix(s) for s in include_suffixes}
        )
        self.encoding = encoding
        self.include_hidden = include_hidden

    @staticmethod
    def _normalize_suffix(suffix: str) -> str:
        """Lower-case a suffix and ensure non-empty values start with a dot."""
        lowered = suffix.lower()
        # Path.suffix always starts with "." (or is empty), so a dotless entry
        # such as "md" could otherwise never match anything.
        if lowered and not lowered.startswith("."):
            return f".{lowered}"
        return lowered

    def load(self) -> PartialTree | None:
        """Walk ``root`` and build a ``PartialTree``.

        Returns:
            ``None`` when ``root`` does not exist; otherwise a tree holding a
            root key node with id ".", one key node per visited directory, one
            content node per readable included file, and parent-to-child edges.
            Files that cannot be decoded with the configured encoding are
            skipped.
        """
        if not self.root.exists():
            return None

        tree = PartialTree(metadata={"source": "filesystem", "root": str(self.root)})
        root_id = "."
        tree.key_nodes.append(KeyNode(id=root_id, label=self.root.name or str(self.root)))

        for path in sorted(self.root.rglob("*")):
            rel_path = path.relative_to(self.root)
            # A hidden component anywhere in the path hides everything below it.
            if not self.include_hidden and any(part.startswith(".") for part in rel_path.parts):
                continue
            rel = rel_path.as_posix()
            parent = rel_path.parent.as_posix() if path.parent != self.root else root_id

            if path.is_dir():
                tree.key_nodes.append(KeyNode(id=rel, label=path.name))
                tree.edges.append(TreeEdge(parent_id=parent, child_id=rel))
                continue
            if not path.is_file() or not self._included(path):
                continue

            # Decode first so undecodable files are skipped without hashing them.
            try:
                content = path.read_text(encoding=self.encoding)
            except UnicodeDecodeError:
                continue
            file_id = _file_md5(path)

            # NOTE(review): files with identical bytes share one id, so each
            # copy appends a content node with the same id — confirm consumers
            # tolerate duplicate ids.
            tree.content_nodes.append(
                ContentNode(
                    id=file_id,
                    content=content,
                    text=path.stem,
                    payload={
                        "path": str(path),
                        "relative_path": rel,
                        "filename": path.name,
                    },
                    version=file_id,
                )
            )
            tree.edges.append(TreeEdge(parent_id=parent, child_id=file_id))

        return tree

    def _included(self, path: Path) -> bool:
        """Return True when no suffix filter is set or the file's suffix is in it."""
        return self.include_suffixes is None or path.suffix.lower() in self.include_suffixes
76
+
77
+
def _file_md5(path: Path) -> str:
    """Return the hexadecimal MD5 digest of the file's bytes.

    The file is read in 1 MiB chunks so large files are never held in memory
    whole.
    """
    # The digest is a content identifier, not a security measure; saying so
    # keeps this working on builds that restrict MD5 for security purposes.
    digest = hashlib.md5(usedforsecurity=False)
    with path.open("rb") as f:
        while chunk := f.read(1024 * 1024):
            digest.update(chunk)
    return digest.hexdigest()
@@ -0,0 +1,49 @@
1
+ """JSON-backed tree loader."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from embed_tree.persisters.model import MaterializedTreeState
11
+ from embed_tree.representation import PartialTree
12
+ from embed_tree.representation.default import partial_tree_from_dict, partial_tree_to_dict
13
+
14
+
class JsonTreeLoader:
    """Load/save a PartialTree or materialized state as one JSON file."""

    def __init__(self, path: str | Path) -> None:
        """Remember the target path and run the ``post_init`` hook."""
        self.path = Path(path)
        self.post_init()

    def post_init(self) -> None:
        """Hook for implementations that need setup after construction."""
        pass

    def load(self) -> PartialTree | MaterializedTreeState | None:
        """Read the JSON file and return what it stores.

        Returns:
            ``None`` when the file does not exist; a ``PartialTree`` for
            documents tagged ``"partial_tree"``; the stored ``"state"`` value
            for documents tagged ``"materialized_tree_state"``; otherwise the
            decoded JSON document unchanged.
        """
        # Open directly rather than checking exists() first, so a file removed
        # in between cannot cause a spurious error.
        try:
            with self.path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None

        kind = data.get("kind")
        if kind == "partial_tree":
            return partial_tree_from_dict(data["tree"])
        if kind == "materialized_tree_state":
            return data["state"]
        return data

    def save(self, state: PartialTree | MaterializedTreeState) -> None:
        """Write ``state`` to the JSON file, replacing any previous content.

        The document is written to a sibling temporary file, flushed and
        fsynced, then moved over the target with ``os.replace``. If any step
        fails the temporary file is removed and the error is re-raised.
        """
        if isinstance(state, PartialTree):
            payload: dict[str, Any] = {"kind": "partial_tree", "tree": partial_tree_to_dict(state)}
        else:
            payload = {"kind": "materialized_tree_state", "state": state}

        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self.path.with_name(f"{self.path.name}.tmp.{os.getpid()}")
        try:
            with tmp.open("w", encoding="utf-8") as f:
                json.dump(payload, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, self.path)
        except BaseException:
            # Do not leave a half-written temp file behind on failure.
            tmp.unlink(missing_ok=True)
            raise
@@ -0,0 +1,20 @@
1
+ """Abstract loader contract."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol, runtime_checkable
6
+
7
+ from embed_tree.representation import PartialTree
8
+
9
+
@runtime_checkable
class TreeLoader(Protocol):
    """Structural contract for anything that can load a partial tree.

    Ground-truth inputs and reusable-state inputs both use this shape; which
    one a loader represents is decided by the argument position it is passed in.
    """

    def load(self) -> PartialTree | None:
        """Return the loaded tree data, or None when the source is empty."""