embed-tree 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embed_tree/__init__.py +121 -0
- embed_tree/cache/__init__.py +10 -0
- embed_tree/cache/json.py +7 -0
- embed_tree/cache/model.py +13 -0
- embed_tree/cache/sqlalchemy.py +7 -0
- embed_tree/config.py +100 -0
- embed_tree/embedders/__init__.py +7 -0
- embed_tree/embedders/huggingface.py +69 -0
- embed_tree/embedders/model.py +27 -0
- embed_tree/labelers/__init__.py +8 -0
- embed_tree/labelers/function.py +25 -0
- embed_tree/labelers/llm.py +26 -0
- embed_tree/labelers/model.py +38 -0
- embed_tree/loaders/__init__.py +17 -0
- embed_tree/loaders/filesystem.py +83 -0
- embed_tree/loaders/json.py +49 -0
- embed_tree/loaders/model.py +20 -0
- embed_tree/loaders/sqlalchemy.py +91 -0
- embed_tree/loaders/sqlalchemy_content.py +63 -0
- embed_tree/loaders/sqlite.py +21 -0
- embed_tree/persisters/__init__.py +15 -0
- embed_tree/persisters/filesystem.py +293 -0
- embed_tree/persisters/json.py +29 -0
- embed_tree/persisters/model.py +23 -0
- embed_tree/persisters/sqlalchemy.py +76 -0
- embed_tree/projectors/__init__.py +7 -0
- embed_tree/projectors/model.py +39 -0
- embed_tree/projectors/pca.py +57 -0
- embed_tree/providers/__init__.py +20 -0
- embed_tree/providers/base.py +104 -0
- embed_tree/providers/fake.py +26 -0
- embed_tree/providers/local.py +44 -0
- embed_tree/providers/openai.py +49 -0
- embed_tree/reconcilers/__init__.py +6 -0
- embed_tree/reconcilers/default.py +65 -0
- embed_tree/reconcilers/model.py +25 -0
- embed_tree/reducers.py +194 -0
- embed_tree/representation/__init__.py +27 -0
- embed_tree/representation/default.py +59 -0
- embed_tree/representation/model.py +87 -0
- embed_tree/store.py +5 -0
- embed_tree/stores/__init__.py +8 -0
- embed_tree/stores/file.py +32 -0
- embed_tree/stores/model.py +25 -0
- embed_tree/stores/null.py +16 -0
- embed_tree/taggers.py +132 -0
- embed_tree/tree.py +691 -0
- embed_tree-0.0.6.dist-info/METADATA +182 -0
- embed_tree-0.0.6.dist-info/RECORD +51 -0
- embed_tree-0.0.6.dist-info/WHEEL +5 -0
- embed_tree-0.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""SQLAlchemy-backed tree loader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from embed_tree.persisters.model import MaterializedTreeState
|
|
8
|
+
from embed_tree.representation import PartialTree
|
|
9
|
+
from embed_tree.representation.default import partial_tree_from_dict, partial_tree_to_dict
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SQLAlchemyTreeLoader:
    """Read and write one serialized tree state in a pre-existing SQL table.

    The table is never created or migrated here; its shape belongs to the
    application. Subclasses may override ``post_init`` to perform setup
    (``SQLiteTreeLoader`` does so to create its table).

    The table is addressed through three columns: ``cache_key`` (primary
    key), ``kind`` and ``payload`` (JSON).
    """

    def __init__(
        self,
        engine_or_url: Any,
        *,
        table_name: str = "embed_tree_state",
        cache_key: str = "default",
    ) -> None:
        """Remember the target table/key, resolve the engine, run ``post_init``."""
        self.engine_or_url = engine_or_url
        self.table_name = table_name
        self.cache_key = cache_key
        # A URL string becomes a new engine; anything else is used as given.
        self.engine = _engine(_sqlalchemy(), self.engine_or_url)
        self.post_init()

    def post_init(self) -> None:
        """Hook for migrations or validation owned by the caller/subclass."""
        pass

    def load(self) -> PartialTree | MaterializedTreeState | None:
        """Return the state stored under ``cache_key``, or ``None`` if absent.

        Rows whose ``kind`` is ``"partial_tree"`` are rebuilt with
        ``partial_tree_from_dict``; every other kind yields the raw payload.
        """
        sa = _sqlalchemy()
        state_table = self._table(sa)
        query = sa.select(state_table.c.kind, state_table.c.payload).where(
            state_table.c.cache_key == self.cache_key
        )
        with self.engine.connect() as connection:
            record = connection.execute(query).mappings().first()

        if record is None:
            return None
        if record["kind"] == "partial_tree":
            return partial_tree_from_dict(record["payload"])
        # "materialized_tree_state" and any unrecognised kind are passed through.
        return record["payload"]

    def save(self, state: PartialTree | MaterializedTreeState) -> None:
        """Replace the row for ``cache_key`` with *state* in one transaction."""
        sa = _sqlalchemy()
        state_table = self._table(sa)

        payload: dict[str, Any]
        if isinstance(state, PartialTree):
            kind, payload = "partial_tree", partial_tree_to_dict(state)
        else:
            kind, payload = "materialized_tree_state", state

        drop_old = state_table.delete().where(state_table.c.cache_key == self.cache_key)
        add_new = state_table.insert().values(cache_key=self.cache_key, kind=kind, payload=payload)
        # engine.begin() commits on success and rolls back on error, so the
        # delete and the insert land together or not at all.
        with self.engine.begin() as connection:
            connection.execute(drop_old)
            connection.execute(add_new)

    def _table(self, sa: Any) -> Any:
        """Build the table description on a fresh ``MetaData`` object."""
        columns = (
            sa.Column("cache_key", sa.String, primary_key=True),
            sa.Column("kind", sa.String, nullable=False),
            sa.Column("payload", sa.JSON, nullable=False),
        )
        return sa.Table(self.table_name, sa.MetaData(), *columns)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _sqlalchemy() -> Any:
|
|
81
|
+
try:
|
|
82
|
+
import sqlalchemy as sa
|
|
83
|
+
except ImportError as e: # pragma: no cover
|
|
84
|
+
raise ImportError('SQLAlchemyTreeLoader needs the "sql" extra: pip install "embed-tree[sql]"') from e
|
|
85
|
+
return sa
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _engine(sa: Any, engine_or_url: Any) -> Any:
|
|
89
|
+
if isinstance(engine_or_url, str):
|
|
90
|
+
return sa.create_engine(engine_or_url)
|
|
91
|
+
return engine_or_url
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""SQLAlchemy-backed content loader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Iterable
|
|
6
|
+
|
|
7
|
+
from embed_tree.representation import ContentNode, PartialTree
|
|
8
|
+
|
|
9
|
+
from .sqlalchemy import _engine, _sqlalchemy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SQLAlchemyContentLoader:
    """Load content nodes from an existing SQL table via SQLAlchemy Core.

    Args:
        engine_or_url: A database URL string or an already-built engine.
        table_name: Table to reflect and read from.
        id_column: Column mapped to ``ContentNode.id``.
        content_column: Column mapped to ``ContentNode.content``.
        text_column: Optional column mapped to ``ContentNode.text``; when
            ``None`` the node text is ``None``.
        payload_columns: Columns copied into each node's ``payload`` dict.
        where: Optional SQLAlchemy filter expression applied to the select.
    """

    def __init__(
        self,
        engine_or_url: Any,
        table_name: str,
        *,
        id_column: str = "id",
        content_column: str = "content",
        text_column: str | None = None,
        payload_columns: Iterable[str] | None = None,
        where: Any | None = None,
    ) -> None:
        self.engine_or_url = engine_or_url
        self.table_name = table_name
        self.id_column = id_column
        self.content_column = content_column
        self.text_column = text_column
        # Materialize once so a one-shot iterable survives repeated load() calls.
        self.payload_columns = list(payload_columns or [])
        self.where = where
        self.post_init()

    def post_init(self) -> None:
        """Hook for migrations or validation owned by the caller/subclass."""
        pass

    def load(self) -> PartialTree | None:
        """Reflect the table, read every (optionally filtered) row, build nodes.

        Returns:
            A ``PartialTree`` holding one ``ContentNode`` per row, tagged with
            ``{"source": "sqlalchemy", "table": <table_name>}`` metadata.
        """
        sa = _sqlalchemy()
        # An engine built here from a URL is private to this call; dispose of
        # it afterwards so its connection pool is not leaked on every load().
        # A caller-supplied engine is left alone.
        owns_engine = isinstance(self.engine_or_url, str)
        engine = _engine(sa, self.engine_or_url)
        try:
            meta = sa.MetaData()
            table = sa.Table(self.table_name, meta, autoload_with=engine)

            stmt = sa.select(table)
            if self.where is not None:
                stmt = stmt.where(self.where)

            nodes: list[ContentNode] = []
            with engine.connect() as conn:
                for row in conn.execute(stmt).mappings():
                    payload = {name: row[name] for name in self.payload_columns}
                    nodes.append(
                        ContentNode(
                            id=row[self.id_column],
                            content=row[self.content_column],
                            text=None if self.text_column is None else row[self.text_column],
                            # An empty payload dict is normalised to None.
                            payload=payload or None,
                        )
                    )
        finally:
            if owns_engine:
                engine.dispose()

        return PartialTree(content_nodes=nodes, metadata={"source": "sqlalchemy", "table": self.table_name})
|
|
63
|
+
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""SQLite-specific tree loader."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from .sqlalchemy import SQLAlchemyTreeLoader, _sqlalchemy
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SQLiteTreeLoader(SQLAlchemyTreeLoader):
    """Tree-state loader bound to a SQLite file that creates its own table."""

    def __init__(self, path: str | Path, *, table_name: str = "embed_tree_state", cache_key: str = "default") -> None:
        """Point the loader at the SQLite database file located at *path*."""
        # Must be set before the base initialiser runs: it calls post_init(),
        # which reads self.path.
        self.path = Path(path)
        url = f"sqlite:///{self.path}"
        super().__init__(url, table_name=table_name, cache_key=cache_key)

    def post_init(self) -> None:
        """Ensure the parent directory and the state table both exist."""
        sa = _sqlalchemy()
        self.path.parent.mkdir(parents=True, exist_ok=True)
        state_table = self._table(sa)
        state_table.metadata.create_all(self.engine)
|
|
21
|
+
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Tree persister contracts and implementations."""
|
|
2
|
+
|
|
3
|
+
from .filesystem import FileSystemTreePersister, FolderTreePersister
|
|
4
|
+
from .json import JsonTreePersister
|
|
5
|
+
from .model import MaterializedTreeState, TreePersister
|
|
6
|
+
from .sqlalchemy import SQLAlchemyTreePersister
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"TreePersister",
|
|
10
|
+
"MaterializedTreeState",
|
|
11
|
+
"FolderTreePersister",
|
|
12
|
+
"FileSystemTreePersister",
|
|
13
|
+
"JsonTreePersister",
|
|
14
|
+
"SQLAlchemyTreePersister",
|
|
15
|
+
]
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""Folder-backed external persister."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import re
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from embed_tree.representation import ContentNode, KeyNode, NodeId, PartialTree
|
|
12
|
+
|
|
13
|
+
_UNSAFE_PATH_CHARS = re.compile(r'[\\/:*?"<>|#\[\]^\n\r\t]+')
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FolderTreePersister:
    """Materialize a live tree as folders and files.

    The local folder is treated as mutable ground truth: every save reloads its
    current file-md5 map, compares it with the target layout from the in-memory
    tree, moves known files into place, creates missing files only when content
    is available, prunes empty folders, and leaves unknown files untouched.

    Args:
        root: Directory that receives the materialized layout.
        include_suffixes: Optional file suffixes (matched case-insensitively)
            restricting which existing files are indexed; ``None`` indexes all.
        encoding: Text encoding used when new files are written.
        include_hidden: When false, dot-prefixed files and folders are neither
            indexed nor pruned.
    """

    def __init__(
        self,
        root: str | Path,
        *,
        include_suffixes: set[str] | list[str] | tuple[str, ...] | None = None,
        encoding: str = "utf-8",
        include_hidden: bool = False,
    ) -> None:
        self.root = Path(root)
        self.include_suffixes = None if include_suffixes is None else {s.lower() for s in include_suffixes}
        self.encoding = encoding
        self.include_hidden = include_hidden

    def save(self, state: Any) -> None:
        """Reconcile the folder under ``root`` with the layout described by *state*.

        Existing files are matched by MD5 and moved to their target path;
        missing files are written only when their content is available. Empty
        directories are pruned afterwards.

        Raises:
            TypeError: If *state* cannot be resolved to a ``PartialTree`` or a
                live node.
        """
        self.root.mkdir(parents=True, exist_ok=True)
        current = self._current_files_by_md5()
        desired = self._desired_files(state, current)

        reserved: set[Path] = set()
        placed: set[str] = set()
        for file in desired:
            wanted = self.root / file.folder / file.filename
            # One on-disk file can satisfy only one desired entry. A later
            # entry with the same md5 must not try to rename a source that has
            # already been moved; it falls back to writing its own content.
            existing = None if file.md5 in placed else current.get(file.md5)
            if existing is not None and existing == wanted and wanted not in reserved:
                # Already where it belongs. Asking _unique_path here would see
                # the path as occupied and rename the file to "-2" on every save.
                target = wanted
            else:
                target = _unique_path(wanted, reserved)
            reserved.add(target)
            if existing is not None:
                self._move_file(existing, target)
                placed.add(file.md5)
            elif file.content is not None:
                self._write_file(target, file.content)

        self._prune_empty_dirs()

    def _desired_files(self, state: Any, current: dict[str, Path]) -> list["_DesiredFile"]:
        """Resolve *state* to a tree and list the files it should produce.

        Objects exposing a callable ``get_tree`` are unwrapped first.
        """
        tree = state.get_tree() if hasattr(state, "get_tree") and callable(state.get_tree) else state
        if isinstance(tree, PartialTree):
            return self._desired_files_from_partial_tree(tree, current)
        if _looks_like_live_node(tree):
            return list(self._desired_files_from_live_node(tree, current))
        raise TypeError("FolderTreePersister persists an EmbedTree, live Node, or PartialTree")

    def _current_files_by_md5(self) -> dict[str, Path]:
        """Index the files currently under ``root`` by the MD5 of their bytes.

        Paths are visited in sorted order and the first path seen for a given
        digest wins, so duplicates on disk are ignored deterministically.
        """
        files: dict[str, Path] = {}
        for path in sorted(self.root.rglob("*")):
            if not path.is_file() or not self._included(path):
                continue
            rel_parts = path.relative_to(self.root).parts
            if not self.include_hidden and any(part.startswith(".") for part in rel_parts):
                continue
            files.setdefault(_file_md5(path), path)
        return files

    def _desired_files_from_partial_tree(
        self,
        tree: PartialTree,
        current: dict[str, Path],
    ) -> list["_DesiredFile"]:
        """Derive one desired file per content node of a ``PartialTree``.

        The node id (as a string) is used as the lookup key into *current*.
        """
        key_nodes = {node.id: node for node in tree.key_nodes}
        parent_by_child = {edge.child_id: edge.parent_id for edge in tree.edges}
        desired: list[_DesiredFile] = []

        for node in tree.content_nodes:
            folder = self._folder_for(node.id, key_nodes, parent_by_child)
            filename = self._filename_for_node(node, current.get(str(node.id)))
            content = None if node.content is None else str(node.content)
            desired.append(_DesiredFile(str(node.id), folder, filename, content))
        return desired

    def _desired_files_from_live_node(
        self,
        root: Any,
        current: dict[str, Path],
    ) -> list["_DesiredFile"]:
        """Derive desired files from a live node tree rooted at *root*.

        A leaf root places its items in one folder named after its label
        (``"topics"`` when unlabeled); otherwise the children are walked and
        the root itself contributes no folder level.
        """
        if root.is_leaf:
            folder = Path(_safe_name(root.label or "topics"))
            return [
                file
                for item in root.items or []
                if (file := self._desired_file_for_item(item, folder, current)) is not None
            ]

        desired: list[_DesiredFile] = []
        self._collect_live_node_files(root, Path(), current, desired, include_self=False)
        return desired

    def _collect_live_node_files(
        self,
        node: Any,
        prefix: Path,
        current: dict[str, Path],
        desired: list["_DesiredFile"],
        *,
        include_self: bool,
    ) -> None:
        """Recursively append the desired files below *node* to *desired*.

        *prefix* is the folder accumulated so far; with ``include_self`` the
        node's own (sanitized) label is appended to it first.
        """
        folder = prefix
        if include_self:
            folder = prefix / _safe_name(node.label or f"topic-{getattr(node, 'id', 'node')}")
        if node.is_leaf:
            for item in node.items or []:
                file = self._desired_file_for_item(item, folder, current)
                if file is not None:
                    desired.append(file)
            return

        # Sibling folders must not collide after sanitizing their labels.
        used: set[str] = set()
        for index, child in enumerate(node.children or []):
            label = _dedupe_name(_safe_name(child.label or f"topic-{index + 1}"), used)
            self._collect_live_node_files(child, folder / label, current, desired, include_self=False)

    def _desired_file_for_item(
        self,
        item: Any,
        folder: Path,
        current: dict[str, Path],
    ) -> "_DesiredFile | None":
        """Describe the file for *item*, or ``None`` when no MD5 can be found."""
        md5 = self._md5_for_item(item)
        if md5 is None:
            return None
        existing = current.get(md5)
        filename = self._filename_for_item(item, existing)
        content = self._content_for_item(item)
        return _DesiredFile(md5, folder, filename, content)

    def _folder_for(
        self,
        node_id: NodeId,
        key_nodes: dict[NodeId, KeyNode],
        parent_by_child: dict[NodeId, NodeId],
    ) -> Path:
        """Build the relative folder for *node_id* from its chain of parents.

        Each ancestor contributes its key-node label (or its id when it has no
        label), sanitized. The walk stops at an ancestor whose id renders as
        ``"."`` and at any id seen before, which guards against edge cycles.
        """
        parts: list[str] = []
        current = parent_by_child.get(node_id)
        seen: set[NodeId] = set()
        while current is not None and current not in seen:
            seen.add(current)
            if str(current) == ".":
                break
            key = key_nodes.get(current)
            raw = key.label if key is not None and key.label else str(current)
            part = _safe_name(raw)
            if part:
                parts.append(part)
            current = parent_by_child.get(current)
        # Parts were collected child-to-root; paths read root-to-child.
        return Path(*reversed(parts)) if parts else Path()

    def _filename_for_node(self, node: ContentNode, existing: Path | None) -> str:
        """Choose a filename for a content node.

        Preference order: the name of the file already on disk, the basename
        of a ``filename``/``relative_path``/``path`` payload entry, the node
        text, and finally ``"<id>.txt"``.
        """
        if existing is not None:
            return existing.name
        payload = node.payload if isinstance(node.payload, dict) else {}
        for key in ("filename", "relative_path", "path"):
            value = payload.get(key)
            if value:
                name = Path(str(value)).name
                if name:
                    return _safe_name(name, fallback=f"{node.id}.txt")
        text = node.text.strip() if isinstance(node.text, str) else ""
        return _safe_name(text, fallback=f"{node.id}.txt")

    def _md5_for_item(self, item: Any) -> str | None:
        """Return an MD5-shaped identifier for *item*, or ``None``.

        The item id is used when it looks like an MD5 digest; otherwise the
        first MD5-shaped value among a fixed list of payload keys is used.
        """
        item_id = str(item.id)
        if _is_md5(item_id):
            return item_id
        payload = item.payload if isinstance(item.payload, dict) else {}
        for key in ("md5", "file_md5", "content_md5", "content_id", "id"):
            value = payload.get(key)
            if value is not None and _is_md5(str(value)):
                return str(value)
        return None

    def _filename_for_item(self, item: Any, existing: Path | None) -> str:
        """Choose a filename for a live-tree item.

        Same preference order as ``_filename_for_node``, with ``output_path``
        additionally accepted as a payload key.
        """
        if existing is not None:
            return existing.name
        payload = item.payload if isinstance(item.payload, dict) else {}
        for key in ("filename", "relative_path", "output_path", "path"):
            value = payload.get(key)
            if value:
                name = Path(str(value)).name
                if name:
                    return _safe_name(name, fallback=f"{item.id}.txt")
        text = item.text.strip() if isinstance(item.text, str) else ""
        return _safe_name(text, fallback=f"{item.id}.txt")

    def _content_for_item(self, item: Any) -> str | None:
        """Return the first non-``None`` payload value of content/body/text as a string."""
        payload = item.payload if isinstance(item.payload, dict) else {}
        for key in ("content", "body", "text"):
            value = payload.get(key)
            if value is not None:
                return str(value)
        return None

    def _move_file(self, source: Path, target: Path) -> None:
        """Rename *source* to *target*, creating parent folders; no-op if equal."""
        if source == target:
            return
        target.parent.mkdir(parents=True, exist_ok=True)
        source.rename(target)

    def _write_file(self, target: Path, content: str) -> None:
        """Write *content* to *target* using the configured encoding."""
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding=self.encoding)

    def _prune_empty_dirs(self) -> None:
        """Remove empty directories under ``root``, deepest first.

        Going deepest first lets a parent that only held empty folders be
        removed in the same pass. Hidden paths are skipped unless
        ``include_hidden`` is set.
        """
        for directory in sorted(
            (path for path in self.root.rglob("*") if path.is_dir()),
            key=lambda path: len(path.parts),
            reverse=True,
        ):
            rel_parts = directory.relative_to(self.root).parts
            if not self.include_hidden and any(part.startswith(".") for part in rel_parts):
                continue
            try:
                next(directory.iterdir())
            except StopIteration:
                directory.rmdir()

    def _included(self, path: Path) -> bool:
        """Tell whether *path* passes the optional suffix filter."""
        return self.include_suffixes is None or path.suffix.lower() in self.include_suffixes
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
FileSystemTreePersister = FolderTreePersister
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@dataclass(frozen=True)
|
|
243
|
+
class _DesiredFile:
|
|
244
|
+
md5: str
|
|
245
|
+
folder: Path
|
|
246
|
+
filename: str
|
|
247
|
+
content: str | None
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _file_md5(path: Path) -> str:
|
|
251
|
+
digest = hashlib.md5()
|
|
252
|
+
with path.open("rb") as f:
|
|
253
|
+
for chunk in iter(lambda: f.read(1024 * 1024), b""):
|
|
254
|
+
digest.update(chunk)
|
|
255
|
+
return digest.hexdigest()
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _safe_name(value: str, *, fallback: str = "untitled") -> str:
    """Turn *value* into a single path-safe name, or *fallback* if nothing remains.

    Runs of unsafe characters become a space, whitespace is collapsed to
    single spaces, and leading/trailing dots and spaces are trimmed.
    """
    spaced = _UNSAFE_PATH_CHARS.sub(" ", value)
    collapsed = " ".join(spaced.split())
    trimmed = collapsed.strip(". ")
    if trimmed:
        return trimmed
    return fallback
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _dedupe_name(value: str, used: set[str]) -> str:
|
|
265
|
+
candidate = value
|
|
266
|
+
i = 2
|
|
267
|
+
while candidate in used:
|
|
268
|
+
candidate = f"{value}-{i}"
|
|
269
|
+
i += 1
|
|
270
|
+
used.add(candidate)
|
|
271
|
+
return candidate
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _is_md5(value: str) -> bool:
|
|
275
|
+
return len(value) == 32 and all(c in "0123456789abcdefABCDEF" for c in value)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _looks_like_live_node(value: Any) -> bool:
|
|
279
|
+
return hasattr(value, "is_leaf") and (hasattr(value, "items") or hasattr(value, "children"))
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _unique_path(target: Path, reserved: set[Path]) -> Path:
|
|
283
|
+
if target not in reserved and not target.exists():
|
|
284
|
+
return target
|
|
285
|
+
stem = target.stem
|
|
286
|
+
suffix = target.suffix
|
|
287
|
+
parent = target.parent
|
|
288
|
+
i = 2
|
|
289
|
+
while True:
|
|
290
|
+
candidate = parent / f"{stem}-{i}{suffix}"
|
|
291
|
+
if candidate not in reserved and not candidate.exists():
|
|
292
|
+
return candidate
|
|
293
|
+
i += 1
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
"""JSON external persister."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from embed_tree.representation import PartialTree
|
|
11
|
+
from embed_tree.representation.default import partial_tree_to_dict
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class JsonTreePersister:
    """Persist an exported tree artifact to JSON.

    Args:
        path: Destination JSON file; parent directories are created on save.
    """

    def __init__(self, path: str | Path) -> None:
        self.path = Path(path)

    def save(self, state: Any) -> None:
        """Serialize *state* to ``path``, replacing any previous file.

        A ``PartialTree`` is first converted with ``partial_tree_to_dict``;
        any other value is handed to ``json.dump`` as is. The JSON is written
        to a sibling ``<name>.tmp.<pid>`` file, flushed and fsynced, and then
        moved over the destination with ``os.replace``.

        Raises:
            TypeError: If *state* is not JSON-serializable (from ``json.dump``).
            OSError: If writing or replacing the file fails.
        """
        payload = partial_tree_to_dict(state) if isinstance(state, PartialTree) else state
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self.path.with_name(f"{self.path.name}.tmp.{os.getpid()}")
        try:
            with tmp.open("w", encoding="utf-8") as f:
                json.dump(payload, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, self.path)
        except BaseException:
            # Do not leave a half-written temp file behind when serialization
            # or the final replace fails; the destination stays untouched.
            try:
                os.unlink(tmp)
            except FileNotFoundError:
                pass
            raise
|
|
29
|
+
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Abstract persister contract."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
from embed_tree.representation import PartialTree
|
|
8
|
+
|
|
9
|
+
MaterializedTreeState = dict[str, Any]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@runtime_checkable
class TreePersister(Protocol):
    """Structural interface for anything that can store a tree representation.

    One protocol serves durable internal state, reusable state for later
    builds, and user-facing exports alike. Which of those roles a persister
    plays is decided by where it is passed in, not by a dedicated cache type.
    """

    def save(self, state: PartialTree | MaterializedTreeState | Any) -> None:
        """Store *state*, which may be a tree representation or an export artifact."""
        ...
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""SQLAlchemy external persister."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from embed_tree.representation import PartialTree
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SQLAlchemyTreePersister:
    """Persist PartialTree content nodes to a SQL table via SQLAlchemy Core.

    Args:
        engine_or_url: A database URL string or an already-built engine.
        table_name: Table that receives one row per content node.
        id_column: Column storing ``node.id`` (string primary key).
        content_column: Column storing ``node.content`` (non-null text).
        text_column: Column storing ``node.text``.
        payload_column: Column storing ``node.payload`` as JSON.
    """

    def __init__(
        self,
        engine_or_url: Any,
        table_name: str,
        *,
        id_column: str = "id",
        content_column: str = "content",
        text_column: str = "text",
        payload_column: str = "payload",
    ) -> None:
        self.engine_or_url = engine_or_url
        self.table_name = table_name
        self.id_column = id_column
        self.content_column = content_column
        self.text_column = text_column
        self.payload_column = payload_column

    def save(self, state: Any) -> None:
        """Replace the table's contents with the content nodes of *state*.

        The table is created if missing. All existing rows are deleted and the
        new rows inserted inside a single transaction.

        Raises:
            TypeError: If *state* is not a ``PartialTree``.
        """
        if not isinstance(state, PartialTree):
            raise TypeError("SQLAlchemyTreePersister only persists PartialTree instances")

        sa = _sqlalchemy()
        # An engine built here from a URL is private to this call; dispose of
        # it afterwards so its connection pool is not leaked on every save().
        # A caller-supplied engine is left alone.
        owns_engine = isinstance(self.engine_or_url, str)
        engine = _engine(sa, self.engine_or_url)
        try:
            table = self._table(sa)
            table.metadata.create_all(engine)
            rows = [
                {
                    self.id_column: node.id,
                    self.content_column: node.content,
                    self.text_column: node.text,
                    self.payload_column: node.payload,
                }
                for node in state.content_nodes
            ]
            with engine.begin() as conn:
                conn.execute(table.delete())
                # Skip the insert when there is nothing to write.
                if rows:
                    conn.execute(table.insert(), rows)
        finally:
            if owns_engine:
                engine.dispose()

    def _table(self, sa: Any) -> Any:
        """Describe the target table on a fresh ``MetaData`` object."""
        meta = sa.MetaData()
        return sa.Table(
            self.table_name,
            meta,
            sa.Column(self.id_column, sa.String, primary_key=True),
            sa.Column(self.content_column, sa.Text, nullable=False),
            sa.Column(self.text_column, sa.Text),
            sa.Column(self.payload_column, sa.JSON),
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _sqlalchemy() -> Any:
|
|
65
|
+
try:
|
|
66
|
+
import sqlalchemy as sa
|
|
67
|
+
except ImportError as e: # pragma: no cover
|
|
68
|
+
raise ImportError('SQLAlchemyTreePersister needs the "sql" extra: pip install "embed-tree[sql]"') from e
|
|
69
|
+
return sa
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _engine(sa: Any, engine_or_url: Any) -> Any:
|
|
73
|
+
if isinstance(engine_or_url, str):
|
|
74
|
+
return sa.create_engine(engine_or_url)
|
|
75
|
+
return engine_or_url
|
|
76
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""Vector projection contracts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Literal, Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from pydantic import BaseModel
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PCAConfig(BaseModel):
    """Settings for PCA-based projection."""

    # NOTE(review): presumably the target dimensionality, with None meaning
    # "not set" — confirm against the PCA projector implementation.
    dims: int | None = None
    # One of "freeze" or "incremental"; defaults to "freeze".
    mode: Literal["freeze", "incremental"] = "freeze"
    # NOTE(review): presumably the number of vectors gathered before fitting —
    # confirm against the PCA projector implementation.
    warmup: int = 1000
    # NOTE(review): presumably the chunk size used when fitting/transforming —
    # confirm against the PCA projector implementation.
    batch_size: int = 256
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@runtime_checkable
class VectorProjector(Protocol):
    """Structural interface for objects that map raw vectors to operational vectors."""

    @property
    def is_fitted(self) -> bool:
        """Report whether the projector is ready to transform vectors."""
        ...

    def fit(self, vectors: np.ndarray) -> None:
        """Derive the projector's state from *vectors*."""
        ...

    def transform(self, vectors: np.ndarray) -> np.ndarray:
        """Return the projection of *vectors*."""
        ...

    def __call__(self, vectors: np.ndarray) -> np.ndarray:
        """Return the projection of *vectors* when the projector is called directly."""
        ...
|