embed-tree 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. embed_tree/__init__.py +121 -0
  2. embed_tree/cache/__init__.py +10 -0
  3. embed_tree/cache/json.py +7 -0
  4. embed_tree/cache/model.py +13 -0
  5. embed_tree/cache/sqlalchemy.py +7 -0
  6. embed_tree/config.py +100 -0
  7. embed_tree/embedders/__init__.py +7 -0
  8. embed_tree/embedders/huggingface.py +69 -0
  9. embed_tree/embedders/model.py +27 -0
  10. embed_tree/labelers/__init__.py +8 -0
  11. embed_tree/labelers/function.py +25 -0
  12. embed_tree/labelers/llm.py +26 -0
  13. embed_tree/labelers/model.py +38 -0
  14. embed_tree/loaders/__init__.py +17 -0
  15. embed_tree/loaders/filesystem.py +83 -0
  16. embed_tree/loaders/json.py +49 -0
  17. embed_tree/loaders/model.py +20 -0
  18. embed_tree/loaders/sqlalchemy.py +91 -0
  19. embed_tree/loaders/sqlalchemy_content.py +63 -0
  20. embed_tree/loaders/sqlite.py +21 -0
  21. embed_tree/persisters/__init__.py +15 -0
  22. embed_tree/persisters/filesystem.py +293 -0
  23. embed_tree/persisters/json.py +29 -0
  24. embed_tree/persisters/model.py +23 -0
  25. embed_tree/persisters/sqlalchemy.py +76 -0
  26. embed_tree/projectors/__init__.py +7 -0
  27. embed_tree/projectors/model.py +39 -0
  28. embed_tree/projectors/pca.py +57 -0
  29. embed_tree/providers/__init__.py +20 -0
  30. embed_tree/providers/base.py +104 -0
  31. embed_tree/providers/fake.py +26 -0
  32. embed_tree/providers/local.py +44 -0
  33. embed_tree/providers/openai.py +49 -0
  34. embed_tree/reconcilers/__init__.py +6 -0
  35. embed_tree/reconcilers/default.py +65 -0
  36. embed_tree/reconcilers/model.py +25 -0
  37. embed_tree/reducers.py +194 -0
  38. embed_tree/representation/__init__.py +27 -0
  39. embed_tree/representation/default.py +59 -0
  40. embed_tree/representation/model.py +87 -0
  41. embed_tree/store.py +5 -0
  42. embed_tree/stores/__init__.py +8 -0
  43. embed_tree/stores/file.py +32 -0
  44. embed_tree/stores/model.py +25 -0
  45. embed_tree/stores/null.py +16 -0
  46. embed_tree/taggers.py +132 -0
  47. embed_tree/tree.py +691 -0
  48. embed_tree-0.0.6.dist-info/METADATA +182 -0
  49. embed_tree-0.0.6.dist-info/RECORD +51 -0
  50. embed_tree-0.0.6.dist-info/WHEEL +5 -0
  51. embed_tree-0.0.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,91 @@
1
+ """SQLAlchemy-backed tree loader."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from embed_tree.persisters.model import MaterializedTreeState
8
+ from embed_tree.representation import PartialTree
9
+ from embed_tree.representation.default import partial_tree_from_dict, partial_tree_to_dict
10
+
11
+
12
class SQLAlchemyTreeLoader:
    """Read and write one tree-state row in an existing SQL table.

    The row is addressed by ``cache_key``. This base class never creates
    tables and never runs migrations, because the database shape is an
    application decision. Subclasses can override ``post_init`` for setup;
    ``SQLiteTreeLoader`` is the built-in simple implementation that creates
    its table.
    """

    def __init__(
        self,
        engine_or_url: Any,
        *,
        table_name: str = "embed_tree_state",
        cache_key: str = "default",
    ) -> None:
        """Store the target table/key, resolve the engine, then run ``post_init``."""
        self.engine_or_url = engine_or_url
        self.table_name = table_name
        self.cache_key = cache_key
        self.engine = _engine(_sqlalchemy(), self.engine_or_url)
        self.post_init()

    def post_init(self) -> None:
        """Hook for migrations or validation owned by the caller/subclass."""

    def load(self) -> PartialTree | MaterializedTreeState | None:
        """Return the state stored under ``cache_key``, or ``None`` if there is no row."""
        sa = _sqlalchemy()
        table = self._table(sa)
        query = sa.select(table.c.kind, table.c.payload).where(table.c.cache_key == self.cache_key)
        with self.engine.connect() as conn:
            record = conn.execute(query).mappings().first()
        if record is None:
            return None
        stored = record["payload"]
        # Only partial trees are decoded; every other kind is handed back as stored.
        if record["kind"] == "partial_tree":
            return partial_tree_from_dict(stored)
        return stored

    def save(self, state: PartialTree | MaterializedTreeState) -> None:
        """Replace the row for ``cache_key`` with ``state`` inside one transaction."""
        sa = _sqlalchemy()
        table = self._table(sa)
        is_partial = isinstance(state, PartialTree)
        kind = "partial_tree" if is_partial else "materialized_tree_state"
        payload: dict[str, Any] = partial_tree_to_dict(state) if is_partial else state

        # Delete-then-insert acts as an upsert for the single row of this key.
        statements = (
            table.delete().where(table.c.cache_key == self.cache_key),
            table.insert().values(cache_key=self.cache_key, kind=kind, payload=payload),
        )
        with self.engine.begin() as conn:
            for statement in statements:
                conn.execute(statement)

    def _table(self, sa: Any) -> Any:
        """Describe the state table (``cache_key``, ``kind``, ``payload``) on fresh metadata."""
        columns = (
            sa.Column("cache_key", sa.String, primary_key=True),
            sa.Column("kind", sa.String, nullable=False),
            sa.Column("payload", sa.JSON, nullable=False),
        )
        return sa.Table(self.table_name, sa.MetaData(), *columns)
78
+
79
+
80
+ def _sqlalchemy() -> Any:
81
+ try:
82
+ import sqlalchemy as sa
83
+ except ImportError as e: # pragma: no cover
84
+ raise ImportError('SQLAlchemyTreeLoader needs the "sql" extra: pip install "embed-tree[sql]"') from e
85
+ return sa
86
+
87
+
88
+ def _engine(sa: Any, engine_or_url: Any) -> Any:
89
+ if isinstance(engine_or_url, str):
90
+ return sa.create_engine(engine_or_url)
91
+ return engine_or_url
@@ -0,0 +1,63 @@
1
+ """SQLAlchemy-backed content loader."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Iterable
6
+
7
+ from embed_tree.representation import ContentNode, PartialTree
8
+
9
+ from .sqlalchemy import _engine, _sqlalchemy
10
+
11
+
12
+ class SQLAlchemyContentLoader:
13
+ """Load content nodes from an existing SQL table via SQLAlchemy Core."""
14
+
15
+ def __init__(
16
+ self,
17
+ engine_or_url: Any,
18
+ table_name: str,
19
+ *,
20
+ id_column: str = "id",
21
+ content_column: str = "content",
22
+ text_column: str | None = None,
23
+ payload_columns: Iterable[str] | None = None,
24
+ where: Any | None = None,
25
+ ) -> None:
26
+ self.engine_or_url = engine_or_url
27
+ self.table_name = table_name
28
+ self.id_column = id_column
29
+ self.content_column = content_column
30
+ self.text_column = text_column
31
+ self.payload_columns = list(payload_columns or [])
32
+ self.where = where
33
+ self.post_init()
34
+
35
+ def post_init(self) -> None:
36
+ """Hook for migrations or validation owned by the caller/subclass."""
37
+ pass
38
+
39
+ def load(self) -> PartialTree | None:
40
+ sa = _sqlalchemy()
41
+ engine = _engine(sa, self.engine_or_url)
42
+ meta = sa.MetaData()
43
+ table = sa.Table(self.table_name, meta, autoload_with=engine)
44
+
45
+ stmt = sa.select(table)
46
+ if self.where is not None:
47
+ stmt = stmt.where(self.where)
48
+
49
+ nodes: list[ContentNode] = []
50
+ with engine.connect() as conn:
51
+ for row in conn.execute(stmt).mappings():
52
+ payload = {name: row[name] for name in self.payload_columns}
53
+ nodes.append(
54
+ ContentNode(
55
+ id=row[self.id_column],
56
+ content=row[self.content_column],
57
+ text=None if self.text_column is None else row[self.text_column],
58
+ payload=payload or None,
59
+ )
60
+ )
61
+
62
+ return PartialTree(content_nodes=nodes, metadata={"source": "sqlalchemy", "table": self.table_name})
63
+
@@ -0,0 +1,21 @@
1
+ """SQLite-specific tree loader."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from .sqlalchemy import SQLAlchemyTreeLoader, _sqlalchemy
8
+
9
+
10
class SQLiteTreeLoader(SQLAlchemyTreeLoader):
    """Tree-state loader bound to a SQLite file that creates its own table."""

    def __init__(self, path: str | Path, *, table_name: str = "embed_tree_state", cache_key: str = "default") -> None:
        """Point the loader at the SQLite database file ``path``."""
        # ``path`` must be set before the base initializer runs, because that
        # initializer calls ``post_init``, which reads it.
        self.path = Path(path)
        url = f"sqlite:///{self.path}"
        super().__init__(url, table_name=table_name, cache_key=cache_key)

    def post_init(self) -> None:
        """Create the parent folder and the state table when they are missing."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        table = self._table(_sqlalchemy())
        table.metadata.create_all(self.engine)
21
+
@@ -0,0 +1,15 @@
1
+ """Tree persister contracts and implementations."""
2
+
3
+ from .filesystem import FileSystemTreePersister, FolderTreePersister
4
+ from .json import JsonTreePersister
5
+ from .model import MaterializedTreeState, TreePersister
6
+ from .sqlalchemy import SQLAlchemyTreePersister
7
+
8
+ __all__ = [
9
+ "TreePersister",
10
+ "MaterializedTreeState",
11
+ "FolderTreePersister",
12
+ "FileSystemTreePersister",
13
+ "JsonTreePersister",
14
+ "SQLAlchemyTreePersister",
15
+ ]
@@ -0,0 +1,293 @@
1
+ """Folder-backed external persister."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import re
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from embed_tree.representation import ContentNode, KeyNode, NodeId, PartialTree
12
+
13
+ _UNSAFE_PATH_CHARS = re.compile(r'[\\/:*?"<>|#\[\]^\n\r\t]+')
14
+
15
+
16
class FolderTreePersister:
    """Materialize a live tree as folders and files.

    The local folder is treated as mutable ground truth: every save reloads its
    current file-md5 map, compares it with the target layout from the in-memory
    tree, moves known files into place, creates missing files only when content
    is available, prunes empty folders, and leaves unknown files untouched.
    """

    def __init__(
        self,
        root: str | Path,
        *,
        include_suffixes: set[str] | list[str] | tuple[str, ...] | None = None,
        encoding: str = "utf-8",
        include_hidden: bool = False,
    ) -> None:
        """Configure the output folder.

        Args:
            root: Folder that receives the materialized layout.
            include_suffixes: Optional file suffixes (compared lower-cased)
                limiting which existing files are indexed; ``None`` indexes
                every file.
            encoding: Text encoding used when new files are written.
            include_hidden: When false, paths with a dot-prefixed component
                below ``root`` are neither indexed nor pruned.
        """
        self.root = Path(root)
        self.include_suffixes = None if include_suffixes is None else {s.lower() for s in include_suffixes}
        self.encoding = encoding
        self.include_hidden = include_hidden

    def save(self, state: Any) -> None:
        """Bring the folder in line with ``state``.

        Files already on disk are matched by md5 and moved; files without a
        match are written only when content is available. A file that already
        sits at its desired path is left where it is, and each md5 is
        materialized at most once per save.

        Raises:
            TypeError: If ``state`` is not an EmbedTree, live node, or PartialTree.
        """
        self.root.mkdir(parents=True, exist_ok=True)
        current = self._current_files_by_md5()
        desired = self._desired_files(state, current)

        reserved: set[Path] = set()
        handled: set[str] = set()
        for file in desired:
            if file.md5 in handled:
                # The on-disk source for this md5 was already placed; its old
                # path in ``current`` is stale and must not be moved again.
                continue
            wanted = self.root / file.folder / file.filename
            existing = current.get(file.md5)
            target = self._resolve_target(wanted, existing, reserved)
            reserved.add(target)
            if existing is not None:
                self._move_file(existing, target)
            elif file.content is not None:
                self._write_file(target, file.content)
            else:
                continue
            handled.add(file.md5)

        self._prune_empty_dirs()

    def _resolve_target(self, wanted: Path, existing: Path | None, reserved: set[Path]) -> Path:
        """Pick the final path for a desired file.

        A file that already lives at ``wanted`` keeps that path. Asking
        ``_unique_path`` in that case would reject the path because it exists
        and rename the file to a ``-2`` variant on every save.
        """
        if existing is not None and existing == wanted and wanted not in reserved:
            return wanted
        return _unique_path(wanted, reserved)

    def _desired_files(self, state: Any, current: dict[str, Path]) -> list["_DesiredFile"]:
        """Derive the target layout from an EmbedTree-like object, live node, or PartialTree."""
        # Objects exposing a callable ``get_tree`` are unwrapped first.
        tree = state.get_tree() if hasattr(state, "get_tree") and callable(state.get_tree) else state
        if isinstance(tree, PartialTree):
            return self._desired_files_from_partial_tree(tree, current)
        if _looks_like_live_node(tree):
            return list(self._desired_files_from_live_node(tree, current))
        raise TypeError("FolderTreePersister persists an EmbedTree, live Node, or PartialTree")

    def _current_files_by_md5(self) -> dict[str, Path]:
        """Index the files under ``root`` by content md5.

        Paths are visited in sorted order and the first path seen for a
        digest wins.
        """
        files: dict[str, Path] = {}
        for path in sorted(self.root.rglob("*")):
            if not path.is_file() or not self._included(path):
                continue
            rel_parts = path.relative_to(self.root).parts
            if not self.include_hidden and any(part.startswith(".") for part in rel_parts):
                continue
            files.setdefault(_file_md5(path), path)
        return files

    def _desired_files_from_partial_tree(
        self,
        tree: PartialTree,
        current: dict[str, Path],
    ) -> list["_DesiredFile"]:
        """Build one desired file per content node of a PartialTree."""
        key_nodes = {node.id: node for node in tree.key_nodes}
        parent_by_child = {edge.child_id: edge.parent_id for edge in tree.edges}
        desired: list[_DesiredFile] = []

        for node in tree.content_nodes:
            folder = self._folder_for(node.id, key_nodes, parent_by_child)
            # The stringified node id is used as the md5 key here, both for
            # the lookup in ``current`` and for the resulting entry.
            filename = self._filename_for_node(node, current.get(str(node.id)))
            content = None if node.content is None else str(node.content)
            desired.append(_DesiredFile(str(node.id), folder, filename, content))
        return desired

    def _desired_files_from_live_node(
        self,
        root: Any,
        current: dict[str, Path],
    ) -> list["_DesiredFile"]:
        """Build desired files from a live node and its descendants."""
        if root.is_leaf:
            # A lone leaf gets a single folder named after its label.
            folder = Path(_safe_name(root.label or "topics"))
            return [
                file
                for item in root.items or []
                if (file := self._desired_file_for_item(item, folder, current)) is not None
            ]

        desired: list[_DesiredFile] = []
        self._collect_live_node_files(root, Path(), current, desired, include_self=False)
        return desired

    def _collect_live_node_files(
        self,
        node: Any,
        prefix: Path,
        current: dict[str, Path],
        desired: list["_DesiredFile"],
        *,
        include_self: bool,
    ) -> None:
        """Walk ``node`` recursively and append its leaf items to ``desired``.

        ``prefix`` is the folder accumulated so far. With ``include_self`` the
        node's own label is appended to it; otherwise the caller has already
        done that.
        """
        folder = prefix
        if include_self:
            folder = prefix / _safe_name(node.label or f"topic-{getattr(node, 'id', 'node')}")
        if node.is_leaf:
            for item in node.items or []:
                file = self._desired_file_for_item(item, folder, current)
                if file is not None:
                    desired.append(file)
            return

        # Sibling folders must not collide, so labels are de-duplicated per parent.
        used: set[str] = set()
        for index, child in enumerate(node.children or []):
            label = _dedupe_name(_safe_name(child.label or f"topic-{index + 1}"), used)
            self._collect_live_node_files(child, folder / label, current, desired, include_self=False)

    def _desired_file_for_item(
        self,
        item: Any,
        folder: Path,
        current: dict[str, Path],
    ) -> "_DesiredFile | None":
        """Describe the file for one live item, or ``None`` when it has no md5."""
        md5 = self._md5_for_item(item)
        if md5 is None:
            return None
        existing = current.get(md5)
        filename = self._filename_for_item(item, existing)
        content = self._content_for_item(item)
        return _DesiredFile(md5, folder, filename, content)

    def _folder_for(
        self,
        node_id: NodeId,
        key_nodes: dict[NodeId, KeyNode],
        parent_by_child: dict[NodeId, NodeId],
    ) -> Path:
        """Return the relative folder for a node by following its parent edges."""
        parts: list[str] = []
        current = parent_by_child.get(node_id)
        # ``seen`` stops the walk if the edges form a cycle.
        seen: set[NodeId] = set()
        while current is not None and current not in seen:
            seen.add(current)
            # An ancestor whose id is "." ends the walk and adds no folder.
            if str(current) == ".":
                break
            key = key_nodes.get(current)
            raw = key.label if key is not None and key.label else str(current)
            part = _safe_name(raw)
            if part:
                parts.append(part)
            current = parent_by_child.get(current)
        # Parts were collected child-to-root, so flip them into path order.
        return Path(*reversed(parts)) if parts else Path()

    def _filename_for_node(self, node: ContentNode, existing: Path | None) -> str:
        """Choose the file name for a PartialTree content node."""
        return self._pick_filename(node, existing, ("filename", "relative_path", "path"))

    def _md5_for_item(self, item: Any) -> str | None:
        """Return the item's md5: its id if md5-shaped, else the first md5-shaped payload value."""
        item_id = str(item.id)
        if _is_md5(item_id):
            return item_id
        payload = item.payload if isinstance(item.payload, dict) else {}
        for key in ("md5", "file_md5", "content_md5", "content_id", "id"):
            value = payload.get(key)
            if value is not None and _is_md5(str(value)):
                return str(value)
        return None

    def _filename_for_item(self, item: Any, existing: Path | None) -> str:
        """Choose the file name for a live-tree item."""
        return self._pick_filename(item, existing, ("filename", "relative_path", "output_path", "path"))

    def _pick_filename(self, entry: Any, existing: Path | None, keys: tuple[str, ...]) -> str:
        """Shared file-name policy for nodes and items.

        Preference order: the name of the file already on disk, the base name
        of the first non-empty payload value among ``keys``, the entry's
        text, and finally ``"<id>.txt"``.
        """
        if existing is not None:
            return existing.name
        payload = entry.payload if isinstance(entry.payload, dict) else {}
        fallback = f"{entry.id}.txt"
        for key in keys:
            value = payload.get(key)
            if value:
                name = Path(str(value)).name
                if name:
                    return _safe_name(name, fallback=fallback)
        text = entry.text.strip() if isinstance(entry.text, str) else ""
        return _safe_name(text, fallback=fallback)

    def _content_for_item(self, item: Any) -> str | None:
        """Return the first non-``None`` payload value among content/body/text as a string."""
        payload = item.payload if isinstance(item.payload, dict) else {}
        for key in ("content", "body", "text"):
            value = payload.get(key)
            if value is not None:
                return str(value)
        return None

    def _move_file(self, source: Path, target: Path) -> None:
        """Rename ``source`` to ``target``, creating parent folders; no-op when equal."""
        if source == target:
            return
        target.parent.mkdir(parents=True, exist_ok=True)
        source.rename(target)

    def _write_file(self, target: Path, content: str) -> None:
        """Write ``content`` to ``target`` using the configured encoding."""
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(content, encoding=self.encoding)

    def _prune_empty_dirs(self) -> None:
        """Remove empty folders under ``root``, deepest first.

        Deepest-first order lets a parent that becomes empty through the
        removal of its children be removed in the same pass.
        """
        directories = sorted(
            (path for path in self.root.rglob("*") if path.is_dir()),
            key=lambda path: len(path.parts),
            reverse=True,
        )
        for directory in directories:
            rel_parts = directory.relative_to(self.root).parts
            if not self.include_hidden and any(part.startswith(".") for part in rel_parts):
                continue
            if not any(directory.iterdir()):
                directory.rmdir()

    def _included(self, path: Path) -> bool:
        """Tell whether ``path`` passes the suffix filter (always true without one)."""
        return self.include_suffixes is None or path.suffix.lower() in self.include_suffixes
237
+
238
+
239
+ FileSystemTreePersister = FolderTreePersister
240
+
241
+
242
@dataclass(frozen=True)
class _DesiredFile:
    """Immutable description of one file the persister wants on disk."""

    # Key used to match this entry against files already on disk.
    md5: str
    # Folder for the file, relative to the persister root.
    folder: Path
    # File name inside ``folder``.
    filename: str
    # Text to write when no file on disk matches; ``None`` when no content is available.
    content: str | None
248
+
249
+
250
+ def _file_md5(path: Path) -> str:
251
+ digest = hashlib.md5()
252
+ with path.open("rb") as f:
253
+ for chunk in iter(lambda: f.read(1024 * 1024), b""):
254
+ digest.update(chunk)
255
+ return digest.hexdigest()
256
+
257
+
258
def _safe_name(value: str, *, fallback: str = "untitled") -> str:
    """Turn ``value`` into a name usable as a single path component.

    Unsafe characters become spaces, whitespace runs collapse to one space,
    and leading/trailing dots and spaces are dropped. ``fallback`` is
    returned when nothing is left.
    """
    spaced = _UNSAFE_PATH_CHARS.sub(" ", value)
    collapsed = " ".join(spaced.split())
    trimmed = collapsed.strip(". ")
    return trimmed if trimmed else fallback
262
+
263
+
264
+ def _dedupe_name(value: str, used: set[str]) -> str:
265
+ candidate = value
266
+ i = 2
267
+ while candidate in used:
268
+ candidate = f"{value}-{i}"
269
+ i += 1
270
+ used.add(candidate)
271
+ return candidate
272
+
273
+
274
+ def _is_md5(value: str) -> bool:
275
+ return len(value) == 32 and all(c in "0123456789abcdefABCDEF" for c in value)
276
+
277
+
278
+ def _looks_like_live_node(value: Any) -> bool:
279
+ return hasattr(value, "is_leaf") and (hasattr(value, "items") or hasattr(value, "children"))
280
+
281
+
282
+ def _unique_path(target: Path, reserved: set[Path]) -> Path:
283
+ if target not in reserved and not target.exists():
284
+ return target
285
+ stem = target.stem
286
+ suffix = target.suffix
287
+ parent = target.parent
288
+ i = 2
289
+ while True:
290
+ candidate = parent / f"{stem}-{i}{suffix}"
291
+ if candidate not in reserved and not candidate.exists():
292
+ return candidate
293
+ i += 1
@@ -0,0 +1,29 @@
1
+ """JSON external persister."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from embed_tree.representation import PartialTree
11
+ from embed_tree.representation.default import partial_tree_to_dict
12
+
13
+
14
class JsonTreePersister:
    """Persist an exported tree artifact to JSON."""

    def __init__(self, path: str | Path) -> None:
        """Remember the destination file ``path``."""
        self.path = Path(path)

    def save(self, state: Any) -> None:
        """Serialize ``state`` to the destination file.

        A ``PartialTree`` is converted to a dict first; anything else is
        dumped as given.
        """
        payload = partial_tree_to_dict(state) if isinstance(state, PartialTree) else state
        self._write_atomically(payload)

    def _write_atomically(self, payload: Any) -> None:
        """Write ``payload`` as JSON to a sibling temp file and swap it into place.

        The destination is only replaced after the temp file has been fully
        written and synced. If serialization or the swap fails, the temp file
        is removed so failed saves do not leave ``*.tmp.<pid>`` files behind.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp = self.path.with_name(f"{self.path.name}.tmp.{os.getpid()}")
        try:
            with tmp.open("w", encoding="utf-8") as f:
                json.dump(payload, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp, self.path)
        except BaseException:
            tmp.unlink(missing_ok=True)
            raise
29
+
@@ -0,0 +1,23 @@
1
+ """Abstract persister contract."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Protocol, runtime_checkable
6
+
7
+ from embed_tree.representation import PartialTree
8
+
9
+ MaterializedTreeState = dict[str, Any]
10
+
11
+
12
@runtime_checkable
class TreePersister(Protocol):
    """Structural contract for anything that can persist a tree representation.

    Durable internal state, reusable state for future builds, and user-facing
    exports all go through this one protocol. Which role a persister plays is
    decided by where it is passed in, not by a separate cache type.
    """

    def save(self, state: PartialTree | MaterializedTreeState | Any) -> None:
        """Persist a tree representation or export artifact."""
@@ -0,0 +1,76 @@
1
+ """SQLAlchemy external persister."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from embed_tree.representation import PartialTree
8
+
9
+
10
class SQLAlchemyTreePersister:
    """Persist PartialTree content nodes to a SQL table via SQLAlchemy Core."""

    def __init__(
        self,
        engine_or_url: Any,
        table_name: str,
        *,
        id_column: str = "id",
        content_column: str = "content",
        text_column: str = "text",
        payload_column: str = "payload",
    ) -> None:
        """Record the destination table and the column names to write.

        Args:
            engine_or_url: Database URL string or an engine-like object.
            table_name: Table that receives one row per content node.
            id_column: Column receiving each node's ``id`` (primary key).
            content_column: Column receiving each node's ``content``.
            text_column: Column receiving each node's ``text``.
            payload_column: JSON column receiving each node's ``payload``.
        """
        self.engine_or_url = engine_or_url
        self.table_name = table_name
        self.id_column = id_column
        self.content_column = content_column
        self.text_column = text_column
        self.payload_column = payload_column

    def save(self, state: Any) -> None:
        """Replace the table's rows with the content nodes of ``state``.

        The table is created when missing. Existing rows are deleted and the
        new rows inserted inside one transaction.

        Raises:
            TypeError: If ``state`` is not a ``PartialTree``.
        """
        if not isinstance(state, PartialTree):
            raise TypeError("SQLAlchemyTreePersister only persists PartialTree instances")

        sa = _sqlalchemy()
        engine = _engine(sa, self.engine_or_url)
        # ``_engine`` returns the caller's object unchanged unless it built an
        # engine from a URL string; only that private engine is disposed, so
        # repeated saves do not leave a connection pool behind each time.
        owns_engine = engine is not self.engine_or_url
        table = self._table(sa)
        rows = [
            {
                self.id_column: node.id,
                self.content_column: node.content,
                self.text_column: node.text,
                self.payload_column: node.payload,
            }
            for node in state.content_nodes
        ]
        try:
            table.metadata.create_all(engine)
            with engine.begin() as conn:
                conn.execute(table.delete())
                # An empty tree leaves the table empty, so no insert is issued.
                if rows:
                    conn.execute(table.insert(), rows)
        finally:
            if owns_engine:
                engine.dispose()

    def _table(self, sa: Any) -> Any:
        """Describe the destination table on fresh metadata using the configured column names."""
        meta = sa.MetaData()
        return sa.Table(
            self.table_name,
            meta,
            sa.Column(self.id_column, sa.String, primary_key=True),
            sa.Column(self.content_column, sa.Text, nullable=False),
            sa.Column(self.text_column, sa.Text),
            sa.Column(self.payload_column, sa.JSON),
        )
62
+
63
+
64
+ def _sqlalchemy() -> Any:
65
+ try:
66
+ import sqlalchemy as sa
67
+ except ImportError as e: # pragma: no cover
68
+ raise ImportError('SQLAlchemyTreePersister needs the "sql" extra: pip install "embed-tree[sql]"') from e
69
+ return sa
70
+
71
+
72
+ def _engine(sa: Any, engine_or_url: Any) -> Any:
73
+ if isinstance(engine_or_url, str):
74
+ return sa.create_engine(engine_or_url)
75
+ return engine_or_url
76
+
@@ -0,0 +1,7 @@
1
+ """Vector projection integrations."""
2
+
3
+ from .model import PCAConfig, VectorProjector
4
+ from .pca import PCAProjector
5
+
6
+ __all__ = ["PCAConfig", "VectorProjector", "PCAProjector"]
7
+
@@ -0,0 +1,39 @@
1
+ """Vector projection contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Literal, Protocol, runtime_checkable
6
+
7
+ import numpy as np
8
+ from pydantic import BaseModel
9
+
10
+
11
class PCAConfig(BaseModel):
    """Configuration for PCA-based projection."""

    # NOTE(review): presumably the target number of dimensions, with None
    # meaning "no reduction" — confirm against PCAProjector.
    dims: int | None = None
    # One of "freeze" or "incremental"; defaults to "freeze".
    # NOTE(review): how each mode updates the projector is not visible here — confirm.
    mode: Literal["freeze", "incremental"] = "freeze"
    # NOTE(review): presumably the number of vectors gathered before fitting — confirm.
    warmup: int = 1000
    # NOTE(review): presumably the batch size for incremental fitting — confirm.
    batch_size: int = 256
18
+
19
+
20
@runtime_checkable
class VectorProjector(Protocol):
    """Structural contract for objects that map raw vectors to operational vectors."""

    @property
    def is_fitted(self) -> bool:
        """Whether this projector can transform vectors."""

    def fit(self, vectors: np.ndarray) -> None:
        """Fit projector state from vectors."""

    def transform(self, vectors: np.ndarray) -> np.ndarray:
        """Project vectors."""

    def __call__(self, vectors: np.ndarray) -> np.ndarray:
        """Project vectors."""